Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes; see the raw diff for the full set.
- capvector-oft/prismatic/models/__init__.py +2 -0
- capvector-oft/prismatic/models/backbones/__init__.py +0 -0
- capvector-oft/prismatic/models/backbones/llm/__init__.py +4 -0
- capvector-oft/prismatic/models/backbones/llm/base_llm.py +223 -0
- capvector-oft/prismatic/models/backbones/llm/llama2.py +102 -0
- capvector-oft/prismatic/models/backbones/llm/mistral.py +72 -0
- capvector-oft/prismatic/models/backbones/llm/phi.py +64 -0
- capvector-oft/prismatic/models/backbones/llm/prompting/__init__.py +5 -0
- capvector-oft/prismatic/models/backbones/llm/prompting/base_prompter.py +73 -0
- capvector-oft/prismatic/models/backbones/llm/prompting/llama2_chat_prompter.py +91 -0
- capvector-oft/prismatic/models/backbones/llm/prompting/mistral_instruct_prompter.py +60 -0
- capvector-oft/prismatic/models/backbones/llm/prompting/phi_prompter.py +65 -0
- capvector-oft/prismatic/models/backbones/llm/prompting/vicuna_v15_prompter.py +82 -0
- capvector-oft/prismatic/models/backbones/vision/__init__.py +7 -0
- capvector-oft/prismatic/models/backbones/vision/base_vision.py +207 -0
- capvector-oft/prismatic/models/backbones/vision/clip_vit.py +27 -0
- capvector-oft/prismatic/models/backbones/vision/dinoclip_vit.py +147 -0
- capvector-oft/prismatic/models/backbones/vision/dinosiglip_vit.py +164 -0
- capvector-oft/prismatic/models/backbones/vision/dinov2_vit.py +19 -0
- capvector-oft/prismatic/models/backbones/vision/in1k_vit.py +22 -0
- capvector-oft/prismatic/models/backbones/vision/siglip_vit.py +24 -0
- capvector-oft/prismatic/models/load.py +226 -0
- capvector-oft/prismatic/models/materialize.py +130 -0
- capvector-oft/prismatic/models/projectors.py +49 -0
- capvector-oft/prismatic/models/registry.py +691 -0
- capvector-oft/prismatic/models/vlas/__init__.py +1 -0
- capvector-oft/prismatic/models/vlas/openvla.py +131 -0
- capvector-oft/prismatic/models/vlms/__init__.py +1 -0
- capvector-oft/prismatic/models/vlms/base_vlm.py +108 -0
- capvector-oft/prismatic/models/vlms/prismatic.py +621 -0
- capvector-oft/prismatic/overwatch/__init__.py +1 -0
- capvector-oft/prismatic/overwatch/overwatch.py +147 -0
- capvector-oft/prismatic/preprocessing/__init__.py +2 -0
- capvector-oft/prismatic/preprocessing/datasets/__init__.py +1 -0
- capvector-oft/prismatic/preprocessing/datasets/datasets.py +200 -0
- capvector-oft/prismatic/preprocessing/download.py +207 -0
- capvector-oft/prismatic/preprocessing/materialize.py +69 -0
- capvector-oft/prismatic/training/__init__.py +2 -0
- capvector-oft/prismatic/training/materialize.py +66 -0
- capvector-oft/prismatic/training/metrics.py +348 -0
- capvector-oft/prismatic/training/strategies/__init__.py +3 -0
- capvector-oft/prismatic/training/strategies/base_strategy.py +417 -0
- capvector-oft/prismatic/training/strategies/ddp.py +128 -0
- capvector-oft/prismatic/training/strategies/fsdp.py +270 -0
- capvector-oft/prismatic/training/train_utils.py +56 -0
- capvector-oft/prismatic/util/__init__.py +1 -0
- capvector-oft/prismatic/util/batching_utils.py +212 -0
- capvector-oft/prismatic/util/data_utils.py +156 -0
- capvector-oft/prismatic/util/nn_utils.py +53 -0
- capvector-oft/prismatic/util/torch_utils.py +95 -0
capvector-oft/prismatic/models/__init__.py
ADDED
@@ -0,0 +1,2 @@
+from .load import available_model_names, available_models, get_model_description, load, load_vla
+from .materialize import get_llm_backbone_and_tokenizer, get_vision_backbone_and_transform, get_vlm
capvector-oft/prismatic/models/backbones/__init__.py
ADDED
File without changes
capvector-oft/prismatic/models/backbones/llm/__init__.py
ADDED
@@ -0,0 +1,4 @@
+from .base_llm import LLMBackbone
+from .llama2 import LLaMa2LLMBackbone
+from .mistral import MistralLLMBackbone
+from .phi import PhiLLMBackbone
capvector-oft/prismatic/models/backbones/llm/base_llm.py
ADDED
@@ -0,0 +1,223 @@
+"""
+base_llm.py
+
+Abstract class definition of a large (autoregressive) language model backbone (LLM), with full annotations of class
+methods, utility functions, and initialization logic.
+
+We also define the generic HFCausalLLMBackbone class here, providing a default interface for loading any HF
+AutoModelForCausalLM (e.g., LlamaForCausalLM). In general, we make the assumption that any given LLM backbone implements
+the AutoModelForCausalLM API (though we may add Seq2Seq models in the future).
+
+We make this assumption to keep the LLM handling in this codebase relatively lightweight, and to inherit all the nice HF
+utilities around different types of decoding/generation strategies.
+"""
+
+import warnings
+from abc import ABC, abstractmethod
+from functools import partial
+from typing import Callable, List, Optional, Sequence, Type
+
+import torch
+import torch.nn as nn
+from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
+from transformers import AutoConfig, AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
+from prismatic.models.backbones.llm.prompting import PromptBuilder
+from prismatic.overwatch import initialize_overwatch
+
+# Suppress HF Deprecation Warnings
+warnings.filterwarnings("ignore", category=FutureWarning)
+
+# Initialize Overwatch =>> Wraps `logging.Logger`
+overwatch = initialize_overwatch(__name__)
+
+
+# === Abstract Base Class for arbitrary HF LLM Backbones ===
+class LLMBackbone(nn.Module, ABC):
+    def __init__(self, llm_backbone_id: str) -> None:
+        super().__init__()
+        self.identifier = llm_backbone_id
+
+        # Instance attributes for an LLM Backbone
+        self.llm: PreTrainedModel = None
+        self.tokenizer: PreTrainedTokenizerBase = None
+
+    def get_tokenizer(self) -> PreTrainedTokenizerBase:
+        return self.tokenizer
+
+    @abstractmethod
+    def get_fsdp_wrapping_policy(self) -> Callable: ...
+
+    @abstractmethod
+    def enable_gradient_checkpointing(self) -> None: ...
+
+    @abstractmethod
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> CausalLMOutputWithPast:
+        """Run a forward pass through the LLM given targets (labels), returning the scalar Cross-Entropy Loss."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def embed_input_ids(self, input_ids: torch.LongTensor) -> torch.Tensor: ...
+
+    @property
+    @abstractmethod
+    def prompt_builder_fn(self) -> Type[PromptBuilder]: ...
+
+    @property
+    @abstractmethod
+    def transformer_layer_cls(self) -> Type[nn.Module]: ...
+
+    @property
+    @abstractmethod
+    def half_precision_dtype(self) -> torch.dtype: ...
+
+    @property
+    @abstractmethod
+    def last_layer_finetune_modules(self) -> Sequence[nn.Module]: ...
+
+    @property
+    def embed_dim(self) -> int:
+        return self.llm.config.hidden_size
+
+    @property
+    def pad_token_id(self) -> int:
+        return self.tokenizer.pad_token_id
+
+
+# === Abstract Base Class for Arbitrary HF Causal LLMs ===
+class HFCausalLLMBackbone(LLMBackbone, ABC):
+    def __init__(
+        self,
+        llm_backbone_id: str,
+        llm_family: str,
+        llm_cls: Type[PreTrainedModel],
+        hf_hub_path: str,
+        llm_max_length: int = 2048,
+        hf_token: Optional[str] = None,
+        inference_mode: bool = False,
+        use_flash_attention_2: bool = False,
+    ) -> None:
+        super().__init__(llm_backbone_id)
+        self.llm_family = llm_family
+        self.llm_max_length = llm_max_length
+        self.inference_mode = inference_mode
+
+        # Initialize LLM (downloading from HF Hub if necessary) --> `llm_cls` is the actual {Model}ForCausalLM class!
+        #   => Note: We're eschewing use of the AutoModel API so that we can be more explicit about LLM-specific details
+        if not self.inference_mode:
+            overwatch.info(f"Loading [bold]{llm_family}[/] LLM from [underline]`{hf_hub_path}`[/]", ctx_level=1)
+            self.llm = llm_cls.from_pretrained(
+                hf_hub_path,
+                token=hf_token,
+                use_flash_attention_2=use_flash_attention_2 if not self.inference_mode else False,
+                # The following parameters are set to prevent `UserWarnings` from HF; we want greedy decoding!
+                do_sample=False,
+                temperature=1.0,
+                top_p=1.0,
+            )
+
+        # [Contract] `inference_mode` means we're loading from a pretrained checkpoint; no need to load base weights!
+        else:
+            overwatch.info(f"Building empty [bold]{llm_family}[/] LLM from [underline]`{hf_hub_path}`[/]", ctx_level=1)
+            llm_config = AutoConfig.from_pretrained(hf_hub_path, token=hf_token)
+            self.llm = llm_cls._from_config(llm_config)
+
+        # Lightweight Handling (with extended explanation) for setting some LLM Parameters
+        #   => Set `decoder.use_cache = False` --> incompatible with gradient checkpointing (+ training in general)
+        #
+        #      Reference: https://discuss.huggingface.co/t/what-is-the-purpose-of-use-cache-in-decoder/958
+        self.llm.config.use_cache = False if not self.inference_mode else True
+
+        #   => Turns out that when gradient checkpointing is on and the underlying LLM has no "trainable" parameters
+        #      (requires_grad is False), backprop will fail; setting `enable_input_requires_grad()` registers a new
+        #      forward hook that fixes this =>> also totally safe for the "full finetuning" setting!
+        if not self.inference_mode:
+            self.llm.enable_input_require_grads()
+
+        # Load (Fast) Tokenizer
+        overwatch.info(f"Loading [bold]{llm_family}[/] (Fast) Tokenizer via the AutoTokenizer API", ctx_level=1)
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            hf_hub_path, model_max_length=self.llm_max_length, token=hf_token, padding_side="right"
+        )
+
+        # Validation =>> Our VLM logic currently operates under the assumption that the tokenization of a new input
+        #                starts with a <BOS> token unless `add_special_tokens = False`; for these models, we empirically
+        #                find that adding image patches *after* the BOS leads to much better performance.
+        #
+        # As a result, we explicitly validate that a tokenizer conforms to the expected behavior; if you're reading this
+        # line, it's probably because you're adding a new LLM with a different tokenizer behavior. If so, feel free to
+        # override the `SPECIAL_CASES` set below, but make sure to make the appropriate changes in the `datasets.py`
+        # and VLM `forward()` logic!
+        SPECIAL_CASES = {
+            # Phi-2 Tokenizer doesn't add any BOS tokens by default, and sets BOS == EOS == "<|endoftext|>"
+            #   =>> We'll prepend BOS to the first input (to play nicely with image token insertion logic; verified
+            #       that this works well with base LLM generation).
+            #   =>> Like Llama-2 Tokenizers -- we'll add a special PAD token for training purposes.
+            "phi-2-3b",
+        }
+        if self.identifier in SPECIAL_CASES:
+            return
+
+        # Note =>> this assert should hold for all Llama-derived tokenizers (`LlamaTokenizerFast`) =>> includes Mistral!
+        assert (self.tokenizer("Test 123", add_special_tokens=True).input_ids[0] == self.tokenizer.bos_token_id) and (
+            self.tokenizer("Test 123", add_special_tokens=False).input_ids[0] != self.tokenizer.bos_token_id
+        ), (
+            f"Default Tokenizer of type `{type(self.tokenizer)}` does not automatically prefix inputs with BOS token!\n"
+            "Please read the comment in `base_llm.py` for more information!"
+        )
+
+    def get_fsdp_wrapping_policy(self) -> Callable:
+        """Return a `transformer_auto_wrap_policy` where we wrap each instance of `self.transformer_layer_cls`."""
+        transformer_block_policy = partial(
+            transformer_auto_wrap_policy, transformer_layer_cls={self.transformer_layer_cls}
+        )
+
+        return transformer_block_policy
+
+    def enable_gradient_checkpointing(self) -> None:
+        """Dispatch to underlying LLM instance's `gradient_checkpointing_enable`; defined for all `PreTrainedModel`s."""
+        self.llm.gradient_checkpointing_enable()
+
+    def embed_input_ids(self, input_ids: torch.LongTensor) -> torch.Tensor:
+        return self.llm.get_input_embeddings()(input_ids)
+
+    # [Contract] Should match the `forward` call of the underlying `llm` instance!
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> CausalLMOutputWithPast:
+        output: CausalLMOutputWithPast = self.llm(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            labels=labels,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        return output
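
For orientation, here is a minimal sketch of what a concrete subclass of `HFCausalLLMBackbone` has to supply. It is illustrative only, not part of this diff: it wires in the GPT-NeoX classes from `transformers` purely as an example, the backbone id and family name are hypothetical, and a tokenizer like GPT-NeoX's (which does not prepend <BOS>) would also need an entry in the `SPECIAL_CASES` set above.

from typing import Optional, Sequence, Type

import torch
from torch import nn
from transformers import GPTNeoXForCausalLM
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXLayer

from prismatic.models.backbones.llm.base_llm import HFCausalLLMBackbone
from prismatic.models.backbones.llm.prompting import PromptBuilder, PurePromptBuilder


class PythiaLLMBackbone(HFCausalLLMBackbone):  # hypothetical example backbone
    def __init__(self, llm_backbone_id: str, llm_max_length: int = 2048, hf_token: Optional[str] = None) -> None:
        super().__init__(
            llm_backbone_id,
            llm_family="pythia",                    # hypothetical family id
            llm_cls=GPTNeoXForCausalLM,             # the {Model}ForCausalLM class to load
            hf_hub_path="EleutherAI/pythia-1.4b",
            llm_max_length=llm_max_length,
            hf_token=hf_token,
        )

    @property
    def prompt_builder_fn(self) -> Type[PromptBuilder]:
        return PurePromptBuilder                    # plain "In:/Out:" formatting

    @property
    def transformer_layer_cls(self) -> Type[nn.Module]:
        return GPTNeoXLayer                         # the unit wrapped by the FSDP policy

    @property
    def half_precision_dtype(self) -> torch.dtype:
        return torch.bfloat16

    @property
    def last_layer_finetune_modules(self) -> Sequence[nn.Module]:
        return (self.llm.gpt_neox.embed_in, self.llm.gpt_neox.layers[-1], self.llm.embed_out)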
capvector-oft/prismatic/models/backbones/llm/llama2.py
ADDED
@@ -0,0 +1,102 @@
+"""
+llama2.py
+
+Class definition for all LLMs derived from LlamaForCausalLM.
+"""
+
+from typing import Optional, Sequence, Type
+
+import torch
+from torch import nn as nn
+from transformers import LlamaForCausalLM
+from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+
+from prismatic.models.backbones.llm.base_llm import HFCausalLLMBackbone
+from prismatic.models.backbones.llm.prompting import (
+    LLaMa2ChatPromptBuilder,
+    PromptBuilder,
+    PurePromptBuilder,
+    VicunaV15ChatPromptBuilder,
+)
+
+# Registry =>> Support LLaMa-2 Models (from HF Transformers)
+# fmt: off
+LLAMA2_MODELS = {
+    # === Pure Meta LLaMa-2 (non-instruct/chat-tuned) Models ===
+    "llama2-7b-pure": {
+        "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "meta-llama/Llama-2-7b-hf"
+    },
+
+    "llama2-13b-pure": {
+        "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "meta-llama/Llama-2-13b-hf"
+    },
+
+    # === Meta LLaMa-2 Chat Models ===
+    "llama2-7b-chat": {
+        "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "meta-llama/Llama-2-7b-chat-hf"
+    },
+
+    "llama2-13b-chat": {
+        "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "meta-llama/Llama-2-13b-chat-hf"
+    },
+
+    # === Vicuna v1.5 Chat Models ===
+    "vicuna-v15-7b": {
+        "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "lmsys/vicuna-7b-v1.5"
+    },
+
+    "vicuna-v15-13b": {
+        "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "lmsys/vicuna-13b-v1.5"
+    },
+}
+# fmt: on
+
+
+class LLaMa2LLMBackbone(HFCausalLLMBackbone):
+    def __init__(
+        self,
+        llm_backbone_id: str,
+        llm_max_length: int = 2048,
+        hf_token: Optional[str] = None,
+        inference_mode: bool = False,
+        use_flash_attention_2: bool = True,
+    ) -> None:
+        super().__init__(
+            llm_backbone_id,
+            llm_max_length=llm_max_length,
+            hf_token=hf_token,
+            inference_mode=inference_mode,
+            use_flash_attention_2=use_flash_attention_2,
+            **LLAMA2_MODELS[llm_backbone_id],
+        )
+
+        # [Special Case] LLaMa-2 PAD Token Handling --> for clarity, we add an extra token (and resize)
+        self.tokenizer.add_special_tokens({"pad_token": "<PAD>"})
+        self.llm.config.pad_token_id = self.tokenizer.pad_token_id
+        self.llm.resize_token_embeddings(len(self.tokenizer), pad_to_multiple_of=64)
+
+    @property
+    def prompt_builder_fn(self) -> Type[PromptBuilder]:
+        if self.identifier.startswith("llama2-") and self.identifier.endswith("-pure"):
+            return PurePromptBuilder
+
+        elif self.identifier.startswith("llama2-") and self.identifier.endswith("-chat"):
+            return LLaMa2ChatPromptBuilder
+
+        elif self.identifier.startswith("vicuna"):
+            return VicunaV15ChatPromptBuilder
+
+        raise ValueError(f"No PromptBuilder defined for LLM Backbone `{self.identifier}`")
+
+    @property
+    def transformer_layer_cls(self) -> Type[nn.Module]:
+        return LlamaDecoderLayer
+
+    @property
+    def half_precision_dtype(self) -> torch.dtype:
+        """LLaMa-2 was trained in BF16; see https://huggingface.co/docs/transformers/main/model_doc/llama2."""
+        return torch.bfloat16
+
+    @property
+    def last_layer_finetune_modules(self) -> Sequence[nn.Module]:
+        return (self.llm.model.embed_tokens, self.llm.model.layers[-1], self.llm.lm_head)
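
Assuming an HF token with LLaMa-2 access (the token string below is a placeholder), a backbone from the registry above is used roughly as follows; a sketch, not part of this diff:

import torch

from prismatic.models.backbones.llm.llama2 import LLaMa2LLMBackbone

backbone = LLaMa2LLMBackbone("llama2-7b-pure", hf_token="hf_...")  # any key in LLAMA2_MODELS
tokenizer = backbone.get_tokenizer()

# LLaMa tokenizers prepend <BOS> automatically (this is validated in `base_llm.py`)
inputs = tokenizer("In: What is in the image?\nOut: ", return_tensors="pt")

# Passing `labels` makes the underlying HF model also return a scalar cross-entropy loss
with torch.no_grad():
    output = backbone(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask, labels=inputs.input_ids)
print(output.loss, output.logits.shape)  # scalar loss, (1, seq_len, vocab_size)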
capvector-oft/prismatic/models/backbones/llm/mistral.py
ADDED
@@ -0,0 +1,72 @@
+"""
+mistral.py
+
+Class definition for all LLMs derived from MistralForCausalLM.
+"""
+
+from typing import Optional, Sequence, Type
+
+import torch
+from torch import nn as nn
+from transformers import MistralForCausalLM
+from transformers.models.mistral.modeling_mistral import MistralDecoderLayer
+
+from prismatic.models.backbones.llm.base_llm import HFCausalLLMBackbone
+from prismatic.models.backbones.llm.prompting import MistralInstructPromptBuilder, PromptBuilder, PurePromptBuilder
+
+# Registry =>> Support Mistral Models (from HF Transformers)
+# fmt: off
+MISTRAL_MODELS = {
+    # === Base Mistral v0.1 ===
+    "mistral-v0.1-7b-pure": {
+        "llm_family": "mistral", "llm_cls": MistralForCausalLM, "hf_hub_path": "mistralai/Mistral-7B-v0.1"
+    },
+
+    # === Mistral Instruct v0.1 ===
+    "mistral-v0.1-7b-instruct": {
+        "llm_family": "mistral", "llm_cls": MistralForCausalLM, "hf_hub_path": "mistralai/Mistral-7B-Instruct-v0.1"
+    }
+}
+# fmt: on
+
+
+class MistralLLMBackbone(HFCausalLLMBackbone):
+    def __init__(
+        self,
+        llm_backbone_id: str,
+        llm_max_length: int = 2048,
+        hf_token: Optional[str] = None,
+        inference_mode: bool = False,
+        use_flash_attention_2: bool = True,
+    ) -> None:
+        super().__init__(
+            llm_backbone_id,
+            llm_max_length=llm_max_length,
+            hf_token=hf_token,
+            inference_mode=inference_mode,
+            use_flash_attention_2=use_flash_attention_2,
+            **MISTRAL_MODELS[llm_backbone_id],
+        )
+
+        # [Special Case] Mistral PAD Token Handling --> for clarity, we add an extra token (and resize)
+        self.tokenizer.add_special_tokens({"pad_token": "<PAD>"})
+        self.llm.config.pad_token_id = self.tokenizer.pad_token_id
+        self.llm.resize_token_embeddings(len(self.tokenizer), pad_to_multiple_of=64)
+
+    @property
+    def prompt_builder_fn(self) -> Type[PromptBuilder]:
+        if self.identifier.endswith("-pure"):
+            return PurePromptBuilder
+
+        elif self.identifier.endswith("-instruct"):
+            return MistralInstructPromptBuilder
+
+        raise ValueError(f"No PromptBuilder defined for LLM Backbone `{self.identifier}`")
+
+    @property
+    def transformer_layer_cls(self) -> Type[nn.Module]:
+        return MistralDecoderLayer
+
+    @property
+    def half_precision_dtype(self) -> torch.dtype:
+        return torch.bfloat16
+
+    @property
+    def last_layer_finetune_modules(self) -> Sequence[nn.Module]:
+        # [Fix] Declared abstract on `LLMBackbone`; without an override, this class cannot be instantiated.
+        return (self.llm.model.embed_tokens, self.llm.model.layers[-1], self.llm.lm_head)
capvector-oft/prismatic/models/backbones/llm/phi.py
ADDED
@@ -0,0 +1,64 @@
+"""
+phi.py
+
+Class definition for all LLMs derived from PhiForCausalLM.
+"""
+
+from typing import Optional, Sequence, Type
+
+import torch
+from torch import nn as nn
+from transformers import PhiForCausalLM
+from transformers.models.phi.modeling_phi import PhiDecoderLayer
+
+from prismatic.models.backbones.llm.base_llm import HFCausalLLMBackbone
+from prismatic.models.backbones.llm.prompting import PhiPromptBuilder, PromptBuilder
+
+# Registry =>> Support Phi Models (from HF Transformers)
+# fmt: off
+PHI_MODELS = {
+    # === Phi-2 ===
+    "phi-2-3b": {
+        "llm_family": "phi", "llm_cls": PhiForCausalLM, "hf_hub_path": "microsoft/phi-2"
+    }
+}
+# fmt: on
+
+
+class PhiLLMBackbone(HFCausalLLMBackbone):
+    def __init__(
+        self,
+        llm_backbone_id: str,
+        llm_max_length: int = 2048,
+        hf_token: Optional[str] = None,
+        inference_mode: bool = False,
+        use_flash_attention_2: bool = True,
+    ) -> None:
+        super().__init__(
+            llm_backbone_id,
+            llm_max_length=llm_max_length,
+            hf_token=hf_token,
+            inference_mode=inference_mode,
+            use_flash_attention_2=use_flash_attention_2,
+            **PHI_MODELS[llm_backbone_id],
+        )
+
+        # [Special Case] Phi PAD Token Handling --> for clarity, we add an extra token (and resize)
+        self.tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
+        self.llm.config.pad_token_id = self.tokenizer.pad_token_id
+        self.llm.resize_token_embeddings(len(self.tokenizer), pad_to_multiple_of=64)
+
+    @property
+    def prompt_builder_fn(self) -> Type[PromptBuilder]:
+        if self.identifier.startswith("phi-2"):
+            return PhiPromptBuilder
+
+        raise ValueError(f"No PromptBuilder defined for LLM Backbone `{self.identifier}`")
+
+    @property
+    def transformer_layer_cls(self) -> Type[nn.Module]:
+        return PhiDecoderLayer
+
+    @property
+    def half_precision_dtype(self) -> torch.dtype:
+        return torch.bfloat16
+
+    @property
+    def last_layer_finetune_modules(self) -> Sequence[nn.Module]:
+        # [Fix] Declared abstract on `LLMBackbone`; without an override, this class cannot be instantiated.
+        return (self.llm.model.embed_tokens, self.llm.model.layers[-1], self.llm.lm_head)
capvector-oft/prismatic/models/backbones/llm/prompting/__init__.py
ADDED
@@ -0,0 +1,5 @@
+from .base_prompter import PromptBuilder, PurePromptBuilder
+from .llama2_chat_prompter import LLaMa2ChatPromptBuilder
+from .mistral_instruct_prompter import MistralInstructPromptBuilder
+from .phi_prompter import PhiPromptBuilder
+from .vicuna_v15_prompter import VicunaV15ChatPromptBuilder
capvector-oft/prismatic/models/backbones/llm/prompting/base_prompter.py
ADDED
@@ -0,0 +1,73 @@
+"""
+base_prompter.py
+
+Abstract class definition of a multi-turn prompt builder for ensuring consistent formatting for chat-based LLMs.
+"""
+
+from abc import ABC, abstractmethod
+from typing import Optional
+
+
+class PromptBuilder(ABC):
+    def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None:
+        self.model_family = model_family
+
+        # Only some models define a system prompt => let subclasses handle this logic!
+        self.system_prompt = system_prompt
+
+    @abstractmethod
+    def add_turn(self, role: str, message: str) -> str: ...
+
+    @abstractmethod
+    def get_potential_prompt(self, user_msg: str) -> str: ...
+
+    @abstractmethod
+    def get_prompt(self) -> str: ...
+
+
+class PurePromptBuilder(PromptBuilder):
+    def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None:
+        super().__init__(model_family, system_prompt)
+
+        # TODO (siddk) =>> Can't always assume LlamaTokenizer --> FIX ME!
+        self.bos, self.eos = "<s>", "</s>"
+
+        # Get role-specific "wrap" functions
+        self.wrap_human = lambda msg: f"In: {msg}\nOut: "
+        self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.eos}"
+
+        # === `self.prompt` gets built up over multiple turns ===
+        self.prompt, self.turn_count = "", 0
+
+    def add_turn(self, role: str, message: str) -> str:
+        assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt")
+        message = message.replace("<image>", "").strip()
+
+        if (self.turn_count % 2) == 0:
+            human_message = self.wrap_human(message)
+            wrapped_message = human_message
+        else:
+            gpt_message = self.wrap_gpt(message)
+            wrapped_message = gpt_message
+
+        # Update Prompt
+        self.prompt += wrapped_message
+
+        # Bump Turn Counter
+        self.turn_count += 1
+
+        # Return "wrapped_message" (effective string added to context)
+        return wrapped_message
+
+    def get_potential_prompt(self, message: str) -> str:
+        # Assumes that it's always the user's (human's) turn!
+        prompt_copy = str(self.prompt)
+
+        human_message = self.wrap_human(message)
+        prompt_copy += human_message
+
+        return prompt_copy.removeprefix(self.bos).rstrip()
+
+    def get_prompt(self) -> str:
+        # Remove prefix <bos> (if it exists) because it gets auto-inserted by the tokenizer!
+        return self.prompt.removeprefix(self.bos).rstrip()
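
Concretely, `PurePromptBuilder` alternates the `In:`/`Out:` wrappers above; a short sketch (not part of this diff) of the string it accumulates:

from prismatic.models.backbones.llm.prompting import PurePromptBuilder

builder = PurePromptBuilder(model_family="llama2")
builder.add_turn(role="human", message="What is in the image?")
builder.add_turn(role="gpt", message="A cat on a couch.")

print(builder.get_prompt())
# In: What is in the image?
# Out: A cat on a couch.</s>

print(builder.get_potential_prompt("What color is it?"))
# In: What is in the image?
# Out: A cat on a couch.</s>In: What color is it?
# Out: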
capvector-oft/prismatic/models/backbones/llm/prompting/llama2_chat_prompter.py
ADDED
@@ -0,0 +1,91 @@
+"""
+llama2_chat_prompter.py
+
+Defines a PromptBuilder for building LLaMa-2 Chat Prompts --> not sure if this is "optimal", but this is the pattern
+that's used by HF and other online tutorials.
+
+Reference: https://huggingface.co/blog/llama2#how-to-prompt-llama-2
+"""
+
+from typing import Optional
+
+from prismatic.models.backbones.llm.prompting.base_prompter import PromptBuilder
+
+# Default System Prompt for Prismatic Models
+SYS_PROMPTS = {
+    "prismatic": (
+        "You are a helpful language and vision assistant. "
+        "You are able to understand the visual content that the user provides, "
+        "and assist the user with a variety of tasks using natural language."
+    ),
+    "openvla": (
+        "You are a helpful language and vision assistant. "
+        "You are able to understand the visual content that the user provides, "
+        "and assist the user with a variety of tasks using natural language."
+    ),
+}
+
+
+def format_system_prompt(system_prompt: str) -> str:
+    # Note =>> the opening tag is `<<SYS>>` per the LLaMa-2 chat format (fixed from a truncated `<<SYS>`)
+    return f"<<SYS>>\n{system_prompt.strip()}\n<</SYS>>\n\n"
+
+
+class LLaMa2ChatPromptBuilder(PromptBuilder):
+    def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None:
+        super().__init__(model_family, system_prompt)
+        self.system_prompt = format_system_prompt(
+            SYS_PROMPTS[self.model_family] if system_prompt is None else system_prompt
+        )
+
+        # LLaMa-2 Specific
+        self.bos, self.eos = "<s>", "</s>"
+
+        # Get role-specific "wrap" functions
+        self.wrap_human = lambda msg: f"[INST] {msg} [/INST] "
+        self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.eos}"
+
+        # === `self.prompt` gets built up over multiple turns ===
+        self.prompt, self.turn_count = "", 0
+
+    def add_turn(self, role: str, message: str) -> str:
+        assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt")
+        message = message.replace("<image>", "").strip()
+
+        # Special Handling for "system" prompt (turn_count == 0)
+        if self.turn_count == 0:
+            sys_message = self.wrap_human(self.system_prompt + message)
+            wrapped_message = sys_message
+        elif (self.turn_count % 2) == 0:
+            human_message = self.wrap_human(message)
+            wrapped_message = human_message
+        else:
+            gpt_message = self.wrap_gpt(message)
+            wrapped_message = gpt_message
+
+        # Update Prompt
+        self.prompt += wrapped_message
+
+        # Bump Turn Counter
+        self.turn_count += 1
+
+        # Return "wrapped_message" (effective string added to context)
+        return wrapped_message
+
+    def get_potential_prompt(self, message: str) -> str:
+        # Assumes that it's always the user's (human's) turn!
+        prompt_copy = str(self.prompt)
+
+        # Special Handling for "system" prompt (turn_count == 0)
+        if self.turn_count == 0:
+            sys_message = self.wrap_human(self.system_prompt + message)
+            prompt_copy += sys_message
+
+        else:
+            human_message = self.wrap_human(message)
+            prompt_copy += human_message
+
+        return prompt_copy.removeprefix(self.bos).rstrip()
+
+    def get_prompt(self) -> str:
+        # Remove prefix <bos> because it gets auto-inserted by the tokenizer!
+        return self.prompt.removeprefix(self.bos).rstrip()
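
For the first turn, the system prompt is folded into the user message inside a single `[INST] ... [/INST]` block; a sketch, not part of this diff:

from prismatic.models.backbones.llm.prompting import LLaMa2ChatPromptBuilder

builder = LLaMa2ChatPromptBuilder(model_family="prismatic")
first = builder.add_turn(role="human", message="What is in the image?")

# first == "[INST] <<SYS>>\n{system prompt}\n<</SYS>>\n\nWhat is in the image? [/INST] "
# Subsequent human turns are wrapped as plain "[INST] ... [/INST] " without the system block.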
capvector-oft/prismatic/models/backbones/llm/prompting/mistral_instruct_prompter.py
ADDED
@@ -0,0 +1,60 @@
+"""
+mistral_instruct_prompter.py
+
+Defines a PromptBuilder for building Mistral Instruct Chat Prompts --> recommended pattern used by HF / online tutorials.
+
+Reference: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1#instruction-format
+"""
+
+from typing import Optional
+
+from prismatic.models.backbones.llm.prompting.base_prompter import PromptBuilder
+
+
+class MistralInstructPromptBuilder(PromptBuilder):
+    def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None:
+        super().__init__(model_family, system_prompt)
+
+        # Note =>> Mistral Tokenizer is an instance of `LlamaTokenizer(Fast)`
+        #      =>> Mistral Instruct *does not* use a System Prompt
+        self.bos, self.eos = "<s>", "</s>"
+
+        # Get role-specific "wrap" functions
+        self.wrap_human = lambda msg: f"[INST] {msg} [/INST] "
+        self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.eos}"
+
+        # === `self.prompt` gets built up over multiple turns ===
+        self.prompt, self.turn_count = "", 0
+
+    def add_turn(self, role: str, message: str) -> str:
+        assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt")
+        message = message.replace("<image>", "").strip()
+
+        if (self.turn_count % 2) == 0:
+            human_message = self.wrap_human(message)
+            wrapped_message = human_message
+        else:
+            gpt_message = self.wrap_gpt(message)
+            wrapped_message = gpt_message
+
+        # Update Prompt
+        self.prompt += wrapped_message
+
+        # Bump Turn Counter
+        self.turn_count += 1
+
+        # Return "wrapped_message" (effective string added to context)
+        return wrapped_message
+
+    def get_potential_prompt(self, message: str) -> str:
+        # Assumes that it's always the user's (human's) turn!
+        prompt_copy = str(self.prompt)
+
+        human_message = self.wrap_human(message)
+        prompt_copy += human_message
+
+        return prompt_copy.removeprefix(self.bos).rstrip()
+
+    def get_prompt(self) -> str:
+        # Remove prefix <bos> because it gets auto-inserted by the tokenizer!
+        return self.prompt.removeprefix(self.bos).rstrip()
capvector-oft/prismatic/models/backbones/llm/prompting/phi_prompter.py
ADDED
@@ -0,0 +1,65 @@
+"""
+phi_prompter.py
+
+Defines a PromptBuilder for building Phi-2 Input/Output Prompts --> recommended pattern used by HF / Microsoft.
+Also handles the Phi special-case BOS token additions.
+
+Reference: https://huggingface.co/microsoft/phi-2#qa-format
+"""
+
+from typing import Optional
+
+from prismatic.models.backbones.llm.prompting.base_prompter import PromptBuilder
+
+
+class PhiPromptBuilder(PromptBuilder):
+    def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None:
+        super().__init__(model_family, system_prompt)
+
+        # Note =>> Phi Tokenizer is an instance of `CodeGenTokenizer(Fast)`
+        #      =>> By default, does *not* append <BOS> / <EOS> tokens --> we handle that here (IMPORTANT)!
+        self.bos, self.eos = "<|endoftext|>", "<|endoftext|>"
+
+        # Get role-specific "wrap" functions
+        #   =>> Note that the placement of <bos>/<eos> was based on experiments generating from Phi-2 in
+        #       Input/Output mode
+        self.wrap_human = lambda msg: f"Input: {msg}\nOutput: "
+        self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}\n{self.eos}"
+
+        # === `self.prompt` gets built up over multiple turns ===
+        self.prompt, self.turn_count = "", 0
+
+    def add_turn(self, role: str, message: str) -> str:
+        assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt")
+        message = message.replace("<image>", "").strip()
+
+        # Special Handling for the "first" input --> prepend a <BOS> token (expected by Prismatic)
+        if self.turn_count == 0:
+            bos_human_message = f"{self.bos}{self.wrap_human(message)}"
+            wrapped_message = bos_human_message
+        elif (self.turn_count % 2) == 0:
+            human_message = self.wrap_human(message)
+            wrapped_message = human_message
+        else:
+            gpt_message = self.wrap_gpt(message)
+            wrapped_message = gpt_message
+
+        # Update Prompt
+        self.prompt += wrapped_message
+
+        # Bump Turn Counter
+        self.turn_count += 1
+
+        # Return "wrapped_message" (effective string added to context)
+        return wrapped_message
+
+    def get_potential_prompt(self, message: str) -> str:
+        # Assumes that it's always the user's (human's) turn!
+        prompt_copy = str(self.prompt)
+
+        human_message = self.wrap_human(message)
+        prompt_copy += human_message
+
+        return prompt_copy.rstrip()
+
+    def get_prompt(self) -> str:
+        return self.prompt.rstrip()
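
Because the CodeGen tokenizer will not insert a <BOS> itself, `PhiPromptBuilder` bakes it into the prompt string, and `get_prompt()` deliberately keeps it (unlike the LLaMa-derived builders, which strip it). A sketch, not part of this diff:

from prismatic.models.backbones.llm.prompting import PhiPromptBuilder

builder = PhiPromptBuilder(model_family="phi")
builder.add_turn(role="human", message="What is in the image?")

print(repr(builder.get_prompt()))
# '<|endoftext|>Input: What is in the image?\nOutput:'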
capvector-oft/prismatic/models/backbones/llm/prompting/vicuna_v15_prompter.py
ADDED
@@ -0,0 +1,82 @@
+"""
+vicuna_v15_prompter.py
+
+Defines a PromptBuilder for building Vicuna-v1.5 Chat Prompts.
+
+Reference: https://huggingface.co/lmsys/vicuna-13b-v1.5
+"""
+
+from typing import Optional
+
+from prismatic.models.backbones.llm.prompting.base_prompter import PromptBuilder
+
+# Default System Prompt for LLaVa Models
+SYS_PROMPTS = {
+    "prismatic": (
+        "A chat between a curious user and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the user's questions."
+    ),
+    "openvla": (
+        "A chat between a curious user and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the user's questions."
+    ),
+}
+
+
+class VicunaV15ChatPromptBuilder(PromptBuilder):
+    def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None:
+        super().__init__(model_family, system_prompt)
+        self.system_prompt = (SYS_PROMPTS[self.model_family] if system_prompt is None else system_prompt).strip() + " "
+
+        # LLaMa-2 Specific
+        self.bos, self.eos = "<s>", "</s>"
+
+        # Get role-specific "wrap" functions
+        self.wrap_human = lambda msg: f"USER: {msg} ASSISTANT: "
+        self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.eos}"
+
+        # === `self.prompt` gets built up over multiple turns ===
+        self.prompt, self.turn_count = "", 0
+
+    def add_turn(self, role: str, message: str) -> str:
+        assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt")
+        message = message.replace("<image>", "").strip()
+
+        # Special Handling for "system" prompt (turn_count == 0)
+        if self.turn_count == 0:
+            sys_message = self.system_prompt + self.wrap_human(message)
+            wrapped_message = sys_message
+        elif (self.turn_count % 2) == 0:
+            human_message = self.wrap_human(message)
+            wrapped_message = human_message
+        else:
+            gpt_message = self.wrap_gpt(message)
+            wrapped_message = gpt_message
+
+        # Update Prompt
+        self.prompt += wrapped_message
+
+        # Bump Turn Counter
+        self.turn_count += 1
+
+        # Return "wrapped_message" (effective string added to context)
+        return wrapped_message
+
+    def get_potential_prompt(self, message: str) -> str:
+        # Assumes that it's always the user's (human's) turn!
+        prompt_copy = str(self.prompt)
+
+        # Special Handling for "system" prompt (turn_count == 0)
+        if self.turn_count == 0:
+            sys_message = self.system_prompt + self.wrap_human(message)
+            prompt_copy += sys_message
+
+        else:
+            human_message = self.wrap_human(message)
+            prompt_copy += human_message
+
+        return prompt_copy.removeprefix(self.bos).rstrip()
+
+    def get_prompt(self) -> str:
+        # Remove prefix <bos> (if it exists) because it gets auto-inserted by the tokenizer!
+        return self.prompt.removeprefix(self.bos).rstrip()
capvector-oft/prismatic/models/backbones/vision/__init__.py
ADDED
@@ -0,0 +1,7 @@
+from .base_vision import ImageTransform, VisionBackbone
+from .clip_vit import CLIPViTBackbone
+from .dinoclip_vit import DinoCLIPViTBackbone
+from .dinosiglip_vit import DinoSigLIPViTBackbone
+from .dinov2_vit import DinoV2ViTBackbone
+from .in1k_vit import IN1KViTBackbone
+from .siglip_vit import SigLIPViTBackbone
capvector-oft/prismatic/models/backbones/vision/base_vision.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
base_vision.py
|
| 3 |
+
|
| 4 |
+
Abstract class definition of a Vision Backbone (Visual Featurizer), with full annotations of class methods, utility
|
| 5 |
+
functions, and initialization logic.
|
| 6 |
+
|
| 7 |
+
We also define the generic TimmViTBackbone class here, providing a default interface for loading any TIMM Vision
|
| 8 |
+
Transformer model for feature extraction.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from abc import ABC, abstractmethod
|
| 12 |
+
from dataclasses import dataclass
|
| 13 |
+
from functools import partial
|
| 14 |
+
from typing import Any, Callable, Dict, Optional, Protocol, Tuple, Union
|
| 15 |
+
|
| 16 |
+
import timm
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
import torchvision.transforms.functional as TVF
|
| 20 |
+
from PIL.Image import Image
|
| 21 |
+
from timm.models.vision_transformer import Block, VisionTransformer
|
| 22 |
+
from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy
|
| 23 |
+
from torchvision.transforms import Compose, Resize
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# === Utility Functions for Monkey-Patching ===
|
| 27 |
+
def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
|
| 28 |
+
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
| 29 |
+
result = fn(*args, **kwargs)
|
| 30 |
+
return result[0] if isinstance(result, tuple) else result
|
| 31 |
+
|
| 32 |
+
return wrapper
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# === Interface for an Image Transform ===
|
| 36 |
+
class ImageTransform(Protocol):
|
| 37 |
+
def __call__(self, img: Image, **kwargs: str) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: ...
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# === Custom Torchvision Image Transforms ===
|
| 41 |
+
@dataclass
|
| 42 |
+
class LetterboxPad:
|
| 43 |
+
padding_fill_value: Tuple[int, int, int]
|
| 44 |
+
|
| 45 |
+
def __call__(self, image: Image) -> Image:
|
| 46 |
+
"""Given a PIL.Image, pad to square by adding a symmetric border around the height/width."""
|
| 47 |
+
(w, h), max_wh = image.size, max(image.size)
|
| 48 |
+
horizontal_pad, vertical_pad = int((max_wh - w) / 2), int((max_wh - h) / 2)
|
| 49 |
+
padding = (horizontal_pad, vertical_pad, horizontal_pad, vertical_pad)
|
| 50 |
+
return TVF.pad(image, padding, fill=self.padding_fill_value, padding_mode="constant")
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# === Abstract Base Class for arbitrary Vision Backbones ===
|
| 54 |
+
class VisionBackbone(nn.Module, ABC):
|
| 55 |
+
def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None:
|
| 56 |
+
super().__init__()
|
| 57 |
+
self.identifier: str = vision_backbone_id
|
| 58 |
+
self.image_resize_strategy: str = image_resize_strategy
|
| 59 |
+
self.default_image_size: int = default_image_size
|
| 60 |
+
|
| 61 |
+
# Instance attributes for a Vision Backbone
|
| 62 |
+
self.featurizer: nn.Module = None
|
| 63 |
+
self.image_transform: ImageTransform = None
|
| 64 |
+
|
| 65 |
+
def get_image_transform(self) -> ImageTransform:
|
| 66 |
+
return self.image_transform
|
| 67 |
+
|
| 68 |
+
@abstractmethod
|
| 69 |
+
def get_fsdp_wrapping_policy(self) -> Callable: ...
|
| 70 |
+
|
| 71 |
+
@abstractmethod
|
| 72 |
+
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
| 73 |
+
"""Run a forward pass through the featurizer given a set of processed images, returning patch/grid features."""
|
| 74 |
+
raise NotImplementedError
|
| 75 |
+
|
| 76 |
+
@property
|
| 77 |
+
@abstractmethod
|
| 78 |
+
def default_image_resolution(self) -> Tuple[int, int, int]: ...
|
| 79 |
+
|
| 80 |
+
@property
|
| 81 |
+
@abstractmethod
|
| 82 |
+
def embed_dim(self) -> int: ...
|
| 83 |
+
|
| 84 |
+
@property
|
| 85 |
+
@abstractmethod
|
| 86 |
+
def num_patches(self) -> int: ...
|
| 87 |
+
|
| 88 |
+
@property
|
| 89 |
+
@abstractmethod
|
| 90 |
+
def half_precision_dtype(self) -> torch.dtype: ...
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
# === Abstract Base Class for Arbitrary TIMM Vision Transformer Backbones ===
|
| 94 |
+
class TimmViTBackbone(VisionBackbone, ABC):
|
| 95 |
+
def __init__(
|
| 96 |
+
self,
|
| 97 |
+
vision_backbone_id: str,
|
| 98 |
+
timm_path_or_url: str,
|
| 99 |
+
image_resize_strategy: str,
|
| 100 |
+
default_image_size: int = 224,
|
| 101 |
+
override_act_layer: Optional[str] = None,
|
| 102 |
+
) -> None:
|
| 103 |
+
super().__init__(vision_backbone_id, image_resize_strategy, default_image_size=default_image_size)
|
| 104 |
+
self.timm_path_or_url = timm_path_or_url
|
| 105 |
+
self.override_act_layer = override_act_layer
|
| 106 |
+
self.dtype = torch.bfloat16
|
| 107 |
+
|
| 108 |
+
# Initialize Featurizer (ViT) by downloading from HF / TIMM Hub if necessary
|
| 109 |
+
if self.override_act_layer is None:
|
| 110 |
+
self.featurizer: VisionTransformer = timm.create_model(
|
+            self.timm_path_or_url, pretrained=True, num_classes=0, img_size=self.default_image_size
+        )
+        else:
+            self.featurizer: VisionTransformer = timm.create_model(
+                self.timm_path_or_url,
+                pretrained=True,
+                num_classes=0,
+                img_size=self.default_image_size,
+                act_layer=self.override_act_layer,
+            )
+        self.featurizer.eval()
+
+        # Monkey-Patch the `forward()` function of the featurizer to ensure FSDP-compatibility
+        #   => Note: By default set `get_intermediate_layers` to return the *SECOND-TO-LAST* layer patches!
+        #   => TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385
+        self.featurizer.forward = unpack_tuple(
+            partial(self.featurizer.get_intermediate_layers, n={len(self.featurizer.blocks) - 2})
+        )
+
+        # Validation =>> for now, this class *only* supports TIMM Vision Transformers (but can be extended!)
+        assert isinstance(self.featurizer, VisionTransformer), (
+            "Featurizer is not a TIMM VisionTransformer; if you would like to support a new visual representation, "
+            "file an issue or implement the requisite logic (see `prismatic/models/backbones/vision/base_vision.py`)!"
+        )
+
+        # Get Config =>> Note :: Override default image size to ensure correct image transform
+        self.data_cfg = timm.data.resolve_model_data_config(self.featurizer)
+        self.data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size)
+
+        # Initialize Default Image Transform --> Modified by `self.image_resize_strategy`
+        default_image_transform = timm.data.create_transform(**self.data_cfg, is_training=False)
+
+        # Fix =>> SigLIP & IN1K default transforms resize to *larger* than `self.default_image_size` (crops image)!
+        if "siglip" in self.timm_path_or_url or "in1k" in self.timm_path_or_url:
+            assert isinstance(default_image_transform, Compose), "Unexpected `default_image_transform`!"
+            assert isinstance(default_image_transform.transforms[0], Resize)
+            default_image_transform = Compose(
+                [
+                    Resize(self.default_image_size, interpolation=default_image_transform.transforms[0].interpolation),
+                    *default_image_transform.transforms[1:],
+                ]
+            )
+
+        # Switch on `image_resize_strategy`
+        if self.image_resize_strategy == "resize-naive":
+            assert isinstance(default_image_transform, Compose), "Unexpected `default_image_transform`!"
+            assert isinstance(default_image_transform.transforms[0], Resize)
+
+            target_size = (self.default_image_size, self.default_image_size)
+            self.image_transform = Compose(
+                [
+                    Resize(target_size, interpolation=default_image_transform.transforms[0].interpolation),
+                    *default_image_transform.transforms[1:],
+                ]
+            )
+
+        elif self.image_resize_strategy == "resize-crop":
+            self.image_transform = default_image_transform
+
+        elif self.image_resize_strategy == "letterbox":
+            assert isinstance(default_image_transform, Compose), "Unexpected `default_image_transform`!"
+            assert "mean" in self.data_cfg, "TIMM `data_cfg` missing image normalization mean!"
+
+            # Compute Padding Fill Value (rescaled normalization mean if applicable)
+            fill = tuple([int(x * 255) for x in self.data_cfg["mean"]])
+
+            # Build New Transform
+            self.image_transform = Compose([LetterboxPad(fill), *default_image_transform.transforms])
+
+        else:
+            raise ValueError(f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!")
+
+    def get_fsdp_wrapping_policy(self) -> Callable:
+        """Return a simple FSDP policy that wraps each ViT block and then the _entire_ featurizer."""
+        vit_wrap_policy = partial(_module_wrap_policy, module_classes={VisionTransformer})
+        transformer_block_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
+        return partial(_or_policy, policies=[vit_wrap_policy, transformer_block_policy])
+
+    def forward(self, pixel_values: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> torch.Tensor:
+        """Runs transformed image/pixel tensor through vision backbone, returning _all_ patch features."""
+        return self.featurizer(pixel_values)
+
+    @property
+    def default_image_resolution(self) -> Tuple[int, int, int]:
+        return self.data_cfg["input_size"]
+
+    @property
+    def embed_dim(self) -> int:
+        return self.featurizer.embed_dim
+
+    @property
+    def num_patches(self) -> int:
+        return self.featurizer.patch_embed.num_patches
+
+    @property
+    def half_precision_dtype(self) -> torch.dtype:
+        return self.dtype
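A minimal sketch of what the monkey-patched `forward()` above does, assuming `unpack_tuple` matches the helper defined earlier in `base_vision.py` (unwrapping the 1-tuple that `get_intermediate_layers` returns); model name and shapes are illustrative:

from functools import partial

import timm
import torch

def unpack_tuple(fn):
    # Assumed behavior of the helper defined earlier in base_vision.py:
    # unwrap single-element tuples/lists returned by `get_intermediate_layers`.
    def wrapper(*args, **kwargs):
        result = fn(*args, **kwargs)
        return result[0] if isinstance(result, (tuple, list)) else result
    return wrapper

vit = timm.create_model("vit_large_patch14_clip_224.openai", pretrained=False, num_classes=0, img_size=224)
vit.forward = unpack_tuple(partial(vit.get_intermediate_layers, n={len(vit.blocks) - 2}))
patches = vit(torch.randn(1, 3, 224, 224))
print(patches.shape)  # torch.Size([1, 256, 1024]) -- patch features from the second-to-last block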
capvector-oft/prismatic/models/backbones/vision/clip_vit.py ADDED
@@ -0,0 +1,27 @@
+"""
+clip_vit.py
+"""
+
+from prismatic.models.backbones.vision.base_vision import TimmViTBackbone
+
+# Registry =>> Supported CLIP Vision Backbones (from TIMM)
+CLIP_VISION_BACKBONES = {
+    "clip-vit-b": "vit_base_patch16_clip_224.openai",
+    "clip-vit-l": "vit_large_patch14_clip_224.openai",
+    "clip-vit-l-336px": "vit_large_patch14_clip_336.openai",
+}
+
+
+# [IMPORTANT] By Default, TIMM initializes OpenAI CLIP models with the standard GELU activation from PyTorch.
+#             HOWEVER =>> Original OpenAI models were trained with the `quick_gelu` *approximation* -- while it's
+#             a decent approximation, substituting exact GELU makes the resulting features *worse*; this was a super
+#             tricky bug to identify, but luckily there's an easy fix (`override_act_layer`)
+class CLIPViTBackbone(TimmViTBackbone):
+    def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None:
+        super().__init__(
+            vision_backbone_id,
+            CLIP_VISION_BACKBONES[vision_backbone_id],
+            image_resize_strategy,
+            default_image_size=default_image_size,
+            override_act_layer="quick_gelu" if CLIP_VISION_BACKBONES[vision_backbone_id].endswith(".openai") else None,
+        )
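For intuition on the `override_act_layer="quick_gelu"` fix, a small numeric sketch: QuickGELU is `x * sigmoid(1.702 * x)`, which differs slightly from exact GELU, and OpenAI CLIP checkpoints expect the former; the per-activation gap compounds across every transformer block:

import torch

x = torch.linspace(-3.0, 3.0, steps=7)
gelu = torch.nn.functional.gelu(x)
quick_gelu = x * torch.sigmoid(1.702 * x)  # the activation OpenAI CLIP was trained with
print((gelu - quick_gelu).abs().max())  # small per-activation gap, amplified over 24 ViT-L blocks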
capvector-oft/prismatic/models/backbones/vision/dinoclip_vit.py ADDED
@@ -0,0 +1,147 @@
+"""
+dinoclip_vit.py
+
+Vision backbone that returns concatenated features from both DINOv2 and CLIP.
+"""
+
+from dataclasses import dataclass
+from functools import partial
+from typing import Callable, Dict, Tuple
+
+import timm
+import torch
+from PIL import Image
+from timm.models.vision_transformer import Block, VisionTransformer
+from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy
+from torchvision.transforms import Compose, Resize
+
+from prismatic.models.backbones.vision.base_vision import ImageTransform, LetterboxPad, VisionBackbone, unpack_tuple
+
+# Registry =>> Supported DinoCLIP Pairs (as TIMM identifiers)
+DINOCLIP_VISION_BACKBONES = {
+    "dinoclip-vit-l-336px": {
+        "dino": "vit_large_patch14_reg4_dinov2.lvd142m",
+        "clip": "vit_large_patch14_clip_336.openai",
+    },
+}
+
+
+@dataclass
+class DinoCLIPImageTransform:
+    dino_image_transform: ImageTransform
+    clip_image_transform: ImageTransform
+    is_prismatic: bool = True
+
+    def __call__(self, img: Image, **kwargs: str) -> Dict[str, torch.Tensor]:
+        return {"dino": self.dino_image_transform(img, **kwargs), "clip": self.clip_image_transform(img, **kwargs)}
+
+
+class DinoCLIPViTBackbone(VisionBackbone):
+    def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None:
+        super().__init__(vision_backbone_id, image_resize_strategy, default_image_size=default_image_size)
+        self.dino_timm_path_or_url = DINOCLIP_VISION_BACKBONES[vision_backbone_id]["dino"]
+        self.clip_timm_path_or_url = DINOCLIP_VISION_BACKBONES[vision_backbone_id]["clip"]
+
+        # Initialize both Featurizers (ViTs) by downloading from HF / TIMM Hub if necessary
+        self.dino_featurizer: VisionTransformer = timm.create_model(
+            self.dino_timm_path_or_url, pretrained=True, num_classes=0, img_size=self.default_image_size
+        )
+        self.dino_featurizer.eval()
+
+        self.clip_featurizer: VisionTransformer = timm.create_model(
+            self.clip_timm_path_or_url, pretrained=True, num_classes=0, img_size=self.default_image_size
+        )
+        self.clip_featurizer.eval()
+
+        # Monkey-Patch the `forward()` function of the featurizers to ensure FSDP-compatibility
+        #   => Note: By default set `get_intermediate_layers` to return the *SECOND-TO-LAST* layer patches!
+        #   => TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385
+        self.dino_featurizer.forward = unpack_tuple(
+            partial(self.dino_featurizer.get_intermediate_layers, n={len(self.dino_featurizer.blocks) - 2})
+        )
+        self.clip_featurizer.forward = unpack_tuple(
+            partial(self.clip_featurizer.get_intermediate_layers, n={len(self.clip_featurizer.blocks) - 2})
+        )
+
+        # Get Configs for _both_ Featurizers =>> Note :: Override default image size for larger resolution models
+        self.dino_data_cfg = timm.data.resolve_model_data_config(self.dino_featurizer)
+        self.dino_data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size)
+
+        self.clip_data_cfg = timm.data.resolve_model_data_config(self.clip_featurizer)
+        self.clip_data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size)
+
+        # Initialize *both* Transforms
+        default_dino_transform = timm.data.create_transform(**self.dino_data_cfg, is_training=False)
+        default_clip_transform = timm.data.create_transform(**self.clip_data_cfg, is_training=False)
+        if self.image_resize_strategy == "resize-naive":
+            assert isinstance(default_dino_transform, Compose), "Unexpected `default_dino_image_transform`!"
+            assert isinstance(default_clip_transform, Compose), "Unexpected `default_clip_image_transform`!"
+            assert isinstance(default_dino_transform.transforms[0], Resize)
+            assert isinstance(default_clip_transform.transforms[0], Resize)
+
+            target_size = (self.default_image_size, self.default_image_size)
+            dino_transform = Compose(
+                [
+                    Resize(target_size, interpolation=default_dino_transform.transforms[0].interpolation),
+                    *default_dino_transform.transforms[1:],
+                ]
+            )
+            clip_transform = Compose(
+                [
+                    Resize(target_size, interpolation=default_clip_transform.transforms[0].interpolation),
+                    *default_clip_transform.transforms[1:],
+                ]
+            )
+
+            self.image_transform = DinoCLIPImageTransform(dino_transform, clip_transform)
+
+        elif self.image_resize_strategy == "resize-crop":
+            self.image_transform = DinoCLIPImageTransform(default_dino_transform, default_clip_transform)
+
+        elif self.image_resize_strategy == "letterbox":
+            assert isinstance(default_dino_transform, Compose), "Unexpected `default_dino_transform`!"
+            assert isinstance(default_clip_transform, Compose), "Unexpected `default_clip_transform`!"
+            assert "mean" in self.dino_data_cfg and "mean" in self.clip_data_cfg, "DinoCLIP `data_cfg` missing `mean`!"
+
+            # Compute Padding Fill Value(s) (rescaled normalization mean if applicable)
+            dino_fill = tuple([int(x * 255) for x in self.dino_data_cfg["mean"]])
+            clip_fill = tuple([int(x * 255) for x in self.clip_data_cfg["mean"]])
+
+            # Build New Transform
+            self.image_transform = DinoCLIPImageTransform(
+                Compose([LetterboxPad(dino_fill), *default_dino_transform.transforms]),
+                Compose([LetterboxPad(clip_fill), *default_clip_transform.transforms]),
+            )
+
+        else:
+            raise ValueError(f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!")
+
+    def get_fsdp_wrapping_policy(self) -> Callable:
+        """Return a simple FSDP policy that wraps each ViT block and then both of the _entire_ featurizers."""
+        vit_wrap_policy = partial(_module_wrap_policy, module_classes={VisionTransformer})
+        transformer_block_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
+        return partial(_or_policy, policies=[vit_wrap_policy, transformer_block_policy])
+
+    def forward(self, pixel_values: Dict[str, torch.Tensor]) -> torch.Tensor:
+        """Runs the transformed image/pixel tensors through each vision backbone, returning concatenated patches."""
+        dino_patches = self.dino_featurizer(pixel_values["dino"])
+        clip_patches = self.clip_featurizer(pixel_values["clip"])
+
+        return torch.cat([dino_patches, clip_patches], dim=2)
+
+    @property
+    def default_image_resolution(self) -> Tuple[int, int, int]:
+        return self.dino_data_cfg["input_size"]
+
+    @property
+    def embed_dim(self) -> int:
+        return self.dino_featurizer.embed_dim + self.clip_featurizer.embed_dim
+
+    @property
+    def num_patches(self) -> int:
+        assert self.dino_featurizer.patch_embed.num_patches == self.clip_featurizer.patch_embed.num_patches
+        return self.dino_featurizer.patch_embed.num_patches
+
+    @property
+    def half_precision_dtype(self) -> torch.dtype:
+        return torch.bfloat16
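A shape sketch for the fused `forward()` above: both featurizers must produce the same number of patch tokens (576 for a 336px input with patch size 14), and concatenation happens along the embedding dimension, so token count is shared while `embed_dim` adds up:

import torch

dino_patches = torch.randn(1, 576, 1024)  # DINOv2 ViT-L/14 @ 336px -> (336 / 14) ** 2 = 576 patches
clip_patches = torch.randn(1, 576, 1024)  # CLIP ViT-L/14 @ 336px
fused = torch.cat([dino_patches, clip_patches], dim=2)
print(fused.shape)  # torch.Size([1, 576, 2048]) -- shared patch count, summed embed dims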
capvector-oft/prismatic/models/backbones/vision/dinosiglip_vit.py ADDED
@@ -0,0 +1,164 @@
+"""
+dinosiglip_vit.py
+
+Vision backbone that returns concatenated features from both DINOv2 and SigLIP.
+"""
+
+from dataclasses import dataclass
+from functools import partial
+from typing import Callable, Dict, Tuple
+
+import timm
+import torch
+from PIL import Image
+from timm.models.vision_transformer import Block, VisionTransformer
+from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy
+from torchvision.transforms import Compose, Resize
+
+from prismatic.models.backbones.vision.base_vision import ImageTransform, LetterboxPad, VisionBackbone, unpack_tuple
+
+# Registry =>> Supported DinoSigLIP Pairs (as TIMM identifiers)
+DINOSigLIP_VISION_BACKBONES = {
+    "dinosiglip-vit-so-224px": {
+        "dino": "vit_large_patch14_reg4_dinov2.lvd142m",
+        "siglip": "vit_so400m_patch14_siglip_224",
+    },
+    "dinosiglip-vit-so-384px": {
+        "dino": "vit_large_patch14_reg4_dinov2.lvd142m",
+        "siglip": "vit_so400m_patch14_siglip_384",
+    },
+}
+
+
+@dataclass
+class DinoSigLIPImageTransform:
+    dino_image_transform: ImageTransform
+    siglip_image_transform: ImageTransform
+    is_prismatic: bool = True
+
+    def __call__(self, img: Image, **kwargs: str) -> Dict[str, torch.Tensor]:
+        return {"dino": self.dino_image_transform(img, **kwargs), "siglip": self.siglip_image_transform(img, **kwargs)}
+
+
+class DinoSigLIPViTBackbone(VisionBackbone):
+    def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None:
+        super().__init__(vision_backbone_id, image_resize_strategy, default_image_size=default_image_size)
+        self.dino_timm_path_or_url = DINOSigLIP_VISION_BACKBONES[vision_backbone_id]["dino"]
+        self.siglip_timm_path_or_url = DINOSigLIP_VISION_BACKBONES[vision_backbone_id]["siglip"]
+
+        # Initialize both Featurizers (ViTs) by downloading from HF / TIMM Hub if necessary
+        self.dino_featurizer: VisionTransformer = timm.create_model(
+            self.dino_timm_path_or_url, pretrained=True, num_classes=0, img_size=self.default_image_size
+        )
+        self.dino_featurizer.eval()
+
+        self.siglip_featurizer: VisionTransformer = timm.create_model(
+            self.siglip_timm_path_or_url, pretrained=True, num_classes=0, img_size=self.default_image_size
+        )
+        self.siglip_featurizer.eval()
+
+        # Monkey-Patch the `forward()` function of the featurizers to ensure FSDP-compatibility
+        #   => Note: By default set `get_intermediate_layers` to return the *SECOND-TO-LAST* layer patches!
+        #   => TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385
+        self.dino_featurizer.forward = unpack_tuple(
+            partial(self.dino_featurizer.get_intermediate_layers, n={len(self.dino_featurizer.blocks) - 2})
+        )
+        self.siglip_featurizer.forward = unpack_tuple(
+            partial(self.siglip_featurizer.get_intermediate_layers, n={len(self.siglip_featurizer.blocks) - 2})
+        )
+
+        # Get Configs for _both_ Featurizers =>> Note :: Override default image size for larger resolution models
+        self.dino_data_cfg = timm.data.resolve_model_data_config(self.dino_featurizer)
+        self.dino_data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size)
+
+        self.siglip_data_cfg = timm.data.resolve_model_data_config(self.siglip_featurizer)
+        self.siglip_data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size)
+
+        # Initialize *both* Transforms
+        default_dino_transform = timm.data.create_transform(**self.dino_data_cfg, is_training=False)
+        default_siglip_transform = timm.data.create_transform(**self.siglip_data_cfg, is_training=False)
+
+        # Fix =>> SigLIP default transform resizes to *larger* than `self.default_image_size` (crops image)!!
+        assert isinstance(default_siglip_transform, Compose), "Unexpected `default_image_transform`!"
+        assert isinstance(default_siglip_transform.transforms[0], Resize)
+        default_siglip_transform = Compose(
+            [
+                Resize(self.default_image_size, interpolation=default_siglip_transform.transforms[0].interpolation),
+                *default_siglip_transform.transforms[1:],
+            ]
+        )
+
+        if self.image_resize_strategy == "resize-naive":
+            assert isinstance(default_dino_transform, Compose), "Unexpected `default_dino_image_transform`!"
+            assert isinstance(default_siglip_transform, Compose), "Unexpected `default_siglip_image_transform`!"
+            assert isinstance(default_dino_transform.transforms[0], Resize)
+            assert isinstance(default_siglip_transform.transforms[0], Resize)
+
+            target_size = (self.default_image_size, self.default_image_size)
+            dino_transform = Compose(
+                [
+                    Resize(target_size, interpolation=default_dino_transform.transforms[0].interpolation),
+                    *default_dino_transform.transforms[1:],
+                ]
+            )
+            siglip_transform = Compose(
+                [
+                    Resize(target_size, interpolation=default_siglip_transform.transforms[0].interpolation),
+                    *default_siglip_transform.transforms[1:],
+                ]
+            )
+
+            self.image_transform = DinoSigLIPImageTransform(dino_transform, siglip_transform)
+
+        elif self.image_resize_strategy == "resize-crop":
+            self.image_transform = DinoSigLIPImageTransform(default_dino_transform, default_siglip_transform)
+
+        elif self.image_resize_strategy == "letterbox":
+            assert isinstance(default_dino_transform, Compose), "Unexpected `default_dino_transform`!"
+            assert isinstance(default_siglip_transform, Compose), "Unexpected `default_siglip_transform`!"
+            assert (
+                "mean" in self.dino_data_cfg and "mean" in self.siglip_data_cfg
+            ), "DinoSigLIP `data_cfg` missing `mean`!"
+
+            # Compute Padding Fill Value(s) (rescaled normalization mean if applicable)
+            dino_fill = tuple([int(x * 255) for x in self.dino_data_cfg["mean"]])
+            siglip_fill = tuple([int(x * 255) for x in self.siglip_data_cfg["mean"]])
+
+            # Build New Transform
+            self.image_transform = DinoSigLIPImageTransform(
+                Compose([LetterboxPad(dino_fill), *default_dino_transform.transforms]),
+                Compose([LetterboxPad(siglip_fill), *default_siglip_transform.transforms]),
+            )
+
+        else:
+            raise ValueError(f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!")
+
+    def get_fsdp_wrapping_policy(self) -> Callable:
+        """Return a simple FSDP policy that wraps each ViT block and then both of the _entire_ featurizers."""
+        vit_wrap_policy = partial(_module_wrap_policy, module_classes={VisionTransformer})
+        transformer_block_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
+        return partial(_or_policy, policies=[vit_wrap_policy, transformer_block_policy])
+
+    def forward(self, pixel_values: Dict[str, torch.Tensor]) -> torch.Tensor:
+        """Runs the transformed image/pixel tensors through each vision backbone, returning concatenated patches."""
+        dino_patches = self.dino_featurizer(pixel_values["dino"])
+        siglip_patches = self.siglip_featurizer(pixel_values["siglip"])
+
+        return torch.cat([dino_patches, siglip_patches], dim=2)
+
+    @property
+    def default_image_resolution(self) -> Tuple[int, int, int]:
+        return self.dino_data_cfg["input_size"]
+
+    @property
+    def embed_dim(self) -> int:
+        return self.dino_featurizer.embed_dim + self.siglip_featurizer.embed_dim
+
+    @property
+    def num_patches(self) -> int:
+        assert self.dino_featurizer.patch_embed.num_patches == self.siglip_featurizer.patch_embed.num_patches
+        return self.dino_featurizer.patch_embed.num_patches
+
+    @property
+    def half_precision_dtype(self) -> torch.dtype:
+        return torch.bfloat16
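A minimal sketch of the paired-transform contract, with stand-in callables in place of the real TIMM eval transforms: each image is transformed once per backbone, and `forward()` consumes the resulting dict by key:

import torch
from PIL import Image

from prismatic.models.backbones.vision.dinosiglip_vit import DinoSigLIPImageTransform

dummy_dino = lambda img, **kw: torch.zeros(3, 384, 384)    # stand-in for the DINOv2 eval transform
dummy_siglip = lambda img, **kw: torch.zeros(3, 384, 384)  # stand-in for the SigLIP eval transform
transform = DinoSigLIPImageTransform(dummy_dino, dummy_siglip)
pixel_values = transform(Image.new("RGB", (640, 480)))
print(sorted(pixel_values.keys()))  # ['dino', 'siglip'] -- the keys `forward()` indexes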
capvector-oft/prismatic/models/backbones/vision/dinov2_vit.py ADDED
@@ -0,0 +1,19 @@
+"""
+dinov2_vit.py
+"""
+
+from prismatic.models.backbones.vision.base_vision import TimmViTBackbone
+
+# Registry =>> Supported DINOv2 Vision Backbones (from TIMM) =>> Note :: Using DINOv2 w/ Registers!
+#   => Reference: https://arxiv.org/abs/2309.16588
+DINOv2_VISION_BACKBONES = {"dinov2-vit-l": "vit_large_patch14_reg4_dinov2.lvd142m"}
+
+
+class DinoV2ViTBackbone(TimmViTBackbone):
+    def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None:
+        super().__init__(
+            vision_backbone_id,
+            DINOv2_VISION_BACKBONES[vision_backbone_id],
+            image_resize_strategy,
+            default_image_size=default_image_size,
+        )
capvector-oft/prismatic/models/backbones/vision/in1k_vit.py ADDED
@@ -0,0 +1,22 @@
+"""
+in1k_vit.py
+
+Vision Transformers trained / finetuned on ImageNet (ImageNet-21K =>> ImageNet-1K)
+"""
+
+from prismatic.models.backbones.vision.base_vision import TimmViTBackbone
+
+# Registry =>> Supported Vision Backbones (from TIMM)
+IN1K_VISION_BACKBONES = {
+    "in1k-vit-l": "vit_large_patch16_224.augreg_in21k_ft_in1k",
+}
+
+
+class IN1KViTBackbone(TimmViTBackbone):
+    def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None:
+        super().__init__(
+            vision_backbone_id,
+            IN1K_VISION_BACKBONES[vision_backbone_id],
+            image_resize_strategy,
+            default_image_size=default_image_size,
+        )
capvector-oft/prismatic/models/backbones/vision/siglip_vit.py ADDED
@@ -0,0 +1,24 @@
+"""
+siglip_vit.py
+"""
+
+from prismatic.models.backbones.vision.base_vision import TimmViTBackbone
+
+# Registry =>> Supported SigLIP Vision Backbones (from TIMM) =>> Note :: Using SigLIP w/ Patch = 14 (but SO400M Arch)
+SIGLIP_VISION_BACKBONES = {
+    "siglip-vit-b16-224px": "vit_base_patch16_siglip_224",
+    "siglip-vit-b16-256px": "vit_base_patch16_siglip_256",
+    "siglip-vit-b16-384px": "vit_base_patch16_siglip_384",
+    "siglip-vit-so400m": "vit_so400m_patch14_siglip_224",
+    "siglip-vit-so400m-384px": "vit_so400m_patch14_siglip_384",
+}
+
+
+class SigLIPViTBackbone(TimmViTBackbone):
+    def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None:
+        super().__init__(
+            vision_backbone_id,
+            SIGLIP_VISION_BACKBONES[vision_backbone_id],
+            image_resize_strategy,
+            default_image_size=default_image_size,
+        )
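For the "letterbox" strategy referenced throughout these backbones, a sketch of the assumed `LetterboxPad` behavior (the class itself is defined earlier in `base_vision.py`): pad the short side to a square, filling with the rescaled normalization mean, so the backbone's own Resize/Normalize can run without distorting aspect ratio:

from PIL import Image, ImageOps

def letterbox_pad(img: Image.Image, fill: tuple) -> Image.Image:
    # Pad the short side symmetrically until the image is square (assumed behavior).
    w, h = img.size
    max_wh = max(w, h)
    pad_w, pad_h = (max_wh - w) // 2, (max_wh - h) // 2
    return ImageOps.expand(img, border=(pad_w, pad_h, pad_w, pad_h), fill=fill)

fill = tuple(int(m * 255) for m in (0.485, 0.456, 0.406))  # e.g., an ImageNet-style mean
square = letterbox_pad(Image.new("RGB", (640, 480)), fill)
print(square.size)  # (640, 640)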
capvector-oft/prismatic/models/load.py ADDED
@@ -0,0 +1,226 @@
+"""
+load.py
+
+Entry point for loading pretrained VLMs for inference; exposes functions for listing available models (with canonical
+IDs, mappings to paper experiments, and short descriptions), as well as for loading models (from disk or HF Hub).
+"""
+
+import json
+import os
+from pathlib import Path
+from typing import List, Optional, Union
+
+from huggingface_hub import HfFileSystem, hf_hub_download
+
+from prismatic.conf import ModelConfig
+from prismatic.models.materialize import get_llm_backbone_and_tokenizer, get_vision_backbone_and_transform
+from prismatic.models.registry import GLOBAL_REGISTRY, MODEL_REGISTRY
+from prismatic.models.vlas import OpenVLA
+from prismatic.models.vlms import PrismaticVLM
+from prismatic.overwatch import initialize_overwatch
+from prismatic.vla.action_tokenizer import ActionTokenizer
+
+# Initialize Overwatch =>> Wraps `logging.Logger`
+overwatch = initialize_overwatch(__name__)
+
+
+# === HF Hub Repository ===
+HF_HUB_REPO = "TRI-ML/prismatic-vlms"
+VLA_HF_HUB_REPO = "openvla/openvla-dev"
+
+
+# === Available Models ===
+def available_models() -> List[str]:
+    return list(MODEL_REGISTRY.keys())
+
+
+def available_model_names() -> List[str]:
+    return list(GLOBAL_REGISTRY.items())
+
+
+def get_model_description(model_id_or_name: str) -> str:
+    if model_id_or_name not in GLOBAL_REGISTRY:
+        raise ValueError(f"Couldn't find `{model_id_or_name = }`; check `prismatic.available_model_names()`")
+
+    # Print Description & Return
+    print(json.dumps(description := GLOBAL_REGISTRY[model_id_or_name]["description"], indent=2))
+
+    return description
+
+
+# === Load Pretrained Model ===
+def load(
+    model_id_or_path: Union[str, Path],
+    hf_token: Optional[str] = None,
+    cache_dir: Optional[Union[str, Path]] = None,
+    load_for_training: bool = False,
+) -> PrismaticVLM:
+    """Loads a pretrained PrismaticVLM from either local disk or the HuggingFace Hub."""
+    if os.path.isdir(model_id_or_path):
+        overwatch.info(f"Loading from local path `{(run_dir := Path(model_id_or_path))}`")
+
+        # Get paths for `config.json` and pretrained checkpoint
+        config_json, checkpoint_pt = run_dir / "config.json", run_dir / "checkpoints" / "latest-checkpoint.pt"
+        assert config_json.exists(), f"Missing `config.json` for `{run_dir = }`"
+        assert checkpoint_pt.exists(), f"Missing checkpoint for `{run_dir = }`"
+    else:
+        if model_id_or_path not in GLOBAL_REGISTRY:
+            raise ValueError(f"Couldn't find `{model_id_or_path = }`; check `prismatic.available_model_names()`")
+
+        overwatch.info(f"Downloading `{(model_id := GLOBAL_REGISTRY[model_id_or_path]['model_id'])}` from HF Hub")
+        with overwatch.local_zero_first():
+            config_json = hf_hub_download(repo_id=HF_HUB_REPO, filename=f"{model_id}/config.json", cache_dir=cache_dir)
+            checkpoint_pt = hf_hub_download(
+                repo_id=HF_HUB_REPO, filename=f"{model_id}/checkpoints/latest-checkpoint.pt", cache_dir=cache_dir
+            )
+
+    # Load Model Config from `config.json`
+    with open(config_json, "r") as f:
+        model_cfg = json.load(f)["model"]
+
+    # = Load Individual Components necessary for Instantiating a VLM =
+    #   =>> Print Minimal Config
+    overwatch.info(
+        f"Found Config =>> Loading & Freezing [bold blue]{model_cfg['model_id']}[/] with:\n"
+        f"             Vision Backbone =>> [bold]{model_cfg['vision_backbone_id']}[/]\n"
+        f"             LLM Backbone    =>> [bold]{model_cfg['llm_backbone_id']}[/]\n"
+        f"             Arch Specifier  =>> [bold]{model_cfg['arch_specifier']}[/]\n"
+        f"             Checkpoint Path =>> [underline]`{checkpoint_pt}`[/]"
+    )
+
+    # Load Vision Backbone
+    overwatch.info(f"Loading Vision Backbone [bold]{model_cfg['vision_backbone_id']}[/]")
+    vision_backbone, image_transform = get_vision_backbone_and_transform(
+        model_cfg["vision_backbone_id"],
+        model_cfg["image_resize_strategy"],
+    )
+
+    # Load LLM Backbone --> note `inference_mode = True` by default when calling `load()`
+    overwatch.info(f"Loading Pretrained LLM [bold]{model_cfg['llm_backbone_id']}[/] via HF Transformers")
+    llm_backbone, tokenizer = get_llm_backbone_and_tokenizer(
+        model_cfg["llm_backbone_id"],
+        llm_max_length=model_cfg.get("llm_max_length", 2048),
+        hf_token=hf_token,
+        inference_mode=not load_for_training,
+    )
+
+    # Load VLM using `from_pretrained` (clobbers HF syntax... eventually should reconcile)
+    overwatch.info(f"Loading VLM [bold blue]{model_cfg['model_id']}[/] from Checkpoint")
+    vlm = PrismaticVLM.from_pretrained(
+        checkpoint_pt,
+        model_cfg["model_id"],
+        vision_backbone,
+        llm_backbone,
+        arch_specifier=model_cfg["arch_specifier"],
+        freeze_weights=not load_for_training,
+    )
+
+    return vlm
+
+
+# === Load Pretrained VLA Model ===
+def load_vla(
+    model_id_or_path: Union[str, Path],
+    hf_token: Optional[str] = None,
+    cache_dir: Optional[Union[str, Path]] = None,
+    load_for_training: bool = False,
+    step_to_load: Optional[int] = None,
+    model_type: str = "pretrained",
+) -> OpenVLA:
+    """Loads a pretrained OpenVLA from either local disk or the HuggingFace Hub."""
+
+    # TODO (siddk, moojink) :: Unify semantics with `load()` above; right now, `load_vla()` assumes path points to
+    #   checkpoint `.pt` file, rather than the top-level run directory!
+    if os.path.isfile(model_id_or_path):
+        overwatch.info(f"Loading from local checkpoint path `{(checkpoint_pt := Path(model_id_or_path))}`")
+
+        # [Validate] Checkpoint Path should look like `.../<RUN_ID>/checkpoints/<CHECKPOINT_PATH>.pt`
+        assert (checkpoint_pt.suffix == ".pt") and (checkpoint_pt.parent.name == "checkpoints"), "Invalid checkpoint!"
+        run_dir = checkpoint_pt.parents[1]
+
+        # Get paths for `config.json`, `dataset_statistics.json` and pretrained checkpoint
+        config_json, dataset_statistics_json = run_dir / "config.json", run_dir / "dataset_statistics.json"
+        assert config_json.exists(), f"Missing `config.json` for `{run_dir = }`"
+        assert dataset_statistics_json.exists(), f"Missing `dataset_statistics.json` for `{run_dir = }`"
+
+    # Otherwise =>> try looking for a match on `model_id_or_path` on the HF Hub (`VLA_HF_HUB_REPO`)
+    else:
+        # Search HF Hub Repo via fsspec API
+        overwatch.info(f"Checking HF for `{(hf_path := str(Path(VLA_HF_HUB_REPO) / model_type / model_id_or_path))}`")
+        if not (tmpfs := HfFileSystem()).exists(hf_path):
+            raise ValueError(f"Couldn't find valid HF Hub Path `{hf_path = }`")
+
+        # Identify Checkpoint to Load (via `step_to_load`)
+        step_to_load = f"{step_to_load:06d}" if step_to_load is not None else None
+        valid_ckpts = tmpfs.glob(f"{hf_path}/checkpoints/step-{step_to_load if step_to_load is not None else ''}*.pt")
+        if (len(valid_ckpts) == 0) or (step_to_load is not None and len(valid_ckpts) != 1):
+            raise ValueError(f"Couldn't find a valid checkpoint to load from HF Hub Path `{hf_path}/checkpoints/`")
+
+        # Call to `glob` will sort steps in ascending order (if `step_to_load` is None); just grab last element
+        target_ckpt = Path(valid_ckpts[-1]).name
+
+        overwatch.info(f"Downloading Model `{model_id_or_path}` Config & Checkpoint `{target_ckpt}`")
+        with overwatch.local_zero_first():
+            relpath = Path(model_type) / model_id_or_path
+            config_json = hf_hub_download(
+                repo_id=VLA_HF_HUB_REPO, filename=f"{(relpath / 'config.json')!s}", cache_dir=cache_dir
+            )
+            dataset_statistics_json = hf_hub_download(
+                repo_id=VLA_HF_HUB_REPO, filename=f"{(relpath / 'dataset_statistics.json')!s}", cache_dir=cache_dir
+            )
+            checkpoint_pt = hf_hub_download(
+                repo_id=VLA_HF_HUB_REPO, filename=f"{(relpath / 'checkpoints' / target_ckpt)!s}", cache_dir=cache_dir
+            )
+
+    # Load VLA Config (and corresponding base VLM `ModelConfig`) from `config.json`
+    with open(config_json, "r") as f:
+        vla_cfg = json.load(f)["vla"]
+        model_cfg = ModelConfig.get_choice_class(vla_cfg["base_vlm"])()
+
+    # Load Dataset Statistics for Action Denormalization
+    with open(dataset_statistics_json, "r") as f:
+        norm_stats = json.load(f)
+
+    # = Load Individual Components necessary for Instantiating a VLA (via base VLM components) =
+    #   =>> Print Minimal Config
+    overwatch.info(
+        f"Found Config =>> Loading & Freezing [bold blue]{model_cfg.model_id}[/] with:\n"
+        f"             Vision Backbone =>> [bold]{model_cfg.vision_backbone_id}[/]\n"
+        f"             LLM Backbone    =>> [bold]{model_cfg.llm_backbone_id}[/]\n"
+        f"             Arch Specifier  =>> [bold]{model_cfg.arch_specifier}[/]\n"
+        f"             Checkpoint Path =>> [underline]`{checkpoint_pt}`[/]"
+    )
+
+    # Load Vision Backbone
+    overwatch.info(f"Loading Vision Backbone [bold]{model_cfg.vision_backbone_id}[/]")
+    vision_backbone, image_transform = get_vision_backbone_and_transform(
+        model_cfg.vision_backbone_id,
+        model_cfg.image_resize_strategy,
+    )
+
+    # Load LLM Backbone --> note `inference_mode = True` by default when calling `load()`
+    overwatch.info(f"Loading Pretrained LLM [bold]{model_cfg.llm_backbone_id}[/] via HF Transformers")
+    llm_backbone, tokenizer = get_llm_backbone_and_tokenizer(
+        model_cfg.llm_backbone_id,
+        llm_max_length=model_cfg.llm_max_length,
+        hf_token=hf_token,
+        inference_mode=not load_for_training,
+    )
+
+    # Create Action Tokenizer
+    action_tokenizer = ActionTokenizer(llm_backbone.get_tokenizer())
+
+    # Load VLM using `from_pretrained` (clobbers HF syntax... eventually should reconcile)
+    overwatch.info(f"Loading VLA [bold blue]{model_cfg.model_id}[/] from Checkpoint")
+    vla = OpenVLA.from_pretrained(
+        checkpoint_pt,
+        model_cfg.model_id,
+        vision_backbone,
+        llm_backbone,
+        arch_specifier=model_cfg.arch_specifier,
+        freeze_weights=not load_for_training,
+        norm_stats=norm_stats,
+        action_tokenizer=action_tokenizer,
+    )
+
+    return vla
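An illustrative usage sketch for the loader (the model ID comes from the registry further below; weights are fetched from the HF Hub on first call):

from prismatic.models import available_models, load

print(available_models()[:3])                              # canonical model IDs from MODEL_REGISTRY
vlm = load("siglip-224px+7b")                              # inference mode: weights frozen
vlm_trainable = load("siglip-224px+7b", load_for_training=True)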
capvector-oft/prismatic/models/materialize.py ADDED
@@ -0,0 +1,130 @@
+"""
+materialize.py
+
+Factory class for initializing Vision Backbones, LLM Backbones, and VLMs from a set registry; provides and exports
+individual functions for clear control flow.
+"""
+
+from typing import Optional, Tuple
+
+from transformers import PreTrainedTokenizerBase
+
+from prismatic.models.backbones.llm import LLaMa2LLMBackbone, LLMBackbone, MistralLLMBackbone, PhiLLMBackbone
+from prismatic.models.backbones.vision import (
+    CLIPViTBackbone,
+    DinoCLIPViTBackbone,
+    DinoSigLIPViTBackbone,
+    DinoV2ViTBackbone,
+    ImageTransform,
+    IN1KViTBackbone,
+    SigLIPViTBackbone,
+    VisionBackbone,
+)
+from prismatic.models.vlms import PrismaticVLM
+
+# === Registries =>> Maps ID --> {cls(), kwargs} :: Different Registries for Vision Backbones, LLM Backbones, VLMs ===
+# fmt: off
+
+# === Vision Backbone Registry ===
+VISION_BACKBONES = {
+    # === 224px Backbones ===
+    "clip-vit-l": {"cls": CLIPViTBackbone, "kwargs": {"default_image_size": 224}},
+    "siglip-vit-so400m": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 224}},
+    "dinov2-vit-l": {"cls": DinoV2ViTBackbone, "kwargs": {"default_image_size": 224}},
+    "in1k-vit-l": {"cls": IN1KViTBackbone, "kwargs": {"default_image_size": 224}},
+    "dinosiglip-vit-so-224px": {"cls": DinoSigLIPViTBackbone, "kwargs": {"default_image_size": 224}},
+
+    # === Assorted CLIP Backbones ===
+    "clip-vit-b": {"cls": CLIPViTBackbone, "kwargs": {"default_image_size": 224}},
+    "clip-vit-l-336px": {"cls": CLIPViTBackbone, "kwargs": {"default_image_size": 336}},
+
+    # === Assorted SigLIP Backbones ===
+    "siglip-vit-b16-224px": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 224}},
+    "siglip-vit-b16-256px": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 256}},
+    "siglip-vit-b16-384px": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 384}},
+    "siglip-vit-so400m-384px": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 384}},
+
+    # === Fused Backbones ===
+    "dinoclip-vit-l-336px": {"cls": DinoCLIPViTBackbone, "kwargs": {"default_image_size": 336}},
+    "dinosiglip-vit-so-384px": {"cls": DinoSigLIPViTBackbone, "kwargs": {"default_image_size": 384}},
+}
+
+
+# === Language Model Registry ===
+LLM_BACKBONES = {
+    # === LLaMa-2 Pure (Non-Chat) Backbones ===
+    "llama2-7b-pure": {"cls": LLaMa2LLMBackbone, "kwargs": {}},
+    "llama2-13b-pure": {"cls": LLaMa2LLMBackbone, "kwargs": {}},
+
+    # === LLaMa-2 Chat Backbones ===
+    "llama2-7b-chat": {"cls": LLaMa2LLMBackbone, "kwargs": {}},
+    "llama2-13b-chat": {"cls": LLaMa2LLMBackbone, "kwargs": {}},
+
+    # === Vicuna-v1.5 Backbones ===
+    "vicuna-v15-7b": {"cls": LLaMa2LLMBackbone, "kwargs": {}},
+    "vicuna-v15-13b": {"cls": LLaMa2LLMBackbone, "kwargs": {}},
+
+    # === Mistral v0.1 Backbones ===
+    "mistral-v0.1-7b-pure": {"cls": MistralLLMBackbone, "kwargs": {}},
+    "mistral-v0.1-7b-instruct": {"cls": MistralLLMBackbone, "kwargs": {}},
+
+    # === Phi-2 Backbone ===
+    "phi-2-3b": {"cls": PhiLLMBackbone, "kwargs": {}},
+}
+
+# fmt: on
+
+
+def get_vision_backbone_and_transform(
+    vision_backbone_id: str, image_resize_strategy: str
+) -> Tuple[VisionBackbone, ImageTransform]:
+    """Instantiate a Vision Backbone, returning both the nn.Module wrapper class and default Image Transform."""
+    if vision_backbone_id in VISION_BACKBONES:
+        vision_cfg = VISION_BACKBONES[vision_backbone_id]
+        vision_backbone: VisionBackbone = vision_cfg["cls"](
+            vision_backbone_id, image_resize_strategy, **vision_cfg["kwargs"]
+        )
+        image_transform = vision_backbone.get_image_transform()
+        return vision_backbone, image_transform
+
+    else:
+        raise ValueError(f"Vision Backbone `{vision_backbone_id}` is not supported!")
+
+
+def get_llm_backbone_and_tokenizer(
+    llm_backbone_id: str,
+    llm_max_length: int = 2048,
+    hf_token: Optional[str] = None,
+    inference_mode: bool = False,
+) -> Tuple[LLMBackbone, PreTrainedTokenizerBase]:
+    if llm_backbone_id in LLM_BACKBONES:
+        llm_cfg = LLM_BACKBONES[llm_backbone_id]
+        llm_backbone: LLMBackbone = llm_cfg["cls"](
+            llm_backbone_id,
+            llm_max_length=llm_max_length,
+            hf_token=hf_token,
+            inference_mode=inference_mode,
+            **llm_cfg["kwargs"],
+        )
+        tokenizer = llm_backbone.get_tokenizer()
+        return llm_backbone, tokenizer
+
+    else:
+        raise ValueError(f"LLM Backbone `{llm_backbone_id}` is not supported!")
+
+
+def get_vlm(
+    model_id: str,
+    arch_specifier: str,
+    vision_backbone: VisionBackbone,
+    llm_backbone: LLMBackbone,
+    enable_mixed_precision_training: bool = True,
+) -> PrismaticVLM:
+    """Lightweight wrapper around initializing a VLM, mostly for future-proofing (if one wants to add a new VLM)."""
+    return PrismaticVLM(
+        model_id,
+        vision_backbone,
+        llm_backbone,
+        enable_mixed_precision_training=enable_mixed_precision_training,
+        arch_specifier=arch_specifier,
+    )
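A sketch of composing a VLM from the registries above. The backbone IDs come from the registries in this file; the `model_id` and `arch_specifier` strings passed to `get_vlm` are illustrative assumptions, not values confirmed by this diff:

from prismatic.models.materialize import (
    get_llm_backbone_and_tokenizer,
    get_vision_backbone_and_transform,
    get_vlm,
)

vision_backbone, image_transform = get_vision_backbone_and_transform(
    "dinosiglip-vit-so-384px", image_resize_strategy="resize-naive"
)
llm_backbone, tokenizer = get_llm_backbone_and_tokenizer("llama2-7b-pure", llm_max_length=2048)
vlm = get_vlm("my-vlm", "gelu-mlp", vision_backbone, llm_backbone)  # hypothetical id / arch specifier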
capvector-oft/prismatic/models/projectors.py ADDED
@@ -0,0 +1,49 @@
+"""Implementation of additional projectors for additional inputs to the VLA models."""
+import torch
+import torch.nn as nn
+
+
+class ProprioProjector(nn.Module):
+    """
+    Projects proprio state inputs into the LLM's embedding space.
+    """
+    def __init__(self, llm_dim: int, proprio_dim: int) -> None:
+        super().__init__()
+        self.llm_dim = llm_dim
+        self.proprio_dim = proprio_dim
+
+        self.fc1 = nn.Linear(self.proprio_dim, self.llm_dim, bias=True)
+        self.fc2 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
+        self.act_fn1 = nn.GELU()
+
+    def forward(self, proprio: torch.Tensor = None) -> torch.Tensor:
+        # proprio: (bsz, proprio_dim)
+        projected_features = self.fc1(proprio)
+        projected_features = self.act_fn1(projected_features)
+        projected_features = self.fc2(projected_features)
+        return projected_features
+
+
+class NoisyActionProjector(nn.Module):
+    """
+    [Diffusion] Projects noisy action inputs into the LLM's embedding space.
+
+    Note that since each action is tokenized into 7 tokens in OpenVLA (rather
+    than having 1 token per action), each noisy action token will have dimension 1
+    instead of 7.
+    """
+    def __init__(self, llm_dim: int) -> None:
+        super().__init__()
+        self.llm_dim = llm_dim
+        self.action_token_dim = 1
+
+        self.fc1 = nn.Linear(self.action_token_dim, self.llm_dim, bias=True)
+        self.fc2 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
+        self.act_fn1 = nn.GELU()
+
+    def forward(self, noisy_actions: torch.Tensor = None) -> torch.Tensor:
+        # noisy_actions: (bsz, num_action_tokens=chunk_len*action_dim, 1)
+        projected_features = self.fc1(noisy_actions)
+        projected_features = self.act_fn1(projected_features)
+        projected_features = self.fc2(projected_features)
+        return projected_features
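A quick shape check for the projectors above; the dimensions are illustrative assumptions (a 7-DoF proprio state and a 4096-d LLM embedding space, with 56 = 8-step chunk x 7 action dims):

import torch

from prismatic.models.projectors import NoisyActionProjector, ProprioProjector

proprio_proj = ProprioProjector(llm_dim=4096, proprio_dim=7)
print(proprio_proj(torch.randn(2, 7)).shape)    # torch.Size([2, 4096])

noisy_proj = NoisyActionProjector(llm_dim=4096)
print(noisy_proj(torch.randn(2, 56, 1)).shape)  # torch.Size([2, 56, 4096]) -- one embedding per action token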
capvector-oft/prismatic/models/registry.py ADDED
@@ -0,0 +1,691 @@
+"""
+registry.py
+
+Exhaustive list of pretrained VLMs (with full descriptions / links to corresponding names and sections of paper).
+"""
+
+# === Pretrained Model Registry ===
+# fmt: off
+MODEL_REGISTRY = {
+    # === LLaVa v1.5 Reproductions ===
+    "reproduction-llava-v15+7b": {
+        "model_id": "reproduction-llava-v15+7b",
+        "names": ["LLaVa v1.5 7B (Reproduction)"],
+        "description": {
+            "name": "LLaVa v1.5 7B (Reproduction)",
+            "optimization_procedure": "multi-stage",
+            "visual_representation": "CLIP ViT-L/14 @ 336px",
+            "image_processing": "Letterbox",
+            "language_model": "Vicuña v1.5 7B",
+            "datasets": ["LLaVa v1.5 Instruct"],
+            "train_epochs": 1,
+        }
+    },
+    "reproduction-llava-v15+13b": {
+        "model_id": "reproduction-llava-v15+13b",
+        "names": ["LLaVa v1.5 13B (Reproduction)"],
+        "description": {
+            "name": "LLaVa v1.5 13B (Reproduction)",
+            "optimization_procedure": "multi-stage",
+            "visual_representation": "CLIP ViT-L/14 @ 336px",
+            "image_processing": "Letterbox",
+            "language_model": "Vicuña v1.5 13B",
+            "datasets": ["LLaVa v1.5 Instruct"],
+            "train_epochs": 1,
+        }
+    },
+
+    # === Section 4.1 :: Optimization Procedure ===
+    "one-stage+7b": {
+        "model_id": "one-stage+7b",
+        "names": [
+            "One-Stage 7B",
+            "Single-Stage 7B",
+            "Frozen ViT (Single-Stage)",
+            "CLIP ViT-L 336px (Letterbox)",
+            "CLIP ViT-L 336px",
+            "Vicuña v1.5 7B",
+            "1 Epoch",
+            "Base",
+        ],
+        "description": {
+            "name": "Single-Stage 7B",
+            "optimization_procedure": "single-stage",
+            "visual_representation": "CLIP ViT-L/14 @ 336px",
+            "image_processing": "Letterbox",
+            "language_model": "Vicuña v1.5 7B",
+            "datasets": ["LLaVa v1.5 Instruct"],
+            "train_epochs": 1,
+        }
+    },
+    "one-stage+13b": {
+        "model_id": "one-stage+13b",
+        "names": [
+            "One-Stage 13B",
+            "Single-Stage 13B",
+            "Vicuña v1.5 13B",
+        ],
+        "description": {
+            "name": "Single-Stage 13B",
+            "optimization_procedure": "single-stage",
+            "visual_representation": "CLIP ViT-L/14 @ 336px",
+            "image_processing": "Letterbox",
+            "language_model": "Vicuña v1.5 13B",
+            "datasets": ["LLaVa v1.5 Instruct"],
+            "train_epochs": 1,
+        }
+    },
+
+    "full-ft-multi-stage+7b": {
+        "model_id": "full-ft-multi-stage+7b",
+        "names": ["Finetune ViT (Multi-Stage)"],
+        "description": {
+            "name": "Finetune ViT (Multi-Stage)",
+            "optimization_procedure": "multi-stage-full-finetune",
+            "visual_representation": "CLIP ViT-L/14 @ 336px",
+            "image_processing": "Letterbox",
+            "language_model": "Vicuña v1.5 7B",
+            "datasets": ["LLaVa v1.5 Instruct"],
+            "train_epochs": 1,
+        }
+    },
+    "full-ft-one-stage+7b": {
+        "model_id": "full-ft-one-stage+7b",
+        "names": ["Finetune ViT (Single-Stage)"],
+        "description": {
+            "name": "Finetune ViT (Single-Stage)",
+            "optimization_procedure": "single-stage-full-finetune",
+            "visual_representation": "CLIP ViT-L/14 @ 336px",
+            "image_processing": "Letterbox",
+            "language_model": "Vicuña v1.5 7B",
+            "datasets": ["LLaVa v1.5 Instruct"],
+            "train_epochs": 1,
+        }
+    },
+
+    # === Section 4.2 :: Image Processing and Visual Representations ===
+    "in1k-224px+7b": {
+        "model_id": "in1k-224px+7b",
+        "names": ["IN1K ViT-L 224px"],
+        "description": {
+            "name": "IN1K ViT-L 224px",
+            "optimization_procedure": "single-stage",
+            "visual_representation": "ImageNet-21K+1K ViT-L/16 @ 224px",
+            "image_processing": "Letterbox",
+            "language_model": "Vicuña v1.5 7B",
+            "datasets": ["LLaVa v1.5 Instruct"],
+            "train_epochs": 1,
+        },
+    },
+    "dinov2-224px+7b": {
+        "model_id": "dinov2-224px+7b",
+        "names": ["DINOv2 ViT-L 224px"],
+        "description": {
+            "name": "DINOv2 ViT-L 224px",
+            "optimization_procedure": "single-stage",
+            "visual_representation": "DINOv2 ViT-L/14 @ 224px",
+            "image_processing": "Letterbox",
+            "language_model": "Vicuña v1.5 7B",
+            "datasets": ["LLaVa v1.5 Instruct"],
+            "train_epochs": 1,
+        },
+    },
+    "clip-224px+7b": {
+        "model_id": "clip-224px+7b",
+        "names": ["CLIP ViT-L 224px"],
+        "description": {
+            "name": "CLIP ViT-L 224px",
+            "optimization_procedure": "single-stage",
+            "visual_representation": "CLIP ViT-L/14 @ 224px",
+            "image_processing": "Letterbox",
+            "language_model": "Vicuña v1.5 7B",
+            "datasets": ["LLaVa v1.5 Instruct"],
+            "train_epochs": 1,
+        },
+    },
+    "siglip-224px+7b": {
+        "model_id": "siglip-224px+7b",
+        "names": ["SigLIP ViT-SO 224px"],
+        "description": {
+            "name": "SigLIP ViT-SO 224px",
+            "optimization_procedure": "single-stage",
+            "visual_representation": "SigLIP ViT-SO/14 @ 224px",
+            "image_processing": "Letterbox",
+            "language_model": "Vicuña v1.5 7B",
+            "datasets": ["LLaVa v1.5 Instruct"],
+            "train_epochs": 1,
+        },
+    },
+
+    "clip-336px-resize-crop+7b": {
+        "model_id": "clip-336px-resize-crop+7b",
+        "names": ["CLIP ViT-L 336px (Resize Crop)"],
+        "description": {
+            "name": "CLIP ViT-L 336px (Resize Crop)",
+            "optimization_procedure": "single-stage",
+            "visual_representation": "CLIP ViT-L/14 @ 336px",
+            "image_processing": "Resize Crop",
+            "language_model": "Vicuña v1.5 7B",
+            "datasets": ["LLaVa v1.5 Instruct"],
+            "train_epochs": 1,
+        }
+    },
+    "clip-336px-resize-naive+7b": {
+        "model_id": "clip-336px-resize-naive+7b",
+        "names": ["CLIP ViT-L 336px (Naive Resize)", "CLIP 336px (Naive Resize)"],
+        "description": {
+            "name": "CLIP ViT-L 336px (Naive Resize)",
+            "optimization_procedure": "single-stage",
+            "visual_representation": "CLIP ViT-L/14 @ 336px",
+            "image_processing": "Naive Resize",
+            "language_model": "Vicuña v1.5 7B",
+            "datasets": ["LLaVa v1.5 Instruct"],
+            "train_epochs": 1,
+        }
+    },
+    "siglip-384px-letterbox+7b": {
+        "model_id": "siglip-384px-letterbox+7b",
+        "names": ["SigLIP ViT-SO 384px (Letterbox)", "SigLIP ViT-SO 384px"],
+        "description": {
+            "name": "SigLIP ViT-SO 384px (Letterbox)",
+            "optimization_procedure": "single-stage",
+            "visual_representation": "SigLIP ViT-SO/14 @ 384px",
+            "image_processing": "Letterbox",
+            "language_model": "Vicuña v1.5 7B",
+            "datasets": ["LLaVa v1.5 Instruct"],
+            "train_epochs": 1,
+        }
+    },
+    "siglip-384px-resize-crop+7b": {
+        "model_id": "siglip-384px-resize-crop+7b",
+        "names": ["SigLIP ViT-SO 384px (Resize Crop)"],
+        "description": {
+            "name": "SigLIP ViT-SO 384px (Resize Crop)",
+            "optimization_procedure": "single-stage",
|
| 205 |
+
"visual_representation": "SigLIP ViT-SO/14 @ 384px",
|
| 206 |
+
"image_processing": "Resize Crop",
|
| 207 |
+
"language_model": "Vicuña v1.5 7B",
|
| 208 |
+
"datasets": ["LLaVa v1.5 Instruct"],
|
| 209 |
+
"train_epochs": 1,
|
| 210 |
+
}
|
| 211 |
+
},
|
| 212 |
+
"siglip-384px-resize-naive+7b": {
|
| 213 |
+
"model_id": "siglip-384px-resize-naive+7b",
|
| 214 |
+
"names": ["SigLIP ViT-SO 384px (Naive Resize)", "SigLIP 384px (Naive Resize)"],
|
| 215 |
+
"description": {
|
| 216 |
+
"name": "SigLIP ViT-SO 384px (Naive Resize)",
|
| 217 |
+
"optimization_procedure": "single-stage",
|
| 218 |
+
"visual_representation": "SigLIP ViT-SO/14 @ 384px",
|
| 219 |
+
"image_processing": "Naive Resize",
|
| 220 |
+
"language_model": "Vicuña v1.5 7B",
|
| 221 |
+
"datasets": ["LLaVa v1.5 Instruct"],
|
| 222 |
+
"train_epochs": 1,
|
| 223 |
+
}
|
| 224 |
+
},
|
| 225 |
+
|
| 226 |
+
"dinoclip-336px-letterbox+7b": {
|
| 227 |
+
"model_id": "dinoclip-336px-letterbox+7b",
|
| 228 |
+
"names": ["DINOv2 + CLIP 336px (Letterbox)"],
|
| 229 |
+
"description": {
|
| 230 |
+
"name": "DINOv2 + CLIP 336px (Letterbox)",
|
| 231 |
+
"optimization_procedure": "single-stage",
|
| 232 |
+
"visual_representation": "DINOv2 ViT-L/14 + CLIP ViT-L/14 @ 336px",
|
| 233 |
+
"image_processing": "Letterbox",
|
| 234 |
+
"language_model": "Vicuña v1.5 7B",
|
| 235 |
+
"datasets": ["LLaVa v1.5 Instruct"],
|
| 236 |
+
"train_epochs": 1,
|
| 237 |
+
}
|
| 238 |
+
},
|
| 239 |
+
"dinoclip-336px-resize-naive+7b": {
|
| 240 |
+
"model_id": "dinoclip-336px-resize-naive+7b",
|
| 241 |
+
"names": ["DINOv2 + CLIP 336px (Naive Resize)"],
|
| 242 |
+
"description": {
|
| 243 |
+
"name": "DINOv2 + CLIP 336px (Naive Resize)",
|
| 244 |
+
"optimization_procedure": "single-stage",
|
| 245 |
+
"visual_representation": "DINOv2 ViT-L/14 + CLIP ViT-L/14 @ 336px",
|
| 246 |
+
"image_processing": "Naive Resize",
|
| 247 |
+
"language_model": "Vicuña v1.5 7B",
|
| 248 |
+
"datasets": ["LLaVa v1.5 Instruct"],
|
| 249 |
+
"train_epochs": 1,
|
| 250 |
+
}
|
| 251 |
+
},
|
| 252 |
+
"dinosiglip-384px-letterbox+7b": {
|
| 253 |
+
"model_id": "dinosiglip-384px-letterbox+7b",
|
| 254 |
+
"names": ["DINOv2 + SigLIP 384px (Letterbox)"],
|
| 255 |
+
"description": {
|
| 256 |
+
"name": "DINOv2 + SigLIP 384px (Letterbox)",
|
| 257 |
+
"optimization_procedure": "single-stage",
|
| 258 |
+
"visual_representation": "DINOv2 ViT-L/14 + SigLIP ViT-L/14 @ 384px",
|
| 259 |
+
"image_processing": "Letterbox",
|
| 260 |
+
"language_model": "Vicuña v1.5 7B",
|
| 261 |
+
"datasets": ["LLaVa v1.5 Instruct"],
|
| 262 |
+
"train_epochs": 1,
|
| 263 |
+
}
|
| 264 |
+
},
|
| 265 |
+
"dinosiglip-384px-resize-naive+7b": {
|
| 266 |
+
"model_id": "dinosiglip-384px-resize-naive+7b",
|
| 267 |
+
"names": ["DINOv2 + SigLIP 384px (Naive Resize)"],
|
| 268 |
+
"description": {
|
| 269 |
+
"name": "DINOv2 + SigLIP 384px (Naive Resize)",
|
| 270 |
+
"optimization_procedure": "single-stage",
|
| 271 |
+
"visual_representation": "DINOv2 ViT-L/14 + SigLIP ViT-L/14 @ 384px",
|
| 272 |
+
"image_processing": "Naive Resize",
|
| 273 |
+
"language_model": "Vicuña v1.5 7B",
|
| 274 |
+
"datasets": ["LLaVa v1.5 Instruct"],
|
| 275 |
+
"train_epochs": 1,
|
| 276 |
+
}
|
| 277 |
+
},
|
| 278 |
+
|
| 279 |
+
# === Section 4.3 :: Language Models ===
|
| 280 |
+
"llama2+7b": {
|
| 281 |
+
"model_id": "llama2+7b",
|
| 282 |
+
"names": ["Llama-2 7B"],
|
| 283 |
+
"description": {
|
| 284 |
+
"name": "Llama-2 7B",
|
| 285 |
+
"optimization_procedure": "single-stage",
|
| 286 |
+
"visual_representation": "CLIP ViT-L/14 @ 336px",
|
| 287 |
+
"image_processing": "Letterbox",
|
| 288 |
+
"language_model": "Llama-2 7B",
|
| 289 |
+
"datasets": ["LLaVa v1.5 Instruct"],
|
| 290 |
+
"train_epochs": 1,
|
| 291 |
+
},
|
| 292 |
+
},
|
| 293 |
+
"llama2+13b": {
|
| 294 |
+
"model_id": "llama2+13b",
|
| 295 |
+
"names": ["Llama-2 13B"],
|
| 296 |
+
"description": {
|
| 297 |
+
"name": "Llama-2 13B",
|
| 298 |
+
"optimization_procedure": "single-stage",
|
| 299 |
+
"visual_representation": "CLIP ViT-L/14 @ 336px",
|
| 300 |
+
"image_processing": "Letterbox",
|
| 301 |
+
"language_model": "Llama-2 13B",
|
| 302 |
+
"datasets": ["LLaVa v1.5 Instruct"],
|
| 303 |
+
"train_epochs": 1,
|
| 304 |
+
},
|
| 305 |
+
},
|
| 306 |
+
|
| 307 |
+
"vicuna-no-cotraining+7b": {
|
| 308 |
+
"model_id": "vicuna-no-cotraining+7b",
|
| 309 |
+
"names": ["Vicuña v1.5 7B (No Co-training)"],
|
| 310 |
+
"description": {
|
| 311 |
+
"name": "Vicuña v1.5 7B (No Co-training)",
|
| 312 |
+
"optimization_procedure": "single-stage",
|
| 313 |
+
"visual_representation": "CLIP ViT-L/14 @ 336px",
|
| 314 |
+
"image_processing": "Letterbox",
|
| 315 |
+
"language_model": "Vicuña v1.5 7B",
|
| 316 |
+
"datasets": ["LLaVa v1.5 Multimodal-Only"],
|
| 317 |
+
"train_epochs": 1,
|
| 318 |
+
},
|
| 319 |
+
},
|
| 320 |
+
"llama2-no-cotraining+7b": {
|
| 321 |
+
"model_id": "llama2-no-cotraining+7b",
|
| 322 |
+
"names": ["Llama-2 7B (No Co-training)"],
|
| 323 |
+
"description": {
|
| 324 |
+
"name": "Llama-2 7B (No Co-training)",
|
| 325 |
+
"optimization_procedure": "single-stage",
|
| 326 |
+
"visual_representation": "CLIP ViT-L/14 @ 336px",
|
| 327 |
+
"image_processing": "Letterbox",
|
| 328 |
+
"language_model": "Llama-2 7B",
|
| 329 |
+
"datasets": ["LLaVa v1.5 Multimodal-Only"],
|
| 330 |
+
"train_epochs": 1,
|
| 331 |
+
},
|
| 332 |
+
},
|
| 333 |
+
|
| 334 |
+
# === Section 4.4 :: Scaling Properties ===
|
| 335 |
+
"train-1.25-epochs+7b": {
|
| 336 |
+
"model_id": "train-1.25-epochs+7b",
|
| 337 |
+
"names": ["1.25 Epochs"],
|
| 338 |
+
"description": {
|
| 339 |
+
"name": "1.25 Epochs",
|
| 340 |
+
"optimization_procedure": "single-stage",
|
| 341 |
+
"visual_representation": "CLIP ViT-L/14 @ 336px",
|
| 342 |
+
"image_processing": "Letterbox",
|
| 343 |
+
"language_model": "Vicuña v1.5 7B",
|
| 344 |
+
"datasets": ["LLaVa v1.5 Instruct"],
|
| 345 |
+
"train_epochs": 1.25,
|
| 346 |
+
}
|
| 347 |
+
},
|
| 348 |
+
"train-1.5-epochs+7b": {
|
| 349 |
+
"model_id": "train-1.5-epochs+7b",
|
| 350 |
+
"names": ["1.5 Epochs"],
|
| 351 |
+
"description": {
|
| 352 |
+
"name": "1.5 Epochs",
|
| 353 |
+
"optimization_procedure": "single-stage",
|
| 354 |
+
"visual_representation": "CLIP ViT-L/14 @ 336px",
|
| 355 |
+
"image_processing": "Letterbox",
|
| 356 |
+
"language_model": "Vicuña v1.5 7B",
|
| 357 |
+
"datasets": ["LLaVa v1.5 Instruct"],
|
| 358 |
+
"train_epochs": 1.5,
|
| 359 |
+
}
|
| 360 |
+
},
|
| 361 |
+
"train-2-epochs+7b": {
|
| 362 |
+
"model_id": "train-2-epochs+7b",
|
| 363 |
+
"names": ["2 Epochs"],
|
| 364 |
+
"description": {
|
| 365 |
+
"name": "2 Epochs",
|
| 366 |
+
"optimization_procedure": "single-stage",
|
| 367 |
+
"visual_representation": "CLIP ViT-L/14 @ 336px",
|
| 368 |
+
"image_processing": "Letterbox",
|
| 369 |
+
"language_model": "Vicuña v1.5 7B",
|
| 370 |
+
"datasets": ["LLaVa v1.5 Instruct"],
|
| 371 |
+
"train_epochs": 2,
|
| 372 |
+
}
|
| 373 |
+
},
|
| 374 |
+
"train-3-epochs+7b": {
|
| 375 |
+
"model_id": "train-3-epochs+7b",
|
| 376 |
+
"names": ["3 Epochs"],
|
| 377 |
+
"description": {
|
| 378 |
+
"name": "3 Epochs",
|
| 379 |
+
"optimization_procedure": "single-stage",
|
| 380 |
+
"visual_representation": "CLIP ViT-L/14 @ 336px",
|
| 381 |
+
"image_processing": "Letterbox",
|
| 382 |
+
"language_model": "Vicuña v1.5 7B",
|
| 383 |
+
"datasets": ["LLaVa v1.5 Instruct"],
|
| 384 |
+
"train_epochs": 3,
|
| 385 |
+
}
|
| 386 |
+
},
|
| 387 |
+
|
| 388 |
+
"llava-lvis4v+7b": {
|
| 389 |
+
"model_id": "llava-lvis4v+7b",
|
| 390 |
+
"names": ["Base + LVIS-4V"],
|
| 391 |
+
"description": {
|
| 392 |
+
"name": "Base + LVIS-4V",
|
| 393 |
+
"optimization_procedure": "single-stage",
|
| 394 |
+
"visual_representation": "CLIP ViT-L/14 @ 336px",
|
| 395 |
+
"image_processing": "Letterbox",
|
| 396 |
+
"language_model": "Vicuña v1.5 7B",
|
| 397 |
+
"datasets": ["LLaVa v1.5 Instruct", "LVIS-Instruct-4V"],
|
| 398 |
+
"train_epochs": 1,
|
| 399 |
+
}
|
| 400 |
+
},
|
| 401 |
+
"llava-lrv+7b": {
|
| 402 |
+
"model_id": "llava-lrv+7b",
|
| 403 |
+
"names": ["Base + LRV"],
|
| 404 |
+
"description": {
|
| 405 |
+
"name": "Base + LRV",
|
| 406 |
+
"optimization_procedure": "single-stage",
|
| 407 |
+
"visual_representation": "CLIP ViT-L/14 @ 336px",
|
| 408 |
+
"image_processing": "Letterbox",
|
| 409 |
+
"language_model": "Vicuña v1.5 7B",
|
| 410 |
+
"datasets": ["LLaVa v1.5 Instruct", "LRV-Instruct"],
|
| 411 |
+
"train_epochs": 1,
|
| 412 |
+
}
|
| 413 |
+
},
|
| 414 |
+
"llava-lvis4v-lrv+7b": {
|
| 415 |
+
"model_id": "llava-lvis4v-lrv+7b",
|
| 416 |
+
"names": ["Base + LVIS-4V + LRV"],
|
| 417 |
+
"description": {
|
| 418 |
+
"name": "Base + LVIS-4V + LRV",
|
| 419 |
+
"optimization_procedure": "single-stage",
|
| 420 |
+
"visual_representation": "CLIP ViT-L/14 @ 336px",
|
| 421 |
+
"image_processing": "Letterbox",
|
| 422 |
+
"language_model": "Vicuña v1.5 7B",
|
| 423 |
+
"datasets": ["LLaVa v1.5 Instruct", "LVIS-Instruct-4V", "LRV-Instruct"],
|
| 424 |
+
"train_epochs": 1,
|
| 425 |
+
}
|
| 426 |
+
},
|
| 427 |
+
|
| 428 |
+
# ===
|
| 429 |
+
|
| 430 |
+
# === CLIP Prism Models ===
|
| 431 |
+
"prism-clip-controlled+7b": {
|
| 432 |
+
"model_id": "prism-clip-controlled+7b",
|
| 433 |
+
"names": ["Prism-CLIP 7B (Controlled)"],
|
| 434 |
+
"description": {
|
| 435 |
+
"name": "CLIP Prism 7B (Controlled)",
|
| 436 |
+
"optimization_procedure": "single-stage",
|
| 437 |
+
"visual_representation": "CLIP ViT-L/14 @ 336px",
|
| 438 |
+
"image_processing": "Naive Resize",
|
| 439 |
+
"language_model": "Llama-2 7B",
|
| 440 |
+
"datasets": ["LLaVa v1.5 Instruct"],
|
| 441 |
+
"train_epochs": 1,
|
| 442 |
+
}
|
| 443 |
+
},
|
| 444 |
+
"prism-clip-controlled+13b": {
|
| 445 |
+
"model_id": "prism-clip-controlled+13b",
|
| 446 |
+
"names": ["Prism-CLIP 13B (Controlled)"],
|
| 447 |
+
"description": {
|
| 448 |
+
"name": "CLIP Prism 13B (Controlled)",
|
| 449 |
+
"optimization_procedure": "single-stage",
|
| 450 |
+
"visual_representation": "CLIP ViT-L/14 @ 336px",
|
| 451 |
+
"image_processing": "Naive Resize",
|
| 452 |
+
"language_model": "Llama-2 13B",
|
| 453 |
+
"datasets": ["LLaVa v1.5 Instruct"],
|
| 454 |
+
"train_epochs": 1,
|
| 455 |
+
}
|
| 456 |
+
},
|
| 457 |
+
"prism-clip+7b": {
|
| 458 |
+
"model_id": "prism-clip+7b",
|
| 459 |
+
"names": ["Prism-CLIP 7B"],
|
| 460 |
+
"description": {
|
| 461 |
+
"name": "CLIP Prism 7B",
|
| 462 |
+
"optimization_procedure": "single-stage",
|
| 463 |
+
"visual_representation": "CLIP ViT-L/14 @ 336px",
|
| 464 |
+
"image_processing": "Naive Resize",
|
| 465 |
+
"language_model": "Llama-2 7B",
|
| 466 |
+
"datasets": ["LLaVa v1.5 Instruct", "LVIS-Instruct-4V", "LRV-Instruct"],
|
| 467 |
+
"train_epochs": 2,
|
| 468 |
+
},
|
| 469 |
+
},
|
| 470 |
+
"prism-clip+13b": {
|
| 471 |
+
"model_id": "prism-clip+13b",
|
| 472 |
+
"names": ["Prism-CLIP 13B"],
|
| 473 |
+
"description": {
|
| 474 |
+
"name": "CLIP Prism 13B",
|
| 475 |
+
"optimization_procedure": "single-stage",
|
| 476 |
+
"visual_representation": "CLIP ViT-L/14 @ 336px",
|
| 477 |
+
"image_processing": "Naive Resize",
|
| 478 |
+
"language_model": "Llama-2 13B",
|
| 479 |
+
"datasets": ["LLaVa v1.5 Instruct", "LVIS-Instruct-4V", "LRV-Instruct"],
|
| 480 |
+
"train_epochs": 2,
|
| 481 |
+
},
|
| 482 |
+
},
|
| 483 |
+
|
| 484 |
+
    # === SigLIP Prism Models ===
    "prism-siglip-controlled+7b": {
        "model_id": "prism-siglip-controlled+7b",
        "names": ["Prism-SigLIP 7B (Controlled)"],
        "description": {
            "name": "SigLIP Prism 7B (Controlled)",
            "optimization_procedure": "single-stage",
            "visual_representation": "SigLIP ViT-SO/14 @ 384px",
            "image_processing": "Naive Resize",
            "language_model": "Llama-2 7B",
            "datasets": ["LLaVa v1.5 Instruct"],
            "train_epochs": 1,
        }
    },
    "prism-siglip-controlled+13b": {
"model_id": "prism-siglip-controlled+7b",
        "names": ["Prism-SigLIP 13B (Controlled)"],
        "description": {
            "name": "SigLIP Prism 13B (Controlled)",
            "optimization_procedure": "single-stage",
            "visual_representation": "SigLIP ViT-SO/14 @ 384px",
            "image_processing": "Naive Resize",
            "language_model": "Llama-2 13B",
            "datasets": ["LLaVa v1.5 Instruct"],
            "train_epochs": 1,
        }
    },
    "prism-siglip+7b": {
        "model_id": "prism-siglip+7b",
        "names": ["Prism-SigLIP 7B"],
        "description": {
            "name": "SigLIP Prism 7B",
            "optimization_procedure": "single-stage",
            "visual_representation": "SigLIP ViT-SO/14 @ 384px",
            "image_processing": "Naive Resize",
            "language_model": "Llama-2 7B",
            "datasets": ["LLaVa v1.5 Instruct", "LVIS-Instruct-4V", "LRV-Instruct"],
            "train_epochs": 2,
        }
    },
    "prism-siglip+13b": {
        "model_id": "prism-siglip+13b",
        "names": ["Prism-SigLIP 13B"],
        "description": {
            "name": "SigLIP Prism 13B",
            "optimization_procedure": "single-stage",
            "visual_representation": "SigLIP ViT-SO/14 @ 384px",
            "image_processing": "Naive Resize",
            "language_model": "Llama-2 13B",
            "datasets": ["LLaVa v1.5 Instruct", "LVIS-Instruct-4V", "LRV-Instruct"],
            "train_epochs": 2,
        }
    },

    # === DINOSigLIP Prism Models ===
    "prism-dinosiglip-controlled+7b": {
        "model_id": "prism-dinosiglip-controlled+7b",
        "names": ["Prism-DINOSigLIP 7B (Controlled)", "Prism 7B (Controlled)"],
        "description": {
            "name": "DINOSigLIP Prism 7B (Controlled)",
            "optimization_procedure": "single-stage",
            "visual_representation": "DINOv2 ViT-L/14 + SigLIP ViT-SO/14 @ 384px",
            "image_processing": "Naive Resize",
            "language_model": "Llama-2 7B",
            "datasets": ["LLaVa v1.5 Instruct"],
            "train_epochs": 1,
        }
    },
    "prism-dinosiglip-controlled+13b": {
        "model_id": "prism-dinosiglip-controlled+13b",
        "names": ["Prism-DINOSigLIP 13B (Controlled)", "Prism 13B (Controlled)"],
        "description": {
            "name": "DINOSigLIP Prism 13B (Controlled)",
            "optimization_procedure": "single-stage",
            "visual_representation": "DINOv2 ViT-L/14 + SigLIP ViT-SO/14 @ 384px",
            "image_processing": "Naive Resize",
            "language_model": "Llama-2 13B",
            "datasets": ["LLaVa v1.5 Instruct"],
            "train_epochs": 1,
        }
    },
    "prism-dinosiglip+7b": {
        "model_id": "prism-dinosiglip+7b",
        "names": ["Prism-DINOSigLIP 7B"],
        "description": {
            "name": "DINOSigLIP Prism 7B",
            "optimization_procedure": "single-stage",
            "visual_representation": "DINOv2 ViT-L/14 + SigLIP ViT-SO/14 @ 384px",
            "image_processing": "Naive Resize",
            "language_model": "Llama-2 7B",
            "datasets": ["LLaVa v1.5 Instruct", "LVIS-Instruct-4V", "LRV-Instruct"],
            "train_epochs": 2,
        },
    },
    "prism-dinosiglip+13b": {
        "model_id": "prism-dinosiglip+13b",
        "names": ["Prism-DINOSigLIP 13B"],
        "description": {
            "name": "DINOSigLIP Prism 13B",
            "optimization_procedure": "single-stage",
            "visual_representation": "DINOv2 ViT-L/14 + SigLIP ViT-SO/14 @ 384px",
            "image_processing": "Naive Resize",
            "language_model": "Llama-2 13B",
            "datasets": ["LLaVa v1.5 Instruct", "LVIS-Instruct-4V", "LRV-Instruct"],
            "train_epochs": 2,
        },
    },

    # === DINOSigLIP 224px Prism Models ===
    "prism-dinosiglip-224px-controlled+7b": {
        "model_id": "prism-dinosiglip-224px-controlled+7b",
        "names": ["Prism-DINOSigLIP 224px 7B (Controlled)"],
        "description": {
            "name": "DINOSigLIP 224px 7B (Controlled)",
            "optimization_procedure": "single-stage",
"visual_representation": "DINOv2 ViT-L/14 + SigLIP ViT-SO 14 @ 224px",
            "image_processing": "Naive Resize",
            "language_model": "Llama-2 7B",
            "datasets": ["LLaVa v1.5 Instruct"],
            "train_epochs": 1,
        }
    },
    "prism-dinosiglip-224px+7b": {
        "model_id": "prism-dinosiglip-224px+7b",
        "names": ["Prism-DINOSigLIP 224px 7B"],
        "description": {
            "name": "DINOSigLIP 224px 7B",
            "optimization_procedure": "single-stage",
"visual_representation": "DINOv2 ViT-L/14 + SigLIP ViT-SO 14 @ 224px",
            "image_processing": "Naive Resize",
            "language_model": "Llama-2 7B",
            "datasets": ["LLaVa v1.5 Instruct", "LVIS-Instruct-4V", "LRV-Instruct"],
            "train_epochs": 2,
        }
    },

    # === Additional LLM Backbones ===
    "llama2-chat+7b": {
        "model_id": "llama2-chat+7b",
        "names": ["Llama-2 Chat 7B"],
        "description": {
            "name": "Llama-2 Chat 7B",
            "optimization_procedure": "single-stage",
            "visual_representation": "CLIP ViT-L/14 @ 336px",
            "image_processing": "Letterbox",
            "language_model": "Llama-2 Chat 7B",
            "datasets": ["LLaVa v1.5 Instruct"],
            "train_epochs": 1,
        }
    },
    "llama2-chat+13b": {
        "model_id": "llama2-chat+13b",
        "names": ["Llama-2 Chat 13B"],
        "description": {
            "name": "Llama-2 Chat 13B",
            "optimization_procedure": "single-stage",
            "visual_representation": "CLIP ViT-L/14 @ 336px",
            "image_processing": "Letterbox",
            "language_model": "Llama-2 Chat 13B",
            "datasets": ["LLaVa v1.5 Instruct"],
            "train_epochs": 1,
        }
    },
    "mistral-v0.1+7b": {
        "model_id": "mistral-v0.1+7b",
        "names": ["Mistral v0.1 7B"],
        "description": {
            "name": "Mistral v0.1 7B",
            "optimization_procedure": "single-stage",
            "visual_representation": "CLIP ViT-L/14 @ 336px",
            "image_processing": "Letterbox",
            "language_model": "Mistral v0.1 7B",
            "datasets": ["LLaVa v1.5 Instruct"],
            "train_epochs": 1,
        }
    },
    "mistral-instruct-v0.1+7b": {
        "model_id": "mistral-instruct-v0.1+7b",
        "names": ["Mistral Instruct v0.1 7B"],
        "description": {
            "name": "Mistral Instruct v0.1 7B",
            "optimization_procedure": "single-stage",
            "visual_representation": "CLIP ViT-L/14 @ 336px",
            "image_processing": "Letterbox",
            "language_model": "Mistral Instruct v0.1 7B",
            "datasets": ["LLaVa v1.5 Instruct"],
            "train_epochs": 1,
        }
    },
    "phi-2+3b": {
        "model_id": "phi-2+3b",
        "names": ["Phi-2 3B"],
        "description": {
            "name": "Phi-2 3B",
            "optimization_procedure": "single-stage",
            "visual_representation": "CLIP ViT-L/14 @ 336px",
            "image_processing": "Letterbox",
            "language_model": "Phi-2 3B",
            "datasets": ["LLaVa v1.5 Instruct"],
            "train_epochs": 1,
        }
    },
}

# Build Global Registry (Model ID, Name) -> Metadata
GLOBAL_REGISTRY = {name: v for k, v in MODEL_REGISTRY.items() for name in [k] + v["names"]}

# fmt: on
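Since `GLOBAL_REGISTRY` maps the canonical model ID *and* every alias in `names` to the same metadata
record, lookups work with either form. A minimal sketch (illustrative, not part of the diff; assumes the
module is importable as `prismatic.models.registry`):

    from prismatic.models.registry import GLOBAL_REGISTRY

    # Canonical ID and human-readable alias resolve to the *same* dict object
    assert GLOBAL_REGISTRY["one-stage+7b"] is GLOBAL_REGISTRY["Single-Stage 7B"]
    print(GLOBAL_REGISTRY["Prism-CLIP 7B"]["description"]["language_model"])  # -> "Llama-2 7B"
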
capvector-oft/prismatic/models/vlas/__init__.py
ADDED
@@ -0,0 +1 @@
from .openvla import OpenVLA

capvector-oft/prismatic/models/vlas/openvla.py
ADDED
@@ -0,0 +1,131 @@
"""
openvla.py

PyTorch Module defining OpenVLA as a lightweight wrapper around a PrismaticVLM; defines custom logic around
discretizing actions with the ActionTokenizer.
"""

from typing import Dict, List, Optional

import numpy as np
import torch
from PIL import Image
from transformers import LlamaTokenizerFast

from prismatic.models.vlms.prismatic import PrismaticVLM
from prismatic.overwatch import initialize_overwatch
from prismatic.vla.action_tokenizer import ActionTokenizer

# Initialize Overwatch =>> Wraps `logging.Logger`
overwatch = initialize_overwatch(__name__)


class OpenVLA(PrismaticVLM):
    def __init__(
        self,
        *args,
        norm_stats: Dict[str, Dict[str, Dict[str, Dict[str, List[float]]]]],
        action_tokenizer: ActionTokenizer,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        self.norm_stats = norm_stats
        self.action_tokenizer = action_tokenizer

    @torch.inference_mode()
    def predict_action(
        self, image: Image, instruction: str, unnorm_key: Optional[str] = None, **kwargs: str
    ) -> np.ndarray:
        """
        Core function for VLA inference; maps input image and task instruction to continuous action (de-tokenizes).

        @param image: PIL Image as [height, width, 3]
        @param instruction: Task instruction string
        @param unnorm_key: Optional dataset name for retrieving un-normalizing statistics; if None, checks that model
                           was trained only on a single dataset, and retrieves those statistics.

        @return Unnormalized (continuous) action vector --> end-effector deltas.
        """
        image_transform, tokenizer = self.vision_backbone.image_transform, self.llm_backbone.tokenizer

        # Build VLA Prompt
        prompt_builder = self.get_prompt_builder()
        prompt_builder.add_turn(role="human", message=f"What action should the robot take to {instruction.lower()}?")
        prompt_text = prompt_builder.get_prompt()

        # Prepare Inputs
        input_ids = tokenizer(prompt_text, truncation=True, return_tensors="pt").input_ids.to(self.device)
        if isinstance(tokenizer, LlamaTokenizerFast):
            # If the special empty token ('') does not already appear after the colon (':') token in the prompt
            # (after "OUT:" or "ASSISTANT:"), insert it to match the inputs seen at training time
            if not torch.all(input_ids[:, -1] == 29871):
                input_ids = torch.cat(
                    (input_ids, torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(input_ids.device)), dim=1
                )
        else:
            raise ValueError(f"Unsupported `tokenizer` type = {type(tokenizer)}")

        # Preprocess Image
        pixel_values = image_transform(image)
        if isinstance(pixel_values, torch.Tensor):
            pixel_values = pixel_values[None, ...].to(self.device)
        elif isinstance(pixel_values, dict):
            pixel_values = {k: v[None, ...].to(self.device) for k, v in pixel_values.items()}
        else:
            raise ValueError(f"Unsupported `pixel_values` type = {type(pixel_values)}")

        # Invoke super().generate --> taps into `GenerationMixin` which (redirects) to `forward()`
        autocast_dtype = self.llm_backbone.half_precision_dtype
        with torch.autocast("cuda", dtype=autocast_dtype, enabled=self.enable_mixed_precision_training):
            # fmt: off
            generated_ids = super(PrismaticVLM, self).generate(
                input_ids=input_ids,                            # Shape: [1, seq]
                pixel_values=pixel_values,                      # Shape: [1, 3, res, res] or Dict[str, ...]
                max_new_tokens=self.get_action_dim(unnorm_key),
                **kwargs
            )
            # fmt: on

        # Extract predicted action tokens and translate into (normalized) continuous actions
        predicted_action_token_ids = generated_ids[0, -self.get_action_dim(unnorm_key) :]
        normalized_actions = self.action_tokenizer.decode_token_ids_to_actions(predicted_action_token_ids.cpu().numpy())

        # Un-normalize Actions
        action_norm_stats = self.get_action_stats(unnorm_key)
        mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats["q01"], dtype=bool))
        action_high, action_low = np.array(action_norm_stats["q99"]), np.array(action_norm_stats["q01"])
        actions = np.where(
            mask,
            0.5 * (normalized_actions + 1) * (action_high - action_low) + action_low,
            normalized_actions,
        )
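        # (Each masked dimension is affinely mapped from the tokenizer's normalized [-1, 1] range
        #  back to the dataset's [q01, q99] range; dimensions where `mask` is False pass through
        #  in their normalized form.)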

        return actions

    @staticmethod
    def _check_unnorm_key(norm_stats: Dict, unnorm_key: str) -> str:
        if unnorm_key is None:
            assert len(norm_stats) == 1, (
                f"Your model was trained on more than one dataset, please pass a `unnorm_key` from the following "
                f"options to choose the statistics used for un-normalizing actions: {norm_stats.keys()}"
            )
            unnorm_key = next(iter(norm_stats.keys()))

        # Error Handling
        assert (
            unnorm_key in norm_stats
        ), f"The `unnorm_key` you chose is not in the set of available statistics; choose from: {norm_stats.keys()}"

        return unnorm_key

    def get_action_dim(self, unnorm_key: Optional[str] = None) -> int:
        """Dimensionality of the policy's action space."""
        unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)

        return len(self.norm_stats[unnorm_key]["action"]["q01"])

    def get_action_stats(self, unnorm_key: Optional[str] = None) -> Dict:
"""Dimensionality of the policy's action space."""
        unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)

        return self.norm_stats[unnorm_key]["action"]

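A minimal inference sketch for `predict_action` (illustrative only, not part of the diff: `vla` is assumed
to be an already-instantiated `OpenVLA`, and the image path, instruction, and dataset key are hypothetical):

    from PIL import Image

    obs = Image.open("frame.png")  # current camera frame, [H, W, 3]
    action = vla.predict_action(obs, "pick up the red block", unnorm_key="bridge_orig", do_sample=False)
    assert action.shape == (vla.get_action_dim("bridge_orig"),)
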
capvector-oft/prismatic/models/vlms/__init__.py
ADDED
@@ -0,0 +1 @@
from .prismatic import PrismaticVLM

capvector-oft/prismatic/models/vlms/base_vlm.py
ADDED
@@ -0,0 +1,108 @@
"""
base_vlm.py

Abstract class definition of a Vision-Language Model (VLM), with full annotations of class methods, utility functions,
and initialization logic. This is mostly to future-proof the codebase; while all our experiments instantiate
from PrismaticVLM, theoretically, this base class should be general enough to cover almost all models (e.g., IDEFICS,
PALI, Fuyu) in the future.

We use Abstract base classes *sparingly* -- mostly as a way to encapsulate any redundant logic or nested inheritance
(e.g., dependence on nn.Module, HF PretrainedModel, etc.). For other abstract objects (e.g., Tokenizers/Transforms),
prefer Protocol definitions instead.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from pathlib import Path
from typing import Callable, List, Optional

import torch
import torch.nn as nn
from transformers import GenerationMixin, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast

from prismatic.models.backbones.llm import LLMBackbone
from prismatic.models.backbones.llm.prompting import PromptBuilder
from prismatic.models.backbones.vision import VisionBackbone


# === Abstract Base Class for arbitrary Vision-Language Models ===
class VLM(nn.Module, GenerationMixin, ABC):
    def __init__(
        self,
        model_family: str,
        model_id: str,
        vision_backbone: VisionBackbone,
        llm_backbone: LLMBackbone,
        enable_mixed_precision_training: bool = True,
    ) -> None:
        super().__init__()
        self.model_family, self.model_id = model_family, model_id
        self.vision_backbone, self.llm_backbone = vision_backbone, llm_backbone
        self.enable_mixed_precision_training = enable_mixed_precision_training

        # Instance Attributes for a generic VLM
        self.all_module_keys, self.trainable_module_keys = None, None

        # === GenerationMixin Expected Attributes =>> *DO NOT MODIFY* ===
        self.generation_config = self.llm_backbone.llm.generation_config
        self.main_input_name = "input_ids"

    @property
    def device(self) -> torch.device:
        """Borrowed from `transformers.modeling_utils.py` -- checks parameter device; assumes model on *ONE* device!"""
        return next(self.parameters()).device

    @classmethod
    @abstractmethod
    def from_pretrained(
        cls,
        pretrained_checkpoint: Path,
        model_family: str,
        model_id: str,
        vision_backbone: VisionBackbone,
        llm_backbone: LLMBackbone,
        **kwargs: str,
    ) -> VLM: ...

    @abstractmethod
    def get_prompt_builder(self, system_prompt: Optional[str] = None) -> PromptBuilder: ...

    @abstractmethod
    def freeze_backbones(self, stage: str) -> None: ...

    @abstractmethod
    def load_from_checkpoint(self, stage: str, run_dir: Path, pretrained_checkpoint: Optional[Path] = None) -> None: ...

    @abstractmethod
    def get_fsdp_wrapping_policy(self) -> Callable: ...

    @abstractmethod
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        multimodal_indices: Optional[torch.LongTensor] = None,
    ) -> CausalLMOutputWithPast: ...

    # === GenerationMixin Expected Properties & Methods (DO NOT MODIFY) ===
    @staticmethod
    def can_generate() -> bool:
        return True

    @property
    def config(self) -> PretrainedConfig:
        return self.llm_backbone.llm.config

    # => Beam Search Utility
    def _reorder_cache(self, past_key_values, beam_idx):
        return self.llm_backbone.llm._reorder_cache(past_key_values, beam_idx)

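An illustrative stub (not part of the diff) of the surface a new model family would override -- the class
name and placeholder bodies below are hypothetical; only the abstract method set is dictated by `VLM`:

    class MyVLM(VLM):
        @classmethod
        def from_pretrained(cls, pretrained_checkpoint, model_family, model_id, vision_backbone, llm_backbone, **kwargs):
            raise NotImplementedError

        def get_prompt_builder(self, system_prompt=None): raise NotImplementedError
        def freeze_backbones(self, stage): raise NotImplementedError
        def load_from_checkpoint(self, stage, run_dir, pretrained_checkpoint=None): raise NotImplementedError
        def get_fsdp_wrapping_policy(self): raise NotImplementedError
        def forward(self, input_ids=None, **kwargs): raise NotImplementedError
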
capvector-oft/prismatic/models/vlms/prismatic.py
ADDED
@@ -0,0 +1,621 @@
"""
prismatic.py

PyTorch Module defining a PrismaticVLM, our general interface for defining the various different VLMs in our work.

Notes:
    - For now, we don't subclass `transformers.PretrainedModel` (or CausalLM). Instead, we assume a very limited subset
      of the {Model}ForCausalLM API that enables dispatch to the underlying LLM's `generate` utilities (feeding inputs
      through our custom projection shim).
"""

from __future__ import annotations

from functools import partial
from pathlib import Path
from typing import Callable, Dict, List, Optional, Type, Union

import torch
from PIL import Image
from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy
from transformers.modeling_outputs import CausalLMOutputWithPast

from prismatic.models.backbones.llm import LLMBackbone
from prismatic.models.backbones.llm.prompting import PromptBuilder
from prismatic.models.backbones.vision import VisionBackbone
from prismatic.models.vlms.base_vlm import VLM
from prismatic.overwatch import initialize_overwatch
from prismatic.util.nn_utils import FusedMLPProjector, LinearProjector, MLPProjector

# Initialize Overwatch =>> Wraps `logging.Logger`
overwatch = initialize_overwatch(__name__)


# HuggingFace Default / LLaMa-2 IGNORE_INDEX (for labels)
IGNORE_INDEX = -100


class PrismaticVLM(VLM):
    def __init__(
        self,
        model_id: str,
        vision_backbone: VisionBackbone,
        llm_backbone: LLMBackbone,
        enable_mixed_precision_training: bool = True,
        arch_specifier: str = "gelu-mlp",
        **kwargs,
    ) -> None:
        super().__init__(
            "prismatic",
            model_id,
            vision_backbone,
            llm_backbone,
            enable_mixed_precision_training=enable_mixed_precision_training,
        )

        # Set Weight Initialization Seed for Projector Consistency
        torch.manual_seed(vision_backbone.embed_dim)

        # Initialize Projection (Adapter) based on `arch_specifier`
        self.arch_specifier = arch_specifier
        if arch_specifier == "linear":
            self.projector = LinearProjector(vision_backbone.embed_dim, llm_backbone.embed_dim)
        elif arch_specifier.endswith("fused-gelu-mlp"):
            self.projector = FusedMLPProjector(vision_backbone.embed_dim, llm_backbone.embed_dim)
        elif arch_specifier.endswith("gelu-mlp"):
            self.projector = MLPProjector(vision_backbone.embed_dim, llm_backbone.embed_dim)
        else:
            raise ValueError(f"PrismaticVLM with `{arch_specifier = }` is not supported!")

        # Trackers
        self.vision_backbone_requires_grad = False

        # Set Module Keys =>> used in Checkpoint Saving / Model Loading
        self.all_module_keys = ["vision_backbone", "llm_backbone", "projector"]
        self.trainable_module_keys = []

        # === Generation Utilities ===
        # => For computing likelihoods --> get tokens corresponding to "True", "False" and "Yes", "No"
        self.string2idx = {}
        for trigger_string in ["True", "False", "Yes", "No"] + [chr(ord("A") + i) for i in range(26)]:
            token_idx_list = self.llm_backbone.tokenizer.encode(trigger_string, add_special_tokens=False)
            assert len(token_idx_list) == 1, f'String "{trigger_string}" is tokenized as more than one token!'
            self.string2idx[trigger_string] = token_idx_list[0]

    @classmethod
    def from_pretrained(
        cls,
        pretrained_checkpoint: Path,
        model_id: str,
        vision_backbone: VisionBackbone,
        llm_backbone: LLMBackbone,
        enable_mixed_precision_training: bool = True,
        arch_specifier: str = "gelu-mlp",
        freeze_weights: bool = True,
        **kwargs,
    ) -> PrismaticVLM:
        """Initialize a PrismaticVLM from a pretrained checkpoint, freezing all weights, tailored for inference."""
        vlm = cls(
            model_id,
            vision_backbone,
            llm_backbone,
            enable_mixed_precision_training=enable_mixed_precision_training,
            arch_specifier=arch_specifier,
            **kwargs,
        )

        # Load from Checkpoint (Custom --> should load both *projector* and *llm* weights)
        model_state_dict = torch.load(pretrained_checkpoint, map_location="cpu")["model"]
        assert (
            "projector" in model_state_dict and "llm_backbone" in model_state_dict
        ), "PrismaticVLM `from_pretrained` expects checkpoint with keys for `projector` AND `llm_backbone`!"

        vlm.projector.load_state_dict(model_state_dict["projector"])
        vlm.llm_backbone.load_state_dict(model_state_dict["llm_backbone"])
        if "vision_backbone" in model_state_dict.keys():
            vlm.vision_backbone.load_state_dict(model_state_dict["vision_backbone"])

        # Freeze Weights
        if freeze_weights:
            vlm.requires_grad_(False)
            vlm.eval()

        return vlm
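
    # (Checkpoint layout implied by the asserts above, shown for orientation:
    #  {"model": {"projector": ..., "llm_backbone": ..., "vision_backbone": ... (optional)}})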

    def get_prompt_builder(self, system_prompt: Optional[str] = None) -> PromptBuilder:
        prompt_initializer: Type[PromptBuilder] = self.llm_backbone.prompt_builder_fn
        return prompt_initializer(self.model_family, system_prompt=system_prompt)

    def freeze_backbones(self, stage: str) -> None:
        """
        This function sets `requires_grad_` on each of the component modules explicitly, depending on stage.

        We support two primary stages --> "align" and "finetune" (the full-finetune, VLA, last-layer, and
        sandwich variants handled below follow the same pattern).
            => "align" --> vision_backbone*, llm_backbone* are frozen; only the `projector` is trained.
            => "finetune" --> vision_backbone* is frozen; both `projector` and `llm_backbone` are trained.

        :param stage: Pretraining stage in < "align" | "finetune" | "full-finetune" | "vla-train" | "vla-full-train" >
        """
        if stage == "align":
            self.vision_backbone.requires_grad_(False)
            self.llm_backbone.requires_grad_(False)
            self.projector.requires_grad_(True)

            # Add to `self.trainable_module_keys`
            self.trainable_module_keys = ["projector"]

            # Update Trackers
            self.vision_backbone_requires_grad = False

            # Explicitly Log Frozen / Trainable Components
            overwatch.info(f"[Frozen] 🥶 =>> Vision Backbone `{self.vision_backbone.identifier}`", ctx_level=1)
            overwatch.info(f"[Frozen] 🥶 =>> LLM Backbone `{self.llm_backbone.identifier}`", ctx_level=1)
            overwatch.info(f"[TRAINABLE] 🔥 =>> Projector `{self.arch_specifier}`", ctx_level=1)

        elif stage in {"finetune", "vla-train"}:
            self.vision_backbone.requires_grad_(False)
            self.llm_backbone.requires_grad_(True)
            self.projector.requires_grad_(True)

            # Add to `self.trainable_module_keys`
            self.trainable_module_keys = ["projector", "llm_backbone"]

            # Update Trackers
            self.vision_backbone_requires_grad = False

            # Explicitly Log Frozen / Unfrozen Components
            overwatch.info(f"[Frozen] 🥶 =>> Vision Backbone `{self.vision_backbone.identifier}`", ctx_level=1)
            overwatch.info(f"[TRAINABLE] 🔥 =>> LLM Backbone `{self.llm_backbone.identifier}`", ctx_level=1)
            overwatch.info(f"[TRAINABLE] 🔥 =>> Projector `{self.arch_specifier}`", ctx_level=1)

        elif stage in {"full-finetune", "vla-full-train"}:
            self.vision_backbone.dtype = torch.float32
            self.vision_backbone.requires_grad_(True)
            self.llm_backbone.requires_grad_(True)
            self.projector.requires_grad_(True)

            # Add to `self.trainable_module_keys`
            self.trainable_module_keys = ["vision_backbone", "projector", "llm_backbone"]

            # Update Trackers
            self.vision_backbone_requires_grad = True

            # Explicitly Log Frozen / Unfrozen Components
            overwatch.info(f"[TRAINABLE] 🔥 =>> Vision Backbone `{self.vision_backbone.identifier}`", ctx_level=1)
            overwatch.info(f"[TRAINABLE] 🔥 =>> LLM Backbone `{self.llm_backbone.identifier}`", ctx_level=1)
            overwatch.info(f"[TRAINABLE] 🔥 =>> Projector `{self.arch_specifier}`", ctx_level=1)

        elif stage in {"last-layer-finetune", "vla-last-layer-train"}:
            self.vision_backbone.requires_grad_(False)
            self.projector.requires_grad_(False)
            self.llm_backbone.requires_grad_(False)

            # Unfreeze final LLM layer
            for module in self.llm_backbone.last_layer_finetune_modules:
                module.requires_grad_(True)

            # Add to `self.trainable_module_keys`
            self.trainable_module_keys = ["llm_backbone"]

            # Update Trackers
            self.vision_backbone_requires_grad = False

            # Explicitly Log Frozen / Unfrozen Components
            # fmt: off
            overwatch.info(f"[Frozen] 🥶 =>> Vision Backbone `{self.vision_backbone.identifier}`", ctx_level=1)  # noqa: E501
            overwatch.info(f"[Frozen, except last layer] 🥶🔥 =>> LLM Backbone `{self.llm_backbone.identifier}`", ctx_level=1)  # noqa: E501
            overwatch.info(f"[Frozen] 🥶 =>> Projector `{self.arch_specifier}`", ctx_level=1)
            # fmt: on

        elif stage in {"vla-sandwich-train"}:
            self.vision_backbone.dtype = torch.float32
            self.vision_backbone.requires_grad_(True)
            self.projector.requires_grad_(True)
            self.llm_backbone.requires_grad_(False)

            # Unfreeze final LLM layer
            for module in self.llm_backbone.last_layer_finetune_modules:
                module.requires_grad_(True)

            # Add to `self.trainable_module_keys`
            self.trainable_module_keys = ["vision_backbone", "projector", "llm_backbone"]

            # Update Trackers
            self.vision_backbone_requires_grad = True

            # Explicitly Log Frozen / Unfrozen Components
            # fmt: off
            overwatch.info(f"[TRAINABLE] 🔥 =>> Vision Backbone `{self.vision_backbone.identifier}`", ctx_level=1)  # noqa: E501
            overwatch.info(f"[Frozen, except last layer] 🥶🔥 =>> LLM Backbone `{self.llm_backbone.identifier}`", ctx_level=1)  # noqa: E501
            overwatch.info(f"[TRAINABLE] 🔥 =>> Projector `{self.arch_specifier}`", ctx_level=1)
            # fmt: on

        else:
            raise ValueError(f"Stage `{stage}` is not supported for LLaVa! Try < align | finetune >")

        overwatch.debug("##################################################")
        overwatch.debug("#####      Trainable Network Parameters:    #####")
        overwatch.debug("##################################################")
        for name, param in self.named_parameters():
            if param.requires_grad:
                overwatch.debug(name)

    def load_from_checkpoint(self, stage: str, run_dir: Path, pretrained_checkpoint: Optional[Path] = None) -> None:
        """Load weights from checkpoint (if required by the given stage)."""
        assert stage in {"align", "finetune", "full-finetune"}, f"Stage {stage} is not supported!"

        # If we're running a `no-align` architecture, we're good!
        if self.arch_specifier.startswith("no-align"):
            overwatch.info(
                f"PrismaticVLM with `{self.arch_specifier = }` does not require pretrained weights!", ctx_level=1
            )
            return

        # Otherwise, handle stage-specific logic!
        if stage == "align":
            overwatch.info("Stage `align` does not require pretrained weights =>> Starting Training", ctx_level=1)
            return

        # Otherwise, load from `pretrained_checkpoint` or match on `run_dir` (s/+stage-finetune/+stage-align/g)
        overwatch.info("Stage `finetune` requires `align` pretrained weights", ctx_level=1)

        # Config specifies path to a checkpoint to load
        if pretrained_checkpoint is not None:
            overwatch.info(f"Loading from Provided Checkpoint `{pretrained_checkpoint}`", ctx_level=1)
            model_state_dict = torch.load(pretrained_checkpoint)["model"]
            self.projector.load_state_dict(model_state_dict["projector"])

            return

        # [Contract] If no `pretrained_checkpoint`, assume `align` lives in the run directory; string substitution!
        model, scale, _, seed = run_dir.name.split("+")
        align_dirs = [
            d
            for d in run_dir.parent.iterdir()
            if (d.name.startswith(f"{model}+{scale}") and d.name.endswith(f"+stage-align+{seed}"))
        ]
        assert len(align_dirs) == 1, "Multiple or No Valid Pretrained Directories Exist -- Double Check `runs`!"
        if (pretrained_checkpoint := (align_dirs[0] / "checkpoints" / "latest-checkpoint.pt")).exists():
            overwatch.info(f"Loading from Discovered Checkpoint `{pretrained_checkpoint}`", ctx_level=1)
            model_state_dict = torch.load(pretrained_checkpoint)["model"]
            self.projector.load_state_dict(model_state_dict["projector"])
        else:
            raise ValueError(f"Could not find valid `align` checkpoint at {pretrained_checkpoint}!")

    def get_fsdp_wrapping_policy(self) -> Callable:
        """Return an FSDP _or_policy over the policies returned by each individual backbone (and our VLM policy)."""
        vision_fsdp_wrapping_policy = self.vision_backbone.get_fsdp_wrapping_policy()
        llm_fsdp_wrapping_policy = self.llm_backbone.get_fsdp_wrapping_policy()

        # Get Prismatic Wrapping Policy =>> just a module wrapping policy around `self.projector`
        prismatic_fsdp_wrapping_policy = partial(
            _module_wrap_policy,
            module_classes={LinearProjector, MLPProjector, FusedMLPProjector},
        )

        # Return union (_or_) over constituent policies
        # => Note: there is *not* a fall-through policy; any module that isn't covered by the above constituents will
        #    automatically be folded into the root VLM FSDP instance.
        return partial(
            _or_policy,
            policies=[
                vision_fsdp_wrapping_policy,
                llm_fsdp_wrapping_policy,
                prismatic_fsdp_wrapping_policy,
            ],
        )

    # Note =>> We're not explicitly subclassing `PreTrainedModel` because we don't need the bloat; however, `forward()`
    #          *must* match the signature of a `{Model}ForCausalLM` so that we can inherit from `GenerationMixin`

    # ruff: noqa: C901
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        multimodal_indices: Optional[torch.LongTensor] = None,
    ) -> CausalLMOutputWithPast:
        """Run a forward pass through the VLM, returning a CausalLMOutputWithPast instance (contains loss)."""

        # Handle Inference (leverage cache, short-circuit on just LLM forward)
        if input_ids.shape[1] == 1 and past_key_values is not None:
            # We're leveraging the cache, so just redirect to `self.llm_backbone` with `input_ids` and `past_key_values`
            output = self.llm_backbone(
                input_ids=input_ids,
                attention_mask=None,
                position_ids=None,
                past_key_values=past_key_values,
                inputs_embeds=None,
                labels=None,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            return output

        elif input_ids.shape[1] == 1 or pixel_values is None:
            raise RuntimeError("Invalid `forward()` call!")

        # Handle Multimodal Indices is None --> pretend like the batch is fully multimodal (always image + text)!
        if multimodal_indices is None:
            multimodal_indices = torch.arange(len(input_ids), dtype=torch.long, device=input_ids.device)

        # Handle Multimodal Indices is Empty (len == 0) --> simple unimodal forward
        elif len(multimodal_indices) == 0:
            return self.llm_backbone(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=None,
                past_key_values=past_key_values,
                inputs_embeds=None,
                labels=labels,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        # Run Visual Feature Extraction
        with torch.set_grad_enabled(self.vision_backbone_requires_grad):
            if isinstance(pixel_values, dict):
                patch_features = self.vision_backbone({k: pixel_values[k][multimodal_indices] for k in pixel_values})
            else:
                patch_features = self.vision_backbone(pixel_values[multimodal_indices])

        # Projection Logic :: [bsz, num_patches, llm_embed_dim] =>> num_patches = (2 *) (256 + 1) for ViT-L + CLS
        projected_patch_embeddings = self.projector(patch_features)
        projected_patch_attention_mask = None
        if attention_mask is not None:
            projected_patch_attention_mask = torch.full(
                (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
                True,
                dtype=attention_mask.dtype,
                device=attention_mask.device,
            )

        # Get Input Embeddings from LLM Backbone :: [bsz, input_seq_len, llm_embed_dim]
        input_embeddings = self.llm_backbone.embed_input_ids(input_ids)

        # Build Multimodal Embeddings (and build resulting attention mask)
        multimodal_embeddings = torch.cat(
            [
                input_embeddings[multimodal_indices, :1, :],
                projected_patch_embeddings,
                input_embeddings[multimodal_indices, 1:, :],
            ],
            dim=1,
        )
|
| 397 |
+
multimodal_attention_mask = None
|
| 398 |
+
if attention_mask is not None:
|
| 399 |
+
multimodal_attention_mask = torch.cat(
|
| 400 |
+
[
|
| 401 |
+
attention_mask[multimodal_indices, :1],
|
| 402 |
+
projected_patch_attention_mask,
|
| 403 |
+
attention_mask[multimodal_indices, 1:],
|
| 404 |
+
],
|
| 405 |
+
dim=1,
|
| 406 |
+
)
|
| 407 |
+
|
| 408 |
+
# [Contract] We assume the first token of `labels` (associated with <BOS>) is already marked as "IGNORE"
|
| 409 |
+
# => We'll ignore the per-token outputs for each of the patch embeddings as well!
|
| 410 |
+
multimodal_labels = None
|
| 411 |
+
if labels is not None:
|
| 412 |
+
projected_patch_labels = torch.full(
|
| 413 |
+
(projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
|
| 414 |
+
IGNORE_INDEX,
|
| 415 |
+
dtype=labels.dtype,
|
| 416 |
+
device=labels.device,
|
| 417 |
+
)
|
| 418 |
+
multimodal_labels = torch.cat(
|
| 419 |
+
[labels[multimodal_indices, :1], projected_patch_labels, labels[multimodal_indices, 1:]], dim=1
|
| 420 |
+
)
|
| 421 |
+
|
| 422 |
+
# === Add Unimodal Handling ===
|
| 423 |
+
|
| 424 |
+
# Create Fused Embeddings, Attention Mask, and Labels by Merging with "unimodal" Inputs (if applicable)
|
| 425 |
+
unimodal_indices = torch.tensor(
|
| 426 |
+
[idx for idx in range(len(input_ids)) if idx not in multimodal_indices],
|
| 427 |
+
dtype=torch.long,
|
| 428 |
+
device=multimodal_indices.device,
|
| 429 |
+
)
|
| 430 |
+
|
| 431 |
+
# No "unimodal" data --> Fused == Multimodal
|
| 432 |
+
if len(unimodal_indices) == 0:
|
| 433 |
+
fused_embeddings = multimodal_embeddings
|
| 434 |
+
fused_attention_mask = multimodal_attention_mask
|
| 435 |
+
fused_labels = multimodal_labels
|
| 436 |
+
|
| 437 |
+
else:
|
| 438 |
+
# Otherwise --> Merge w/ unimodal data
|
| 439 |
+
|
| 440 |
+
# This doesn't matter --> but in the "normal" case this is the embedding of the <PAD> token
|
| 441 |
+
# => NOTE :: Verified that `zeros/randn/empty/<PAD> embedding` all return the same result!
|
| 442 |
+
unimodal_embeddings_pad = torch.zeros(
|
| 443 |
+
(len(unimodal_indices), projected_patch_embeddings.shape[1], input_embeddings.shape[2]),
|
| 444 |
+
dtype=input_embeddings.dtype,
|
| 445 |
+
device=input_embeddings.device,
|
| 446 |
+
)
|
| 447 |
+
unimodal_attention_pad = torch.full(
|
| 448 |
+
(len(unimodal_indices), projected_patch_embeddings.shape[1]),
|
| 449 |
+
False,
|
| 450 |
+
dtype=attention_mask.dtype,
|
| 451 |
+
device=attention_mask.device,
|
| 452 |
+
)
|
| 453 |
+
unimodal_labels_pad = torch.full(
|
| 454 |
+
(len(unimodal_indices), projected_patch_embeddings.shape[1]),
|
| 455 |
+
IGNORE_INDEX,
|
| 456 |
+
dtype=labels.dtype,
|
| 457 |
+
device=labels.device,
|
| 458 |
+
)
|
| 459 |
+
|
| 460 |
+
unimodal_embeddings = torch.cat([input_embeddings[unimodal_indices], unimodal_embeddings_pad], dim=1)
|
| 461 |
+
unimodal_attention_mask = torch.cat([attention_mask[unimodal_indices], unimodal_attention_pad], dim=1)
|
| 462 |
+
unimodal_labels = torch.cat([labels[unimodal_indices], unimodal_labels_pad], dim=1)
|
| 463 |
+
|
| 464 |
+
# Create "Fused" Tensors by Stacking Multimodal & Unimodal
|
| 465 |
+
fused_embeddings = torch.vstack([multimodal_embeddings, unimodal_embeddings])
|
| 466 |
+
fused_attention_mask = torch.vstack([multimodal_attention_mask, unimodal_attention_mask])
|
| 467 |
+
fused_labels = torch.vstack([multimodal_labels, unimodal_labels])
|
| 468 |
+
|
| 469 |
+
# Run LLM Forward --> returns CausalLMOutputWithPast!
|
| 470 |
+
return self.llm_backbone(
|
| 471 |
+
input_ids=None,
|
| 472 |
+
attention_mask=fused_attention_mask,
|
| 473 |
+
position_ids=None,
|
| 474 |
+
past_key_values=past_key_values,
|
| 475 |
+
inputs_embeds=fused_embeddings,
|
| 476 |
+
labels=fused_labels,
|
| 477 |
+
use_cache=use_cache,
|
| 478 |
+
output_attentions=output_attentions,
|
| 479 |
+
output_hidden_states=output_hidden_states,
|
| 480 |
+
return_dict=return_dict,
|
| 481 |
+
)
|
| 482 |
+
|
| 483 |
+
# === GenerationMixin Methods ===
|
| 484 |
+
# => Note: The following methods override the functionality of `transformers.GenerationMixin`; these expect the
|
| 485 |
+
# contract in each of the function signatures, and also expect our `forward` function to roughly take
|
| 486 |
+
# the same arguments as the underlying LLM (see `LlamaModelForCausalLM` as an example)
|
| 487 |
+
|
| 488 |
+
def prepare_inputs_for_generation(
|
| 489 |
+
self,
|
| 490 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 491 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 492 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
| 493 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 494 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 495 |
+
use_cache: Optional[bool] = None,
|
| 496 |
+
**kwargs: torch.Tensor,
|
| 497 |
+
) -> Dict[str, torch.Tensor]:
|
| 498 |
+
"""Borrowed from `LlamaForCausalLM` --> in general, just handles caching logic during generation."""
|
| 499 |
+
if past_key_values:
|
| 500 |
+
input_ids = input_ids[:, -1:]
|
| 501 |
+
|
| 502 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
| 503 |
+
if inputs_embeds is not None and past_key_values is None:
|
| 504 |
+
model_inputs = {"inputs_embeds": inputs_embeds}
|
| 505 |
+
else:
|
| 506 |
+
model_inputs = {"input_ids": input_ids}
|
| 507 |
+
|
| 508 |
+
# Make sure `pixel_values` are preserved in `model_inputs`
|
| 509 |
+
model_inputs.update(
|
| 510 |
+
{
|
| 511 |
+
"attention_mask": attention_mask,
|
| 512 |
+
"pixel_values": pixel_values,
|
| 513 |
+
"past_key_values": past_key_values,
|
| 514 |
+
"use_cache": use_cache,
|
| 515 |
+
}
|
| 516 |
+
)
|
| 517 |
+
|
| 518 |
+
return model_inputs
|
| 519 |
+
|
| 520 |
+
@torch.inference_mode()
|
| 521 |
+
def generate_batch(
|
| 522 |
+
self,
|
| 523 |
+
pixel_values: Union[torch.Tensor, Dict[str, torch.Tensor]],
|
| 524 |
+
texts: List[str],
|
| 525 |
+
return_string_probabilities: Optional[List[str]] = None,
|
| 526 |
+
**kwargs: str,
|
| 527 |
+
) -> Union[List[str], List[List[float]]]:
|
| 528 |
+
# For now, only support generation with a batch size of 1 for simplicity
|
| 529 |
+
tokenizer = self.llm_backbone.tokenizer
|
| 530 |
+
|
| 531 |
+
# Prepare Inputs
|
| 532 |
+
batch_input_ids = [
|
| 533 |
+
tokenizer(text, truncation=True, return_tensors="pt").input_ids.to(self.device) for text in texts
|
| 534 |
+
]
|
| 535 |
+
if isinstance(pixel_values, torch.Tensor):
|
| 536 |
+
pixel_values = pixel_values[None, ...].to(self.device)
|
| 537 |
+
elif isinstance(pixel_values, dict):
|
| 538 |
+
pixel_values = {k: v[None, ...].to(self.device) for k, v in pixel_values.items()}
|
| 539 |
+
else:
|
| 540 |
+
raise ValueError(f"Unsupported `pixel_values` type = {type(pixel_values)}")
|
| 541 |
+
|
| 542 |
+
# Create Output Lists
|
| 543 |
+
gen_texts, gen_probabilities = [], []
|
| 544 |
+
|
| 545 |
+
# Invoke super().generate --> taps into `GenerationMixin` which (redirects) to `forward()`
|
| 546 |
+
autocast_dtype = self.llm_backbone.half_precision_dtype
|
| 547 |
+
with torch.autocast("cuda", dtype=autocast_dtype, enabled=self.enable_mixed_precision_training):
|
| 548 |
+
for idx, input_ids in enumerate(batch_input_ids):
|
| 549 |
+
if isinstance(pixel_values, torch.Tensor):
|
| 550 |
+
pixel_values = pixel_values[idx]
|
| 551 |
+
elif isinstance(pixel_values, dict):
|
| 552 |
+
pixel_values = {k: pixel_values[k][idx] for k in pixel_values}
|
| 553 |
+
else:
|
| 554 |
+
raise ValueError(f"Unsupported `pixel_values` type = {type(pixel_values)}")
|
| 555 |
+
|
| 556 |
+
# Handle `return_string_probabilities`
|
| 557 |
+
if return_string_probabilities is None:
|
| 558 |
+
full_out_ids = super().generate(input_ids=input_ids, pixel_values=pixel_values, **kwargs)
|
| 559 |
+
gen_ids = full_out_ids[0, input_ids.shape[1] :]
|
| 560 |
+
|
| 561 |
+
# Decode `gen_ids` and strip any <EOS> tokens
|
| 562 |
+
gen_texts.append(tokenizer.decode(gen_ids, skip_special_tokens=True).strip())
|
| 563 |
+
|
| 564 |
+
else:
|
| 565 |
+
full_out_dict = super().generate(
|
| 566 |
+
input_ids=input_ids,
|
| 567 |
+
pixel_values=pixel_values,
|
| 568 |
+
output_scores=True,
|
| 569 |
+
return_dict_in_generate=True,
|
| 570 |
+
**kwargs,
|
| 571 |
+
)
|
| 572 |
+
|
| 573 |
+
# Generation pattern should usually be [TOKEN] <EOS> for True/False and Yes/No Generations
|
| 574 |
+
gen_ids = full_out_dict.sequences[0, input_ids.shape[1] :]
|
| 575 |
+
|
| 576 |
+
# [Debug] Verify that the first token generated is in `self.string2idx.values()`
|
| 577 |
+
# assert gen_ids[0] in self.string2idx.values(), "Generated ID not in mapping!"
|
| 578 |
+
|
| 579 |
+
# Decode `gen_ids` and strip any <EOS> tokens
|
| 580 |
+
gen_texts.append(tokenizer.decode(gen_ids, skip_special_tokens=True).strip())
|
| 581 |
+
|
| 582 |
+
# Get all token probabilities --> softmax over logits
|
| 583 |
+
token_probs = torch.softmax(full_out_dict.scores[0][0], dim=0)
|
| 584 |
+
|
| 585 |
+
# Get *normalized* probabilities for all values in `return_token_probabilities`
|
| 586 |
+
slice_idxs = torch.tensor([self.string2idx[s] for s in return_string_probabilities])
|
| 587 |
+
string_probs_unnormalized = token_probs[slice_idxs]
|
| 588 |
+
string_probs = string_probs_unnormalized / string_probs_unnormalized.sum()
|
| 589 |
+
gen_probabilities.append(string_probs.cpu().numpy().tolist())
|
| 590 |
+
|
| 591 |
+
return gen_texts if return_string_probabilities is None else gen_probabilities
|
| 592 |
+
|
| 593 |
+
@torch.inference_mode()
|
| 594 |
+
def generate(self, image: Image, prompt_text: str, **kwargs: str) -> str:
|
| 595 |
+
# For now, only support generation with a batch size of 1 for simplicity
|
| 596 |
+
image_transform, tokenizer = self.vision_backbone.image_transform, self.llm_backbone.tokenizer
|
| 597 |
+
|
| 598 |
+
# Prepare Inputs
|
| 599 |
+
input_ids = tokenizer(prompt_text, truncation=True, return_tensors="pt").input_ids.to(self.device)
|
| 600 |
+
pixel_values = image_transform(image)
|
| 601 |
+
if isinstance(pixel_values, torch.Tensor):
|
| 602 |
+
pixel_values = pixel_values[None, ...].to(self.device)
|
| 603 |
+
elif isinstance(pixel_values, dict):
|
| 604 |
+
pixel_values = {k: v[None, ...].to(self.device) for k, v in pixel_values.items()}
|
| 605 |
+
else:
|
| 606 |
+
raise ValueError(f"Unsupported `pixel_values` type = {type(pixel_values)}")
|
| 607 |
+
|
| 608 |
+
# Invoke super().generate --> taps into `GenerationMixin` which (redirects) to `forward()`
|
| 609 |
+
autocast_dtype = self.llm_backbone.half_precision_dtype
|
| 610 |
+
with torch.autocast("cuda", dtype=autocast_dtype, enabled=self.enable_mixed_precision_training):
|
| 611 |
+
# fmt: off
|
| 612 |
+
generated_ids = super().generate(
|
| 613 |
+
input_ids=input_ids, # Shape: [1, seq]
|
| 614 |
+
pixel_values=pixel_values, # Shape: [1, 3, res, res] or Dict[str, Shape[1, 3, res, res]]
|
| 615 |
+
**kwargs
|
| 616 |
+
)
|
| 617 |
+
# fmt: on
|
| 618 |
+
|
| 619 |
+
generated_text = tokenizer.decode(generated_ids[0, input_ids.shape[1] :], skip_special_tokens=True).strip()
|
| 620 |
+
|
| 621 |
+
return generated_text
|
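
For orientation, here is a minimal driver for the generation entry point above. This is an illustrative sketch, not repo code: the model ID, image path, and prompt string are placeholders, and it assumes `load` (exported from `prismatic.models`) returns a `PrismaticVLM`; the real prompt format comes from the model's `PromptBuilder`.

import torch
from PIL import Image
from prismatic.models import load

vlm = load("prism-dinosiglip+7b")                    # placeholder model ID -- pick any registered model
vlm.to("cuda", dtype=torch.bfloat16)

image = Image.open("example.jpg").convert("RGB")     # placeholder image path
prompt = "USER: What is in this image? ASSISTANT: "  # placeholder; use the model's PromptBuilder in practice
answer = vlm.generate(image, prompt, do_sample=False, max_new_tokens=64)

# `generate_batch` additionally supports returning normalized probabilities over candidate
# strings (e.g., ["Yes", "No"]) via `return_string_probabilities`, using `self.string2idx`.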
capvector-oft/prismatic/overwatch/__init__.py
ADDED
@@ -0,0 +1 @@
from .overwatch import initialize_overwatch
capvector-oft/prismatic/overwatch/overwatch.py
ADDED
@@ -0,0 +1,147 @@
"""
overwatch.py

Utility class for creating a centralized/standardized logger (built on Rich) and accelerate handler.
"""

import logging
import logging.config
import os
from contextlib import nullcontext
from logging import LoggerAdapter
from typing import Any, Callable, ClassVar, Dict, MutableMapping, Tuple, Union

# Overwatch Default Format String
RICH_FORMATTER, DATEFMT = "| >> %(message)s", "%m/%d [%H:%M:%S]"

# Set Logging Configuration
LOG_CONFIG = {
    "version": 1,
    "disable_existing_loggers": True,
    "formatters": {"simple-console": {"format": RICH_FORMATTER, "datefmt": DATEFMT}},
    "handlers": {
        "console": {
            "class": "rich.logging.RichHandler",
            "formatter": "simple-console",
            "markup": True,
            "rich_tracebacks": True,
            "show_level": True,
            "show_path": True,
            "show_time": True,
        }
    },
    "root": {"level": "INFO", "handlers": ["console"]},
}
logging.config.dictConfig(LOG_CONFIG)


# === Custom Contextual Logging Logic ===
class ContextAdapter(LoggerAdapter):
    CTX_PREFIXES: ClassVar[Dict[int, str]] = {**{0: "[*] "}, **{idx: "|=> ".rjust(4 + (idx * 4)) for idx in [1, 2, 3]}}

    def process(self, msg: str, kwargs: MutableMapping[str, Any]) -> Tuple[str, MutableMapping[str, Any]]:
        ctx_level = kwargs.pop("ctx_level", 0)
        return f"{self.CTX_PREFIXES[ctx_level]}{msg}", kwargs


class DistributedOverwatch:
    def __init__(self, name: str) -> None:
        """Initializer for an Overwatch object that wraps logging & `accelerate.PartialState`."""
        from accelerate import PartialState

        # Note that PartialState is always safe to initialize regardless of `accelerate launch` or `torchrun`
        #   =>> However, might be worth actually figuring out if we need the `accelerate` dependency at all!
        self.logger, self.distributed_state = ContextAdapter(logging.getLogger(name), extra={}), PartialState()

        # Logger Delegation (for convenience; would be nice to just compose & dynamic dispatch eventually)
        self.debug = self.logger.debug
        self.info = self.logger.info
        self.warning = self.logger.warning
        self.error = self.logger.error
        self.critical = self.logger.critical

        # Logging Defaults =>> only Log `INFO` on Main Process, `ERROR` on others!
        self.logger.setLevel(logging.INFO if self.distributed_state.is_main_process else logging.ERROR)

    @property
    def rank_zero_only(self) -> Callable[..., Any]:
        return self.distributed_state.on_main_process

    @property
    def local_zero_only(self) -> Callable[..., Any]:
        return self.distributed_state.on_local_main_process

    @property
    def rank_zero_first(self) -> Callable[..., Any]:
        return self.distributed_state.main_process_first

    @property
    def local_zero_first(self) -> Callable[..., Any]:
        return self.distributed_state.local_main_process_first

    def is_rank_zero(self) -> bool:
        return self.distributed_state.is_main_process

    def rank(self) -> int:
        return self.distributed_state.process_index

    def local_rank(self) -> int:
        return self.distributed_state.local_process_index

    def world_size(self) -> int:
        return self.distributed_state.num_processes


class PureOverwatch:
    def __init__(self, name: str) -> None:
        """Initializer for an Overwatch object that just wraps logging."""
        self.logger = ContextAdapter(logging.getLogger(name), extra={})

        # Logger Delegation (for convenience; would be nice to just compose & dynamic dispatch eventually)
        self.debug = self.logger.debug
        self.info = self.logger.info
        self.warning = self.logger.warning
        self.error = self.logger.error
        self.critical = self.logger.critical

        # Logging Defaults =>> INFO
        self.logger.setLevel(logging.INFO)

    @staticmethod
    def get_identity_ctx() -> Callable[..., Any]:
        def identity(fn: Callable[..., Any]) -> Callable[..., Any]:
            return fn

        return identity

    @property
    def rank_zero_only(self) -> Callable[..., Any]:
        return self.get_identity_ctx()

    @property
    def local_zero_only(self) -> Callable[..., Any]:
        return self.get_identity_ctx()

    @property
    def rank_zero_first(self) -> Callable[..., Any]:
        return nullcontext

    @property
    def local_zero_first(self) -> Callable[..., Any]:
        return nullcontext

    @staticmethod
    def is_rank_zero() -> bool:
        return True

    @staticmethod
    def rank() -> int:
        return 0

    @staticmethod
    def world_size() -> int:
        return 1


def initialize_overwatch(name: str) -> Union[DistributedOverwatch, PureOverwatch]:
    return DistributedOverwatch(name) if int(os.environ.get("WORLD_SIZE", -1)) != -1 else PureOverwatch(name)
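
As a quick sanity check on the adapter above: `ctx_level` selects one of the `CTX_PREFIXES`, so nested pipeline stages render as an indented tree. A sketch of single-process usage (with `WORLD_SIZE` unset, `initialize_overwatch` returns a `PureOverwatch`); the messages are placeholders:

overwatch = initialize_overwatch(__name__)
overwatch.info("Loading VLM")                      # renders as "[*] Loading VLM"
overwatch.info("Loading projector", ctx_level=1)   # renders as "    |=> Loading projector"
overwatch.info("Loading weights", ctx_level=2)     # renders as "        |=> Loading weights"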
capvector-oft/prismatic/preprocessing/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .download import convert_to_jpg, download_extract
from .materialize import get_dataset_and_collator
capvector-oft/prismatic/preprocessing/datasets/__init__.py
ADDED
@@ -0,0 +1 @@
from .datasets import AlignDataset, FinetuneDataset
capvector-oft/prismatic/preprocessing/datasets/datasets.py
ADDED
@@ -0,0 +1,200 @@
"""
datasets.py

PyTorch Dataset Definitions for Prismatic models; supports processing for both the `align` and `finetune` stages, with
utilities for formatting conversations during the `finetune` stage subject to the given LLM backbone's expected
formatting (e.g., SYS_PROMPT + USER: ... ASSISTANT: ... for Vicuña v1.5 Chat models).

We currently only support Map-style Datasets; assumes that all files (annotations, images) are on local disk, and that
random access image reading is relatively cheap/fast.
"""

import copy
import json
from pathlib import Path
from typing import Dict, List, Tuple, Type

import torch
from PIL import Image
from torch.utils.data import Dataset
from transformers import CodeGenTokenizerFast, LlamaTokenizerFast, PreTrainedTokenizerBase

from prismatic.models.backbones.llm.prompting import PromptBuilder
from prismatic.models.backbones.vision import ImageTransform

# HuggingFace Default / LLaMa-2 IGNORE_INDEX (for labels)
IGNORE_INDEX = -100


class AlignDataset(Dataset[Dict[str, torch.Tensor]]):
    def __init__(
        self,
        chat_json: Path,
        image_dir: Path,
        image_transform: ImageTransform,
        tokenizer: PreTrainedTokenizerBase,
    ) -> None:
        super().__init__()
        self.chat_json, self.image_dir = chat_json, image_dir
        self.image_transform, self.tokenizer = image_transform, tokenizer
        self.dataset_type = "align"

        # Create Prompt Template
        self.prompt_template = "{caption}" + self.tokenizer.eos_token

        # Load Chat JSON
        with open(self.chat_json, "r") as f:
            self.examples = json.load(f)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Following the *actual* code executed from the LLaVa codebase, during the "align" phase, we actually discard
        the "prompt" from the human, and instead directly predict the caption from the image.

        As a concrete example given the "raw data" for the first example:
            example = self.examples[0]["conversations"] = {
                [
                    {"from": "human", "value": "Render a clear and concise summary of the photo.\n<image>"},
                    {"from": "gpt", "value": "select luxury furniture 3 - inch gel memory foam mattress topper"}
                ]
            }

        Return =>> self.tokenizer("<image> select luxury furniture 3 - inch gel memory foam mattress topper\n")

        :param idx: Index to retrieve from the dataset.

        :return: Dictionary of {"pixel_values": torch.Tensor, "input_ids": torch.Tensor, "labels": torch.Tensor}
        """
        image_path, conversation = Path(self.examples[idx]["image"]), self.examples[idx]["conversations"]
        assert (len(conversation) == 2) and ("<image>" not in conversation[-1]["value"]), "Unexpected text!"

        # Format Caption --> {caption}{eos_token}
        caption = self.prompt_template.format(caption=conversation[-1]["value"].strip())

        # We treat image patches as "tokens = [p1 p2 p3, ...]"; we need to specify ordering of text/patch tokens.
        #   => Critically, we find that inserting *after* the BOS token leads to the strongest performance!
        #       - input_ids = "<s> p1 p2 p3 ... <caption_text> \n"
        #       - labels = "IGNORE IGNORE ..." (copy `input_ids` replacing <s> and p{1...K} with IGNORE)
        #
        #   IMPORTANT => IF WE'RE USING HF LLM.forward(... labels=labels), SHIFTING HAPPENS _INSIDE_ MODEL!
        input_ids = self.tokenizer(caption, truncation=True, return_tensors="pt").input_ids[0]
        labels = copy.deepcopy(input_ids)

        # Set the <BOS> token's label to IGNORE_INDEX (since we're inserting the image patches right after)
        labels[0] = IGNORE_INDEX

        # Process Image --> get "pixel_values" (will either be a torch.Tensor OR a Dict[str,torch.Tensor])
        pixel_values = self.image_transform(Image.open(self.image_dir / image_path).convert("RGB"))

        return dict(pixel_values=pixel_values, input_ids=input_ids, labels=labels)

    def get_modality_lengths(self, n_image_patches: int) -> List[Tuple[bool, int]]:
        """Get a list of modalities (unimodal / text-only vs. multimodal) and length of conversations per example."""
        modality_lengths = []
        for example in self.examples:
            is_multimodal = "image" in example
            n_words = sum([len(turn["value"].replace("<image>", "").split()) for turn in example["conversations"]])
            modality_lengths.append((is_multimodal, (n_image_patches + n_words) if is_multimodal else n_words))
        return modality_lengths

    def __len__(self) -> int:
        return len(self.examples)


class FinetuneDataset(Dataset[Dict[str, torch.Tensor]]):
    def __init__(
        self,
        instruct_json: Path,
        image_dir: Path,
        image_transform: ImageTransform,
        tokenizer: PreTrainedTokenizerBase,
        prompt_builder_fn: Type[PromptBuilder],
    ) -> None:
        super().__init__()
        self.instruct_json, self.image_dir = instruct_json, image_dir
        self.image_transform, self.tokenizer = image_transform, tokenizer
        self.prompt_builder_fn = prompt_builder_fn
        self.dataset_type = "finetune"

        # Load Instruct JSON
        with open(self.instruct_json, "r") as f:
            self.examples = json.load(f)

    # === Unimodal + Multimodal Handling ===
    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Unlike the *align* stage handling, for the *finetune* stage, we actually need to handle multiple "turns" of
        dialog grounded in a single image.

        To do this, we leverage the `prompt_builder_fn` which instantiates a PromptBuilder object. By calling the
        methods for adding turns and getting a prompt, we ensure proper formatting and consistency for each example.

        :param idx: Index to retrieve from the dataset.

        :return: Dictionary of {"pixel_values": torch.Tensor, "input_ids": torch.Tensor, "labels": torch.Tensor}
        """
        conversation = self.examples[idx]["conversations"]

        # Create Prompt Builder --> add each message sequentially
        prompt_builder, input_ids, labels = self.prompt_builder_fn(model_family="prismatic"), [], []
        for turn_idx, turn in enumerate(conversation):
            # Get "effective" string added to prompt --> handle whitespace for tokenizer type!
            msg = prompt_builder.add_turn(turn["from"], turn["value"])

            # Llama Tokenizer (Fast) adds extra character if a string ends in whitespace --> strip if non-empty!
            if isinstance(self.tokenizer, LlamaTokenizerFast):
                msg = msg.rstrip()

            # Phi-2 Tokenizer == CodeGenTokenizer (Fast) -- no special handling!
            elif isinstance(self.tokenizer, CodeGenTokenizerFast):
                pass

            else:
                raise ValueError(f"Tokenizer of type `{type(self.tokenizer)}` is not explicitly handled!")

            # Tokenize Input IDs
            turn_input_ids = self.tokenizer(msg, add_special_tokens=turn_idx == 0).input_ids

            # [CRITICAL] We do not want to take the loss for the "USER: <msg>" prompts =>> just the responses!
            turn_labels = (
                [IGNORE_INDEX for _ in range(len(turn_input_ids))] if (turn_idx % 2) == 0 else list(turn_input_ids)
            )

            # Add to Trackers
            input_ids.extend(turn_input_ids)
            labels.extend(turn_labels)

        # Tensorize =>> Set the <BOS> token's label to IGNORE_INDEX (since we're inserting the image patches after)
        #   - IMPORTANT => IF WE'RE USING HF LLM.forward(... labels=labels), SHIFTING HAPPENS _INSIDE_ MODEL!
        input_ids, labels = torch.tensor(input_ids), torch.tensor(labels)

        # Handle Truncation (if necessary)
        input_ids, labels = input_ids[: self.tokenizer.model_max_length], labels[: self.tokenizer.model_max_length]

        # === Handle "unimodal" (language-only) vs. "multimodal" ===
        if "image" in self.examples[idx]:
            image_path = Path(self.examples[idx]["image"])

            # Set the <BOS> token's label to IGNORE_INDEX (since we're inserting the image patches right after)
            labels[0] = IGNORE_INDEX

            # Process Image --> get "pixel_values" (will either be a torch.Tensor OR a Dict[str,torch.Tensor])
            pixel_values = self.image_transform(Image.open(self.image_dir / image_path).convert("RGB"))

            return dict(pixel_values=pixel_values, input_ids=input_ids, labels=labels)

        else:
            # No image --> return `pixel_values` = None; Collator will do the smart batch handling for us!
            return dict(pixel_values=None, input_ids=input_ids, labels=labels)

    def get_modality_lengths(self) -> List[Tuple[bool, int]]:
        """Get a list of modalities (unimodal / text-only vs. multimodal) and length of conversations per example."""
        modality_lengths = []
        for example in self.examples:
            is_multimodal = "image" in example
            n_words = sum([len(turn["value"].split()) for turn in example["conversations"]])
            modality_lengths.append((is_multimodal, n_words))
        return modality_lengths

    def __len__(self) -> int:
        return len(self.examples)
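
To make the BOS-label contract above concrete, a short sketch of pulling a single `align` example. The paths follow the download layout only as an assumption, and `image_transform` / `tokenizer` stand in for objects materialized from the vision and LLM backbones:

from pathlib import Path
from prismatic.preprocessing.datasets import AlignDataset

dataset = AlignDataset(
    chat_json=Path("data/download/llava-laion-cc-sbu-558k/chat.json"),  # hypothetical local layout
    image_dir=Path("data/download/llava-laion-cc-sbu-558k/images"),
    image_transform=image_transform,                                    # assumed: from the vision backbone
    tokenizer=tokenizer,                                                # assumed: from the LLM backbone
)
example = dataset[0]
assert example["labels"][0] == -100  # IGNORE_INDEX: <BOS> is masked; patches are spliced in right after it in `forward`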
capvector-oft/prismatic/preprocessing/download.py
ADDED
@@ -0,0 +1,207 @@
"""
download.py

Utility functions for downloading and extracting various datasets to (local) disk.
"""

import os
import shutil
from pathlib import Path
from typing import Dict, List, TypedDict
from zipfile import ZipFile

import requests
from PIL import Image
from rich.progress import BarColumn, DownloadColumn, MofNCompleteColumn, Progress, TextColumn, TransferSpeedColumn
from tqdm import tqdm

from prismatic.overwatch import initialize_overwatch

# Initialize Overwatch =>> Wraps `logging.Logger`
overwatch = initialize_overwatch(__name__)


# === Dataset Registry w/ Links ===
# fmt: off
DatasetComponent = TypedDict(
    "DatasetComponent",
    {"name": str, "extract": bool, "extract_type": str, "url": str, "do_rename": bool},
    total=False
)

DATASET_REGISTRY: Dict[str, List[DatasetComponent]] = {
    # === LLaVa v1.5 Dataset(s) ===

    # Note =>> This is the full suite of datasets included in the LLaVa 1.5 "finetuning" stage; all the LLaVa v1.5
    #          models are finetuned on this split. We use this dataset for all experiments in our paper.
    "llava-laion-cc-sbu-558k": [
        {
            "name": "chat.json",        # Contains the "chat" traces :: {"human" => <prompt>, "gpt" => <caption>}
            "extract": False,
            "url": "https://huggingface.co/datasets/liuhaotian/LLaVA-Pretrain/resolve/main/blip_laion_cc_sbu_558k.json",
            "do_rename": True,
        },
        {
            "name": "images",           # Contains the LLaVa Processed Images (jpgs, 224x224 resolution)
            "extract": True,
            "extract_type": "directory",
            "url": "https://huggingface.co/datasets/liuhaotian/LLaVA-Pretrain/resolve/main/images.zip",
            "do_rename": False,
        }
    ],

    "llava-v1.5-instruct": [
        {
            "name": "llava_v1_5_mix665k.json",
            "extract": False,
            "url": (
                "https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/resolve/main/llava_v1_5_mix665k.json"
            ),
            "do_rename": True,
        },
        {
            "name": "coco/train2017",   # Visual Instruct Tuning images are all sourced from COCO Train 2017
            "extract": True,
            "extract_type": "directory",
            "url": "http://images.cocodataset.org/zips/train2017.zip",
            "do_rename": True,
        },
        {
            "name": "gqa/images",
            "extract": True,
            "extract_type": "directory",
            "url": "https://downloads.cs.stanford.edu/nlp/data/gqa/images.zip",
            "do_rename": True,
        },
        {
            "name": "ocr_vqa/images",
            "extract": True,
            "extract_type": "directory",
            "url": "https://huggingface.co/datasets/qnguyen3/ocr_vqa/resolve/main/ocr_vqa.zip",
            "do_rename": True,
        },
        {
            "name": "textvqa/train_images",
            "extract": True,
            "extract_type": "directory",
            "url": "https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip",
            "do_rename": True,
        },
        {
            "name": "vg/VG_100K",
            "extract": True,
            "extract_type": "directory",
            "url": "https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip",
            "do_rename": True,
        },
        {
            "name": "vg/VG_100K_2",
            "extract": True,
            "extract_type": "directory",
            "url": "https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip",
            "do_rename": True,
        },
    ]
}
# fmt: on


def convert_to_jpg(image_dir: Path) -> None:
    """Handling for OCR-VQA Images specifically; iterates through directory, converts all GIFs/PNGs."""
    overwatch.info(f"Converting all Images in `{image_dir}` to JPG")

    for image_fn in tqdm(list(image_dir.iterdir())):
        if image_fn.suffix in {".jpg", ".jpeg"} or (jpg_fn := image_dir / f"{image_fn.stem}.jpg").exists():
            continue

        if image_fn.suffix == ".gif":
            gif = Image.open(image_fn)
            gif.seek(0)
            gif.convert("RGB").save(jpg_fn)
        elif image_fn.suffix == ".png":
            Image.open(image_fn).convert("RGB").save(jpg_fn)
        else:
            raise ValueError(f"Unexpected image format `{image_fn.suffix}`")


def download_with_progress(url: str, download_dir: Path, chunk_size_bytes: int = 1024) -> Path:
    """Utility function for downloading files from the internet, with a handy Rich-based progress bar."""
    overwatch.info(f"Downloading {(dest_path := download_dir / Path(url).name)} from `{url}`", ctx_level=1)
    if dest_path.exists():
        return dest_path

    # Otherwise --> fire an HTTP Request, with `stream = True`
    response = requests.get(url, stream=True)

    # Download w/ Transfer-Aware Progress
    #   => Reference: https://github.com/Textualize/rich/blob/master/examples/downloader.py
    with Progress(
        TextColumn("[bold]{task.description} - {task.fields[fname]}"),
        BarColumn(bar_width=None),
        "[progress.percentage]{task.percentage:>3.1f}%",
        "•",
        DownloadColumn(),
        "•",
        TransferSpeedColumn(),
        transient=True,
    ) as dl_progress:
        dl_tid = dl_progress.add_task(
            "Downloading", fname=dest_path.name, total=int(response.headers.get("content-length", "None"))
        )
        with open(dest_path, "wb") as f:
            for data in response.iter_content(chunk_size=chunk_size_bytes):
                dl_progress.advance(dl_tid, f.write(data))

    return dest_path


def extract_with_progress(archive_path: Path, download_dir: Path, extract_type: str, cleanup: bool = False) -> Path:
    """Utility function for extracting compressed archives, with a handy Rich-based progress bar."""
    assert archive_path.suffix == ".zip", "Only `.zip` compressed archives are supported for now!"
    overwatch.info(f"Extracting {archive_path.name} to `{download_dir}`", ctx_level=1)

    # Extract w/ Progress
    with Progress(
        TextColumn("[bold]{task.description} - {task.fields[aname]}"),
        BarColumn(bar_width=None),
        "[progress.percentage]{task.percentage:>3.1f}%",
        "•",
        MofNCompleteColumn(),
        transient=True,
    ) as ext_progress:
        with ZipFile(archive_path) as zf:
            ext_tid = ext_progress.add_task("Extracting", aname=archive_path.name, total=len(members := zf.infolist()))
            extract_path = Path(zf.extract(members[0], download_dir))
            if extract_type == "file":
                assert len(members) == 1, f"Archive `{archive_path}` with extract type `{extract_type}` has > 1 member!"
            elif extract_type == "directory":
                for member in members[1:]:
                    zf.extract(member, download_dir)
                    ext_progress.advance(ext_tid)
            else:
                raise ValueError(f"Extract type `{extract_type}` for archive `{archive_path}` is not defined!")

    # Cleanup (if specified)
    if cleanup:
        archive_path.unlink()

    return extract_path


def download_extract(dataset_id: str, root_dir: Path) -> None:
    """Download all files for a given dataset (querying registry above), extracting archives if necessary."""
    os.makedirs(download_dir := root_dir / "download" / dataset_id, exist_ok=True)

    # Download Files => Single-Threaded, with Progress Bar
    dl_tasks = [d for d in DATASET_REGISTRY[dataset_id] if not (download_dir / d["name"]).exists()]
    for dl_task in dl_tasks:
        dl_path = download_with_progress(dl_task["url"], download_dir)

        # Extract Files (if specified) --> Note (assumes ".zip" ONLY!)
        if dl_task["extract"]:
            dl_path = extract_with_progress(dl_path, download_dir, dl_task["extract_type"])
            dl_path = dl_path.parent if dl_path.is_file() else dl_path

        # Rename Path --> dl_task["name"]
        if dl_task["do_rename"]:
            shutil.move(dl_path, download_dir / dl_task["name"])
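
A single call drives the registry above end to end; an illustrative invocation (the root directory is a placeholder):

from pathlib import Path
from prismatic.preprocessing import download_extract

# Fetches every component of the LLaVa v1.5 instruct mixture into `data/download/llava-v1.5-instruct/`,
# extracting `.zip` archives and renaming per the registry entries
download_extract("llava-v1.5-instruct", root_dir=Path("data"))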
capvector-oft/prismatic/preprocessing/materialize.py
ADDED
@@ -0,0 +1,69 @@
"""
materialize.py

Factory class for initializing pretraining datasets on a per-VLM basis; provides and exports individual functions for
clear control flow.
"""

from typing import Tuple, Type

from torch.utils.data import Dataset
from transformers import PreTrainedTokenizerBase

from prismatic.conf import DatasetConfig
from prismatic.models.backbones.llm.prompting import PromptBuilder
from prismatic.models.backbones.vision import ImageTransform
from prismatic.preprocessing.datasets import AlignDataset, FinetuneDataset
from prismatic.util.data_utils import PaddedCollatorForLanguageModeling

# Dataset Initializers =>> Maps Stage --> cls()
DATASET_INITIALIZER = {"align": AlignDataset, "finetune": FinetuneDataset, "full-finetune": FinetuneDataset}


def get_dataset_and_collator(
    stage: str,
    dataset_cfg: DatasetConfig,
    image_transform: ImageTransform,
    tokenizer: PreTrainedTokenizerBase,
    prompt_builder_fn: Type[PromptBuilder],
    default_image_resolution: Tuple[int, int, int],
    padding_side: str = "right",
) -> Tuple[Dataset, PaddedCollatorForLanguageModeling]:
    dataset_cls = DATASET_INITIALIZER[stage]
    dataset_root_dir = dataset_cfg.dataset_root_dir
    collator = PaddedCollatorForLanguageModeling(
        tokenizer.model_max_length, tokenizer.pad_token_id, default_image_resolution, padding_side=padding_side
    )

    # Switch on `stage`
    if stage == "align":
        annotation_json, image_dir = dataset_cfg.align_stage_components
        dataset = dataset_cls(
            dataset_root_dir / annotation_json, dataset_root_dir / image_dir, image_transform, tokenizer
        )
        return dataset, collator

    elif stage == "finetune":
        annotation_json, image_dir = dataset_cfg.finetune_stage_components
        dataset = dataset_cls(
            dataset_root_dir / annotation_json,
            dataset_root_dir / image_dir,
            image_transform,
            tokenizer,
            prompt_builder_fn=prompt_builder_fn,
        )
        return dataset, collator

    elif stage == "full-finetune":
        annotation_json, image_dir = dataset_cfg.finetune_stage_components
        dataset = dataset_cls(
            dataset_root_dir / annotation_json,
            dataset_root_dir / image_dir,
            image_transform,
            tokenizer,
            prompt_builder_fn=prompt_builder_fn,
        )
        return dataset, collator

    else:
        raise ValueError(f"Stage `{stage}` is not supported!")
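
A sketch of the factory in use for the `finetune` stage; everything bound here (the config object and the backbone-derived arguments) is an assumption, standing in for what the training script materializes earlier:

dataset, collator = get_dataset_and_collator(
    stage="finetune",
    dataset_cfg=dataset_cfg,                                        # assumed: a prismatic.conf.DatasetConfig
    image_transform=vision_backbone.image_transform,                # assumed: materialized vision backbone
    tokenizer=llm_backbone.tokenizer,                               # assumed: materialized LLM backbone
    prompt_builder_fn=llm_backbone.prompt_builder_fn,
    default_image_resolution=vision_backbone.default_image_resolution,
)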
capvector-oft/prismatic/training/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .materialize import get_train_strategy
from .metrics import Metrics, VLAMetrics
capvector-oft/prismatic/training/materialize.py
ADDED
@@ -0,0 +1,66 @@
"""
materialize.py

Factory class defining functions for instantiating various Training Strategies, supporting different VLMs, backbones,
and strategy configurations.
"""

from typing import Callable, Optional

import torch

from prismatic.models.vlms import PrismaticVLM
from prismatic.training.strategies import FSDPStrategy, TrainingStrategy

# Registry =>> Maps ID --> {cls(), kwargs} :: supports FSDP for now, but DDP handler is also implemented!
TRAIN_STRATEGIES = {
    "fsdp-shard-grad-op": {"cls": FSDPStrategy, "kwargs": {"sharding_strategy": "shard-grad-op"}},
    "fsdp-full-shard": {"cls": FSDPStrategy, "kwargs": {"sharding_strategy": "full-shard"}},
}


def get_train_strategy(
    train_strategy: str,
    vlm: PrismaticVLM,
    device_id: int,
    stage: str,
    epochs: int,
    max_steps: Optional[int],
    global_batch_size: int,
    per_device_batch_size: int,
    learning_rate: float,
    weight_decay: float,
    max_grad_norm: float,
    lr_scheduler_type: str,
    warmup_ratio: float,
    enable_gradient_checkpointing: bool = True,
    enable_mixed_precision_training: bool = True,
    reduce_in_full_precision: bool = False,
    mixed_precision_dtype: torch.dtype = torch.bfloat16,
    worker_init_fn: Optional[Callable[[int], None]] = None,
) -> TrainingStrategy:
    if train_strategy in TRAIN_STRATEGIES:
        strategy_cfg = TRAIN_STRATEGIES[train_strategy]
        strategy = strategy_cfg["cls"](
            vlm=vlm,
            device_id=device_id,
            stage=stage,
            epochs=epochs,
            max_steps=max_steps,
            global_batch_size=global_batch_size,
            per_device_batch_size=per_device_batch_size,
            learning_rate=learning_rate,
            weight_decay=weight_decay,
            max_grad_norm=max_grad_norm,
            lr_scheduler_type=lr_scheduler_type,
            warmup_ratio=warmup_ratio,
            enable_gradient_checkpointing=enable_gradient_checkpointing,
            enable_mixed_precision_training=enable_mixed_precision_training,
            reduce_in_full_precision=reduce_in_full_precision,
            mixed_precision_dtype=mixed_precision_dtype,
            worker_init_fn=worker_init_fn,
        )
        return strategy
    else:
        raise ValueError(f"Train Strategy `{train_strategy}` is not supported!")
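
The strategy factory is meant to be called once from the pretraining script; a hedged sketch of that call (the hyperparameter values are placeholders, not the repo's defaults, and `vlm`, `device_id`, `run_dir`, and `dataset` are assumed from surrounding context):

train_strategy = get_train_strategy(
    train_strategy="fsdp-full-shard",
    vlm=vlm,
    device_id=device_id,
    stage="finetune",
    epochs=1,
    max_steps=None,
    global_batch_size=128,
    per_device_batch_size=16,
    learning_rate=2e-5,
    weight_decay=0.1,
    max_grad_norm=1.0,
    lr_scheduler_type="linear-warmup+cosine-decay",
    warmup_ratio=0.03,
)
train_strategy.run_setup(run_dir=run_dir, n_train_examples=len(dataset))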
capvector-oft/prismatic/training/metrics.py
ADDED
|
@@ -0,0 +1,348 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
metrics.py
|
| 3 |
+
|
| 4 |
+
Utility classes defining a Metrics container and multiple Trackers to enable model/stage-specific logging to various
|
| 5 |
+
endpoints (e.g., JSONL local logs, Weights & Biases).
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import time
|
| 9 |
+
from collections import defaultdict, deque
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Any, Dict, Optional, Protocol, Tuple, Union
|
| 12 |
+
|
| 13 |
+
import jsonlines
|
| 14 |
+
import numpy as np
|
| 15 |
+
import torch
|
| 16 |
+
import wandb
|
| 17 |
+
|
| 18 |
+
from prismatic.overwatch import initialize_overwatch
|
| 19 |
+
|
| 20 |
+
# Initialize Overwatch =>> Wraps `logging.Logger`
|
| 21 |
+
overwatch = initialize_overwatch(__name__)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# === Define Tracker Interface ===
|
| 25 |
+
class Tracker(Protocol):
|
| 26 |
+
def write_hyperparameters(self) -> None: ...
|
| 27 |
+
|
| 28 |
+
def write(self, global_step: int, metrics: Dict[str, Union[int, float]]) -> None: ...
|
| 29 |
+
|
| 30 |
+
def finalize(self) -> None: ...
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# === Individual Tracker Definitions ===
|
| 34 |
+
class JSONLinesTracker:
|
| 35 |
+
def __init__(self, run_id: str, run_dir: Path, hparams: Dict[str, Any]) -> None:
|
| 36 |
+
self.run_id, self.run_dir, self.hparams = run_id, run_dir, hparams
|
| 37 |
+
|
| 38 |
+
@overwatch.rank_zero_only
|
| 39 |
+
def write_hyperparameters(self) -> None:
|
| 40 |
+
with jsonlines.open(self.run_dir / "run-metrics.jsonl", mode="w", sort_keys=True) as js_tracker:
|
| 41 |
+
js_tracker.write({"run_id": self.run_id, "hparams": self.hparams})
|
| 42 |
+
|
| 43 |
+
@overwatch.rank_zero_only
|
| 44 |
+
def write(self, _: int, metrics: Dict[str, Union[int, float]]) -> None:
|
| 45 |
+
with jsonlines.open(self.run_dir / f"{self.run_id}.jsonl", mode="a", sort_keys=True) as js_tracker:
|
| 46 |
+
js_tracker.write(metrics)
|
| 47 |
+
|
| 48 |
+
def finalize(self) -> None:
|
| 49 |
+
return
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class WeightsBiasesTracker:
|
| 53 |
+
def __init__(
|
| 54 |
+
self,
|
| 55 |
+
run_id: str,
|
| 56 |
+
run_dir: Path,
|
| 57 |
+
hparams: Dict[str, Any],
|
| 58 |
+
project: str = "prismatic",
|
| 59 |
+
entity: Optional[str] = None,
|
| 60 |
+
group: str = "align",
|
| 61 |
+
) -> None:
|
| 62 |
+
self.run_id, self.run_dir, self.hparams = run_id, run_dir, hparams
|
| 63 |
+
|
| 64 |
+
# Get W&B-Specific Initialization Parameters
|
| 65 |
+
self.project, self.entity, self.group, self.wandb_dir = project, entity, group, self.run_dir
|
| 66 |
+
|
| 67 |
+
# Call W&B.init()
|
| 68 |
+
self.initialize()
|
| 69 |
+
|
| 70 |
+
@overwatch.rank_zero_only
|
| 71 |
+
def initialize(self) -> None:
|
| 72 |
+
wandb.init(
|
| 73 |
+
name=self.run_id,
|
| 74 |
+
dir=self.wandb_dir,
|
| 75 |
+
config=self.hparams,
|
| 76 |
+
project=self.project,
|
| 77 |
+
entity=self.entity,
|
| 78 |
+
group=self.group,
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
@overwatch.rank_zero_only
|
| 82 |
+
def write_hyperparameters(self) -> None:
|
| 83 |
+
wandb.config = self.hparams
|
| 84 |
+
|
| 85 |
+
@overwatch.rank_zero_only
|
| 86 |
+
def write(self, global_step: int, metrics: Dict[str, Union[int, float]]) -> None:
|
| 87 |
+
wandb.log(metrics, step=global_step)
|
| 88 |
+
|
| 89 |
+
@staticmethod
|
| 90 |
+
def finalize() -> None:
|
| 91 |
+
if overwatch.is_rank_zero():
|
| 92 |
+
wandb.finish()
|
| 93 |
+
|
| 94 |
+
# A job gets 210 seconds to get its affairs in order
|
| 95 |
+
time.sleep(210)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
# === Core Metrics Container :: Initializes Trackers => Compiles/Pushes Metrics ===


class Metrics:
    def __init__(
        self,
        active_trackers: Tuple[str, ...],
        run_id: str,
        run_dir: Path,
        hparams: Dict[str, Any],
        stage: str,
        wandb_project: str = "prismatic",
        wandb_entity: Optional[str] = None,
        grad_accumulation_steps: int = 1,
        window_size: int = 128,
    ) -> None:
        self.run_id, self.run_dir, self.hparams, self.stage = run_id, run_dir, hparams, stage

        # Initialize Trackers
        self.trackers = []
        for tracker_type in active_trackers:
            if tracker_type == "jsonl":
                tracker = JSONLinesTracker(run_id, run_dir, hparams)
            elif tracker_type == "wandb":
                tracker = WeightsBiasesTracker(
                    run_id, run_dir, hparams, project=wandb_project, entity=wandb_entity, group=self.stage
                )
            else:
                raise ValueError(f"Tracker with type `{tracker_type}` is not supported!")

            # Add Hyperparameters --> add to `self.trackers`
            tracker.write_hyperparameters()
            self.trackers.append(tracker)

        # Create Universal Metrics Buffers
        self.global_step, self.start_time, self.step_start_time = 0, time.time(), time.time()
        self.state = {
            "loss_raw": deque(maxlen=grad_accumulation_steps),
            "loss": deque(maxlen=window_size),
            "step_time": deque(maxlen=window_size),
            "lr": [],
        }

    def log(self, global_step: int, metrics: Dict[str, Union[int, float]]) -> None:
        for tracker in self.trackers:
            tracker.write(global_step, metrics)

    def get_status(self, loss: Optional[torch.Tensor] = None) -> str:
        lr = self.state["lr"][-1] if len(self.state["lr"]) > 0 else 0
        if loss is None:
            return f"=>> [Global Step] {self.global_step:06d} =>> LR :: {lr:.6f}"

        # Otherwise, embed `loss` in status report!
        return f"=>> [Global Step] {self.global_step:06d} =>> LR :: {lr:.6f} -- Loss :: {loss:.4f}"

    def commit(
        self, *, global_step: Optional[int] = None, lr: Optional[float] = None, update_step_time: bool = False, **kwargs
    ) -> None:
        """Update all metrics in `self.state` by iterating through special positional arguments & kwargs."""
        if global_step is not None:
            self.global_step = global_step

        # For all other variables --> only track on rank zero!
        if not overwatch.is_rank_zero():
            return

        # Special Positional Arguments
        if lr is not None:
            self.state["lr"].append(lr)

        if update_step_time:
            self.state["step_time"].append(time.time() - self.step_start_time)
            self.step_start_time = time.time()

        # Generic Keyword Arguments
        for key, value in kwargs.items():
            if key == "loss":
                loss_val = value.detach()
                self.state["loss_raw"].append(loss_val)
                self.state["loss"].append(loss_val)
            else:
                self.state[key].append(value.detach())

    @overwatch.rank_zero_only
    def push(self) -> str:
        # Note :: Raw Loss is an Average over Gradient Accumulation Steps --> No Smoothing!
        loss_raw = torch.stack(list(self.state["loss_raw"])).mean().item()
        loss = torch.stack(list(self.state["loss"])).mean().item()
        step_time, lr = np.mean(list(self.state["step_time"])), self.state["lr"][-1]
        status = self.get_status(loss)

        # Fire to Trackers
        prefix = self.stage.capitalize()
        self.log(
            self.global_step,
            metrics={
                f"{prefix}/Step": self.global_step,
                f"{prefix}/Loss": loss,
                f"{prefix}/Loss (Raw)": loss_raw,
                f"{prefix}/Learning Rate": lr,
                f"{prefix}/Step Time": step_time,
            },
        )
        return status

    def finalize(self) -> None:
        for tracker in self.trackers:
            tracker.finalize()


class VLAMetrics:
    def __init__(
        self,
        active_trackers: Tuple[str, ...],
        run_id: str,
        run_dir: Path,
        hparams: Dict[str, Any],
        wandb_project: str = "openvla",
        wandb_entity: Optional[str] = "stanford-voltron",
        grad_accumulation_steps: int = 1,
        window_size: int = 1,
        resume_step: Optional[int] = None,
        resume_epoch: Optional[int] = None,
    ) -> None:
        self.run_id, self.run_dir, self.hparams = run_id, run_dir, hparams

        # Initialize Trackers
        self.trackers = []
        for tracker_type in active_trackers:
            if tracker_type == "jsonl":
                tracker = JSONLinesTracker(run_id, run_dir, hparams)
            elif tracker_type == "wandb":
                tracker = WeightsBiasesTracker(
                    run_id, run_dir, hparams, project=wandb_project, entity=wandb_entity, group="vla-train"
                )
            else:
                raise ValueError(f"Tracker with type `{tracker_type}` is not supported!")

            # Add Hyperparameters --> add to `self.trackers`
            tracker.write_hyperparameters()
            self.trackers.append(tracker)

        # Create Universal Metrics Buffers
        self.global_step = 0 if resume_step is None else resume_step
        self.epoch = 0 if resume_epoch is None else resume_epoch
        self.start_time, self.step_start_time = time.time(), time.time()
        self.state = {
            "loss_raw": deque(maxlen=grad_accumulation_steps),
            "loss": deque(maxlen=window_size),
            "l1_loss": deque(maxlen=window_size),
            "action_accuracy": deque(maxlen=window_size),
            # Buffers for the next-actions metrics committed by the VLA training loop (prevents a KeyError in `commit`)
            "next_actions_accuracy": deque(maxlen=window_size),
            "next_actions_l1_loss": deque(maxlen=window_size),
            "step_time": deque(maxlen=window_size),
            "lr": [],
        }

        # Create metrics buffers for individual tracked datasets
        self.dataset_trackers = defaultdict(lambda: VLAMetrics([], "", "", {}))

    def log(self, global_step: int, metrics: Dict[str, Union[int, float]]) -> None:
        for tracker in self.trackers:
            tracker.write(global_step, metrics)

    def get_status(self, loss: Optional[torch.Tensor] = None) -> str:
        lr = self.state["lr"][-1] if len(self.state["lr"]) > 0 else 0
        if loss is None:
            return f"=>> [Epoch {self.epoch:03d}] Global Step {self.global_step:06d} =>> LR :: {lr:.6f}"

        # Otherwise, embed `loss` in status report!
        return f"=>> [Epoch {self.epoch:03d}] Global Step {self.global_step:06d} =>> LR :: {lr:.6f} - Loss :: {loss:.4f}"

    def commit(
        self,
        *,
        global_step: Optional[int] = None,
        epoch: Optional[int] = None,
        lr: Optional[float] = None,
        update_step_time: bool = False,
        **kwargs,
    ) -> None:
        """Update all metrics in `self.state` by iterating through special positional arguments & kwargs."""
        if global_step is not None:
            self.global_step = global_step

        if epoch is not None:
            self.epoch = epoch

        # For all other variables --> only track on rank zero!
        if not overwatch.is_rank_zero():
            return

        # Special Positional Arguments
        if lr is not None:
            self.state["lr"].append(lr)

        if update_step_time:
            self.state["step_time"].append(time.time() - self.step_start_time)
            self.step_start_time = time.time()

        # Generic Keyword Arguments
        for key, value in kwargs.items():
            if key == "loss":
                loss_val = value.detach()
                self.state["loss_raw"].append(loss_val)
                self.state["loss"].append(loss_val)
            else:
                self.state[key].append(value.detach())

    def commit_for_dataset(self, dataset_name: str, **kwargs) -> None:
        self.dataset_trackers[dataset_name].commit(**kwargs)

    @overwatch.rank_zero_only
    def push(self) -> str:
        # Note :: Raw Loss is an Average over Gradient Accumulation Steps --> No Smoothing!
        loss_raw = torch.stack(list(self.state["loss_raw"])).mean().item()
        loss = torch.stack(list(self.state["loss"])).mean().item()
        l1_loss = torch.stack(list(self.state["l1_loss"])).mean().item()
        action_accuracy = torch.stack(list(self.state["action_accuracy"])).mean().item()
        step_time, lr = np.mean(list(self.state["step_time"])), self.state["lr"][-1]
        status = self.get_status(loss)

        # Get metrics per dataset
        dataset_metrics = {}
        for ds, tracker in self.dataset_trackers.items():
            dataset_metrics.update(
                {
                    f"{ds}/L1 Loss": torch.stack(list(tracker.state["l1_loss"])).mean().item(),
                    f"{ds}/Action Token Accuracy": torch.stack(list(tracker.state["action_accuracy"])).mean().item(),
                }
            )

        # Fire to Trackers
        prefix = "VLA Train"
        self.log(
            self.global_step,
            metrics={
                f"{prefix}/Step": self.global_step,
                f"{prefix}/Epoch": self.epoch,
                f"{prefix}/Loss": loss,
                f"{prefix}/L1 Loss": l1_loss,
                f"{prefix}/Action Token Accuracy": action_accuracy,
                f"{prefix}/Loss (Raw)": loss_raw,
                f"{prefix}/Learning Rate": lr,
                f"{prefix}/Step Time": step_time,
                **dataset_metrics,
            },
        )
        return status

    def finalize(self) -> None:
        for tracker in self.trackers:
            tracker.finalize()
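The intended call pattern: `commit(loss=...)` after every micro-batch, `commit(global_step=..., lr=..., update_step_time=True)` after each optimizer step, then `push()` to flush; `loss_raw` averages only the last `grad_accumulation_steps` values while `loss` smooths over `window_size`. A hedged sketch of that flow (illustrative values, `jsonl` tracker only, single-process assumed so rank-zero checks pass):

import torch
from pathlib import Path

run_dir = Path("runs/demo")
run_dir.mkdir(parents=True, exist_ok=True)

metrics = Metrics(("jsonl",), "demo", run_dir, hparams={}, stage="finetune", grad_accumulation_steps=2)
for micro_step in range(2):                           # two micro-batches per optimizer step
    metrics.commit(loss=torch.tensor(1.5 - 0.5 * micro_step))
metrics.commit(global_step=1, lr=1e-5, update_step_time=True)
print(metrics.push())                                 # "=>> [Global Step] 000001 =>> LR :: 0.000010 -- Loss :: 1.2500"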
capvector-oft/prismatic/training/strategies/__init__.py ADDED
@@ -0,0 +1,3 @@
from .base_strategy import TrainingStrategy
from .ddp import DDPStrategy
from .fsdp import FSDPStrategy
capvector-oft/prismatic/training/strategies/base_strategy.py ADDED
@@ -0,0 +1,417 @@
"""
base_strategy.py

Abstract class definition of a (distributed) training strategy, with full annotations of class methods, utility
functions, and initialization logic.

Training Strategies (DDP, FSDP-Grad, FSDP-Full) tend to have a lot of repeated components; this class does a lot of
heavy lifting.
"""

from abc import ABC, abstractmethod
from pathlib import Path
from typing import Callable, Optional

import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, Dataset, DistributedSampler, IterableDataset
from tqdm import tqdm
from transformers.modeling_outputs import CausalLMOutputWithPast

from prismatic.models.vlms import PrismaticVLM
from prismatic.overwatch import initialize_overwatch
from prismatic.training.metrics import Metrics, VLAMetrics
from prismatic.training.train_utils import (
    compute_actions_l1_loss,
    compute_token_accuracy,
    get_current_action_mask,
    get_next_actions_mask,
)
from prismatic.util import check_bloat16_supported
from prismatic.util.batching_utils import SplitModalitySampler
from prismatic.util.data_utils import PaddedCollatorForActionPrediction, PaddedCollatorForLanguageModeling
from prismatic.vla.action_tokenizer import ActionTokenizer

# HuggingFace Default / LLaMa-2 IGNORE_INDEX (for labels)
from prismatic.vla.constants import ACTION_DIM, ACTION_TOKEN_BEGIN_IDX, NUM_ACTIONS_CHUNK, IGNORE_INDEX

NEWLINE_INDEX = 13  # '\n'
STOP_INDEX = 2  # '</s>'

# Initialize Overwatch =>> Wraps `logging.Logger`
overwatch = initialize_overwatch(__name__)


# === Abstract Base Class for an arbitrary Training Strategy ===
class TrainingStrategy(ABC):
    def __init__(
        self,
        vlm: PrismaticVLM,
        device_id: int,
        stage: str,
        epochs: int,
        max_steps: Optional[int],
        global_batch_size: int,
        per_device_batch_size: int,
        learning_rate: float,
        weight_decay: float,
        max_grad_norm: float,
        lr_scheduler_type: str,
        warmup_ratio: float,
        enable_gradient_checkpointing: bool = True,
        enable_mixed_precision_training: bool = True,
        reduce_in_full_precision: bool = False,
        mixed_precision_dtype: torch.dtype = torch.bfloat16,
        worker_init_fn: Optional[Callable[[int], None]] = None,
        **_: str,
    ) -> None:
        self.vlm, self.device_id, self.stage = vlm, device_id, stage

        # Get relevant VLM instance parameters before they get (potentially) wrapped
        self.all_module_keys, self.trainable_module_keys = self.vlm.all_module_keys, self.vlm.trainable_module_keys
        self.llm_transformer_layer_cls = self.vlm.llm_backbone.transformer_layer_cls

        # Optimization Parameters
        self.epochs, self.max_steps = epochs, max_steps
        self.global_batch_size, self.per_device_batch_size = global_batch_size, per_device_batch_size

        self.learning_rate, self.weight_decay, self.max_grad_norm = learning_rate, weight_decay, max_grad_norm
        self.lr_scheduler_type, self.warmup_ratio = lr_scheduler_type, warmup_ratio

        # Generic Strategy Parameters
        self.enable_gradient_checkpointing = enable_gradient_checkpointing
        self.enable_mixed_precision_training = enable_mixed_precision_training
        self.reduce_in_full_precision = reduce_in_full_precision
        self.mixed_precision_dtype = mixed_precision_dtype

        # DataLoader Parameters
        self.worker_init_fn = worker_init_fn

        # Optimizers & Scheduler (initialized in `run_setup`)
        self.optimizer, self.lr_scheduler = None, None

        # Lightweight Validation
        assert (
            self.global_batch_size % self.per_device_batch_size == 0
        ), "Per-device batch size must evenly divide global batch size!"
        self.grad_accumulation_steps = self.global_batch_size // self.per_device_batch_size // overwatch.world_size()
        if self.enable_mixed_precision_training:
            assert self.mixed_precision_dtype == torch.bfloat16, "Only BF16 mixed precision training is supported!"
            assert check_bloat16_supported(), "BFloat16 is not supported on this hardware; unset `mixed_precision`"

    @abstractmethod
    def save_checkpoint(
        self,
        run_dir: Path,
        global_step: int,
        epoch: int,
        train_loss: Optional[float] = None,
        only_trainable: bool = True,
    ) -> None: ...

    @abstractmethod
    def run_setup(self, run_dir: Path, n_train_examples: int) -> None: ...

    @abstractmethod
    def clip_grad_norm(self) -> None: ...

    def run_training(
        self,
        dataset: Dataset,
        collator: PaddedCollatorForLanguageModeling,
        metrics: Metrics,
        stage: str = "finetune",
        batch_construction_strategy: str = "split-modality",
        seed: int = 7,
    ) -> None:
        """Run the training loop for the given `dataset` and `collator`; log losses, results to `metrics`."""
        if "finetune" in stage and batch_construction_strategy == "split-modality":
            # Instantiate the split-modality sampler; if you want to extend with other batch construction schemes
            # (e.g., grouping by length), you can easily add them here!
            modality_lengths = dataset.get_modality_lengths()
            sampler = SplitModalitySampler(
                dataset,
                modality_lengths,
                global_batch_size=self.global_batch_size,
                num_replicas=overwatch.world_size(),
                rank=overwatch.rank(),
                seed=seed,
                drop_last=False,
            )

        else:
            sampler = DistributedSampler(
                dataset,
                num_replicas=overwatch.world_size(),
                rank=overwatch.rank(),
                shuffle=True,
                seed=seed,
                drop_last=False,
            )

        # Create a DataLoader with the initialized sampler, per-device-bsz, and collator
        dataloader = DataLoader(
            dataset,
            batch_size=self.per_device_batch_size,
            sampler=sampler,
            collate_fn=collator,
            num_workers=2,
            worker_init_fn=self.worker_init_fn,
        )

        # Max Steps vs. Epochs Computation
        steps_per_epoch = len(dataloader) // self.grad_accumulation_steps
        if self.max_steps is not None and steps_per_epoch < self.max_steps:
            # Just set `epochs` to some large number --> we'll short-circuit based on steps anyway
            self.epochs = 100

        # === Train ===
        status = metrics.get_status()
        with tqdm(
            total=(
                (self.epochs * (len(dataloader) // self.grad_accumulation_steps))
                if self.max_steps is None
                else self.max_steps
            ),
            desc=status,
            leave=False,
            disable=not overwatch.is_rank_zero(),
        ) as progress:
            for epoch in range(self.epochs):
                self.vlm.train()
                sampler.set_epoch(epoch)

                # Zero-Gradients (just in case)
                self.optimizer.zero_grad()

                # Note that we'll unpack batch (and let AMP/FSDP do its thing) in the VLM.forward() call
                # => Basically, if we're using mixed precision (or not), autocast()/FSDP will move to device!
                for train_idx, batch in enumerate(dataloader):
                    # [Contract] self.vlm.forward() must automatically compute `loss` and return!
                    with torch.autocast(
                        "cuda",
                        dtype=self.mixed_precision_dtype,
                        enabled=self.enable_mixed_precision_training,
                    ):
                        output: CausalLMOutputWithPast = self.vlm(
                            input_ids=batch["input_ids"],
                            attention_mask=batch["attention_mask"],
                            pixel_values=batch["pixel_values"],
                            labels=batch["labels"],
                            multimodal_indices=batch["multimodal_indices"],
                        )
                        loss = output.loss

                    # Commit Loss (Prior to Gradient Accumulation Normalization)
                    metrics.commit(loss=loss)

                    # Normalize Loss to account for Gradient Accumulation --> Backward!
                    # [IMPORTANT] Technically speaking, doing gradient accumulation in this way is "incorrect"; this is
                    #             because in general, each batch has a *different number of masked out tokens* (because
                    #             we're instruct-tuning). Taking the mean over two unbalanced means != the right thing!
                    #
                    #             HOWEVER -- at least at the 7B scale, the "naive" approach is just as performant as
                    #             the "correct" implementation, without adding extra complexity.
                    #
                    # That being said =>> at the 13B scale, *no matter what we tried*, ANY gradient accumulation is just
                    #   really bad for downstream performance. Initial investigation shows that BF16 accumulation
                    #   just really tanks in precision... and we don't have a good/clean way to fix this. Would love for
                    #   someone to PR and fix this (and I'd greatly appreciate it!!!)
                    normalized_loss = loss / self.grad_accumulation_steps
                    normalized_loss.backward()

                    # Step =>> Only if Done w/ Gradient Accumulation
                    if (train_idx + 1) % self.grad_accumulation_steps == 0:
                        metrics.commit(update_step_time=True)

                        # Clip Gradients --> this is custom, per-strategy because of DDP vs. FSDP locality assumptions
                        self.clip_grad_norm()

                        # Optimizer & LR Scheduler Step
                        self.optimizer.step()
                        self.lr_scheduler.step()
                        self.optimizer.zero_grad()

                        # Push Metrics
                        metrics.commit(global_step=metrics.global_step + 1, lr=self.lr_scheduler.get_last_lr()[0])
                        status = metrics.push()

                        # Check for Termination & Save Final Checkpoint (in case `max_steps` is not None)
                        if self.max_steps is not None and metrics.global_step >= self.max_steps:
                            self.save_checkpoint(metrics.run_dir, metrics.global_step, epoch, loss.item())
                            dist.barrier()

                            return

                        # Update Progress Bar
                        progress.update()
                        progress.set_description(status)

                # Save checkpoint at end of each epoch (if `self.max_steps` is None)
                if self.max_steps is None:
                    self.save_checkpoint(metrics.run_dir, metrics.global_step, epoch, loss.item())
                    dist.barrier()
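To make the gradient-accumulation caveat above concrete: when micro-batches supervise different numbers of tokens, averaging per-batch mean losses (what `loss / self.grad_accumulation_steps` implements) differs from the true per-token mean. A toy illustration with made-up numbers:

# Micro-batch A: 2 supervised tokens, summed loss 4.0 -> per-batch mean 2.0
# Micro-batch B: 8 supervised tokens, summed loss 8.0 -> per-batch mean 1.0
naive_mean = (2.0 + 1.0) / 2                # 1.5  <- mean of per-batch means (the "naive" scheme above)
token_weighted = (4.0 + 8.0) / (2 + 8)      # 1.2  <- per-token mean (the "correct" scheme)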
    # === VLA Training ===

    def run_vla_training(
        self,
        vla_dataset: IterableDataset,
        collator: PaddedCollatorForActionPrediction,
        action_tokenizer: ActionTokenizer,
        metrics: VLAMetrics,
        save_interval: int = 2500,
        save_full_model: bool = True,
    ) -> None:
        """Run the VLA training loop for the given `dataset` and `collator`; log losses, action metrics to `metrics`."""
        assert isinstance(vla_dataset, IterableDataset), "VLA training expects an IterableDataset!"
        assert self.grad_accumulation_steps == 1, "VLA training does not support gradient accumulation!"

        # Create a DataLoader =>> Set `num_workers` to 0; RLDS loader handles parallelism!
        dataloader = DataLoader(
            vla_dataset,
            batch_size=self.per_device_batch_size,
            sampler=None,
            collate_fn=collator,
            num_workers=0,
            worker_init_fn=self.worker_init_fn,
        )

        # === Train ===
        status = metrics.get_status()
        with tqdm(
            total=(self.epochs * len(dataloader)) if self.max_steps is None else self.max_steps,
            desc=status,
            leave=False,
            disable=not overwatch.is_rank_zero(),
        ) as progress:
            self.vlm.train()

            # Zero Gradients (just in case)
            self.optimizer.zero_grad()

            # [Contract] DataLoader wraps RLDS Loader (`.as_numpy_iterator()` =>> implicit `.repeat()`)
            #   => This means looping over the DataLoader is basically "infinite" (so no outer loop over epochs).
            #      Slightly breaks default PyTorch semantics, which is why we adaptively compute `epoch` below.
            for batch in dataloader:
                # Note that we'll unpack batch (and let AMP/FSDP do its thing) in the VLM.forward() call
                # => Basically, if we're using mixed precision (or not), autocast()/FSDP will move to device!
                with torch.autocast(
                    "cuda", dtype=self.mixed_precision_dtype, enabled=self.enable_mixed_precision_training
                ):
                    # [Contract] self.vlm.forward() must automatically compute `loss` and return!
                    output: CausalLMOutputWithPast = self.vlm(
                        input_ids=batch["input_ids"],
                        attention_mask=batch["attention_mask"],
                        pixel_values=batch["pixel_values"],
                        labels=batch["labels"],
                    )
                    loss = output.loss

                # Commit Loss =>> Backward!
                metrics.commit(loss=loss)
                loss.backward()

                # Get predicted and ground-truth token IDs
                predicted_token_ids = output.logits[:, self.vlm.vision_backbone.num_patches : -1].argmax(dim=2)
                ground_truth_token_ids = batch["labels"][:, 1:].to(predicted_token_ids.device)

                #######################################################################
                # === Compute Current Action Token Accuracy & L1 Loss ===
                #######################################################################

                # Get current action mask: Target the first ACTION_DIM non-ignore tokens
                current_action_mask = get_current_action_mask(ground_truth_token_ids)

                # Compute Accuracy
                action_accuracy = compute_token_accuracy(
                    predicted_token_ids, ground_truth_token_ids, mask=current_action_mask
                )

                # Compute L1 Loss on Predicted (Continuous) Actions
                action_l1_loss = compute_actions_l1_loss(
                    action_tokenizer, predicted_token_ids, ground_truth_token_ids, mask=current_action_mask
                )

                #######################################################################
                # === Compute Next Actions Token Accuracy & L1 Loss ===
                #######################################################################

                # Get next actions mask: Target all tokens after the first ACTION_DIM non-ignore tokens
                # (excluding the last token, which is the stop token)
                next_actions_mask = get_next_actions_mask(ground_truth_token_ids)

                # Compute Accuracy
                next_actions_accuracy = compute_token_accuracy(
                    predicted_token_ids, ground_truth_token_ids, mask=next_actions_mask
                )

                # Compute L1 Loss on Predicted (Continuous) Actions
                next_actions_l1_loss = compute_actions_l1_loss(
                    action_tokenizer, predicted_token_ids, ground_truth_token_ids, mask=next_actions_mask
                )

                #######################################################################
                # === Log ===
                #######################################################################

                # Commit Metrics
                metrics.commit(
                    action_accuracy=action_accuracy,
                    l1_loss=action_l1_loss,
                    next_actions_accuracy=next_actions_accuracy,
                    next_actions_l1_loss=next_actions_l1_loss,
                    update_step_time=True,
                )

                # Compute metrics per dataset --> only on rank_zero since we don't log them on other workers anyway
                if overwatch.is_rank_zero():
                    datasets = set(batch["dataset_names"])
                    if len(datasets) > 1:
                        for ds in datasets:
                            ds_mask = torch.tensor(
                                [elem == ds for elem in batch["dataset_names"]], device=predicted_token_ids.device
                            )
                            # Restrict correctness counts to this dataset's rows (current-action tokens only)
                            correct_preds = (predicted_token_ids == ground_truth_token_ids) & current_action_mask
                            action_accuracy_ds = (
                                correct_preds[ds_mask].sum().float() / current_action_mask[ds_mask].sum().float()
                            )
                            pred_continuous_actions_ds = torch.tensor(
                                action_tokenizer.decode_token_ids_to_actions(
                                    predicted_token_ids[ds_mask][current_action_mask[ds_mask]].cpu().numpy()
                                )
                            )
                            continuous_actions_gt_ds = torch.tensor(
                                action_tokenizer.decode_token_ids_to_actions(
                                    ground_truth_token_ids[ds_mask][current_action_mask[ds_mask]].cpu().numpy()
                                )
                            )
                            action_l1_loss_ds = torch.nn.functional.l1_loss(
                                pred_continuous_actions_ds, continuous_actions_gt_ds
                            )
                            metrics.commit_for_dataset(
                                dataset_name=ds.decode(),
                                action_accuracy=action_accuracy_ds,
                                l1_loss=action_l1_loss_ds,
                                next_actions_accuracy=next_actions_accuracy,
                                next_actions_l1_loss=next_actions_l1_loss,
                            )

                # === Gradient Step ===

                # Clip Gradients --> this is custom, per-strategy because of DDP vs. FSDP locality assumptions
                self.clip_grad_norm()

                # Optimizer & LR Scheduler Step
                self.optimizer.step()
                self.lr_scheduler.step()
                self.optimizer.zero_grad()

                # Compute epoch value using number of completed gradient steps
                epoch = (metrics.global_step + 1) // (len(vla_dataset) // self.global_batch_size)

                # Push Metrics
                metrics.commit(global_step=metrics.global_step + 1, epoch=epoch, lr=self.lr_scheduler.get_last_lr()[0])
                status = metrics.push()

                # Check for Save Interval or Max Steps & Save Checkpoint
                if (terminate := (self.max_steps is not None and metrics.global_step >= self.max_steps)) or (
                    (metrics.global_step % save_interval) == 0
                ):
                    self.save_checkpoint(
                        metrics.run_dir, metrics.global_step, epoch, loss.item(), only_trainable=not save_full_model
                    )
                    dist.barrier()

                    if terminate:
                        return

                # Update Progress Bar
                progress.update()
                progress.set_description(status)
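A note on the prediction/target alignment above: dropping `num_patches` leading positions (where the projected image patches sit) plus the final position from the logits, and dropping the first position from the labels, leaves two tensors whose index i refers to the same target token. A toy shape check (dimensions are made up; the exact patch placement depends on the VLM's fusion layout):

import torch

num_patches, seq_len, vocab = 3, 5, 11
logits = torch.randn(1, num_patches + seq_len, vocab)    # patch positions precede text positions
labels = torch.randint(0, vocab, (1, seq_len))

pred = logits[:, num_patches:-1].argmax(dim=2)           # predictions aligned to tokens 1..seq_len-1
target = labels[:, 1:]                                   # ground truth for those same positions
assert pred.shape == target.shape == (1, seq_len - 1)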
capvector-oft/prismatic/training/strategies/ddp.py ADDED
@@ -0,0 +1,128 @@
"""
ddp.py

Core class definition for a strategy implementing Torch native Distributed Data Parallel Training; note that on most
GPU hardware and LLM backbones >= 5-7B parameters, DDP training will OOM, which is why we opt for FSDP.
"""

import shutil
from pathlib import Path
from typing import Optional

import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import AdamW
from transformers.optimization import get_constant_schedule, get_cosine_schedule_with_warmup

from prismatic.overwatch import initialize_overwatch
from prismatic.training.strategies.base_strategy import TrainingStrategy

# Initialize Overwatch =>> Wraps `logging.Logger`
overwatch = initialize_overwatch(__name__)


class DDPStrategy(TrainingStrategy):
    @overwatch.rank_zero_only
    def save_checkpoint(
        self,
        run_dir: Path,
        global_step: int,
        epoch: int,
        train_loss: Optional[float] = None,
        only_trainable: bool = True,
    ) -> None:
        """Save a checkpoint to the `run_dir` only containing the state_dicts for trainable parameters by default."""
        assert isinstance(self.vlm, DDP), "save_checkpoint assumes VLM is already wrapped in DDP!"

        # Splinter State Dictionary by Top-Level Submodules (or subset, if `only_trainable`)
        model_state_dicts = {
            mkey: getattr(self.vlm.module, mkey).state_dict()
            for mkey in (self.trainable_module_keys if only_trainable else self.all_module_keys)
        }
        optimizer_state_dict = self.optimizer.state_dict()

        # Set Checkpoint Path =>> Embed *minimal* training statistics!
        checkpoint_dir = run_dir / "checkpoints"
        if train_loss is None:
            checkpoint_path = checkpoint_dir / f"step-{global_step:06d}-epoch-{epoch:02d}-loss=inf.pt"
        else:
            checkpoint_path = checkpoint_dir / f"step-{global_step:06d}-epoch-{epoch:02d}-loss={train_loss:.4f}.pt"

        # Save Checkpoint & Copy Latest to `latest-checkpoint.pt`
        torch.save({"model": model_state_dicts, "optimizer": optimizer_state_dict}, checkpoint_path)
        shutil.copy(checkpoint_path, checkpoint_dir / "latest-checkpoint.pt")

    def run_setup(self, run_dir: Path, n_train_examples: int) -> None:
        # Gradient Checkpointing Setup
        if self.enable_gradient_checkpointing:
            # For Gradient Checkpointing --> we make the assumption that the "bulk" of activation memory is taken up
            #     by the LLM; because we also make the explicit assumption that each LLM is derived from a HF
            #     pretrained model, the only thing we *need* to do (technically) is call `gradient_checkpoint_enable`
            #     on `self.llm_backbone`.
            #
            # What does it actually do? --> runs the *generic* custom_forward + torch.utils.checkpoint.checkpoint logic
            #   => github.com/huggingface/transformers/.../models/llama/modeling_llama.py#L692-L706
            #
            # Additional Reference (to better understand gradient checkpointing in PyTorch writ large)
            #   => github.com/prigoyal/pytorch_memonger/blob/master/tutorial/Checkpointing_for_PyTorch_models.ipynb
            overwatch.info("Enabling Gradient Checkpointing on LLM Backbone", ctx_level=1)
            self.vlm.llm_backbone.gradient_checkpointing_enable()

        # Move to Device =>> Note parameters are in full precision (*mixed precision* will only autocast as appropriate)
        overwatch.info("Placing Entire VLM (Vision Backbone, LLM Backbone, Projector Weights) on GPU", ctx_level=1)
        self.vlm.to(self.device_id)

        # Wrap with Distributed Data Parallel
        #   => Note: By default, wrapping naively with DDP(self.vlm) will initialize a *separate* buffer on GPU that
        #      is the same size/dtype as the model parameters; this will *double* GPU memory!
        #      - stackoverflow.com/questions/68949954/model-takes-twice-the-memory-footprint-with-distributed-data-parallel
        overwatch.info("Wrapping VLM with Distributed Data Parallel", ctx_level=1)
        self.vlm = DDP(self.vlm, device_ids=[self.device_id], gradient_as_bucket_view=True)

        # Create Optimizer and LR Scheduler =>> note that most of the LR Schedulers we use require `max_steps/epochs`
        #   => Optimizer should only operate on parameters that are *unfrozen* / trainable!
        trainable_params = [param for param in self.vlm.parameters() if param.requires_grad]
        if self.max_steps is None:
            num_training_steps = (n_train_examples * self.epochs) // self.global_batch_size
        else:
            num_training_steps = self.max_steps

        if self.lr_scheduler_type == "linear-warmup+cosine-decay":
            # Set warmup steps (floor) based on `warmup_ratio` (should be 0.03 - 0.05)
            num_warmup_steps = int(num_training_steps * self.warmup_ratio)

            assert self.weight_decay == 0, "DDP training does not currently support `weight_decay` > 0!"
            self.optimizer = AdamW(trainable_params, lr=self.learning_rate, weight_decay=self.weight_decay)
            self.lr_scheduler = get_cosine_schedule_with_warmup(self.optimizer, num_warmup_steps, num_training_steps)
            for param_group in self.optimizer.param_groups:
                param_group["lr"] = 0.0

        elif self.lr_scheduler_type == "constant":
            num_warmup_steps = 0

            assert self.weight_decay == 0, "DDP training does not currently support `weight_decay` > 0!"
            self.optimizer = AdamW(trainable_params, lr=self.learning_rate, weight_decay=self.weight_decay)
            self.lr_scheduler = get_constant_schedule(self.optimizer)

        else:
            raise ValueError(f"Learning Rate Schedule with type `{self.lr_scheduler_type}` is not supported!")

        # Finalize Setup =>> Log
        overwatch.info(
            "DDP Strategy =>> Finalized Training Setup:\n"
            f"         |-> Global (Effective) Batch Size = {self.global_batch_size}\n"
            f"         |-> Per-Device Batch Size = {self.per_device_batch_size}\n"
            f"         |-> Distributed World Size = {overwatch.world_size()}\n"
            f"         |-> Gradient Accumulation Steps = {self.grad_accumulation_steps}\n\n"
            f"         |-> LLM Backbone Gradient Checkpointing = {self.enable_gradient_checkpointing}\n"
            f"         |-> Use Native AMP = {self.enable_mixed_precision_training} ({self.mixed_precision_dtype})\n\n"
            f"         |-> Default AdamW LR = {self.learning_rate}\n"
            f"         |-> AdamW Weight Decay = {self.weight_decay}\n"
            f"         |-> LR Scheduler Type = {self.lr_scheduler_type}\n"
            f"         |-> LR Scheduler Warmup Steps (Ratio) = {num_warmup_steps} ({self.warmup_ratio})\n"
            f"         |-> Dataset Size = {n_train_examples} Examples\n"
            f"         |-> Max Steps = {num_training_steps}\n"
        )

    def clip_grad_norm(self) -> None:
        torch.nn.utils.clip_grad_norm_(self.vlm.parameters(), max_norm=self.max_grad_norm)
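Zeroing every `param_group["lr"]` immediately after building the scheduler makes the very first optimizer step run at the warmup floor rather than the configured peak LR; the scheduler then raises the LR step by step. A small standalone sketch of the resulting schedule (assumes `transformers` is installed; the step counts are illustrative):

import torch
from torch.optim import AdamW
from transformers.optimization import get_cosine_schedule_with_warmup

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = AdamW(params, lr=2e-5)
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=3, num_training_steps=10)
for group in optimizer.param_groups:
    group["lr"] = 0.0  # mirror the strategy's warmup floor

for step in range(10):
    optimizer.step()   # no-op here (no gradients), but keeps the step ordering correct
    scheduler.step()
    print(step, scheduler.get_last_lr()[0])  # linear ramp to 2e-5 over 3 steps, then cosine decay toward 0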
capvector-oft/prismatic/training/strategies/fsdp.py ADDED
@@ -0,0 +1,270 @@
"""
fsdp.py

Core class definition for a strategy implementing Torch native Fully Sharded Data Parallel Training (with support for
fine-grained control over wrapping policies and mixed precision per component).
"""

import math
from collections import OrderedDict
from functools import partial
from pathlib import Path
from typing import Callable, Optional

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)
from torch.distributed.fsdp import (
    FullStateDictConfig,
    MixedPrecision,
    ShardingStrategy,
    StateDictType,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.optim import AdamW
from transformers.optimization import get_constant_schedule, get_cosine_schedule_with_warmup

from prismatic.models.vlms import PrismaticVLM
from prismatic.overwatch import initialize_overwatch
from prismatic.training.strategies.base_strategy import TrainingStrategy

# Initialize Overwatch =>> Wraps `logging.Logger`
overwatch = initialize_overwatch(__name__)


class FSDPStrategy(TrainingStrategy):
    def __init__(
        self,
        vlm: PrismaticVLM,
        device_id: int,
        stage: str,
        epochs: int,
        max_steps: Optional[int],
        global_batch_size: int,
        per_device_batch_size: int,
        learning_rate: float,
        weight_decay: float,
        max_grad_norm: float,
        lr_scheduler_type: str,
        warmup_ratio: float,
        enable_gradient_checkpointing: bool = True,
        enable_mixed_precision_training: bool = True,
        reduce_in_full_precision: bool = False,
        mixed_precision_dtype: torch.dtype = torch.bfloat16,
        worker_init_fn: Optional[Callable[[int], None]] = None,
        sharding_strategy: str = "shard-grad-op",
        state_dict_type: StateDictType = StateDictType.FULL_STATE_DICT,
    ) -> None:
        super().__init__(
            vlm=vlm,
            device_id=device_id,
            stage=stage,
            epochs=epochs,
            max_steps=max_steps,
            global_batch_size=global_batch_size,
            per_device_batch_size=per_device_batch_size,
            learning_rate=learning_rate,
            weight_decay=weight_decay,
            max_grad_norm=max_grad_norm,
            lr_scheduler_type=lr_scheduler_type,
            warmup_ratio=warmup_ratio,
            enable_gradient_checkpointing=enable_gradient_checkpointing,
            enable_mixed_precision_training=enable_mixed_precision_training,
            reduce_in_full_precision=reduce_in_full_precision,
            mixed_precision_dtype=mixed_precision_dtype,
            worker_init_fn=worker_init_fn,
        )

        # FSDP-Specific Parameters
        if sharding_strategy == "shard-grad-op":
            self.fsdp_sharding_strategy = ShardingStrategy._HYBRID_SHARD_ZERO2
        elif sharding_strategy == "full-shard":
            self.fsdp_sharding_strategy = ShardingStrategy.HYBRID_SHARD
        else:
            raise ValueError(f"FSDP Sharding Strategy {sharding_strategy} is not supported!")

        assert state_dict_type == StateDictType.FULL_STATE_DICT, "Sharded state saving is not yet implemented!"
        self.fsdp_state_dict_type = state_dict_type
        self.fsdp_save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)

    def save_checkpoint(
        self,
        run_dir: Path,
        global_step: int,
        epoch: int,
        train_loss: Optional[float] = None,
        only_trainable: bool = True,
    ) -> None:
        """Save a checkpoint to the `run_dir` only containing the state_dicts for trainable parameters by default."""
        assert isinstance(self.vlm, FSDP), "FSDPStrategy.save_checkpoint assumes VLM is already wrapped in FSDP!"

        # Summon Full State Dictionary =>> Reconstitute from Shards
        with FSDP.state_dict_type(self.vlm, self.fsdp_state_dict_type, self.fsdp_save_policy):
            full_vlm_state_dict = self.vlm.state_dict()
            model_state_dicts = {
                mkey: OrderedDict() for mkey in (self.trainable_module_keys if only_trainable else self.all_module_keys)
            }

            # Iterate through `full_vlm_state_dict` and split `mkey.{full_dotted_path}` -> `mkey: {full_dotted_path}`
            for key, param in full_vlm_state_dict.items():
                for mkey in model_state_dicts:
                    if key.startswith(mprefix := f"{mkey}."):
                        model_state_dicts[mkey][key.removeprefix(mprefix)] = param

            # Save on rank zero *only*
            if overwatch.is_rank_zero():
                checkpoint_dir = run_dir / "checkpoints"
                if train_loss is None:
                    checkpoint_path = checkpoint_dir / f"step-{global_step:06d}-epoch-{epoch:02d}-loss=inf.pt"
                else:
                    checkpoint_path = (
                        checkpoint_dir / f"step-{global_step:06d}-epoch-{epoch:02d}-loss={train_loss:.4f}.pt"
                    )

                # Save Checkpoint & Copy Latest to `latest-checkpoint.pt`
                torch.save({"model": model_state_dicts}, checkpoint_path)

                # TODO (siddk) :: This breaks w/ Sagemaker default permissions (root vs. <user>)... skip?
                # shutil.copy(checkpoint_path, checkpoint_dir / "latest-checkpoint.pt")

    def run_setup(self, run_dir: Path, n_train_examples: int) -> None:
        # Iteratively Assemble FSDP Wrapping Policy by fetching the wrapping policies for each backbone/constituent
        vlm_fsdp_wrapping_policy = self.vlm.get_fsdp_wrapping_policy()

        # Assemble the Default FSDP Mixed Precision Policy
        if self.enable_mixed_precision_training and self.mixed_precision_dtype == torch.bfloat16:
            # MixedPrecision `param_dtype` specifies *compute* dtype (for forward/backward only)
            #   => Reference: https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.MixedPrecision
            reduce_buffer_dtype = torch.bfloat16 if not self.reduce_in_full_precision else torch.float32
            fsdp_precision_policy = MixedPrecision(
                param_dtype=torch.bfloat16, reduce_dtype=reduce_buffer_dtype, buffer_dtype=reduce_buffer_dtype
            )

            # When running FSDP with a frozen vision backbone --> move to half precision!
            if self.stage not in {"full-finetune", "vla-full-train", "vla-sandwich-train"}:
                overwatch.info("Casting Vision Backbone to *Half Precision* via `.to(dtype=...)`")
                self.vlm.vision_backbone.to(dtype=self.vlm.vision_backbone.half_precision_dtype)

        else:
            # If we're not using mixed precision, everything is in default full precision!
            fsdp_precision_policy = MixedPrecision(
                param_dtype=torch.float32, reduce_dtype=torch.float32, buffer_dtype=torch.float32
            )

        # <FSDP> => note that FSDP will automatically take care of device placement (similar to `autocast`)
        self.vlm = FSDP(
            self.vlm,
            auto_wrap_policy=vlm_fsdp_wrapping_policy,
            mixed_precision=fsdp_precision_policy,
            sharding_strategy=self.fsdp_sharding_strategy,
            device_id=torch.cuda.current_device(),
            limit_all_gathers=True,
            use_orig_params=True,
        )

        # Gradient Checkpoint Setup
        if self.enable_gradient_checkpointing:
            # For Gradient Checkpointing under FSDP --> we make the same assumption as in the DDP/other strategies; the
            #   bulk of activation memory is taken up by the LLM activations. However, unlike other strategies, we
            #   cannot rely on the HF Transformers default `gradient_checkpointing_enable()` --> FSDP breaks semantics!
            #
            # Instead, we need to write our own *NO-REENTRANT* wrapper, and apply it to the LLM's Transformer Layer.
            non_reentrant_wrapper = partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT)

            def check_fn(submodule: nn.Module) -> bool:
                return isinstance(submodule, self.llm_transformer_layer_cls)

            # Note that the terms "activation checkpointing" and "gradient checkpointing" are synonymous!
            apply_activation_checkpointing(self.vlm, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn)

        # Barrier =>> Sharding takes a minute?
        dist.barrier()

        # Create Optimizer and LR Scheduler =>> note that most of the LR Schedulers we use require `max_steps/epochs`
        #   => Optimizer should only operate on parameters that are *unfrozen* / trainable!
        n_train_examples = math.ceil(n_train_examples / self.global_batch_size) * self.global_batch_size
        if self.max_steps is None:
            num_training_steps = (n_train_examples * self.epochs) // self.global_batch_size
        else:
            num_training_steps = self.max_steps

        if self.lr_scheduler_type == "linear-warmup+cosine-decay":
            # Set warmup steps (floor) based on `warmup_ratio` (should be 0.03 - 0.05)
            num_warmup_steps = int(num_training_steps * self.warmup_ratio)

            # Default AdamW w/ specified LR & Linear Warmup / Cosine Decay & Weight Decay
            #   => Create Parameter Groups --> bias terms, normalization layer parameters shouldn't be decayed!
            decay, no_decay = [], []
            for name, param in self.vlm.named_parameters():
                if not param.requires_grad:
                    continue

                # Check on any parameters with fewer than 2 dimensions or with "bias" in the name
                if param.ndim <= 1 or name.endswith(".bias"):
                    no_decay.append(param)
                else:
                    decay.append(param)

            # Build Parameter Groups
            groups = [{"params": decay, "weight_decay": self.weight_decay}, {"params": no_decay, "weight_decay": 0.0}]

            # Create Optimizer & LR Scheduler
            self.optimizer = AdamW(groups, lr=self.learning_rate)
            self.lr_scheduler = get_cosine_schedule_with_warmup(self.optimizer, num_warmup_steps, num_training_steps)
            for param_group in self.optimizer.param_groups:
                param_group["lr"] = 0.0

        elif self.lr_scheduler_type == "constant":
            num_warmup_steps = 0

            # Default AdamW w/ specified LR & Constant Schedule & Weight Decay
            #   => Create Parameter Groups --> bias terms, normalization layer parameters shouldn't be decayed!
            decay, no_decay = [], []
            for name, param in self.vlm.named_parameters():
                if not param.requires_grad:
                    continue

                # Check on any parameters with fewer than 2 dimensions or with "bias" in the name
                if param.ndim <= 1 or name.endswith(".bias"):
                    no_decay.append(param)
                else:
                    decay.append(param)

            # Build Parameter Groups
            groups = [{"params": decay, "weight_decay": self.weight_decay}, {"params": no_decay, "weight_decay": 0.0}]

            # Create Optimizer & LR Scheduler
            self.optimizer = AdamW(groups, lr=self.learning_rate)
            self.lr_scheduler = get_constant_schedule(self.optimizer)

        else:
            raise ValueError(f"Learning Rate Schedule with type `{self.lr_scheduler_type}` is not supported!")

        # Finalize Setup =>> Log!
        overwatch.info(
            "FSDP Full-Shard Strategy =>> Finalized Training Setup:\n"
            f"         |-> Global (Effective) Batch Size = {self.global_batch_size}\n"
            f"         |-> Per-Device Batch Size = {self.per_device_batch_size}\n"
            f"         |-> Distributed World Size = {overwatch.world_size()}\n"
            f"         |-> Gradient Accumulation Steps = {self.grad_accumulation_steps}\n\n"
            f"         |-> LLM Backbone FSDP Gradient Checkpointing = {self.enable_gradient_checkpointing}\n"
            f"         |-> Use FSDP Mixed Precision = {self.enable_mixed_precision_training}\n"
            f"                 |-> Parameter Precision = {fsdp_precision_policy.param_dtype}\n"
            f"                 |-> Reduction Precision = {fsdp_precision_policy.reduce_dtype}\n"
            f"                 |-> Buffer Precision = {fsdp_precision_policy.buffer_dtype}\n\n"
            f"         |-> Default AdamW LR = {self.learning_rate}\n"
            f"         |-> AdamW Weight Decay = {self.weight_decay}\n"
            f"         |-> LR Scheduler Type = {self.lr_scheduler_type}\n"
            f"         |-> LR Scheduler Warmup Steps (Ratio) = {num_warmup_steps} ({self.warmup_ratio})\n"
            f"         |-> Dataset Size = {n_train_examples} Examples\n"
            f"         |-> Max Steps = {num_training_steps}\n"
        )

    def clip_grad_norm(self) -> None:
        # Note =>> FSDP uses a custom `clip_grad_norm_` function; requires *uniform grad dtype*
        self.vlm.clip_grad_norm_(max_norm=self.max_grad_norm)
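The decay/no-decay split above follows the common heuristic that biases and 1-D parameters (e.g., LayerNorm weights) should be exempt from weight decay. The same grouping on a toy module, as a self-contained sketch:

import torch.nn as nn
from torch.optim import AdamW

model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
decay, no_decay = [], []
for name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    # 1-D params (LayerNorm weights, biases) go to the no-decay group
    (no_decay if param.ndim <= 1 or name.endswith(".bias") else decay).append(param)

optimizer = AdamW(
    [{"params": decay, "weight_decay": 0.01}, {"params": no_decay, "weight_decay": 0.0}],
    lr=2e-5,
)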
capvector-oft/prismatic/training/train_utils.py ADDED
@@ -0,0 +1,56 @@
"""Utils for training/fine-tuning scripts."""

import torch

from prismatic.vla.constants import ACTION_DIM, ACTION_TOKEN_BEGIN_IDX, IGNORE_INDEX


def get_current_action_mask(token_ids):
    # Mark positions that hold real (non-IGNORE_INDEX) labels
    newline_positions = token_ids != IGNORE_INDEX

    # Cumulative count of labeled positions gives each token its ordinal position in the label sequence
    cumsum = torch.cumsum(newline_positions, dim=1)

    # The current action occupies the first ACTION_DIM labeled positions
    mask = (1 <= cumsum) & (cumsum <= ACTION_DIM)

    # Keep only positions that actually hold action tokens
    action_tokens_only_mask = token_ids > ACTION_TOKEN_BEGIN_IDX
    mask = action_tokens_only_mask * mask

    return mask


def get_next_actions_mask(token_ids):
    # Mark positions that hold real (non-IGNORE_INDEX) labels
    newline_positions = token_ids != IGNORE_INDEX

    # Cumulative count of labeled positions gives each token its ordinal position in the label sequence
    cumsum = torch.cumsum(newline_positions, dim=1)

    # The next actions occupy every labeled position after the first ACTION_DIM
    mask = cumsum > ACTION_DIM

    # Keep only positions that actually hold action tokens
    action_tokens_only_mask = token_ids > ACTION_TOKEN_BEGIN_IDX
    mask = action_tokens_only_mask * mask

    return mask


def compute_token_accuracy(predicted_token_ids, ground_truth_token_ids, mask):
    correct_preds = (predicted_token_ids == ground_truth_token_ids) & mask
    accuracy = correct_preds.sum().float() / mask.sum().float()
    return accuracy


def compute_actions_l1_loss(action_tokenizer, predicted_token_ids, ground_truth_token_ids, mask):
    pred_continuous_actions = torch.tensor(
        action_tokenizer.decode_token_ids_to_actions(predicted_token_ids[mask].cpu().numpy())
    )
    true_continuous_actions = torch.tensor(
        action_tokenizer.decode_token_ids_to_actions(ground_truth_token_ids[mask].cpu().numpy())
    )
    l1_loss = torch.nn.functional.l1_loss(pred_continuous_actions, true_continuous_actions)
    return l1_loss
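A toy sanity check of the masking logic above (a sketch only: the constant values below are hypothetical stand-ins for `IGNORE_INDEX`, `ACTION_TOKEN_BEGIN_IDX`, and `ACTION_DIM` from `prismatic.vla.constants`, and the token ids are made up):

import torch

# Hypothetical stand-ins for the imported constants (illustration only)
IGNORE_INDEX, ACTION_TOKEN_BEGIN_IDX, ACTION_DIM = -100, 31743, 7

# One sequence: two ignored prompt positions, then 14 action tokens (two 7-D actions)
token_ids = torch.tensor([[IGNORE_INDEX, IGNORE_INDEX] + [31800] * 14])

not_ignored = token_ids != IGNORE_INDEX
cumsum = torch.cumsum(not_ignored, dim=1)

# First ACTION_DIM labeled positions --> current action; the rest --> next actions
current = (1 <= cumsum) & (cumsum <= ACTION_DIM) & (token_ids > ACTION_TOKEN_BEGIN_IDX)
next_acts = (cumsum > ACTION_DIM) & (token_ids > ACTION_TOKEN_BEGIN_IDX)

assert current.sum().item() == 7 and next_acts.sum().item() == 7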
capvector-oft/prismatic/util/__init__.py
ADDED
@@ -0,0 +1 @@
from .torch_utils import check_bloat16_supported, set_global_seed
capvector-oft/prismatic/util/batching_utils.py
ADDED
@@ -0,0 +1,212 @@
"""
batching_utils.py

Core definitions of (Distributed) Samplers for VLM finetuning; provides functionality for constructing and allocating
"split-modality" batches as described in the LLaVa paper; this makes sure that a given device/batch is either entirely
(vision, language) or (language-only) data, which leads to sizeable efficiency gains.
"""

import math
from typing import Iterator, List, Optional, Tuple

import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import Dataset, Sampler


# High-Fidelity Bitwise Reproduction of the LLaVa Codebase Sampler Strategy + Per-Rank Allocation Scheme (following
# the default batching behavior of HF's Trainer Class --> derived from `accelerate`).
#
#   =>> Reference: https://github.com/haotian-liu/LLaVA/blob/main/llava/train/llava_trainer.py#L60
#   =>> Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/trainer_pt_utils.py#L603
class SplitModalitySampler(Sampler):
    def __init__(
        self,
        dataset: Dataset,
        modality_lengths: List[Tuple[bool, int]],
        global_batch_size: int,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        seed: int = 0,
        drop_last: bool = False,
    ) -> None:
        super().__init__()
        self.num_replicas = num_replicas if num_replicas is not None else dist.get_world_size()
        self.rank = rank if rank is not None else dist.get_rank()
        self.seed, self.epoch = seed, 0

        # Custom Parameters
        self.dataset, self.modality_lengths, self.drop_last = dataset, modality_lengths, drop_last
        self.global_batch_size = global_batch_size

        # For our purposes, `drop_last` is always False!
        assert not self.drop_last, "SplitModalitySampler must set `drop_last = False`!"
        self.total_size = math.ceil(len(self.dataset) / self.global_batch_size) * self.global_batch_size
        self.num_samples = self.total_size // self.num_replicas

    @staticmethod
    def reindex_batch(batch_idxs: List[int], idx2lengths: List[int], n_buckets: int) -> List[List[int]]:
        """Re-indexes a batch in a way that is conducive to DistributedSampler + grouping by seqlen per rank."""
        assert len(batch_idxs) % n_buckets == 0, "Batch length is not divisible by `num_replicas`!"

        # Establish initial buckets, capacities, and max number of elements per bucket
        n_examples_per_bucket = len(batch_idxs) // n_buckets
        bucket_indices = [[] for _ in range(n_buckets)]
        bucket_lengths = [0 for _ in range(n_buckets)]

        # Note that `batch_idxs` is already sorted by corresponding length (in descending order)
        for idx in batch_idxs:
            shortest_bucket_idx = bucket_lengths.index(min(bucket_lengths))
            bucket_indices[shortest_bucket_idx].append(idx)

            # Update `bucket_lengths` --> set length to infinity if at capacity!
            bucket_lengths[shortest_bucket_idx] += idx2lengths[idx]
            if len(bucket_indices[shortest_bucket_idx]) == n_examples_per_bucket:
                bucket_lengths[shortest_bucket_idx] = float("inf")

        return bucket_indices

    def get_modality_and_length_grouped_indices(self, generator: torch.Generator) -> List[int]:
        """
        Returns a list of indices so that each slice of `global_batch_size` consecutive indices corresponds to elements
        of the same modality, with each sub-sequence of `per_replica_batch_size` (the batch size each unique device
        sees during distributed training) roughly grouped by sequence length (for training efficiency).
        """
        multimodal_indices, multimodal_lengths = zip(
            *[(idx, length) for idx, (is_multimodal, length) in enumerate(self.modality_lengths) if is_multimodal]
        )

        # Handle Special Case --> no "unimodal" inputs
        unimodal_split = [
            (idx, length) for idx, (is_multimodal, length) in enumerate(self.modality_lengths) if not is_multimodal
        ]
        if len(unimodal_split) == 0:
            unimodal_indices, unimodal_lengths = [], []
        else:
            unimodal_indices, unimodal_lengths = zip(*unimodal_split)

        # Create a permutation of indices for each of the multimodal and unimodal data
        mm_shuffled_idxs = torch.randperm(len(multimodal_indices), generator=generator)
        uni_shuffled_idxs = torch.randperm(len(unimodal_indices), generator=generator)

        # We're going to be running sorting/grouping relative to `self.global_batch_size` and `self.num_replicas`
        g_bsz = self.global_batch_size

        # Break each of the permutations into batches of length `global_batch_size`
        mm_batch_idxs = [mm_shuffled_idxs[i : i + g_bsz].tolist() for i in range(0, len(mm_shuffled_idxs), g_bsz)]
        uni_batch_idxs = [uni_shuffled_idxs[i : i + g_bsz].tolist() for i in range(0, len(uni_shuffled_idxs), g_bsz)]

        # If the "last" batch is not of length `g_bsz` --> PAD by stealing indices from the first batch!
        if len(mm_batch_idxs[-1]) < g_bsz:
            n_missing = g_bsz - len(mm_batch_idxs[-1])
            mm_batch_idxs[-1].extend(mm_batch_idxs[0][:n_missing])

        if len(uni_batch_idxs) > 0 and len(uni_batch_idxs[-1]) < g_bsz:
            n_missing = g_bsz - len(uni_batch_idxs[-1])
            uni_batch_idxs[-1].extend(uni_batch_idxs[0][:n_missing])

        # Now we're going to sort each batch by length --> this will aid in grouping by length by rank (efficiency!)
        mm_sorted_batch_idxs = [sorted(b, key=lambda i: multimodal_lengths[i], reverse=True) for b in mm_batch_idxs]
        uni_sorted_batch_idxs = [sorted(b, key=lambda i: unimodal_lengths[i], reverse=True) for b in uni_batch_idxs]

        # IMPORTANT :: At this point, for each modality, we have a list of "batches" (made up of indices) where indices
        # are sorted by example sequence length *within* each batch. To make this more concrete, consider the following:
        #   => World Size (`num_replicas`) = 2
        #   => Global Batch Size (`g_bsz`) = 4
        #   => `multimodal_indices` = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
        #      `multimodal_lengths` = [20, 90, 21, 22, 91, 18, 89, 19, 93, 88, 92, 17]
        #
        # At this point in the code, `mm_sorted_batch_idxs` might then look like the following (length in parenthesis):
        #   => `mm_sorted_batch_idxs`: [
        #          [4 (91), 3 (22), 0 (20), 5 (18)]     => Batch 1
        #          [6 (89), 9 (88), 7 (19), 11 (17)]    => Batch 2
        #          [8 (93), 10 (92), 1 (90), 2 (21)]    => Batch 3
        #      ]
        #
        # In practice: `g_bsz` is large (= 128), and for contiguous mini-batch "slices", length variance is low.

        # PROBLEM :: We want to split these "global batches" into equal-sized pieces, so that each "replica" (GPU)
        # sees a "mini-batch" of roughly the same sequence lengths; this is super useful for efficient training.

        # HOWEVER :: The default "access pattern" for splitting a large batch into mini-batches by a DistributedSampler
        # is akin to a "take every k" where `k` is equal to the number of replicas (GPUs) you're training on. Or, in
        # Python notation --> `rank_k_indices = flatten(mm_sorted_batch_idxs)[k::num_replicas]`.
        #
        # Naively translating this to our example means each GPU (in our world of 2 total) sees the following indices
        # (grouped by "mini-batch" = `g_bsz / num_replicas` = 2 for convenience):
        #   => `rank_0_indices`: [ [4 (91), 0 (20)] =>> [6 (89), 7 (19)] =>> [8 (93), 1 (90)] ]
        #   => `rank_1_indices`: [ [3 (22), 5 (18)] =>> [9 (88), 11 (17)] =>> [10 (92), 2 (21)] ]
        #
        # We get lucky sometimes, but for the most part, each "mini-batch" has VASTLY DIFFERENT lengths! Bad!

        # FIX :: If we "undo" the access pattern with the following code and re-arrange the way we allocate batches
        # inside the __iter__ method below, we can allocate indices appropriately. Running the following code gives us
        # groupings like the following (grouped by "mini-batch" again for convenience):
        #   => `rank_0_indices`: [ [4 (91), 3 (22)] =>> [6 (89), 9 (88)] =>> [8 (93), 10 (92)] ]
        #   => `rank_1_indices`: [ [5 (18), 0 (20)] =>> [11 (17), 7 (19)] =>> [2 (21), 1 (90)] ]
        #
        # Much better! As `g_bsz` and `dataset` grow, we're more often than not getting *decent* groupings!
        mm_length_bucketed_idxs = [
            self.reindex_batch(batch, multimodal_lengths, self.num_replicas) for batch in mm_sorted_batch_idxs
        ]
        uni_length_bucketed_idxs = [
            self.reindex_batch(batch, unimodal_lengths, self.num_replicas) for batch in uni_sorted_batch_idxs
        ]

        # Note :: Because of the initial `randperm` --> we're indexing both sets from 0 (we're clobbering the range)
        #   => Flatten indices --> index into original `{modality}_indices` then re-batch!
        mm_output_idxs = [idx for batch in mm_length_bucketed_idxs for bucket in batch for idx in bucket]
        mm_reindexed = [multimodal_indices[idx] for idx in mm_output_idxs]
        mm_batches = [mm_reindexed[i : i + g_bsz] for i in range(0, len(mm_reindexed), g_bsz)]

        uni_output_idxs = [idx for batch in uni_length_bucketed_idxs for bucket in batch for idx in bucket]
        uni_reindexed = [unimodal_indices[idx] for idx in uni_output_idxs]
        uni_batches = [uni_reindexed[i : i + g_bsz] for i in range(0, len(uni_reindexed), g_bsz)]

        # Finally, randomly permute the multimodal & unimodal batches, merging into a single stream of indices
        merged_batches = mm_batches + uni_batches
        merge_idxs = torch.randperm(len(merged_batches), generator=generator)
        all_batches = [merged_batches[idx] for idx in merge_idxs]

        # [Quality of Life] Shift "max length" batch to index 0 --> if we OOM, it happens immediately!
        all_lengths = [length + ((_n_patches := 24 * 24) if is_mm else 0) for is_mm, length in self.modality_lengths]
        all_batches_max_lengths = []
        for batch in all_batches:
            all_batches_max_lengths.append(max([all_lengths[idx] for idx in batch]))

        # Identify Batch with "max length" --> Swap into Index 0
        longest_batch_idx = np.argmax(all_batches_max_lengths)
        all_batches[0], all_batches[longest_batch_idx] = all_batches[longest_batch_idx], all_batches[0]

        # Flatten & Return all Indices
        indices = [idx for batch in all_batches for idx in batch]
        return indices

    def __iter__(self) -> Iterator:
        """Deterministically shuffle, then split indices by modality and length."""
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        indices = self.get_modality_and_length_grouped_indices(g)
        assert len(set(indices)) == len(self.modality_lengths) == len(self.dataset), "Oops!"
        assert (len(indices) % self.global_batch_size == 0) and (len(indices) % self.num_replicas == 0), "Oops"

        # Note :: We compute per-replica batch size as a function of `global_batch_size` and `num_replicas` to ensure
        # that gradient accumulation doesn't affect which indices are assigned to a given rank.
        per_replica_batch_size = self.global_batch_size // self.num_replicas

        # Tensorize & Unravel --> rather than yielding via a `take_every` --> we want to partition a global batch
        # across replicas by assigning each a contiguous sub-sequence.
        indices_t = torch.as_tensor(indices)
        per_replica_batch_indices_t = indices_t.reshape(-1, per_replica_batch_size)
        replica_indices_t = per_replica_batch_indices_t[self.rank :: self.num_replicas]

        replica_indices = replica_indices_t.flatten().tolist()
        return iter(replica_indices)

    def __len__(self) -> int:
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        """To be called *between* epochs, prior to DataLoader instantiation; ensures random order across epochs."""
        self.epoch = epoch
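A small, self-contained check of the greedy bucketing in `reindex_batch`, reusing the toy lengths from the worked example in the comments above (no distributed context is needed, since `reindex_batch` is a staticmethod; each index goes to the open bucket with the smallest running total, and a bucket closes once it reaches capacity):

from prismatic.util.batching_utils import SplitModalitySampler

lengths = [20, 90, 21, 22, 91, 18, 89, 19, 93, 88, 92, 17]
batch = [6, 9, 7, 11]  # "Batch 2" from the comment above, sorted by length (descending)
buckets = SplitModalitySampler.reindex_batch(batch, lengths, n_buckets=2)
print(buckets)  # [[6, 11], [9, 7]] --> per-rank length totals of 89 + 17 = 106 and 88 + 19 = 107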
capvector-oft/prismatic/util/data_utils.py
ADDED
@@ -0,0 +1,156 @@
"""
data_utils.py

General utilities and classes for facilitating data loading and collation.
"""

from dataclasses import dataclass
from typing import Callable, Dict, Sequence, Tuple

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence

# HuggingFace Default / LLaMa-2 IGNORE_INDEX (for labels)
IGNORE_INDEX = -100


def tree_map(fn: Callable, tree: dict) -> dict:
    """Maps a function over a nested dictionary."""
    return {k: tree_map(fn, v) if isinstance(v, dict) else fn(v) for k, v in tree.items()}


def tree_map_with_key(fn: Callable, tree: dict, keys: Sequence = ()) -> dict:
    """Maps a function over a nested dictionary, passing the (nested) key path to the function."""
    return {
        k: tree_map_with_key(fn, v, (*keys, k)) if isinstance(v, dict) else fn((*keys, k), v) for k, v in tree.items()
    }


@dataclass
class PaddedCollatorForLanguageModeling:
    model_max_length: int
    pad_token_id: int
    default_image_resolution: Tuple[int, int, int]
    padding_side: str = "right"
    pixel_values_dtype: torch.dtype = torch.float32

    def __post_init__(self) -> None:
        self.dummy_pixel_values = torch.zeros(self.default_image_resolution, dtype=self.pixel_values_dtype)

    def __call__(self, instances: Sequence[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
        pixel_values = [instance["pixel_values"] for instance in instances]

        # For now, we only support Tokenizers with `padding_side = "right"` during Training (but plan to extend!)
        #   => Handle padding via RNN Utils => `pad_sequence`
        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.pad_token_id)
        labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)

        # Truncate (if necessary)
        input_ids, labels = input_ids[:, : self.model_max_length], labels[:, : self.model_max_length]

        # Get `attention_mask` by checking for `pad_token_id`
        attention_mask = input_ids.ne(self.pad_token_id)

        # === Handle "unimodal" (language-only) vs. "multimodal" ===

        # Some examples are "language-only" --> build a Tensor of `multimodal_indices` that we can slice into easily
        multimodal_indices = torch.tensor(
            [idx for idx in range(len(pixel_values)) if pixel_values[idx] is not None], dtype=torch.long
        )

        # Stack all `pixel_values` --> depending on type (torch.Tensor, or Dict[str, torch.Tensor]) & presence of None
        if len(multimodal_indices) == 0:
            pixel_values = torch.stack([self.dummy_pixel_values for _ in range(len(input_ids))])
        elif isinstance(pv_example := pixel_values[multimodal_indices[0]], torch.Tensor):
            pixel_values = torch.stack(
                [
                    pixel_values[idx] if idx in multimodal_indices else self.dummy_pixel_values
                    for idx in range(len(input_ids))
                ]
            )
        elif isinstance(pv_example, dict):
            pixel_values = {
                k: torch.stack(
                    [
                        pixel_values[idx][k] if idx in multimodal_indices else self.dummy_pixel_values
                        for idx in range(len(input_ids))
                    ]
                )
                for k in pv_example
            }
        else:
            raise ValueError(f"Unsupported `pixel_values` type = {type(pixel_values)}")

        return dict(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            multimodal_indices=multimodal_indices,
        )


@dataclass
class PaddedCollatorForActionPrediction:
    model_max_length: int
    pad_token_id: int
    padding_side: str = "right"
    pixel_values_dtype: torch.dtype = torch.float32

    def __call__(self, instances: Sequence[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
        pixel_values = [instance["pixel_values"] for instance in instances]
        if "dataset_name" in instances[0]:
            dataset_names = [instance["dataset_name"] for instance in instances]
        else:
            dataset_names = None

        # For now, we only support Tokenizers with `padding_side = "right"` during training
        #   => Handle padding via RNN Utils => `pad_sequence`
        assert self.padding_side == "right", f"Invalid Tokenizer `{self.padding_side = }`"
        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.pad_token_id)
        labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)

        # Truncate (if necessary)
        input_ids, labels = input_ids[:, : self.model_max_length], labels[:, : self.model_max_length]

        # Get `attention_mask` by checking for `pad_token_id`
        attention_mask = input_ids.ne(self.pad_token_id)

        # [Contract] For VLA Training =>> No "Unimodal" Data!
        assert all([pv is not None for pv in pixel_values]), "Invalid VLA Example with `pixel_values = None`!"

        # Stack all `pixel_values` --> depending on whether the type is torch.Tensor or Dict[str, torch.Tensor]
        if isinstance(pixel_values[0], torch.Tensor):
            if "pixel_values_wrist" in instances[0]:
                pixel_values_wrist = [instance["pixel_values_wrist"] for instance in instances]
                pixel_values = torch.cat((torch.stack(pixel_values), torch.stack(pixel_values_wrist)), dim=1)
            else:
                pixel_values = torch.stack(pixel_values)
        else:
            raise ValueError(f"Unsupported `pixel_values` type = {type(pixel_values)}")

        # Stack all actions
        actions = [torch.from_numpy(np.copy(instance["actions"])) for instance in instances]
        actions = torch.stack(actions)

        # Stack proprio
        if "proprio" in instances[0]:
            proprio = [instance["proprio"] for instance in instances]
            proprio = torch.Tensor(np.squeeze(np.stack(proprio)))
        else:
            proprio = None

        output = dict(
            pixel_values=pixel_values,
            proprio=proprio,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            actions=actions,
        )
        if dataset_names is not None:
            output["dataset_names"] = dataset_names
        return output
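A minimal sketch of the language-modeling collator in action (illustrative values only: the pad id, token ids, and image resolution below are made up):

import torch

from prismatic.util.data_utils import PaddedCollatorForLanguageModeling

collator = PaddedCollatorForLanguageModeling(
    model_max_length=2048, pad_token_id=0, default_image_resolution=(3, 224, 224)
)

# Two instances of different lengths; the second is "language-only" (pixel_values=None)
instances = [
    {"input_ids": torch.tensor([5, 6, 7]), "labels": torch.tensor([-100, 6, 7]),
     "pixel_values": torch.rand(3, 224, 224)},
    {"input_ids": torch.tensor([8, 9]), "labels": torch.tensor([8, 9]),
     "pixel_values": None},
]

batch = collator(instances)
print(batch["input_ids"].shape)     # torch.Size([2, 3]) -- right-padded with pad_token_id
print(batch["attention_mask"][1])   # tensor([True, True, False])
print(batch["multimodal_indices"])  # tensor([0]) -- only the first example has an image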
capvector-oft/prismatic/util/nn_utils.py
ADDED
@@ -0,0 +1,53 @@
"""
nn_utils.py

Utility functions and PyTorch submodule definitions.
"""

import torch
import torch.nn as nn


# === Definitions for Various Projection Modules, with Signature :: [..., in_dim] --> [..., out_dim] ===
class LinearProjector(nn.Module):
    def __init__(self, vision_dim: int, llm_dim: int) -> None:
        super().__init__()
        self.projector = nn.Linear(vision_dim, llm_dim, bias=True)

    def forward(self, img_patches: torch.Tensor) -> torch.Tensor:
        return self.projector(img_patches)


class MLPProjector(nn.Module):
    def __init__(self, vision_dim: int, llm_dim: int, mlp_type: str = "gelu-mlp") -> None:
        super().__init__()
        if mlp_type == "gelu-mlp":
            self.projector = nn.Sequential(
                nn.Linear(vision_dim, llm_dim, bias=True),
                nn.GELU(),
                nn.Linear(llm_dim, llm_dim, bias=True),
            )
        else:
            raise ValueError(f"Projector with `{mlp_type = }` is not supported!")

    def forward(self, img_patches: torch.Tensor) -> torch.Tensor:
        return self.projector(img_patches)


class FusedMLPProjector(nn.Module):
    def __init__(self, fused_vision_dim: int, llm_dim: int, mlp_type: str = "fused-gelu-mlp") -> None:
        super().__init__()
        self.initial_projection_dim = fused_vision_dim * 4
        if mlp_type == "fused-gelu-mlp":
            self.projector = nn.Sequential(
                nn.Linear(fused_vision_dim, self.initial_projection_dim, bias=True),
                nn.GELU(),
                nn.Linear(self.initial_projection_dim, llm_dim, bias=True),
                nn.GELU(),
                nn.Linear(llm_dim, llm_dim, bias=True),
            )
        else:
            raise ValueError(f"Fused Projector with `{mlp_type = }` is not supported!")

    def forward(self, fused_img_patches: torch.Tensor) -> torch.Tensor:
        return self.projector(fused_img_patches)
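A quick shape check for the projectors above (the dimensions are illustrative stand-ins for a fused vision patch dimension and an LLM hidden size, not values fixed by this module):

import torch

from prismatic.util.nn_utils import FusedMLPProjector

# Project a batch of 256 fused patch embeddings (dim 2176) into a 4096-dim LLM embedding space
projector = FusedMLPProjector(fused_vision_dim=2176, llm_dim=4096)
patches = torch.randn(2, 256, 2176)  # [batch, n_patches, fused_vision_dim]
print(projector(patches).shape)      # torch.Size([2, 256, 4096])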
capvector-oft/prismatic/util/torch_utils.py
ADDED
@@ -0,0 +1,95 @@
"""
torch_utils.py

General utilities for randomness, mixed precision training, and miscellaneous checks in PyTorch.

Random `set_global_seed` functionality is taken directly from PyTorch-Lightning:
    > Ref: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/utilities/seed.py

This is pretty important to get right if we're ever randomly generating our masks (or prefix dropout) inside our
Dataset __getitem__() with multiple workers... if not handled properly, we will get repeated augmentations anytime
we inject randomness from non-PyTorch sources (e.g., numpy, random)!
    > Ref: https://tanelp.github.io/posts/a-bug-that-plagues-thousands-of-open-source-ml-projects/

Terminology
    -> World Size :: Total number of processes distributed over (# nodes x # devices) -- assumed homogeneous!
    -> Rank :: Integer index of current process in the total world size
    -> Local Rank :: Local index on given node in [0, Devices per Node]
"""

import os
import random
from typing import Callable, Optional

import numpy as np
import torch

# === Randomness ===


def set_global_seed(seed: int, get_worker_init_fn: bool = False) -> Optional[Callable[[int], None]]:
    """Sets seed for all randomness libraries (mostly random, numpy, torch) and produces a `worker_init_fn`."""
    assert np.iinfo(np.uint32).min < seed < np.iinfo(np.uint32).max, "Seed outside the np.uint32 bounds!"

    # Set Seed as an Environment Variable
    os.environ["EXPERIMENT_GLOBAL_SEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    return worker_init_function if get_worker_init_fn else None


def worker_init_function(worker_id: int) -> None:
    """
    Borrowed directly from PyTorch-Lightning; inspired by this issue comment in the PyTorch repo:
        > Ref: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562

    Intuition: You can think of the seed sequence spawn function as a "janky" torch.Generator() or jax.PRNGKey that
    you can run iterative splitting on to get new (predictable) randomness.

    :param worker_id: Identifier for the given worker [0, num_workers) for the DataLoader in question.
    """
    # Get current `rank` (if running distributed) and `process_seed`
    global_rank, process_seed = int(os.environ["LOCAL_RANK"]), torch.initial_seed()

    # Back out the "base" (original) seed -- the per-worker seed is set in PyTorch:
    #   > https://pytorch.org/docs/stable/data.html#data-loading-randomness
    base_seed = process_seed - worker_id

    # "Magic" code --> basically creates a seed sequence that mixes different "sources" and seeds every library...
    seed_seq = np.random.SeedSequence([base_seed, worker_id, global_rank])

    # Use 128 bits (4 x 32-bit words) to represent seed --> generate_state(k) produces a `k` element array!
    np.random.seed(seed_seq.generate_state(4))

    # Spawn distinct child sequences for PyTorch (reseed) and stdlib random
    torch_seed_seq, random_seed_seq = seed_seq.spawn(2)

    # Torch Manual seed takes 64 bits (so just specify a dtype of uint64)
    torch.manual_seed(torch_seed_seq.generate_state(1, dtype=np.uint64)[0])

    # Use 128 Bits for `random`, but express as an integer instead of as an array
    random_seed = (random_seed_seq.generate_state(2, dtype=np.uint64).astype(list) * [1 << 64, 1]).sum()
    random.seed(random_seed)


# === BFloat16 Support ===


def check_bloat16_supported() -> bool:
    try:
        import packaging.version
        import torch.cuda.nccl as nccl
        import torch.distributed as dist

        return (
            (torch.version.cuda is not None)
            and torch.cuda.is_bf16_supported()
            and (packaging.version.parse(torch.version.cuda).release >= (11, 0))
            and dist.is_nccl_available()
            and (nccl.version() >= (2, 10))
        )

    except Exception:
        return False
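A short sketch wiring the seeding utilities above into a DataLoader (illustrative: `LOCAL_RANK` is set manually here because `worker_init_function` reads it from the environment; under a launcher such as `torchrun` it is set for you):

import os

import torch
from torch.utils.data import DataLoader, TensorDataset

from prismatic.util import set_global_seed

if __name__ == "__main__":
    os.environ["LOCAL_RANK"] = "0"  # normally set by the distributed launcher

    # Seed `random`, `numpy`, and `torch` once, and grab the per-worker re-seeding hook
    worker_init_fn = set_global_seed(7, get_worker_init_fn=True)

    dataset = TensorDataset(torch.arange(16, dtype=torch.float32))
    loader = DataLoader(dataset, batch_size=4, num_workers=2, worker_init_fn=worker_init_fn)

    # Each worker re-seeds numpy/random/torch from a distinct, reproducible SeedSequence
    for (batch,) in loader:
        pass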