haofuly committed on
Commit cf587f4 · verified · 1 Parent(s): b23769d

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. capvector-oft/prismatic/models/__init__.py +2 -0
  2. capvector-oft/prismatic/models/backbones/__init__.py +0 -0
  3. capvector-oft/prismatic/models/backbones/llm/__init__.py +4 -0
  4. capvector-oft/prismatic/models/backbones/llm/base_llm.py +223 -0
  5. capvector-oft/prismatic/models/backbones/llm/llama2.py +102 -0
  6. capvector-oft/prismatic/models/backbones/llm/mistral.py +72 -0
  7. capvector-oft/prismatic/models/backbones/llm/phi.py +64 -0
  8. capvector-oft/prismatic/models/backbones/llm/prompting/__init__.py +5 -0
  9. capvector-oft/prismatic/models/backbones/llm/prompting/base_prompter.py +73 -0
  10. capvector-oft/prismatic/models/backbones/llm/prompting/llama2_chat_prompter.py +91 -0
  11. capvector-oft/prismatic/models/backbones/llm/prompting/mistral_instruct_prompter.py +60 -0
  12. capvector-oft/prismatic/models/backbones/llm/prompting/phi_prompter.py +65 -0
  13. capvector-oft/prismatic/models/backbones/llm/prompting/vicuna_v15_prompter.py +82 -0
  14. capvector-oft/prismatic/models/backbones/vision/__init__.py +7 -0
  15. capvector-oft/prismatic/models/backbones/vision/base_vision.py +207 -0
  16. capvector-oft/prismatic/models/backbones/vision/clip_vit.py +27 -0
  17. capvector-oft/prismatic/models/backbones/vision/dinoclip_vit.py +147 -0
  18. capvector-oft/prismatic/models/backbones/vision/dinosiglip_vit.py +164 -0
  19. capvector-oft/prismatic/models/backbones/vision/dinov2_vit.py +19 -0
  20. capvector-oft/prismatic/models/backbones/vision/in1k_vit.py +22 -0
  21. capvector-oft/prismatic/models/backbones/vision/siglip_vit.py +24 -0
  22. capvector-oft/prismatic/models/load.py +226 -0
  23. capvector-oft/prismatic/models/materialize.py +130 -0
  24. capvector-oft/prismatic/models/projectors.py +49 -0
  25. capvector-oft/prismatic/models/registry.py +691 -0
  26. capvector-oft/prismatic/models/vlas/__init__.py +1 -0
  27. capvector-oft/prismatic/models/vlas/openvla.py +131 -0
  28. capvector-oft/prismatic/models/vlms/__init__.py +1 -0
  29. capvector-oft/prismatic/models/vlms/base_vlm.py +108 -0
  30. capvector-oft/prismatic/models/vlms/prismatic.py +621 -0
  31. capvector-oft/prismatic/overwatch/__init__.py +1 -0
  32. capvector-oft/prismatic/overwatch/overwatch.py +147 -0
  33. capvector-oft/prismatic/preprocessing/__init__.py +2 -0
  34. capvector-oft/prismatic/preprocessing/datasets/__init__.py +1 -0
  35. capvector-oft/prismatic/preprocessing/datasets/datasets.py +200 -0
  36. capvector-oft/prismatic/preprocessing/download.py +207 -0
  37. capvector-oft/prismatic/preprocessing/materialize.py +69 -0
  38. capvector-oft/prismatic/training/__init__.py +2 -0
  39. capvector-oft/prismatic/training/materialize.py +66 -0
  40. capvector-oft/prismatic/training/metrics.py +348 -0
  41. capvector-oft/prismatic/training/strategies/__init__.py +3 -0
  42. capvector-oft/prismatic/training/strategies/base_strategy.py +417 -0
  43. capvector-oft/prismatic/training/strategies/ddp.py +128 -0
  44. capvector-oft/prismatic/training/strategies/fsdp.py +270 -0
  45. capvector-oft/prismatic/training/train_utils.py +56 -0
  46. capvector-oft/prismatic/util/__init__.py +1 -0
  47. capvector-oft/prismatic/util/batching_utils.py +212 -0
  48. capvector-oft/prismatic/util/data_utils.py +156 -0
  49. capvector-oft/prismatic/util/nn_utils.py +53 -0
  50. capvector-oft/prismatic/util/torch_utils.py +95 -0
capvector-oft/prismatic/models/__init__.py ADDED
@@ -0,0 +1,2 @@
from .load import available_model_names, available_models, get_model_description, load, load_vla
from .materialize import get_llm_backbone_and_tokenizer, get_vision_backbone_and_transform, get_vlm

capvector-oft/prismatic/models/backbones/__init__.py ADDED
File without changes
capvector-oft/prismatic/models/backbones/llm/__init__.py ADDED
@@ -0,0 +1,4 @@
from .base_llm import LLMBackbone
from .llama2 import LLaMa2LLMBackbone
from .mistral import MistralLLMBackbone
from .phi import PhiLLMBackbone

capvector-oft/prismatic/models/backbones/llm/base_llm.py ADDED
@@ -0,0 +1,223 @@
"""
base_llm.py

Abstract class definition of a large (autoregressive) language model backbone (LLM), with full annotations of class
methods, utility functions, and initialization logic.

We also define the generic HFCausalLLMBackbone class here, providing a default interface for loading any HF
AutoModelForCausalLM (e.g., LlamaForCausalLM). In general, we make the assumption that any given LLM backbone implements
the AutoModelForCausalLM API (though we may add Seq2Seq models in the future).

We make this assumption to keep the LLM handling in this codebase relatively lightweight, and to inherit all the nice HF
utilities around different types of decoding/generation strategies.
"""

import warnings
from abc import ABC, abstractmethod
from functools import partial
from typing import Callable, List, Optional, Sequence, Type

import torch
import torch.nn as nn
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers import AutoConfig, AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase
from transformers.modeling_outputs import CausalLMOutputWithPast

from prismatic.models.backbones.llm.prompting import PromptBuilder
from prismatic.overwatch import initialize_overwatch

# Suppress HF Deprecation Warnings
warnings.filterwarnings("ignore", category=FutureWarning)

# Initialize Overwatch =>> Wraps `logging.Logger`
overwatch = initialize_overwatch(__name__)


# === Abstract Base Class for arbitrary HF LLM Backbones ===
class LLMBackbone(nn.Module, ABC):
    def __init__(self, llm_backbone_id: str) -> None:
        super().__init__()
        self.identifier = llm_backbone_id

        # Instance attributes for an LLM Backbone
        self.llm: PreTrainedModel = None
        self.tokenizer: PreTrainedTokenizerBase = None

    def get_tokenizer(self) -> PreTrainedTokenizerBase:
        return self.tokenizer

    @abstractmethod
    def get_fsdp_wrapping_policy(self) -> Callable: ...

    @abstractmethod
    def enable_gradient_checkpointing(self) -> None: ...

    @abstractmethod
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> CausalLMOutputWithPast:
        """Run a forward pass through the LLM given targets (labels), returning the scalar Cross-Entropy Loss."""
        raise NotImplementedError

    @abstractmethod
    def embed_input_ids(self, input_ids: torch.LongTensor) -> torch.Tensor: ...

    @property
    @abstractmethod
    def prompt_builder_fn(self) -> Type[PromptBuilder]: ...

    @property
    @abstractmethod
    def transformer_layer_cls(self) -> Type[nn.Module]: ...

    @property
    @abstractmethod
    def half_precision_dtype(self) -> torch.dtype: ...

    @property
    @abstractmethod
    def last_layer_finetune_modules(self) -> Sequence[nn.Module]: ...

    @property
    def embed_dim(self) -> int:
        return self.llm.config.hidden_size

    @property
    def pad_token_id(self) -> int:
        return self.tokenizer.pad_token_id


# === Abstract Base Class for Arbitrary HF Causal LLMs ===
class HFCausalLLMBackbone(LLMBackbone, ABC):
    def __init__(
        self,
        llm_backbone_id: str,
        llm_family: str,
        llm_cls: Type[PreTrainedModel],
        hf_hub_path: str,
        llm_max_length: int = 2048,
        hf_token: Optional[str] = None,
        inference_mode: bool = False,
        use_flash_attention_2: bool = False,
    ) -> None:
        super().__init__(llm_backbone_id)
        self.llm_family = llm_family
        self.llm_max_length = llm_max_length
        self.inference_mode = inference_mode

        # Initialize LLM (downloading from HF Hub if necessary) --> `llm_cls` is the actual {Model}ForCausalLM class!
        #   => Note: We're eschewing use of the AutoModel API so that we can be more explicit about LLM-specific details
        if not self.inference_mode:
            overwatch.info(f"Loading [bold]{llm_family}[/] LLM from [underline]`{hf_hub_path}`[/]", ctx_level=1)
            self.llm = llm_cls.from_pretrained(
                hf_hub_path,
                token=hf_token,
                use_flash_attention_2=use_flash_attention_2 if not self.inference_mode else False,
                # The following parameters are set to prevent `UserWarnings` from HF; we want greedy decoding!
                do_sample=False,
                temperature=1.0,
                top_p=1.0,
            )

        # [Contract] `inference_mode` means we're loading from a pretrained checkpoint; no need to load base weights!
        else:
            overwatch.info(f"Building empty [bold]{llm_family}[/] LLM from [underline]`{hf_hub_path}`[/]", ctx_level=1)
            llm_config = AutoConfig.from_pretrained(hf_hub_path, token=hf_token)
            self.llm = llm_cls._from_config(llm_config)

        # Lightweight Handling (with extended explanation) for setting some LLM Parameters
        #   => Set `decoder.use_cache = False` --> incompatible with gradient checkpointing (+ training in general)
        #
        #      Reference: https://discuss.huggingface.co/t/what-is-the-purpose-of-use-cache-in-decoder/958
        self.llm.config.use_cache = False if not self.inference_mode else True

        #   => Turns out that when gradient checkpointing is on and the underlying LLM has no "trainable" parameters
        #      (requires_grad is False), backprop will fail; calling `enable_input_require_grads()` registers a new
        #      forward hook that fixes this =>> also totally safe for the "full finetuning" setting!
        if not self.inference_mode:
            self.llm.enable_input_require_grads()

        # Load (Fast) Tokenizer
        overwatch.info(f"Loading [bold]{llm_family}[/] (Fast) Tokenizer via the AutoTokenizer API", ctx_level=1)
        self.tokenizer = AutoTokenizer.from_pretrained(
            hf_hub_path, model_max_length=self.llm_max_length, token=hf_token, padding_side="right"
        )

        # Validation =>> Our VLM logic currently operates under the assumption that the tokenization of a new input
        #                starts with a <BOS> token unless `add_special_tokens = False`; for these models, we empirically
        #                find that adding image patches *after* the BOS leads to much better performance.
        #
        # As a result, we explicitly validate that a tokenizer conforms to the expected behavior; if you're reading this
        # line, it's probably because you're adding a new LLM with a different tokenizer behavior. If so, feel free to
        # override the `SPECIAL_CASES` set below, but make sure to make the appropriate changes in the `datasets.py`
        # and VLM `forward()` logic!
        SPECIAL_CASES = {
            # Phi-2 Tokenizer doesn't add any BOS tokens by default, and sets BOS == EOS == "<|endoftext|>"
            #   =>> We'll prepend BOS to the first input (to play nicely with image token insertion logic; verified
            #       that this works well with base LLM generation).
            #   =>> Like Llama-2 Tokenizers -- we'll add a special PAD token for training purposes.
            "phi-2-3b",
        }
        if self.identifier in SPECIAL_CASES:
            return

        # Note =>> this assert should hold for all Llama-derived tokenizers (`LlamaTokenizerFast`) =>> includes Mistral!
        assert (self.tokenizer("Test 123", add_special_tokens=True).input_ids[0] == self.tokenizer.bos_token_id) and (
            self.tokenizer("Test 123", add_special_tokens=False).input_ids[0] != self.tokenizer.bos_token_id
        ), (
            f"Default Tokenizer of type `{type(self.tokenizer)}` does not automatically prefix inputs with BOS token!\n"
            "Please read the comment in `base_llm.py` for more information!"
        )

    def get_fsdp_wrapping_policy(self) -> Callable:
        """Return a `transformer_auto_wrap_policy` where we wrap each instance of `self.transformer_layer_cls`."""
        transformer_block_policy = partial(
            transformer_auto_wrap_policy, transformer_layer_cls={self.transformer_layer_cls}
        )

        return transformer_block_policy

    def enable_gradient_checkpointing(self) -> None:
        """Dispatch to underlying LLM instance's `gradient_checkpointing_enable`; defined for all `PretrainedModel`."""
        self.llm.gradient_checkpointing_enable()

    def embed_input_ids(self, input_ids: torch.LongTensor) -> torch.Tensor:
        return self.llm.get_input_embeddings()(input_ids)

    # [Contract] Should match the `forward` call of the underlying `llm` instance!
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> CausalLMOutputWithPast:
        output: CausalLMOutputWithPast = self.llm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        return output

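Note (illustrative, not part of the commit): the `forward()` contract pinned down above is the standard HF causal-LM interface, so a concrete backbone (e.g., the LLaMa-2 subclass in the next file) can be driven like any `AutoModelForCausalLM`. A minimal sketch, assuming a backbone instance and pre-tokenized inputs:

import torch

def compute_lm_loss(llm_backbone, input_ids: torch.LongTensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # Labels mirror `input_ids`; positions to ignore would be set to -100 (the HF convention).
    labels = input_ids.clone()
    # Returns a `CausalLMOutputWithPast`; `.loss` is the scalar cross-entropy over the shifted labels.
    output = llm_backbone(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    return output.loss
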
capvector-oft/prismatic/models/backbones/llm/llama2.py ADDED
@@ -0,0 +1,102 @@
"""
llama2.py

Class definition for all LLMs derived from LlamaForCausalLM.
"""

from typing import Optional, Sequence, Type

import torch
from torch import nn as nn
from transformers import LlamaForCausalLM
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

from prismatic.models.backbones.llm.base_llm import HFCausalLLMBackbone
from prismatic.models.backbones.llm.prompting import (
    LLaMa2ChatPromptBuilder,
    PromptBuilder,
    PurePromptBuilder,
    VicunaV15ChatPromptBuilder,
)

# Registry =>> Support LLaMa-2 Models (from HF Transformers)
# fmt: off
LLAMA2_MODELS = {
    # === Pure Meta LLaMa-2 (non-instruct/chat-tuned) Models ===
    "llama2-7b-pure": {
        "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "meta-llama/Llama-2-7b-hf"
    },

    "llama2-13b-pure": {
        "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "meta-llama/Llama-2-13b-hf"
    },

    # === Meta LLaMa-2 Chat Models ===
    "llama2-7b-chat": {
        "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "meta-llama/Llama-2-7b-chat-hf"
    },

    "llama2-13b-chat": {
        "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "meta-llama/Llama-2-13b-chat-hf"
    },

    # === Vicuna v1.5 Chat Models ===
    "vicuna-v15-7b": {
        "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "lmsys/vicuna-7b-v1.5"
    },

    "vicuna-v15-13b": {
        "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "lmsys/vicuna-13b-v1.5"
    },
}
# fmt: on


class LLaMa2LLMBackbone(HFCausalLLMBackbone):
    def __init__(
        self,
        llm_backbone_id: str,
        llm_max_length: int = 2048,
        hf_token: Optional[str] = None,
        inference_mode: bool = False,
        use_flash_attention_2: bool = True,
    ) -> None:
        super().__init__(
            llm_backbone_id,
            llm_max_length=llm_max_length,
            hf_token=hf_token,
            inference_mode=inference_mode,
            use_flash_attention_2=use_flash_attention_2,
            **LLAMA2_MODELS[llm_backbone_id],
        )

        # [Special Case] LLaMa-2 PAD Token Handling --> for clarity, we add an extra token (and resize)
        self.tokenizer.add_special_tokens({"pad_token": "<PAD>"})
        self.llm.config.pad_token_id = self.tokenizer.pad_token_id
        self.llm.resize_token_embeddings(len(self.tokenizer), pad_to_multiple_of=64)

    @property
    def prompt_builder_fn(self) -> Type[PromptBuilder]:
        if self.identifier.startswith("llama2-") and self.identifier.endswith("-pure"):
            return PurePromptBuilder

        elif self.identifier.startswith("llama2-") and self.identifier.endswith("-chat"):
            return LLaMa2ChatPromptBuilder

        elif self.identifier.startswith("vicuna"):
            return VicunaV15ChatPromptBuilder

        raise ValueError(f"No PromptBuilder defined for LLM Backbone `{self.identifier}`")

    @property
    def transformer_layer_cls(self) -> Type[nn.Module]:
        return LlamaDecoderLayer

    @property
    def half_precision_dtype(self) -> torch.dtype:
        """LLaMa-2 was trained in BF16; see https://huggingface.co/docs/transformers/main/model_doc/llama2."""
        return torch.bfloat16

    @property
    def last_layer_finetune_modules(self) -> Sequence[nn.Module]:
        return (self.llm.model.embed_tokens, self.llm.model.layers[-1], self.llm.lm_head)

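Note (illustrative, not part of the commit): a minimal sketch of how a registry entry above is typically consumed together with its prompt builder; the identifier, token placeholder, and prompt text here are assumptions for illustration only.

from prismatic.models.backbones.llm import LLaMa2LLMBackbone

llm_backbone = LLaMa2LLMBackbone("llama2-7b-pure", llm_max_length=2048, hf_token="<HF_TOKEN>")
tokenizer = llm_backbone.get_tokenizer()

# `prompt_builder_fn` resolves to `PurePromptBuilder` for "-pure" identifiers (see the property above)
prompt_builder = llm_backbone.prompt_builder_fn("prismatic")
prompt_builder.add_turn("human", "What is shown in this image?")
input_ids = tokenizer(prompt_builder.get_prompt(), return_tensors="pt").input_ids
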
capvector-oft/prismatic/models/backbones/llm/mistral.py ADDED
@@ -0,0 +1,72 @@
"""
mistral.py

Class definition for all LLMs derived from MistralForCausalLM.
"""

from typing import Optional, Type

import torch
from torch import nn as nn
from transformers import MistralForCausalLM
from transformers.models.mistral.modeling_mistral import MistralDecoderLayer

from prismatic.models.backbones.llm.base_llm import HFCausalLLMBackbone
from prismatic.models.backbones.llm.prompting import MistralInstructPromptBuilder, PromptBuilder, PurePromptBuilder

# Registry =>> Support Mistral Models (from HF Transformers)
# fmt: off
MISTRAL_MODELS = {
    # === Base Mistral v0.1 ===
    "mistral-v0.1-7b-pure": {
        "llm_family": "mistral", "llm_cls": MistralForCausalLM, "hf_hub_path": "mistralai/Mistral-7B-v0.1"
    },

    # === Mistral Instruct v0.1 ===
    "mistral-v0.1-7b-instruct": {
        "llm_family": "mistral", "llm_cls": MistralForCausalLM, "hf_hub_path": "mistralai/Mistral-7B-Instruct-v0.1"
    }
}
# fmt: on


class MistralLLMBackbone(HFCausalLLMBackbone):
    def __init__(
        self,
        llm_backbone_id: str,
        llm_max_length: int = 2048,
        hf_token: Optional[str] = None,
        inference_mode: bool = False,
        use_flash_attention_2: bool = True,
    ) -> None:
        super().__init__(
            llm_backbone_id,
            llm_max_length=llm_max_length,
            hf_token=hf_token,
            inference_mode=inference_mode,
            use_flash_attention_2=use_flash_attention_2,
            **MISTRAL_MODELS[llm_backbone_id],
        )

        # [Special Case] Mistral PAD Token Handling --> for clarity, we add an extra token (and resize)
        self.tokenizer.add_special_tokens({"pad_token": "<PAD>"})
        self.llm.config.pad_token_id = self.tokenizer.pad_token_id
        self.llm.resize_token_embeddings(len(self.tokenizer), pad_to_multiple_of=64)

    @property
    def prompt_builder_fn(self) -> Type[PromptBuilder]:
        if self.identifier.endswith("-pure"):
            return PurePromptBuilder

        elif self.identifier.endswith("-instruct"):
            return MistralInstructPromptBuilder

        raise ValueError(f"No PromptBuilder defined for LLM Backbone `{self.identifier}`")

    @property
    def transformer_layer_cls(self) -> Type[nn.Module]:
        return MistralDecoderLayer

    @property
    def half_precision_dtype(self) -> torch.dtype:
        return torch.bfloat16

capvector-oft/prismatic/models/backbones/llm/phi.py ADDED
@@ -0,0 +1,64 @@
"""
phi.py

Class definition for all LLMs derived from PhiForCausalLM.
"""

from typing import Optional, Type

import torch
from torch import nn as nn
from transformers import PhiForCausalLM
from transformers.models.phi.modeling_phi import PhiDecoderLayer

from prismatic.models.backbones.llm.base_llm import HFCausalLLMBackbone
from prismatic.models.backbones.llm.prompting import PhiPromptBuilder, PromptBuilder

# Registry ==> Support Phi Models (from HF Transformers)
# fmt: off
PHI_MODELS = {
    # === Phi-2 ===
    "phi-2-3b": {
        "llm_family": "phi", "llm_cls": PhiForCausalLM, "hf_hub_path": "microsoft/phi-2"
    }
}
# fmt: on


class PhiLLMBackbone(HFCausalLLMBackbone):
    def __init__(
        self,
        llm_backbone_id: str,
        llm_max_length: int = 2048,
        hf_token: Optional[str] = None,
        inference_mode: bool = False,
        use_flash_attention_2: bool = True,
    ) -> None:
        super().__init__(
            llm_backbone_id,
            llm_max_length=llm_max_length,
            hf_token=hf_token,
            inference_mode=inference_mode,
            use_flash_attention_2=use_flash_attention_2,
            **PHI_MODELS[llm_backbone_id],
        )

        # [Special Case] Phi PAD Token Handling --> for clarity, we add an extra token (and resize)
        self.tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
        self.llm.config.pad_token_id = self.tokenizer.pad_token_id
        self.llm.resize_token_embeddings(len(self.tokenizer), pad_to_multiple_of=64)

    @property
    def prompt_builder_fn(self) -> Type[PromptBuilder]:
        if self.identifier.startswith("phi-2"):
            return PhiPromptBuilder

        raise ValueError(f"No PromptBuilder defined for LLM Backbone `{self.identifier}`")

    @property
    def transformer_layer_cls(self) -> Type[nn.Module]:
        return PhiDecoderLayer

    @property
    def half_precision_dtype(self) -> torch.dtype:
        return torch.bfloat16

capvector-oft/prismatic/models/backbones/llm/prompting/__init__.py ADDED
@@ -0,0 +1,5 @@
from .base_prompter import PromptBuilder, PurePromptBuilder
from .llama2_chat_prompter import LLaMa2ChatPromptBuilder
from .mistral_instruct_prompter import MistralInstructPromptBuilder
from .phi_prompter import PhiPromptBuilder
from .vicuna_v15_prompter import VicunaV15ChatPromptBuilder

capvector-oft/prismatic/models/backbones/llm/prompting/base_prompter.py ADDED
@@ -0,0 +1,73 @@
"""
base_prompter.py

Abstract class definition of a multi-turn prompt builder for ensuring consistent formatting for chat-based LLMs.
"""

from abc import ABC, abstractmethod
from typing import Optional


class PromptBuilder(ABC):
    def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None:
        self.model_family = model_family

        # Only some models define a system prompt => let subclasses handle this logic!
        self.system_prompt = system_prompt

    @abstractmethod
    def add_turn(self, role: str, message: str) -> str: ...

    @abstractmethod
    def get_potential_prompt(self, user_msg: str) -> str: ...

    @abstractmethod
    def get_prompt(self) -> str: ...


class PurePromptBuilder(PromptBuilder):
    def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None:
        super().__init__(model_family, system_prompt)

        # TODO (siddk) =>> Can't always assume LlamaTokenizer --> FIX ME!
        self.bos, self.eos = "<s>", "</s>"

        # Get role-specific "wrap" functions
        self.wrap_human = lambda msg: f"In: {msg}\nOut: "
        self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.eos}"

        # === `self.prompt` gets built up over multiple turns ===
        self.prompt, self.turn_count = "", 0

    def add_turn(self, role: str, message: str) -> str:
        assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt")
        message = message.replace("<image>", "").strip()

        if (self.turn_count % 2) == 0:
            human_message = self.wrap_human(message)
            wrapped_message = human_message
        else:
            gpt_message = self.wrap_gpt(message)
            wrapped_message = gpt_message

        # Update Prompt
        self.prompt += wrapped_message

        # Bump Turn Counter
        self.turn_count += 1

        # Return "wrapped_message" (effective string added to context)
        return wrapped_message

    def get_potential_prompt(self, message: str) -> str:
        # Assumes that it's always the user's (human's) turn!
        prompt_copy = str(self.prompt)

        human_message = self.wrap_human(message)
        prompt_copy += human_message

        return prompt_copy.removeprefix(self.bos).rstrip()

    def get_prompt(self) -> str:
        # Remove prefix <bos> (if it exists) because it gets auto-inserted by the tokenizer!
        return self.prompt.removeprefix(self.bos).rstrip()

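Note (illustrative, not part of the commit): a short sketch of the turn-taking contract of `PurePromptBuilder`; the example messages are made up.

from prismatic.models.backbones.llm.prompting import PurePromptBuilder

builder = PurePromptBuilder("prismatic")
builder.add_turn("human", "What color is the sky?")  # appends "In: What color is the sky?\nOut: "
builder.add_turn("gpt", "The sky is blue.")          # appends "The sky is blue.</s>"
builder.add_turn("human", "Why?")                    # roles must strictly alternate (enforced by the assert)

# get_prompt() strips a leading "<s>" (if present) and trailing whitespace:
# "In: What color is the sky?\nOut: The sky is blue.</s>In: Why?\nOut:"
print(builder.get_prompt())
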
capvector-oft/prismatic/models/backbones/llm/prompting/llama2_chat_prompter.py ADDED
@@ -0,0 +1,91 @@
"""
llama2_chat_prompter.py

Defines a PromptBuilder for building LLaMa-2 Chat Prompts --> not sure if this is "optimal", but this is the pattern
that's used by HF and other online tutorials.

Reference: https://huggingface.co/blog/llama2#how-to-prompt-llama-2
"""

from typing import Optional

from prismatic.models.backbones.llm.prompting.base_prompter import PromptBuilder

# Default System Prompt for Prismatic Models
SYS_PROMPTS = {
    "prismatic": (
        "You are a helpful language and vision assistant. "
        "You are able to understand the visual content that the user provides, "
        "and assist the user with a variety of tasks using natural language."
    ),
    "openvla": (
        "You are a helpful language and vision assistant. "
        "You are able to understand the visual content that the user provides, "
        "and assist the user with a variety of tasks using natural language."
    ),
}


def format_system_prompt(system_prompt: str) -> str:
    return f"<<SYS>>\n{system_prompt.strip()}\n<</SYS>>\n\n"


class LLaMa2ChatPromptBuilder(PromptBuilder):
    def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None:
        super().__init__(model_family, system_prompt)
        self.system_prompt = format_system_prompt(
            SYS_PROMPTS[self.model_family] if system_prompt is None else system_prompt
        )

        # LLaMa-2 Specific
        self.bos, self.eos = "<s>", "</s>"

        # Get role-specific "wrap" functions
        self.wrap_human = lambda msg: f"[INST] {msg} [/INST] "
        self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.eos}"

        # === `self.prompt` gets built up over multiple turns ===
        self.prompt, self.turn_count = "", 0

    def add_turn(self, role: str, message: str) -> str:
        assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt")
        message = message.replace("<image>", "").strip()

        # Special Handling for "system" prompt (turn_count == 0)
        if self.turn_count == 0:
            sys_message = self.wrap_human(self.system_prompt + message)
            wrapped_message = sys_message
        elif (self.turn_count % 2) == 0:
            human_message = self.wrap_human(message)
            wrapped_message = human_message
        else:
            gpt_message = self.wrap_gpt(message)
            wrapped_message = gpt_message

        # Update Prompt
        self.prompt += wrapped_message

        # Bump Turn Counter
        self.turn_count += 1

        # Return "wrapped_message" (effective string added to context)
        return wrapped_message

    def get_potential_prompt(self, message: str) -> str:
        # Assumes that it's always the user's (human's) turn!
        prompt_copy = str(self.prompt)

        # Special Handling for "system" prompt (turn_count == 0)
        if self.turn_count == 0:
            sys_message = self.wrap_human(self.system_prompt + message)
            prompt_copy += sys_message

        else:
            human_message = self.wrap_human(message)
            prompt_copy += human_message

        return prompt_copy.removeprefix(self.bos).rstrip()

    def get_prompt(self) -> str:
        # Remove prefix <bos> because it gets auto-inserted by the tokenizer!
        return self.prompt.removeprefix(self.bos).rstrip()

capvector-oft/prismatic/models/backbones/llm/prompting/mistral_instruct_prompter.py ADDED
@@ -0,0 +1,60 @@
"""
mistral_instruct_prompter.py

Defines a PromptBuilder for building Mistral Instruct Chat Prompts --> recommended pattern used by HF / online tutorials.

Reference: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1#instruction-format
"""

from typing import Optional

from prismatic.models.backbones.llm.prompting.base_prompter import PromptBuilder


class MistralInstructPromptBuilder(PromptBuilder):
    def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None:
        super().__init__(model_family, system_prompt)

        # Note =>> Mistral Tokenizer is an instance of `LlamaTokenizer(Fast)`
        #      =>> Mistral Instruct *does not* use a System Prompt
        self.bos, self.eos = "<s>", "</s>"

        # Get role-specific "wrap" functions
        self.wrap_human = lambda msg: f"[INST] {msg} [/INST] "
        self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.eos}"

        # === `self.prompt` gets built up over multiple turns ===
        self.prompt, self.turn_count = "", 0

    def add_turn(self, role: str, message: str) -> str:
        assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt")
        message = message.replace("<image>", "").strip()

        if (self.turn_count % 2) == 0:
            human_message = self.wrap_human(message)
            wrapped_message = human_message
        else:
            gpt_message = self.wrap_gpt(message)
            wrapped_message = gpt_message

        # Update Prompt
        self.prompt += wrapped_message

        # Bump Turn Counter
        self.turn_count += 1

        # Return "wrapped_message" (effective string added to context)
        return wrapped_message

    def get_potential_prompt(self, message: str) -> str:
        # Assumes that it's always the user's (human's) turn!
        prompt_copy = str(self.prompt)

        human_message = self.wrap_human(message)
        prompt_copy += human_message

        return prompt_copy.removeprefix(self.bos).rstrip()

    def get_prompt(self) -> str:
        # Remove prefix <bos> because it gets auto-inserted by the tokenizer!
        return self.prompt.removeprefix(self.bos).rstrip()

capvector-oft/prismatic/models/backbones/llm/prompting/phi_prompter.py ADDED
@@ -0,0 +1,65 @@
"""
phi_prompter.py

Defines a PromptBuilder for building Phi-2 Input/Output Prompts --> recommended pattern used by HF / Microsoft.
Also handles the Phi special case of BOS token additions.

Reference: https://huggingface.co/microsoft/phi-2#qa-format
"""

from typing import Optional

from prismatic.models.backbones.llm.prompting.base_prompter import PromptBuilder


class PhiPromptBuilder(PromptBuilder):
    def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None:
        super().__init__(model_family, system_prompt)

        # Note =>> Phi Tokenizer is an instance of `CodeGenTokenizer(Fast)`
        #      =>> By default, does *not* append <BOS> / <EOS> tokens --> we handle that here (IMPORTANT)!
        self.bos, self.eos = "<|endoftext|>", "<|endoftext|>"

        # Get role-specific "wrap" functions
        #   =>> Note that placement of <bos>/<eos> was based on experiments generating from Phi-2 in Input/Output mode
        self.wrap_human = lambda msg: f"Input: {msg}\nOutput: "
        self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}\n{self.eos}"

        # === `self.prompt` gets built up over multiple turns ===
        self.prompt, self.turn_count = "", 0

    def add_turn(self, role: str, message: str) -> str:
        assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt")
        message = message.replace("<image>", "").strip()

        # Special Handling for the "first" input --> prepend a <BOS> token (expected by Prismatic)
        if self.turn_count == 0:
            bos_human_message = f"{self.bos}{self.wrap_human(message)}"
            wrapped_message = bos_human_message
        elif (self.turn_count % 2) == 0:
            human_message = self.wrap_human(message)
            wrapped_message = human_message
        else:
            gpt_message = self.wrap_gpt(message)
            wrapped_message = gpt_message

        # Update Prompt
        self.prompt += wrapped_message

        # Bump Turn Counter
        self.turn_count += 1

        # Return "wrapped_message" (effective string added to context)
        return wrapped_message

    def get_potential_prompt(self, message: str) -> str:
        # Assumes that it's always the user's (human's) turn!
        prompt_copy = str(self.prompt)

        human_message = self.wrap_human(message)
        prompt_copy += human_message

        return prompt_copy.rstrip()

    def get_prompt(self) -> str:
        return self.prompt.rstrip()

capvector-oft/prismatic/models/backbones/llm/prompting/vicuna_v15_prompter.py ADDED
@@ -0,0 +1,82 @@
"""
vicuna_v15_prompter.py

Defines a PromptBuilder for building Vicuna-v1.5 Chat Prompts.

Reference: https://huggingface.co/lmsys/vicuna-13b-v1.5
"""

from typing import Optional

from prismatic.models.backbones.llm.prompting.base_prompter import PromptBuilder

# Default System Prompt for LLaVa Models
SYS_PROMPTS = {
    "prismatic": (
        "A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions."
    ),
    "openvla": (
        "A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions."
    ),
}


class VicunaV15ChatPromptBuilder(PromptBuilder):
    def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None:
        super().__init__(model_family, system_prompt)
        self.system_prompt = (SYS_PROMPTS[self.model_family] if system_prompt is None else system_prompt).strip() + " "

        # LLaMa-2 Specific
        self.bos, self.eos = "<s>", "</s>"

        # Get role-specific "wrap" functions
        self.wrap_human = lambda msg: f"USER: {msg} ASSISTANT: "
        self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.eos}"

        # === `self.prompt` gets built up over multiple turns ===
        self.prompt, self.turn_count = "", 0

    def add_turn(self, role: str, message: str) -> str:
        assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt")
        message = message.replace("<image>", "").strip()

        # Special Handling for "system" prompt (turn_count == 0)
        if self.turn_count == 0:
            sys_message = self.system_prompt + self.wrap_human(message)
            wrapped_message = sys_message
        elif (self.turn_count % 2) == 0:
            human_message = self.wrap_human(message)
            wrapped_message = human_message
        else:
            gpt_message = self.wrap_gpt(message)
            wrapped_message = gpt_message

        # Update Prompt
        self.prompt += wrapped_message

        # Bump Turn Counter
        self.turn_count += 1

        # Return "wrapped_message" (effective string added to context)
        return wrapped_message

    def get_potential_prompt(self, message: str) -> str:
        # Assumes that it's always the user's (human's) turn!
        prompt_copy = str(self.prompt)

        # Special Handling for "system" prompt (turn_count == 0)
        if self.turn_count == 0:
            sys_message = self.system_prompt + self.wrap_human(message)
            prompt_copy += sys_message

        else:
            human_message = self.wrap_human(message)
            prompt_copy += human_message

        return prompt_copy.removeprefix(self.bos).rstrip()

    def get_prompt(self) -> str:
        # Remove prefix <bos> (if it exists) because it gets auto-inserted by the tokenizer!
        return self.prompt.removeprefix(self.bos).rstrip()

capvector-oft/prismatic/models/backbones/vision/__init__.py ADDED
@@ -0,0 +1,7 @@
from .base_vision import ImageTransform, VisionBackbone
from .clip_vit import CLIPViTBackbone
from .dinoclip_vit import DinoCLIPViTBackbone
from .dinosiglip_vit import DinoSigLIPViTBackbone
from .dinov2_vit import DinoV2ViTBackbone
from .in1k_vit import IN1KViTBackbone
from .siglip_vit import SigLIPViTBackbone

capvector-oft/prismatic/models/backbones/vision/base_vision.py ADDED
@@ -0,0 +1,207 @@
"""
base_vision.py

Abstract class definition of a Vision Backbone (Visual Featurizer), with full annotations of class methods, utility
functions, and initialization logic.

We also define the generic TimmViTBackbone class here, providing a default interface for loading any TIMM Vision
Transformer model for feature extraction.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict, Optional, Protocol, Tuple, Union

import timm
import torch
import torch.nn as nn
import torchvision.transforms.functional as TVF
from PIL.Image import Image
from timm.models.vision_transformer import Block, VisionTransformer
from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy
from torchvision.transforms import Compose, Resize


# === Utility Functions for Monkey-Patching ===
def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        result = fn(*args, **kwargs)
        return result[0] if isinstance(result, tuple) else result

    return wrapper


# === Interface for an Image Transform ===
class ImageTransform(Protocol):
    def __call__(self, img: Image, **kwargs: str) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: ...


# === Custom Torchvision Image Transforms ===
@dataclass
class LetterboxPad:
    padding_fill_value: Tuple[int, int, int]

    def __call__(self, image: Image) -> Image:
        """Given a PIL.Image, pad to square by adding a symmetric border around the height/width."""
        (w, h), max_wh = image.size, max(image.size)
        horizontal_pad, vertical_pad = int((max_wh - w) / 2), int((max_wh - h) / 2)
        padding = (horizontal_pad, vertical_pad, horizontal_pad, vertical_pad)
        return TVF.pad(image, padding, fill=self.padding_fill_value, padding_mode="constant")


# === Abstract Base Class for arbitrary Vision Backbones ===
class VisionBackbone(nn.Module, ABC):
    def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None:
        super().__init__()
        self.identifier: str = vision_backbone_id
        self.image_resize_strategy: str = image_resize_strategy
        self.default_image_size: int = default_image_size

        # Instance attributes for a Vision Backbone
        self.featurizer: nn.Module = None
        self.image_transform: ImageTransform = None

    def get_image_transform(self) -> ImageTransform:
        return self.image_transform

    @abstractmethod
    def get_fsdp_wrapping_policy(self) -> Callable: ...

    @abstractmethod
    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Run a forward pass through the featurizer given a set of processed images, returning patch/grid features."""
        raise NotImplementedError

    @property
    @abstractmethod
    def default_image_resolution(self) -> Tuple[int, int, int]: ...

    @property
    @abstractmethod
    def embed_dim(self) -> int: ...

    @property
    @abstractmethod
    def num_patches(self) -> int: ...

    @property
    @abstractmethod
    def half_precision_dtype(self) -> torch.dtype: ...


# === Abstract Base Class for Arbitrary TIMM Vision Transformer Backbones ===
class TimmViTBackbone(VisionBackbone, ABC):
    def __init__(
        self,
        vision_backbone_id: str,
        timm_path_or_url: str,
        image_resize_strategy: str,
        default_image_size: int = 224,
        override_act_layer: Optional[str] = None,
    ) -> None:
        super().__init__(vision_backbone_id, image_resize_strategy, default_image_size=default_image_size)
        self.timm_path_or_url = timm_path_or_url
        self.override_act_layer = override_act_layer
        self.dtype = torch.bfloat16

        # Initialize Featurizer (ViT) by downloading from HF / TIMM Hub if necessary
        if self.override_act_layer is None:
            self.featurizer: VisionTransformer = timm.create_model(
                self.timm_path_or_url, pretrained=True, num_classes=0, img_size=self.default_image_size
            )
        else:
            self.featurizer: VisionTransformer = timm.create_model(
                self.timm_path_or_url,
                pretrained=True,
                num_classes=0,
                img_size=self.default_image_size,
                act_layer=self.override_act_layer,
            )
        self.featurizer.eval()

        # Monkey-Patch the `forward()` function of the featurizer to ensure FSDP-compatibility
        #   => Note: By default set `get_intermediate_layers` to return the *SECOND-TO-LAST* layer patches!
        #   => TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385
        self.featurizer.forward = unpack_tuple(
            partial(self.featurizer.get_intermediate_layers, n={len(self.featurizer.blocks) - 2})
        )

        # Validation =>> for now, this class *only* supports TIMM Vision Transformers (but can be extended!)
        assert isinstance(self.featurizer, VisionTransformer), (
            "Featurizer is not a TIMM VisionTransformer; if you would like to support a new visual representation, "
            "file an issue or implement the requisite logic (see `prismatic/models/backbones/vision/base_vision.py`)!"
        )

        # Get Config =>> Note :: Override default image size to ensure correct image transform
        self.data_cfg = timm.data.resolve_model_data_config(self.featurizer)
        self.data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size)

        # Initialize Default Image Transform --> Modified by `self.image_resize_strategy`
        default_image_transform = timm.data.create_transform(**self.data_cfg, is_training=False)

        # Fix =>> SigLIP & IN1K default transforms resize to *larger* than `self.default_image_size` (crops image)!
        if "siglip" in self.timm_path_or_url or "in1k" in self.timm_path_or_url:
            assert isinstance(default_image_transform, Compose), "Unexpected `default_image_transform`!"
            assert isinstance(default_image_transform.transforms[0], Resize)
            default_image_transform = Compose(
                [
                    Resize(self.default_image_size, interpolation=default_image_transform.transforms[0].interpolation),
                    *default_image_transform.transforms[1:],
                ]
            )

        # Switch on `image_resize_strategy`
        if self.image_resize_strategy == "resize-naive":
            assert isinstance(default_image_transform, Compose), "Unexpected `default_image_transform`!"
            assert isinstance(default_image_transform.transforms[0], Resize)

            target_size = (self.default_image_size, self.default_image_size)
            self.image_transform = Compose(
                [
                    Resize(target_size, interpolation=default_image_transform.transforms[0].interpolation),
                    *default_image_transform.transforms[1:],
                ]
            )

        elif self.image_resize_strategy == "resize-crop":
            self.image_transform = default_image_transform

        elif self.image_resize_strategy == "letterbox":
            assert isinstance(default_image_transform, Compose), "Unexpected `default_image_transform`!"
            assert "mean" in self.data_cfg, "TIMM `data_cfg` missing image normalization mean!"

            # Compute Padding Fill Value (rescaled normalization mean if applicable)
            fill = tuple([int(x * 255) for x in self.data_cfg["mean"]])

            # Build New Transform
            self.image_transform = Compose([LetterboxPad(fill), *default_image_transform.transforms])

        else:
            raise ValueError(f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!")

    def get_fsdp_wrapping_policy(self) -> Callable:
        """Return a simple FSDP policy that wraps each ViT block and then the _entire_ featurizer."""
        vit_wrap_policy = partial(_module_wrap_policy, module_classes={VisionTransformer})
        transformer_block_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
        return partial(_or_policy, policies=[vit_wrap_policy, transformer_block_policy])

    def forward(self, pixel_values: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> torch.Tensor:
        """Runs transformed image/pixel tensor through vision backbone, returning _all_ patch features."""
        return self.featurizer(pixel_values)

    @property
    def default_image_resolution(self) -> Tuple[int, int, int]:
        return self.data_cfg["input_size"]

    @property
    def embed_dim(self) -> int:
        return self.featurizer.embed_dim

    @property
    def num_patches(self) -> int:
        return self.featurizer.patch_embed.num_patches

    @property
    def half_precision_dtype(self) -> torch.dtype:
        return self.dtype

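Note (illustrative, not part of the commit): a sketch of what the monkey-patched `forward()` produces, using the CLIP backbone defined in `clip_vit.py` below; downloading the pretrained TIMM weights is assumed, and exact patch counts depend on the checkpoint.

import torch
from PIL import Image

from prismatic.models.backbones.vision.clip_vit import CLIPViTBackbone

backbone = CLIPViTBackbone("clip-vit-l", image_resize_strategy="resize-naive", default_image_size=224)
pixel_values = backbone.image_transform(Image.new("RGB", (640, 480))).unsqueeze(0)  # [1, 3, 224, 224]

with torch.no_grad():
    patches = backbone(pixel_values)

# `get_intermediate_layers` (patched in above) drops the prefix/CLS tokens and returns the second-to-last
# block outputs, so `patches` has shape [1, backbone.num_patches, backbone.embed_dim]
# (e.g., [1, 256, 1024] for a ViT-L/14 at 224px).
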
capvector-oft/prismatic/models/backbones/vision/clip_vit.py ADDED
@@ -0,0 +1,27 @@
"""
clip_vit.py
"""

from prismatic.models.backbones.vision.base_vision import TimmViTBackbone

# Registry =>> Supported CLIP Vision Backbones (from TIMM)
CLIP_VISION_BACKBONES = {
    "clip-vit-b": "vit_base_patch16_clip_224.openai",
    "clip-vit-l": "vit_large_patch14_clip_224.openai",
    "clip-vit-l-336px": "vit_large_patch14_clip_336.openai",
}


# [IMPORTANT] By default, TIMM initializes OpenAI CLIP models with the standard GELU activation from PyTorch.
#             HOWEVER =>> Original OpenAI models were trained with the quick_gelu *approximation* -- while it's
#                         a decent approximation, the resulting features are *worse*; this was a super tricky bug
#                         to identify, but luckily there's an easy fix (`override_act_layer`)
class CLIPViTBackbone(TimmViTBackbone):
    def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None:
        super().__init__(
            vision_backbone_id,
            CLIP_VISION_BACKBONES[vision_backbone_id],
            image_resize_strategy,
            default_image_size=default_image_size,
            override_act_layer="quick_gelu" if CLIP_VISION_BACKBONES[vision_backbone_id].endswith(".openai") else None,
        )

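Note (illustrative, not part of the commit): the activation mismatch the comment above refers to. OpenAI CLIP's "quick" GELU is x * sigmoid(1.702 * x), which only approximates the exact GELU; a tiny sketch of the gap:

import torch
import torch.nn.functional as F

x = torch.linspace(-4, 4, steps=9)
quick_gelu = x * torch.sigmoid(1.702 * x)
exact_gelu = F.gelu(x)

# The gap is small (on the order of 1e-2) but systematic -- enough to shift CLIP features if the wrong
# activation is used at inference, hence `override_act_layer="quick_gelu"` for the `.openai` checkpoints.
print((quick_gelu - exact_gelu).abs().max())
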
capvector-oft/prismatic/models/backbones/vision/dinoclip_vit.py ADDED
@@ -0,0 +1,147 @@
1
+ """
2
+ dinoclip_vit.py
3
+
4
+ Vision backbone that returns concatenated features from both DINOv2 and CLIP.
5
+ """
6
+
7
+ from dataclasses import dataclass
8
+ from functools import partial
9
+ from typing import Callable, Dict, Tuple
10
+
11
+ import timm
12
+ import torch
13
+ from PIL import Image
14
+ from timm.models.vision_transformer import Block, VisionTransformer
15
+ from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy
16
+ from torchvision.transforms import Compose, Resize
17
+
18
+ from prismatic.models.backbones.vision.base_vision import ImageTransform, LetterboxPad, VisionBackbone, unpack_tuple
19
+
20
+ # Registry =>> Supported DinoCLIP Pairs (as TIMM identifiers)
21
+ DINOCLIP_VISION_BACKBONES = {
22
+ "dinoclip-vit-l-336px": {
23
+ "dino": "vit_large_patch14_reg4_dinov2.lvd142m",
24
+ "clip": "vit_large_patch14_clip_336.openai",
25
+ },
26
+ }
27
+
28
+
29
+ @dataclass
30
+ class DinoCLIPImageTransform:
31
+ dino_image_transform: ImageTransform
32
+ clip_image_transform: ImageTransform
33
+ is_prismatic: bool = True
34
+
35
+ def __call__(self, img: Image, **kwargs: str) -> Dict[str, torch.Tensor]:
36
+ return {"dino": self.dino_image_transform(img, **kwargs), "clip": self.clip_image_transform(img, **kwargs)}
37
+
38
+
39
+ class DinoCLIPViTBackbone(VisionBackbone):
40
+ def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None:
41
+ super().__init__(vision_backbone_id, image_resize_strategy, default_image_size=default_image_size)
42
+ self.dino_timm_path_or_url = DINOCLIP_VISION_BACKBONES[vision_backbone_id]["dino"]
43
+ self.clip_timm_path_or_url = DINOCLIP_VISION_BACKBONES[vision_backbone_id]["clip"]
44
+
45
+ # Initialize both Featurizers (ViTs) by downloading from HF / TIMM Hub if necessary
46
+ self.dino_featurizer: VisionTransformer = timm.create_model(
47
+ self.dino_timm_path_or_url, pretrained=True, num_classes=0, img_size=self.default_image_size
48
+ )
49
+ self.dino_featurizer.eval()
50
+
51
+ self.clip_featurizer: VisionTransformer = timm.create_model(
52
+ self.clip_timm_path_or_url, pretrained=True, num_classes=0, img_size=self.default_image_size
53
+ )
54
+ self.clip_featurizer.eval()
55
+
56
+ # Monkey-Patch the `forward()` function of the featurizers to ensure FSDP-compatibility
57
+ # => Note: By default set `get_intermediate_layers` to return the *SECOND-TO-LAST* layer patches!
58
+ # => TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385
59
+ self.dino_featurizer.forward = unpack_tuple(
60
+ partial(self.dino_featurizer.get_intermediate_layers, n={len(self.dino_featurizer.blocks) - 2})
61
+ )
62
+ self.clip_featurizer.forward = unpack_tuple(
63
+ partial(self.clip_featurizer.get_intermediate_layers, n={len(self.clip_featurizer.blocks) - 2})
64
+ )
65
+
66
+ # Get Configs for _both_ Featurizers =>> Note :: Override default image size for larger resolution models
67
+ self.dino_data_cfg = timm.data.resolve_model_data_config(self.dino_featurizer)
68
+ self.dino_data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size)
69
+
70
+ self.clip_data_cfg = timm.data.resolve_model_data_config(self.clip_featurizer)
71
+ self.clip_data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size)
72
+
73
+ # Initialize *both* Transforms
74
+ default_dino_transform = timm.data.create_transform(**self.dino_data_cfg, is_training=False)
75
+ default_clip_transform = timm.data.create_transform(**self.clip_data_cfg, is_training=False)
76
+ if self.image_resize_strategy == "resize-naive":
77
+ assert isinstance(default_dino_transform, Compose), "Unexpected `default_dino_image_transform`!"
78
+ assert isinstance(default_clip_transform, Compose), "Unexpected `default_clip_image_transform`!"
79
+ assert isinstance(default_dino_transform.transforms[0], Resize)
80
+ assert isinstance(default_clip_transform.transforms[0], Resize)
81
+
82
+ target_size = (self.default_image_size, self.default_image_size)
83
+ dino_transform = Compose(
84
+ [
85
+ Resize(target_size, interpolation=default_dino_transform.transforms[0].interpolation),
86
+ *default_dino_transform.transforms[1:],
87
+ ]
88
+ )
89
+ clip_transform = Compose(
90
+ [
91
+ Resize(target_size, interpolation=default_clip_transform.transforms[0].interpolation),
92
+ *default_clip_transform.transforms[1:],
93
+ ]
94
+ )
95
+
96
+ self.image_transform = DinoCLIPImageTransform(dino_transform, clip_transform)
97
+
98
+ elif self.image_resize_strategy == "resize-crop":
99
+ self.image_transform = DinoCLIPImageTransform(default_dino_transform, default_clip_transform)
100
+
101
+ elif self.image_resize_strategy == "letterbox":
102
+ assert isinstance(default_dino_transform, Compose), "Unexpected `default_dino_transform`!"
103
+ assert isinstance(default_clip_transform, Compose), "Unexpected `default_clip_transform`!"
104
+ assert "mean" in self.dino_data_cfg and "mean" in self.clip_data_cfg, "DinoCLIP `data_cfg` missing `mean`!"
105
+
106
+ # Compute Padding Fill Value(s) (rescaled normalization mean if applicable)
107
+ dino_fill = tuple([int(x * 255) for x in self.dino_data_cfg["mean"]])
108
+ clip_fill = tuple([int(x * 255) for x in self.clip_data_cfg["mean"]])
109
+
110
+ # Build New Transform
111
+ self.image_transform = DinoCLIPImageTransform(
112
+ Compose([LetterboxPad(dino_fill), *default_dino_transform.transforms]),
113
+ Compose([LetterboxPad(clip_fill), *default_clip_transform.transforms]),
114
+ )
115
+
116
+ else:
117
+ raise ValueError(f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!")
118
+
119
+ def get_fsdp_wrapping_policy(self) -> Callable:
120
+ """Return a simple FSDP policy that wraps each ViT block and then both of the _entire_ featurizers."""
121
+ vit_wrap_policy = partial(_module_wrap_policy, module_classes={VisionTransformer})
122
+ transformer_block_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
123
+ return partial(_or_policy, policies=[vit_wrap_policy, transformer_block_policy])
124
+
125
+ def forward(self, pixel_values: Dict[str, torch.Tensor]) -> torch.Tensor:
126
+ """Runs the transformed image/pixel tensors through each vision backbone, returning concatenated patches."""
127
+ dino_patches = self.dino_featurizer(pixel_values["dino"])
128
+ clip_patches = self.clip_featurizer(pixel_values["clip"])
129
+
130
+ return torch.cat([dino_patches, clip_patches], dim=2)
131
+
132
+ @property
133
+ def default_image_resolution(self) -> Tuple[int, int, int]:
134
+ return self.dino_data_cfg["input_size"]
135
+
136
+ @property
137
+ def embed_dim(self) -> int:
138
+ return self.dino_featurizer.embed_dim + self.clip_featurizer.embed_dim
139
+
140
+ @property
141
+ def num_patches(self) -> int:
142
+ assert self.dino_featurizer.patch_embed.num_patches == self.clip_featurizer.patch_embed.num_patches
143
+ return self.dino_featurizer.patch_embed.num_patches
144
+
145
+ @property
146
+ def half_precision_dtype(self) -> torch.dtype:
147
+ return torch.bfloat16
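As a quick worked example of the letterbox fill computation in the backbone above: the padding colour is simply the normalization mean rescaled to the 0-255 range. The mean used below is the standard ImageNet value and is only an illustrative assumption; the actual values come from each model's TIMM data config.

# Worked example (assumption: ImageNet normalization mean; real values come from timm.data.resolve_model_data_config)
mean = (0.485, 0.456, 0.406)
fill = tuple(int(x * 255) for x in mean)
print(fill)  # (123, 116, 103) -> per-channel RGB padding colour handed to LetterboxPad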
capvector-oft/prismatic/models/backbones/vision/dinosiglip_vit.py ADDED
@@ -0,0 +1,164 @@
1
+ """
2
+ dinosiglip_vit.py
3
+
4
+ Vision backbone that returns concatenated features from both DINOv2 and SigLIP.
5
+ """
6
+
7
+ from dataclasses import dataclass
8
+ from functools import partial
9
+ from typing import Callable, Dict, Tuple
10
+
11
+ import timm
12
+ import torch
13
+ from PIL import Image
14
+ from timm.models.vision_transformer import Block, VisionTransformer
15
+ from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy
16
+ from torchvision.transforms import Compose, Resize
17
+
18
+ from prismatic.models.backbones.vision.base_vision import ImageTransform, LetterboxPad, VisionBackbone, unpack_tuple
19
+
20
+ # Registry =>> Supported DinoSigLIP Pairs (as TIMM identifiers)
21
+ DINOSigLIP_VISION_BACKBONES = {
22
+ "dinosiglip-vit-so-224px": {
23
+ "dino": "vit_large_patch14_reg4_dinov2.lvd142m",
24
+ "siglip": "vit_so400m_patch14_siglip_224",
25
+ },
26
+ "dinosiglip-vit-so-384px": {
27
+ "dino": "vit_large_patch14_reg4_dinov2.lvd142m",
28
+ "siglip": "vit_so400m_patch14_siglip_384",
29
+ },
30
+ }
31
+
32
+
33
+ @dataclass
34
+ class DinoSigLIPImageTransform:
35
+ dino_image_transform: ImageTransform
36
+ siglip_image_transform: ImageTransform
37
+ is_prismatic: bool = True
38
+
39
+ def __call__(self, img: Image, **kwargs: str) -> Dict[str, torch.Tensor]:
40
+ return {"dino": self.dino_image_transform(img, **kwargs), "siglip": self.siglip_image_transform(img, **kwargs)}
41
+
42
+
43
+ class DinoSigLIPViTBackbone(VisionBackbone):
44
+ def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None:
45
+ super().__init__(vision_backbone_id, image_resize_strategy, default_image_size=default_image_size)
46
+ self.dino_timm_path_or_url = DINOSigLIP_VISION_BACKBONES[vision_backbone_id]["dino"]
47
+ self.siglip_timm_path_or_url = DINOSigLIP_VISION_BACKBONES[vision_backbone_id]["siglip"]
48
+
49
+ # Initialize both Featurizers (ViTs) by downloading from HF / TIMM Hub if necessary
50
+ self.dino_featurizer: VisionTransformer = timm.create_model(
51
+ self.dino_timm_path_or_url, pretrained=True, num_classes=0, img_size=self.default_image_size
52
+ )
53
+ self.dino_featurizer.eval()
54
+
55
+ self.siglip_featurizer: VisionTransformer = timm.create_model(
56
+ self.siglip_timm_path_or_url, pretrained=True, num_classes=0, img_size=self.default_image_size
57
+ )
58
+ self.siglip_featurizer.eval()
59
+
60
+ # Monkey-Patch the `forward()` function of the featurizers to ensure FSDP-compatibility
61
+ # => Note: By default set `get_intermediate_layers` to return the *SECOND-TO-LAST* layer patches!
62
+ # => TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385
63
+ self.dino_featurizer.forward = unpack_tuple(
64
+ partial(self.dino_featurizer.get_intermediate_layers, n={len(self.dino_featurizer.blocks) - 2})
65
+ )
66
+ self.siglip_featurizer.forward = unpack_tuple(
67
+ partial(self.siglip_featurizer.get_intermediate_layers, n={len(self.siglip_featurizer.blocks) - 2})
68
+ )
69
+
70
+ # Get Configs for _both_ Featurizers =>> Note :: Override default image size for larger resolution models
71
+ self.dino_data_cfg = timm.data.resolve_model_data_config(self.dino_featurizer)
72
+ self.dino_data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size)
73
+
74
+ self.siglip_data_cfg = timm.data.resolve_model_data_config(self.siglip_featurizer)
75
+ self.siglip_data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size)
76
+
77
+ # Initialize *both* Transforms
78
+ default_dino_transform = timm.data.create_transform(**self.dino_data_cfg, is_training=False)
79
+ default_siglip_transform = timm.data.create_transform(**self.siglip_data_cfg, is_training=False)
80
+
81
+ # Fix =>> SigLIP default transform resizes to *larger* than `self.default_image_size` (crops image)!!
82
+ assert isinstance(default_siglip_transform, Compose), "Unexpected `default_image_transform`!"
83
+ assert isinstance(default_siglip_transform.transforms[0], Resize)
84
+ default_siglip_transform = Compose(
85
+ [
86
+ Resize(self.default_image_size, interpolation=default_siglip_transform.transforms[0].interpolation),
87
+ *default_siglip_transform.transforms[1:],
88
+ ]
89
+ )
90
+
91
+ if self.image_resize_strategy == "resize-naive":
92
+ assert isinstance(default_dino_transform, Compose), "Unexpected `default_dino_image_transform`!"
93
+ assert isinstance(default_siglip_transform, Compose), "Unexpected `default_siglip_image_transform`!"
94
+ assert isinstance(default_dino_transform.transforms[0], Resize)
95
+ assert isinstance(default_siglip_transform.transforms[0], Resize)
96
+
97
+ target_size = (self.default_image_size, self.default_image_size)
98
+ dino_transform = Compose(
99
+ [
100
+ Resize(target_size, interpolation=default_dino_transform.transforms[0].interpolation),
101
+ *default_dino_transform.transforms[1:],
102
+ ]
103
+ )
104
+ siglip_transform = Compose(
105
+ [
106
+ Resize(target_size, interpolation=default_siglip_transform.transforms[0].interpolation),
107
+ *default_siglip_transform.transforms[1:],
108
+ ]
109
+ )
110
+
111
+ self.image_transform = DinoSigLIPImageTransform(dino_transform, siglip_transform)
112
+
113
+ elif self.image_resize_strategy == "resize-crop":
114
+ self.image_transform = DinoSigLIPImageTransform(default_dino_transform, default_siglip_transform)
115
+
116
+ elif self.image_resize_strategy == "letterbox":
117
+ assert isinstance(default_dino_transform, Compose), "Unexpected `default_dino_transform`!"
118
+ assert isinstance(default_siglip_transform, Compose), "Unexpected `default_siglip_transform`!"
119
+ assert (
120
+ "mean" in self.dino_data_cfg and "mean" in self.siglip_data_cfg
121
+ ), "DinoSigLIP `data_cfg` missing `mean`!"
122
+
123
+ # Compute Padding Fill Value(s) (rescaled normalization mean if applicable)
124
+ dino_fill = tuple([int(x * 255) for x in self.dino_data_cfg["mean"]])
125
+ siglip_fill = tuple([int(x * 255) for x in self.siglip_data_cfg["mean"]])
126
+
127
+ # Build New Transform
128
+ self.image_transform = DinoSigLIPImageTransform(
129
+ Compose([LetterboxPad(dino_fill), *default_dino_transform.transforms]),
130
+ Compose([LetterboxPad(siglip_fill), *default_siglip_transform.transforms]),
131
+ )
132
+
133
+ else:
134
+ raise ValueError(f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!")
135
+
136
+ def get_fsdp_wrapping_policy(self) -> Callable:
137
+ """Return a simple FSDP policy that wraps each ViT block and then both of the _entire_ featurizers."""
138
+ vit_wrap_policy = partial(_module_wrap_policy, module_classes={VisionTransformer})
139
+ transformer_block_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
140
+ return partial(_or_policy, policies=[vit_wrap_policy, transformer_block_policy])
141
+
142
+ def forward(self, pixel_values: Dict[str, torch.Tensor]) -> torch.Tensor:
143
+ """Runs the transformed image/pixel tensors through each vision backbone, returning concatenated patches."""
144
+ dino_patches = self.dino_featurizer(pixel_values["dino"])
145
+ siglip_patches = self.siglip_featurizer(pixel_values["siglip"])
146
+
147
+ return torch.cat([dino_patches, siglip_patches], dim=2)
148
+
149
+ @property
150
+ def default_image_resolution(self) -> Tuple[int, int, int]:
151
+ return self.dino_data_cfg["input_size"]
152
+
153
+ @property
154
+ def embed_dim(self) -> int:
155
+ return self.dino_featurizer.embed_dim + self.siglip_featurizer.embed_dim
156
+
157
+ @property
158
+ def num_patches(self) -> int:
159
+ assert self.dino_featurizer.patch_embed.num_patches == self.siglip_featurizer.patch_embed.num_patches
160
+ return self.dino_featurizer.patch_embed.num_patches
161
+
162
+ @property
163
+ def half_precision_dtype(self) -> torch.dtype:
164
+ return torch.bfloat16
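To make the monkey-patched `forward()` and the fused output shape concrete, here is a minimal usage sketch of the backbone above. It is a sketch under assumptions, not part of the repository: it presumes the TIMM checkpoints can be downloaded and feeds random tensors purely for shape illustration.

import torch
from prismatic.models.backbones.vision import DinoSigLIPViTBackbone

# Instantiate the 224px fused backbone (downloads the DINOv2 and SigLIP weights via TIMM on first use)
backbone = DinoSigLIPViTBackbone("dinosiglip-vit-so-224px", "resize-naive", default_image_size=224)

# The image transform normally produces this dict; random tensors stand in for a batch of one image
pixel_values = {"dino": torch.randn(1, 3, 224, 224), "siglip": torch.randn(1, 3, 224, 224)}
with torch.no_grad():
    fused = backbone(pixel_values)

# Each patched featurizer returns its second-to-last-layer patch tokens; they are concatenated along the
# embedding dimension, so fused.shape == (1, backbone.num_patches, backbone.embed_dim)
print(fused.shape)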
capvector-oft/prismatic/models/backbones/vision/dinov2_vit.py ADDED
@@ -0,0 +1,19 @@
1
+ """
2
+ dinov2_vit.py
3
+ """
4
+
5
+ from prismatic.models.backbones.vision.base_vision import TimmViTBackbone
6
+
7
+ # Registry =>> Supported DINOv2 Vision Backbones (from TIMM) =>> Note:: Using DINOv2 w/ Registers!
8
+ # => Reference: https://arxiv.org/abs/2309.16588
9
+ DINOv2_VISION_BACKBONES = {"dinov2-vit-l": "vit_large_patch14_reg4_dinov2.lvd142m"}
10
+
11
+
12
+ class DinoV2ViTBackbone(TimmViTBackbone):
13
+ def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None:
14
+ super().__init__(
15
+ vision_backbone_id,
16
+ DINOv2_VISION_BACKBONES[vision_backbone_id],
17
+ image_resize_strategy,
18
+ default_image_size=default_image_size,
19
+ )
capvector-oft/prismatic/models/backbones/vision/in1k_vit.py ADDED
@@ -0,0 +1,22 @@
1
+ """
2
+ in1k_vit.py
3
+
4
+ Vision Transformers trained / finetuned on ImageNet (ImageNet-21K =>> ImageNet-1K)
5
+ """
6
+
7
+ from prismatic.models.backbones.vision.base_vision import TimmViTBackbone
8
+
9
+ # Registry =>> Supported Vision Backbones (from TIMM)
10
+ IN1K_VISION_BACKBONES = {
11
+ "in1k-vit-l": "vit_large_patch16_224.augreg_in21k_ft_in1k",
12
+ }
13
+
14
+
15
+ class IN1KViTBackbone(TimmViTBackbone):
16
+ def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None:
17
+ super().__init__(
18
+ vision_backbone_id,
19
+ IN1K_VISION_BACKBONES[vision_backbone_id],
20
+ image_resize_strategy,
21
+ default_image_size=default_image_size,
22
+ )
capvector-oft/prismatic/models/backbones/vision/siglip_vit.py ADDED
@@ -0,0 +1,24 @@
1
+ """
2
+ siglip_vit.py
3
+ """
4
+
5
+ from prismatic.models.backbones.vision.base_vision import TimmViTBackbone
6
+
7
+ # Registry =>> Supported SigLIP Vision Backbones (from TIMM) =>> Note:: Using SigLIP w/ Patch = 14 (but SO400M Arch)
8
+ SIGLIP_VISION_BACKBONES = {
9
+ "siglip-vit-b16-224px": "vit_base_patch16_siglip_224",
10
+ "siglip-vit-b16-256px": "vit_base_patch16_siglip_256",
11
+ "siglip-vit-b16-384px": "vit_base_patch16_siglip_384",
12
+ "siglip-vit-so400m": "vit_so400m_patch14_siglip_224",
13
+ "siglip-vit-so400m-384px": "vit_so400m_patch14_siglip_384",
14
+ }
15
+
16
+
17
+ class SigLIPViTBackbone(TimmViTBackbone):
18
+ def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None:
19
+ super().__init__(
20
+ vision_backbone_id,
21
+ SIGLIP_VISION_BACKBONES[vision_backbone_id],
22
+ image_resize_strategy,
23
+ default_image_size=default_image_size,
24
+ )
capvector-oft/prismatic/models/load.py ADDED
@@ -0,0 +1,226 @@
1
+ """
2
+ load.py
3
+
4
+ Entry point for loading pretrained VLMs for inference; exposes functions for listing available models (with canonical
5
+ IDs, mappings to paper experiments, and short descriptions), as well as for loading models (from disk or HF Hub).
6
+ """
7
+
8
+ import json
9
+ import os
10
+ from pathlib import Path
11
+ from typing import List, Optional, Union
12
+
13
+ from huggingface_hub import HfFileSystem, hf_hub_download
14
+
15
+ from prismatic.conf import ModelConfig
16
+ from prismatic.models.materialize import get_llm_backbone_and_tokenizer, get_vision_backbone_and_transform
17
+ from prismatic.models.registry import GLOBAL_REGISTRY, MODEL_REGISTRY
18
+ from prismatic.models.vlas import OpenVLA
19
+ from prismatic.models.vlms import PrismaticVLM
20
+ from prismatic.overwatch import initialize_overwatch
21
+ from prismatic.vla.action_tokenizer import ActionTokenizer
22
+
23
+ # Initialize Overwatch =>> Wraps `logging.Logger`
24
+ overwatch = initialize_overwatch(__name__)
25
+
26
+
27
+ # === HF Hub Repository ===
28
+ HF_HUB_REPO = "TRI-ML/prismatic-vlms"
29
+ VLA_HF_HUB_REPO = "openvla/openvla-dev"
30
+
31
+
32
+ # === Available Models ===
33
+ def available_models() -> List[str]:
34
+ return list(MODEL_REGISTRY.keys())
35
+
36
+
37
+ def available_model_names() -> List[str]:
38
+ return list(GLOBAL_REGISTRY.keys())
39
+
40
+
41
+ def get_model_description(model_id_or_name: str) -> str:
42
+ if model_id_or_name not in GLOBAL_REGISTRY:
43
+ raise ValueError(f"Couldn't find `{model_id_or_name = }`; check `prismatic.available_model_names()`")
44
+
45
+ # Print Description & Return
46
+ print(json.dumps(description := GLOBAL_REGISTRY[model_id_or_name]["description"], indent=2))
47
+
48
+ return description
49
+
50
+
51
+ # === Load Pretrained Model ===
52
+ def load(
53
+ model_id_or_path: Union[str, Path],
54
+ hf_token: Optional[str] = None,
55
+ cache_dir: Optional[Union[str, Path]] = None,
56
+ load_for_training: bool = False,
57
+ ) -> PrismaticVLM:
58
+ """Loads a pretrained PrismaticVLM from either local disk or the HuggingFace Hub."""
59
+ if os.path.isdir(model_id_or_path):
60
+ overwatch.info(f"Loading from local path `{(run_dir := Path(model_id_or_path))}`")
61
+
62
+ # Get paths for `config.json` and pretrained checkpoint
63
+ config_json, checkpoint_pt = run_dir / "config.json", run_dir / "checkpoints" / "latest-checkpoint.pt"
64
+ assert config_json.exists(), f"Missing `config.json` for `{run_dir = }`"
65
+ assert checkpoint_pt.exists(), f"Missing checkpoint for `{run_dir = }`"
66
+ else:
67
+ if model_id_or_path not in GLOBAL_REGISTRY:
68
+ raise ValueError(f"Couldn't find `{model_id_or_path = }`; check `prismatic.available_model_names()`")
69
+
70
+ overwatch.info(f"Downloading `{(model_id := GLOBAL_REGISTRY[model_id_or_path]['model_id'])}` from HF Hub")
71
+ with overwatch.local_zero_first():
72
+ config_json = hf_hub_download(repo_id=HF_HUB_REPO, filename=f"{model_id}/config.json", cache_dir=cache_dir)
73
+ checkpoint_pt = hf_hub_download(
74
+ repo_id=HF_HUB_REPO, filename=f"{model_id}/checkpoints/latest-checkpoint.pt", cache_dir=cache_dir
75
+ )
76
+
77
+ # Load Model Config from `config.json`
78
+ with open(config_json, "r") as f:
79
+ model_cfg = json.load(f)["model"]
80
+
81
+ # = Load Individual Components necessary for Instantiating a VLM =
82
+ # =>> Print Minimal Config
83
+ overwatch.info(
84
+ f"Found Config =>> Loading & Freezing [bold blue]{model_cfg['model_id']}[/] with:\n"
85
+ f" Vision Backbone =>> [bold]{model_cfg['vision_backbone_id']}[/]\n"
86
+ f" LLM Backbone =>> [bold]{model_cfg['llm_backbone_id']}[/]\n"
87
+ f" Arch Specifier =>> [bold]{model_cfg['arch_specifier']}[/]\n"
88
+ f" Checkpoint Path =>> [underline]`{checkpoint_pt}`[/]"
89
+ )
90
+
91
+ # Load Vision Backbone
92
+ overwatch.info(f"Loading Vision Backbone [bold]{model_cfg['vision_backbone_id']}[/]")
93
+ vision_backbone, image_transform = get_vision_backbone_and_transform(
94
+ model_cfg["vision_backbone_id"],
95
+ model_cfg["image_resize_strategy"],
96
+ )
97
+
98
+ # Load LLM Backbone --> note `inference_mode = True` by default when calling `load()`
99
+ overwatch.info(f"Loading Pretrained LLM [bold]{model_cfg['llm_backbone_id']}[/] via HF Transformers")
100
+ llm_backbone, tokenizer = get_llm_backbone_and_tokenizer(
101
+ model_cfg["llm_backbone_id"],
102
+ llm_max_length=model_cfg.get("llm_max_length", 2048),
103
+ hf_token=hf_token,
104
+ inference_mode=not load_for_training,
105
+ )
106
+
107
+ # Load VLM using `from_pretrained` (clobbers HF syntax... eventually should reconcile)
108
+ overwatch.info(f"Loading VLM [bold blue]{model_cfg['model_id']}[/] from Checkpoint")
109
+ vlm = PrismaticVLM.from_pretrained(
110
+ checkpoint_pt,
111
+ model_cfg["model_id"],
112
+ vision_backbone,
113
+ llm_backbone,
114
+ arch_specifier=model_cfg["arch_specifier"],
115
+ freeze_weights=not load_for_training,
116
+ )
117
+
118
+ return vlm
119
+
120
+
121
+ # === Load Pretrained VLA Model ===
122
+ def load_vla(
123
+ model_id_or_path: Union[str, Path],
124
+ hf_token: Optional[str] = None,
125
+ cache_dir: Optional[Union[str, Path]] = None,
126
+ load_for_training: bool = False,
127
+ step_to_load: Optional[int] = None,
128
+ model_type: str = "pretrained",
129
+ ) -> OpenVLA:
130
+ """Loads a pretrained OpenVLA from either local disk or the HuggingFace Hub."""
131
+
132
+ # TODO (siddk, moojink) :: Unify semantics with `load()` above; right now, `load_vla()` assumes path points to
133
+ # checkpoint `.pt` file, rather than the top-level run directory!
134
+ if os.path.isfile(model_id_or_path):
135
+ overwatch.info(f"Loading from local checkpoint path `{(checkpoint_pt := Path(model_id_or_path))}`")
136
+
137
+ # [Validate] Checkpoint Path should look like `.../<RUN_ID>/checkpoints/<CHECKPOINT_PATH>.pt`
138
+ assert (checkpoint_pt.suffix == ".pt") and (checkpoint_pt.parent.name == "checkpoints"), "Invalid checkpoint!"
139
+ run_dir = checkpoint_pt.parents[1]
140
+
141
+ # Get paths for `config.json`, `dataset_statistics.json` and pretrained checkpoint
142
+ config_json, dataset_statistics_json = run_dir / "config.json", run_dir / "dataset_statistics.json"
143
+ assert config_json.exists(), f"Missing `config.json` for `{run_dir = }`"
144
+ assert dataset_statistics_json.exists(), f"Missing `dataset_statistics.json` for `{run_dir = }`"
145
+
146
+ # Otherwise =>> try looking for a match on `model_id_or_path` on the HF Hub (`VLA_HF_HUB_REPO`)
147
+ else:
148
+ # Search HF Hub Repo via fsspec API
149
+ overwatch.info(f"Checking HF for `{(hf_path := str(Path(VLA_HF_HUB_REPO) / model_type / model_id_or_path))}`")
150
+ if not (tmpfs := HfFileSystem()).exists(hf_path):
151
+ raise ValueError(f"Couldn't find valid HF Hub Path `{hf_path = }`")
152
+
153
+ # Identify Checkpoint to Load (via `step_to_load`)
154
+ step_to_load = f"{step_to_load:06d}" if step_to_load is not None else None
155
+ valid_ckpts = tmpfs.glob(f"{hf_path}/checkpoints/step-{step_to_load if step_to_load is not None else ''}*.pt")
156
+ if (len(valid_ckpts) == 0) or (step_to_load is not None and len(valid_ckpts) != 1):
157
+ raise ValueError(f"Couldn't find a valid checkpoint to load from HF Hub Path `{hf_path}/checkpoints/`")
158
+
159
+ # Call to `glob` will sort steps in ascending order (if `step_to_load` is None); just grab last element
160
+ target_ckpt = Path(valid_ckpts[-1]).name
161
+
162
+ overwatch.info(f"Downloading Model `{model_id_or_path}` Config & Checkpoint `{target_ckpt}`")
163
+ with overwatch.local_zero_first():
164
+ relpath = Path(model_type) / model_id_or_path
165
+ config_json = hf_hub_download(
166
+ repo_id=VLA_HF_HUB_REPO, filename=f"{(relpath / 'config.json')!s}", cache_dir=cache_dir
167
+ )
168
+ dataset_statistics_json = hf_hub_download(
169
+ repo_id=VLA_HF_HUB_REPO, filename=f"{(relpath / 'dataset_statistics.json')!s}", cache_dir=cache_dir
170
+ )
171
+ checkpoint_pt = hf_hub_download(
172
+ repo_id=VLA_HF_HUB_REPO, filename=f"{(relpath / 'checkpoints' / target_ckpt)!s}", cache_dir=cache_dir
173
+ )
174
+
175
+ # Load VLA Config (and corresponding base VLM `ModelConfig`) from `config.json`
176
+ with open(config_json, "r") as f:
177
+ vla_cfg = json.load(f)["vla"]
178
+ model_cfg = ModelConfig.get_choice_class(vla_cfg["base_vlm"])()
179
+
180
+ # Load Dataset Statistics for Action Denormalization
181
+ with open(dataset_statistics_json, "r") as f:
182
+ norm_stats = json.load(f)
183
+
184
+ # = Load Individual Components necessary for Instantiating a VLA (via base VLM components) =
185
+ # =>> Print Minimal Config
186
+ overwatch.info(
187
+ f"Found Config =>> Loading & Freezing [bold blue]{model_cfg.model_id}[/] with:\n"
188
+ f" Vision Backbone =>> [bold]{model_cfg.vision_backbone_id}[/]\n"
189
+ f" LLM Backbone =>> [bold]{model_cfg.llm_backbone_id}[/]\n"
190
+ f" Arch Specifier =>> [bold]{model_cfg.arch_specifier}[/]\n"
191
+ f" Checkpoint Path =>> [underline]`{checkpoint_pt}`[/]"
192
+ )
193
+
194
+ # Load Vision Backbone
195
+ overwatch.info(f"Loading Vision Backbone [bold]{model_cfg.vision_backbone_id}[/]")
196
+ vision_backbone, image_transform = get_vision_backbone_and_transform(
197
+ model_cfg.vision_backbone_id,
198
+ model_cfg.image_resize_strategy,
199
+ )
200
+
201
+ # Load LLM Backbone --> note `inference_mode = True` by default when calling `load()`
202
+ overwatch.info(f"Loading Pretrained LLM [bold]{model_cfg.llm_backbone_id}[/] via HF Transformers")
203
+ llm_backbone, tokenizer = get_llm_backbone_and_tokenizer(
204
+ model_cfg.llm_backbone_id,
205
+ llm_max_length=model_cfg.llm_max_length,
206
+ hf_token=hf_token,
207
+ inference_mode=not load_for_training,
208
+ )
209
+
210
+ # Create Action Tokenizer
211
+ action_tokenizer = ActionTokenizer(llm_backbone.get_tokenizer())
212
+
213
+ # Load VLM using `from_pretrained` (clobbers HF syntax... eventually should reconcile)
214
+ overwatch.info(f"Loading VLA [bold blue]{model_cfg.model_id}[/] from Checkpoint")
215
+ vla = OpenVLA.from_pretrained(
216
+ checkpoint_pt,
217
+ model_cfg.model_id,
218
+ vision_backbone,
219
+ llm_backbone,
220
+ arch_specifier=model_cfg.arch_specifier,
221
+ freeze_weights=not load_for_training,
222
+ norm_stats=norm_stats,
223
+ action_tokenizer=action_tokenizer,
224
+ )
225
+
226
+ return vla
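For readers who only need the public surface of this module, a hedged usage sketch follows. The VLM ID is taken from the registry shipped with this repo; the local VLA checkpoint path is a hypothetical placeholder (per the validation above, it just has to be a `.pt` file inside a `checkpoints/` directory).

from prismatic.models.load import available_models, get_model_description, load, load_vla

print(available_models()[:3])                 # canonical IDs from MODEL_REGISTRY
get_model_description("prism-dinosiglip+7b")  # pretty-prints and returns the description dict

vlm = load("prism-dinosiglip+7b")             # weights frozen; pass load_for_training=True to fine-tune
vla = load_vla("/path/to/run_dir/checkpoints/step-010000-epoch-01-loss=0.1234.pt")  # hypothetical local path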
capvector-oft/prismatic/models/materialize.py ADDED
@@ -0,0 +1,130 @@
1
+ """
2
+ materialize.py
3
+
4
+ Factory class for initializing Vision Backbones, LLM Backbones, and VLMs from a set registry; provides and exports
5
+ individual functions for clear control flow.
6
+ """
7
+
8
+ from typing import Optional, Tuple
9
+
10
+ from transformers import PreTrainedTokenizerBase
11
+
12
+ from prismatic.models.backbones.llm import LLaMa2LLMBackbone, LLMBackbone, MistralLLMBackbone, PhiLLMBackbone
13
+ from prismatic.models.backbones.vision import (
14
+ CLIPViTBackbone,
15
+ DinoCLIPViTBackbone,
16
+ DinoSigLIPViTBackbone,
17
+ DinoV2ViTBackbone,
18
+ ImageTransform,
19
+ IN1KViTBackbone,
20
+ SigLIPViTBackbone,
21
+ VisionBackbone,
22
+ )
23
+ from prismatic.models.vlms import PrismaticVLM
24
+
25
+ # === Registries =>> Maps ID --> {cls(), kwargs} :: Different Registries for Vision Backbones, LLM Backbones, VLMs ===
26
+ # fmt: off
27
+
28
+ # === Vision Backbone Registry ===
29
+ VISION_BACKBONES = {
30
+ # === 224px Backbones ===
31
+ "clip-vit-l": {"cls": CLIPViTBackbone, "kwargs": {"default_image_size": 224}},
32
+ "siglip-vit-so400m": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 224}},
33
+ "dinov2-vit-l": {"cls": DinoV2ViTBackbone, "kwargs": {"default_image_size": 224}},
34
+ "in1k-vit-l": {"cls": IN1KViTBackbone, "kwargs": {"default_image_size": 224}},
35
+ "dinosiglip-vit-so-224px": {"cls": DinoSigLIPViTBackbone, "kwargs": {"default_image_size": 224}},
36
+
37
+ # === Assorted CLIP Backbones ===
38
+ "clip-vit-b": {"cls": CLIPViTBackbone, "kwargs": {"default_image_size": 224}},
39
+ "clip-vit-l-336px": {"cls": CLIPViTBackbone, "kwargs": {"default_image_size": 336}},
40
+
41
+ # === Assorted SigLIP Backbones ===
42
+ "siglip-vit-b16-224px": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 224}},
43
+ "siglip-vit-b16-256px": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 256}},
44
+ "siglip-vit-b16-384px": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 384}},
45
+ "siglip-vit-so400m-384px": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 384}},
46
+
47
+ # === Fused Backbones ===
48
+ "dinoclip-vit-l-336px": {"cls": DinoCLIPViTBackbone, "kwargs": {"default_image_size": 336}},
49
+ "dinosiglip-vit-so-384px": {"cls": DinoSigLIPViTBackbone, "kwargs": {"default_image_size": 384}},
50
+ }
51
+
52
+
53
+ # === Language Model Registry ===
54
+ LLM_BACKBONES = {
55
+ # === LLaMa-2 Pure (Non-Chat) Backbones ===
56
+ "llama2-7b-pure": {"cls": LLaMa2LLMBackbone, "kwargs": {}},
57
+ "llama2-13b-pure": {"cls": LLaMa2LLMBackbone, "kwargs": {}},
58
+
59
+ # === LLaMa-2 Chat Backbones ===
60
+ "llama2-7b-chat": {"cls": LLaMa2LLMBackbone, "kwargs": {}},
61
+ "llama2-13b-chat": {"cls": LLaMa2LLMBackbone, "kwargs": {}},
62
+
63
+ # === Vicuna-v1.5 Backbones ===
64
+ "vicuna-v15-7b": {"cls": LLaMa2LLMBackbone, "kwargs": {}},
65
+ "vicuna-v15-13b": {"cls": LLaMa2LLMBackbone, "kwargs": {}},
66
+
67
+ # === Mistral v0.1 Backbones ===
68
+ "mistral-v0.1-7b-pure": {"cls": MistralLLMBackbone, "kwargs": {}},
69
+ "mistral-v0.1-7b-instruct": {"cls": MistralLLMBackbone, "kwargs": {}},
70
+
71
+ # === Phi-2 Backbone ===
72
+ "phi-2-3b": {"cls": PhiLLMBackbone, "kwargs": {}},
73
+ }
74
+
75
+ # fmt: on
76
+
77
+
78
+ def get_vision_backbone_and_transform(
79
+ vision_backbone_id: str, image_resize_strategy: str
80
+ ) -> Tuple[VisionBackbone, ImageTransform]:
81
+ """Instantiate a Vision Backbone, returning both the nn.Module wrapper class and default Image Transform."""
82
+ if vision_backbone_id in VISION_BACKBONES:
83
+ vision_cfg = VISION_BACKBONES[vision_backbone_id]
84
+ vision_backbone: VisionBackbone = vision_cfg["cls"](
85
+ vision_backbone_id, image_resize_strategy, **vision_cfg["kwargs"]
86
+ )
87
+ image_transform = vision_backbone.get_image_transform()
88
+ return vision_backbone, image_transform
89
+
90
+ else:
91
+ raise ValueError(f"Vision Backbone `{vision_backbone_id}` is not supported!")
92
+
93
+
94
+ def get_llm_backbone_and_tokenizer(
95
+ llm_backbone_id: str,
96
+ llm_max_length: int = 2048,
97
+ hf_token: Optional[str] = None,
98
+ inference_mode: bool = False,
99
+ ) -> Tuple[LLMBackbone, PreTrainedTokenizerBase]:
100
+ if llm_backbone_id in LLM_BACKBONES:
101
+ llm_cfg = LLM_BACKBONES[llm_backbone_id]
102
+ llm_backbone: LLMBackbone = llm_cfg["cls"](
103
+ llm_backbone_id,
104
+ llm_max_length=llm_max_length,
105
+ hf_token=hf_token,
106
+ inference_mode=inference_mode,
107
+ **llm_cfg["kwargs"],
108
+ )
109
+ tokenizer = llm_backbone.get_tokenizer()
110
+ return llm_backbone, tokenizer
111
+
112
+ else:
113
+ raise ValueError(f"LLM Backbone `{llm_backbone_id}` is not supported!")
114
+
115
+
116
+ def get_vlm(
117
+ model_id: str,
118
+ arch_specifier: str,
119
+ vision_backbone: VisionBackbone,
120
+ llm_backbone: LLMBackbone,
121
+ enable_mixed_precision_training: bool = True,
122
+ ) -> PrismaticVLM:
123
+ """Lightweight wrapper around initializing a VLM, mostly for future-proofing (if one wants to add a new VLM)."""
124
+ return PrismaticVLM(
125
+ model_id,
126
+ vision_backbone,
127
+ llm_backbone,
128
+ enable_mixed_precision_training=enable_mixed_precision_training,
129
+ arch_specifier=arch_specifier,
130
+ )
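The three factories above are designed to compose. The sketch below uses backbone IDs from the registries in this file; the `arch_specifier` value is an assumption for illustration rather than a documented option (the values actually used are defined by the model configs in `prismatic.conf`).

from prismatic.models.materialize import (
    get_llm_backbone_and_tokenizer,
    get_vision_backbone_and_transform,
    get_vlm,
)

vision_backbone, image_transform = get_vision_backbone_and_transform("dinosiglip-vit-so-384px", "resize-naive")
llm_backbone, tokenizer = get_llm_backbone_and_tokenizer("llama2-7b-pure", llm_max_length=2048, inference_mode=False)

# "gelu-mlp" is an assumed projector spec; check the model configs for the identifiers actually used
vlm = get_vlm("example-run-id", "gelu-mlp", vision_backbone, llm_backbone)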
capvector-oft/prismatic/models/projectors.py ADDED
@@ -0,0 +1,49 @@
1
+ """Projectors for additional (non-image) inputs to the VLA models, e.g. proprio state and noisy actions."""
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+
6
+ class ProprioProjector(nn.Module):
7
+ """
8
+ Projects proprio state inputs into the LLM's embedding space.
9
+ """
10
+ def __init__(self, llm_dim: int, proprio_dim: int) -> None:
11
+ super().__init__()
12
+ self.llm_dim = llm_dim
13
+ self.proprio_dim = proprio_dim
14
+
15
+ self.fc1 = nn.Linear(self.proprio_dim, self.llm_dim, bias=True)
16
+ self.fc2 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
17
+ self.act_fn1 = nn.GELU()
18
+
19
+ def forward(self, proprio: torch.Tensor = None) -> torch.Tensor:
20
+ # proprio: (bsz, proprio_dim)
21
+ projected_features = self.fc1(proprio)
22
+ projected_features = self.act_fn1(projected_features)
23
+ projected_features = self.fc2(projected_features)
24
+ return projected_features
25
+
26
+
27
+ class NoisyActionProjector(nn.Module):
28
+ """
29
+ [Diffusion] Projects noisy action inputs into the LLM's embedding space.
30
+
31
+ Note that since each action is tokenized into 7 tokens in OpenVLA (rather
32
+ than having 1 token per action), each noisy action token will have dimension 1
33
+ instead of 7.
34
+ """
35
+ def __init__(self, llm_dim: int) -> None:
36
+ super().__init__()
37
+ self.llm_dim = llm_dim
38
+ self.action_token_dim = 1
39
+
40
+ self.fc1 = nn.Linear(self.action_token_dim, self.llm_dim, bias=True)
41
+ self.fc2 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
42
+ self.act_fn1 = nn.GELU()
43
+
44
+ def forward(self, noisy_actions: torch.Tensor = None) -> torch.Tensor:
45
+ # noisy_actions: (bsz, num_action_tokens=chunk_len*action_dim, 1)
46
+ projected_features = self.fc1(noisy_actions)
47
+ projected_features = self.act_fn1(projected_features)
48
+ projected_features = self.fc2(projected_features)
49
+ return projected_features
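A shape-level sketch of the two projectors above; the dimensions are illustrative assumptions (a 7-DoF proprio state and a Llama-2-7B-sized hidden dimension), not constants defined in this file.

import torch
from prismatic.models.projectors import NoisyActionProjector, ProprioProjector

proprio_proj = ProprioProjector(llm_dim=4096, proprio_dim=7)   # dims are assumptions
proprio_embed = proprio_proj(torch.randn(2, 7))                # (bsz, proprio_dim) -> (2, 4096)

noisy_proj = NoisyActionProjector(llm_dim=4096)
noisy_embed = noisy_proj(torch.randn(2, 56, 1))                # (bsz, chunk_len * action_dim, 1) -> (2, 56, 4096)
print(proprio_embed.shape, noisy_embed.shape)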
capvector-oft/prismatic/models/registry.py ADDED
@@ -0,0 +1,691 @@
1
+ """
2
+ registry.py
3
+
4
+ Exhaustive list of pretrained VLMs (with full descriptions / links to corresponding names and sections of paper).
5
+ """
6
+
7
+ # === Pretrained Model Registry ===
8
+ # fmt: off
9
+ MODEL_REGISTRY = {
10
+ # === LLaVa v1.5 Reproductions ===
11
+ "reproduction-llava-v15+7b": {
12
+ "model_id": "reproduction-llava-v15+7b",
13
+ "names": ["LLaVa v1.5 7B (Reproduction)"],
14
+ "description": {
15
+ "name": "LLaVa v1.5 7B (Reproduction)",
16
+ "optimization_procedure": "multi-stage",
17
+ "visual_representation": "CLIP ViT-L/14 @ 336px",
18
+ "image_processing": "Letterbox",
19
+ "language_model": "Vicuña v1.5 7B",
20
+ "datasets": ["LLaVa v1.5 Instruct"],
21
+ "train_epochs": 1,
22
+ }
23
+ },
24
+ "reproduction-llava-v15+13b": {
25
+ "model_id": "reproduction-llava-v15+13b",
26
+ "names": ["LLaVa v1.5 13B (Reproduction)"],
27
+ "description": {
28
+ "name": "LLaVa v1.5 13B (Reproduction)",
29
+ "optimization_procedure": "multi-stage",
30
+ "visual_representation": "CLIP ViT-L/14 @ 336px",
31
+ "image_processing": "Letterbox",
32
+ "language_model": "Vicuña v1.5 13B",
33
+ "datasets": ["LLaVa v1.5 Instruct"],
34
+ "train_epochs": 1,
35
+ }
36
+ },
37
+
38
+ # === Section 4.1 :: Optimization Procedure ===
39
+ "one-stage+7b": {
40
+ "model_id": "one-stage+7b",
41
+ "names": [
42
+ "One-Stage 7B",
43
+ "Single-Stage 7B",
44
+ "Frozen ViT (Single-Stage)",
45
+ "CLIP ViT-L 336px (Letterbox)",
46
+ "CLIP ViT-L 336px",
47
+ "Vicuña v1.5 7B",
48
+ "1 Epoch",
49
+ "Base",
50
+ ],
51
+ "description": {
52
+ "name": "Single-Stage 7B",
53
+ "optimization_procedure": "single-stage",
54
+ "visual_representation": "CLIP ViT-L/14 @ 336px",
55
+ "image_processing": "Letterbox",
56
+ "language_model": "Vicuña v1.5 7B",
57
+ "datasets": ["LLaVa v1.5 Instruct"],
58
+ "train_epochs": 1,
59
+ }
60
+ },
61
+ "one-stage+13b": {
62
+ "model_id": "one-stage+13b",
63
+ "names": [
64
+ "One-Stage 13B",
65
+ "Single-Stage 13B",
66
+ "Vicuña v1.5 13B",
67
+ ],
68
+ "description": {
69
+ "name": "Single-Stage 13B",
70
+ "optimization_procedure": "single-stage",
71
+ "visual_representation": "CLIP ViT-L/14 @ 336px",
72
+ "image_processing": "Letterbox",
73
+ "language_model": "Vicuña v1.5 13B",
74
+ "datasets": ["LLaVa v1.5 Instruct"],
75
+ "train_epochs": 1,
76
+ }
77
+ },
78
+
79
+ "full-ft-multi-stage+7b": {
80
+ "model_id": "full-ft-multi-stage+7b",
81
+ "names": ["Finetune ViT (Multi-Stage)"],
82
+ "description": {
83
+ "name": "Finetune ViT (Multi-Stage)",
84
+ "optimization_procedure": "multi-stage-full-finetune",
85
+ "visual_representation": "CLIP ViT-L/14 @ 336px",
86
+ "image_processing": "Letterbox",
87
+ "language_model": "Vicuña v1.5 7B",
88
+ "datasets": ["LLaVa v1.5 Instruct"],
89
+ "train_epochs": 1,
90
+ }
91
+ },
92
+ "full-ft-one-stage+7b": {
93
+ "model_id": "full-ft-one-stage+7b",
94
+ "names": ["Finetune ViT (Single-Stage)"],
95
+ "description": {
96
+ "name": "Finetune ViT (Single-Stage)",
97
+ "optimization_procedure": "single-stage-full-finetune",
98
+ "visual_representation": "CLIP ViT-L/14 @ 336px",
99
+ "image_processing": "Letterbox",
100
+ "language_model": "Vicuña v1.5 7B",
101
+ "datasets": ["LLaVa v1.5 Instruct"],
102
+ "train_epochs": 1,
103
+ }
104
+ },
105
+
106
+ # === Section 4.2 :: Image Processing and Visual Representations ===
107
+ "in1k-224px+7b": {
108
+ "model_id": "in1k-224px+7b",
109
+ "names": ["IN1K ViT-L 224px"],
110
+ "description": {
111
+ "name": "IN1K ViT-L 224px",
112
+ "optimization_procedure": "single-stage",
113
+ "visual_representation": "ImageNet-21K+1K ViT-L/16 @ 224px",
114
+ "image_processing": "Letterbox",
115
+ "language_model": "Vicuña v1.5 7B",
116
+ "datasets": ["LLaVa v1.5 Instruct"],
117
+ "train_epochs": 1,
118
+ },
119
+ },
120
+ "dinov2-224px+7b": {
121
+ "model_id": "dinov2-224px+7b",
122
+ "names": ["DINOv2 ViT-L 224px"],
123
+ "description": {
124
+ "name": "DINOv2 ViT-L 224px",
125
+ "optimization_procedure": "single-stage",
126
+ "visual_representation": "DINOv2 ViT-L/14 @ 224px",
127
+ "image_processing": "Letterbox",
128
+ "language_model": "Vicuña v1.5 7B",
129
+ "datasets": ["LLaVa v1.5 Instruct"],
130
+ "train_epochs": 1,
131
+ },
132
+ },
133
+ "clip-224px+7b": {
134
+ "model_id": "clip-224px+7b",
135
+ "names": ["CLIP ViT-L 224px"],
136
+ "description": {
137
+ "name": "CLIP ViT-L 224px",
138
+ "optimization_procedure": "single-stage",
139
+ "visual_representation": "CLIP ViT-L/14 @ 224px",
140
+ "image_processing": "Letterbox",
141
+ "language_model": "Vicuña v1.5 7B",
142
+ "datasets": ["LLaVa v1.5 Instruct"],
143
+ "train_epochs": 1,
144
+ },
145
+ },
146
+ "siglip-224px+7b": {
147
+ "model_id": "siglip-224px+7b",
148
+ "names": ["SigLIP ViT-SO 224px"],
149
+ "description": {
150
+ "name": "SigLIP ViT-SO 224px",
151
+ "optimization_procedure": "single-stage",
152
+ "visual_representation": "SigLIP ViT-SO/14 @ 224px",
153
+ "image_processing": "Letterbox",
154
+ "language_model": "Vicuña v1.5 7B",
155
+ "datasets": ["LLaVa v1.5 Instruct"],
156
+ "train_epochs": 1,
157
+ },
158
+ },
159
+
160
+ "clip-336px-resize-crop+7b": {
161
+ "model_id": "clip-336px-resize-crop+7b",
162
+ "names": ["CLIP ViT-L 336px (Resize Crop)"],
163
+ "description": {
164
+ "name": "CLIP ViT-L 336px (Resize Crop)",
165
+ "optimization_procedure": "single-stage",
166
+ "visual_representation": "CLIP ViT-L/14 @ 336px",
167
+ "image_processing": "Resize Crop",
168
+ "language_model": "Vicuña v1.5 7B",
169
+ "datasets": ["LLaVa v1.5 Instruct"],
170
+ "train_epochs": 1,
171
+ }
172
+ },
173
+ "clip-336px-resize-naive+7b": {
174
+ "model_id": "clip-336px-resize-naive+7b",
175
+ "names": ["CLIP ViT-L 336px (Naive Resize)", "CLIP 336px (Naive Resize)"],
176
+ "description": {
177
+ "name": "CLIP ViT-L 336px (Naive Resize)",
178
+ "optimization_procedure": "single-stage",
179
+ "visual_representation": "CLIP ViT-L/14 @ 336px",
180
+ "image_processing": "Naive Resize",
181
+ "language_model": "Vicuña v1.5 7B",
182
+ "datasets": ["LLaVa v1.5 Instruct"],
183
+ "train_epochs": 1,
184
+ }
185
+ },
186
+ "siglip-384px-letterbox+7b": {
187
+ "model_id": "siglip-384px-letterbox+7b",
188
+ "names": ["SigLIP ViT-SO 384px (Letterbox)", "SigLIP ViT-SO 384px"],
189
+ "description": {
190
+ "name": "SigLIP ViT-SO 384px (Letterbox)",
191
+ "optimization_procedure": "single-stage",
192
+ "visual_representation": "SigLIP ViT-SO/14 @ 384px",
193
+ "image_processing": "Letterbox",
194
+ "language_model": "Vicuña v1.5 7B",
195
+ "datasets": ["LLaVa v1.5 Instruct"],
196
+ "train_epochs": 1,
197
+ }
198
+ },
199
+ "siglip-384px-resize-crop+7b": {
200
+ "model_id": "siglip-384px-resize-crop+7b",
201
+ "names": ["SigLIP ViT-SO 384px (Resize Crop)"],
202
+ "description": {
203
+ "name": "SigLIP ViT-SO 384px (Resize Crop)",
204
+ "optimization_procedure": "single-stage",
205
+ "visual_representation": "SigLIP ViT-SO/14 @ 384px",
206
+ "image_processing": "Resize Crop",
207
+ "language_model": "Vicuña v1.5 7B",
208
+ "datasets": ["LLaVa v1.5 Instruct"],
209
+ "train_epochs": 1,
210
+ }
211
+ },
212
+ "siglip-384px-resize-naive+7b": {
213
+ "model_id": "siglip-384px-resize-naive+7b",
214
+ "names": ["SigLIP ViT-SO 384px (Naive Resize)", "SigLIP 384px (Naive Resize)"],
215
+ "description": {
216
+ "name": "SigLIP ViT-SO 384px (Naive Resize)",
217
+ "optimization_procedure": "single-stage",
218
+ "visual_representation": "SigLIP ViT-SO/14 @ 384px",
219
+ "image_processing": "Naive Resize",
220
+ "language_model": "Vicuña v1.5 7B",
221
+ "datasets": ["LLaVa v1.5 Instruct"],
222
+ "train_epochs": 1,
223
+ }
224
+ },
225
+
226
+ "dinoclip-336px-letterbox+7b": {
227
+ "model_id": "dinoclip-336px-letterbox+7b",
228
+ "names": ["DINOv2 + CLIP 336px (Letterbox)"],
229
+ "description": {
230
+ "name": "DINOv2 + CLIP 336px (Letterbox)",
231
+ "optimization_procedure": "single-stage",
232
+ "visual_representation": "DINOv2 ViT-L/14 + CLIP ViT-L/14 @ 336px",
233
+ "image_processing": "Letterbox",
234
+ "language_model": "Vicuña v1.5 7B",
235
+ "datasets": ["LLaVa v1.5 Instruct"],
236
+ "train_epochs": 1,
237
+ }
238
+ },
239
+ "dinoclip-336px-resize-naive+7b": {
240
+ "model_id": "dinoclip-336px-resize-naive+7b",
241
+ "names": ["DINOv2 + CLIP 336px (Naive Resize)"],
242
+ "description": {
243
+ "name": "DINOv2 + CLIP 336px (Naive Resize)",
244
+ "optimization_procedure": "single-stage",
245
+ "visual_representation": "DINOv2 ViT-L/14 + CLIP ViT-L/14 @ 336px",
246
+ "image_processing": "Naive Resize",
247
+ "language_model": "Vicuña v1.5 7B",
248
+ "datasets": ["LLaVa v1.5 Instruct"],
249
+ "train_epochs": 1,
250
+ }
251
+ },
252
+ "dinosiglip-384px-letterbox+7b": {
253
+ "model_id": "dinosiglip-384px-letterbox+7b",
254
+ "names": ["DINOv2 + SigLIP 384px (Letterbox)"],
255
+ "description": {
256
+ "name": "DINOv2 + SigLIP 384px (Letterbox)",
257
+ "optimization_procedure": "single-stage",
258
+ "visual_representation": "DINOv2 ViT-L/14 + SigLIP ViT-SO/14 @ 384px",
259
+ "image_processing": "Letterbox",
260
+ "language_model": "Vicuña v1.5 7B",
261
+ "datasets": ["LLaVa v1.5 Instruct"],
262
+ "train_epochs": 1,
263
+ }
264
+ },
265
+ "dinosiglip-384px-resize-naive+7b": {
266
+ "model_id": "dinosiglip-384px-resize-naive+7b",
267
+ "names": ["DINOv2 + SigLIP 384px (Naive Resize)"],
268
+ "description": {
269
+ "name": "DINOv2 + SigLIP 384px (Naive Resize)",
270
+ "optimization_procedure": "single-stage",
271
+ "visual_representation": "DINOv2 ViT-L/14 + SigLIP ViT-SO/14 @ 384px",
272
+ "image_processing": "Naive Resize",
273
+ "language_model": "Vicuña v1.5 7B",
274
+ "datasets": ["LLaVa v1.5 Instruct"],
275
+ "train_epochs": 1,
276
+ }
277
+ },
278
+
279
+ # === Section 4.3 :: Language Models ===
280
+ "llama2+7b": {
281
+ "model_id": "llama2+7b",
282
+ "names": ["Llama-2 7B"],
283
+ "description": {
284
+ "name": "Llama-2 7B",
285
+ "optimization_procedure": "single-stage",
286
+ "visual_representation": "CLIP ViT-L/14 @ 336px",
287
+ "image_processing": "Letterbox",
288
+ "language_model": "Llama-2 7B",
289
+ "datasets": ["LLaVa v1.5 Instruct"],
290
+ "train_epochs": 1,
291
+ },
292
+ },
293
+ "llama2+13b": {
294
+ "model_id": "llama2+13b",
295
+ "names": ["Llama-2 13B"],
296
+ "description": {
297
+ "name": "Llama-2 13B",
298
+ "optimization_procedure": "single-stage",
299
+ "visual_representation": "CLIP ViT-L/14 @ 336px",
300
+ "image_processing": "Letterbox",
301
+ "language_model": "Llama-2 13B",
302
+ "datasets": ["LLaVa v1.5 Instruct"],
303
+ "train_epochs": 1,
304
+ },
305
+ },
306
+
307
+ "vicuna-no-cotraining+7b": {
308
+ "model_id": "vicuna-no-cotraining+7b",
309
+ "names": ["Vicuña v1.5 7B (No Co-training)"],
310
+ "description": {
311
+ "name": "Vicuña v1.5 7B (No Co-training)",
312
+ "optimization_procedure": "single-stage",
313
+ "visual_representation": "CLIP ViT-L/14 @ 336px",
314
+ "image_processing": "Letterbox",
315
+ "language_model": "Vicuña v1.5 7B",
316
+ "datasets": ["LLaVa v1.5 Multimodal-Only"],
317
+ "train_epochs": 1,
318
+ },
319
+ },
320
+ "llama2-no-cotraining+7b": {
321
+ "model_id": "llama2-no-cotraining+7b",
322
+ "names": ["Llama-2 7B (No Co-training)"],
323
+ "description": {
324
+ "name": "Llama-2 7B (No Co-training)",
325
+ "optimization_procedure": "single-stage",
326
+ "visual_representation": "CLIP ViT-L/14 @ 336px",
327
+ "image_processing": "Letterbox",
328
+ "language_model": "Llama-2 7B",
329
+ "datasets": ["LLaVa v1.5 Multimodal-Only"],
330
+ "train_epochs": 1,
331
+ },
332
+ },
333
+
334
+ # === Section 4.4 :: Scaling Properties ===
335
+ "train-1.25-epochs+7b": {
336
+ "model_id": "train-1.25-epochs+7b",
337
+ "names": ["1.25 Epochs"],
338
+ "description": {
339
+ "name": "1.25 Epochs",
340
+ "optimization_procedure": "single-stage",
341
+ "visual_representation": "CLIP ViT-L/14 @ 336px",
342
+ "image_processing": "Letterbox",
343
+ "language_model": "Vicuña v1.5 7B",
344
+ "datasets": ["LLaVa v1.5 Instruct"],
345
+ "train_epochs": 1.25,
346
+ }
347
+ },
348
+ "train-1.5-epochs+7b": {
349
+ "model_id": "train-1.5-epochs+7b",
350
+ "names": ["1.5 Epochs"],
351
+ "description": {
352
+ "name": "1.5 Epochs",
353
+ "optimization_procedure": "single-stage",
354
+ "visual_representation": "CLIP ViT-L/14 @ 336px",
355
+ "image_processing": "Letterbox",
356
+ "language_model": "Vicuña v1.5 7B",
357
+ "datasets": ["LLaVa v1.5 Instruct"],
358
+ "train_epochs": 1.5,
359
+ }
360
+ },
361
+ "train-2-epochs+7b": {
362
+ "model_id": "train-2-epochs+7b",
363
+ "names": ["2 Epochs"],
364
+ "description": {
365
+ "name": "2 Epochs",
366
+ "optimization_procedure": "single-stage",
367
+ "visual_representation": "CLIP ViT-L/14 @ 336px",
368
+ "image_processing": "Letterbox",
369
+ "language_model": "Vicuña v1.5 7B",
370
+ "datasets": ["LLaVa v1.5 Instruct"],
371
+ "train_epochs": 2,
372
+ }
373
+ },
374
+ "train-3-epochs+7b": {
375
+ "model_id": "train-3-epochs+7b",
376
+ "names": ["3 Epochs"],
377
+ "description": {
378
+ "name": "3 Epochs",
379
+ "optimization_procedure": "single-stage",
380
+ "visual_representation": "CLIP ViT-L/14 @ 336px",
381
+ "image_processing": "Letterbox",
382
+ "language_model": "Vicuña v1.5 7B",
383
+ "datasets": ["LLaVa v1.5 Instruct"],
384
+ "train_epochs": 3,
385
+ }
386
+ },
387
+
388
+ "llava-lvis4v+7b": {
389
+ "model_id": "llava-lvis4v+7b",
390
+ "names": ["Base + LVIS-4V"],
391
+ "description": {
392
+ "name": "Base + LVIS-4V",
393
+ "optimization_procedure": "single-stage",
394
+ "visual_representation": "CLIP ViT-L/14 @ 336px",
395
+ "image_processing": "Letterbox",
396
+ "language_model": "Vicuña v1.5 7B",
397
+ "datasets": ["LLaVa v1.5 Instruct", "LVIS-Instruct-4V"],
398
+ "train_epochs": 1,
399
+ }
400
+ },
401
+ "llava-lrv+7b": {
402
+ "model_id": "llava-lrv+7b",
403
+ "names": ["Base + LRV"],
404
+ "description": {
405
+ "name": "Base + LRV",
406
+ "optimization_procedure": "single-stage",
407
+ "visual_representation": "CLIP ViT-L/14 @ 336px",
408
+ "image_processing": "Letterbox",
409
+ "language_model": "Vicuña v1.5 7B",
410
+ "datasets": ["LLaVa v1.5 Instruct", "LRV-Instruct"],
411
+ "train_epochs": 1,
412
+ }
413
+ },
414
+ "llava-lvis4v-lrv+7b": {
415
+ "model_id": "llava-lvis4v-lrv+7b",
416
+ "names": ["Base + LVIS-4V + LRV"],
417
+ "description": {
418
+ "name": "Base + LVIS-4V + LRV",
419
+ "optimization_procedure": "single-stage",
420
+ "visual_representation": "CLIP ViT-L/14 @ 336px",
421
+ "image_processing": "Letterbox",
422
+ "language_model": "Vicuña v1.5 7B",
423
+ "datasets": ["LLaVa v1.5 Instruct", "LVIS-Instruct-4V", "LRV-Instruct"],
424
+ "train_epochs": 1,
425
+ }
426
+ },
427
+
428
+ # ===
429
+
430
+ # === CLIP Prism Models ===
431
+ "prism-clip-controlled+7b": {
432
+ "model_id": "prism-clip-controlled+7b",
433
+ "names": ["Prism-CLIP 7B (Controlled)"],
434
+ "description": {
435
+ "name": "CLIP Prism 7B (Controlled)",
436
+ "optimization_procedure": "single-stage",
437
+ "visual_representation": "CLIP ViT-L/14 @ 336px",
438
+ "image_processing": "Naive Resize",
439
+ "language_model": "Llama-2 7B",
440
+ "datasets": ["LLaVa v1.5 Instruct"],
441
+ "train_epochs": 1,
442
+ }
443
+ },
444
+ "prism-clip-controlled+13b": {
445
+ "model_id": "prism-clip-controlled+13b",
446
+ "names": ["Prism-CLIP 13B (Controlled)"],
447
+ "description": {
448
+ "name": "CLIP Prism 13B (Controlled)",
449
+ "optimization_procedure": "single-stage",
450
+ "visual_representation": "CLIP ViT-L/14 @ 336px",
451
+ "image_processing": "Naive Resize",
452
+ "language_model": "Llama-2 13B",
453
+ "datasets": ["LLaVa v1.5 Instruct"],
454
+ "train_epochs": 1,
455
+ }
456
+ },
457
+ "prism-clip+7b": {
458
+ "model_id": "prism-clip+7b",
459
+ "names": ["Prism-CLIP 7B"],
460
+ "description": {
461
+ "name": "CLIP Prism 7B",
462
+ "optimization_procedure": "single-stage",
463
+ "visual_representation": "CLIP ViT-L/14 @ 336px",
464
+ "image_processing": "Naive Resize",
465
+ "language_model": "Llama-2 7B",
466
+ "datasets": ["LLaVa v1.5 Instruct", "LVIS-Instruct-4V", "LRV-Instruct"],
467
+ "train_epochs": 2,
468
+ },
469
+ },
470
+ "prism-clip+13b": {
471
+ "model_id": "prism-clip+13b",
472
+ "names": ["Prism-CLIP 13B"],
473
+ "description": {
474
+ "name": "CLIP Prism 13B",
475
+ "optimization_procedure": "single-stage",
476
+ "visual_representation": "CLIP ViT-L/14 @ 336px",
477
+ "image_processing": "Naive Resize",
478
+ "language_model": "Llama-2 13B",
479
+ "datasets": ["LLaVa v1.5 Instruct", "LVIS-Instruct-4V", "LRV-Instruct"],
480
+ "train_epochs": 2,
481
+ },
482
+ },
483
+
484
+ # === SigLIP Prism Models ==
485
+ "prism-siglip-controlled+7b": {
486
+ "model_id": "prism-siglip-controlled+7b",
487
+ "names": ["Prism-SigLIP 7B (Controlled)"],
488
+ "description": {
489
+ "name": "SigLIP Prism 7B (Controlled)",
490
+ "optimization_procedure": "single-stage",
491
+ "visual_representation": "SigLIP ViT-SO/14 @ 384px",
492
+ "image_processing": "Naive Resize",
493
+ "language_model": "Llama-2 7B",
494
+ "datasets": ["LLaVa v1.5 Instruct"],
495
+ "train_epochs": 1,
496
+ }
497
+ },
498
+ "prism-siglip-controlled+13b": {
499
+ "model_id": "prism-siglip-controlled+13b",
500
+ "names": ["Prism-SigLIP 13B (Controlled)"],
501
+ "description": {
502
+ "name": "SigLIP Prism 13B (Controlled)",
503
+ "optimization_procedure": "single-stage",
504
+ "visual_representation": "SigLIP ViT-SO/14 @ 384px",
505
+ "image_processing": "Naive Resize",
506
+ "language_model": "Llama-2 13B",
507
+ "datasets": ["LLaVa v1.5 Instruct"],
508
+ "train_epochs": 1,
509
+ }
510
+ },
511
+ "prism-siglip+7b": {
512
+ "model_id": "prism-siglip+7b",
513
+ "names": ["Prism-SigLIP 7B"],
514
+ "description": {
515
+ "name": "SigLIP Prism 7B",
516
+ "optimization_procedure": "single-stage",
517
+ "visual_representation": "SigLIP ViT-SO/14 @ 384px",
518
+ "image_processing": "Naive Resize",
519
+ "language_model": "Llama-2 7B",
520
+ "datasets": ["LLaVa v1.5 Instruct", "LVIS-Instruct-4V", "LRV-Instruct"],
521
+ "train_epochs": 2,
522
+ }
523
+ },
524
+ "prism-siglip+13b": {
525
+ "model_id": "prism-siglip+13b",
526
+ "names": ["Prism-SigLIP 13B"],
527
+ "description": {
528
+ "name": "SigLIP Prism 13B",
529
+ "optimization_procedure": "single-stage",
530
+ "visual_representation": "SigLIP ViT-SO/14 @ 384px",
531
+ "image_processing": "Naive Resize",
532
+ "language_model": "Llama-2 13B",
533
+ "datasets": ["LLaVa v1.5 Instruct", "LVIS-Instruct-4V", "LRV-Instruct"],
534
+ "train_epochs": 2,
535
+ }
536
+ },
537
+
538
+ # === DINOSigLIP Prism Models ===
539
+ "prism-dinosiglip-controlled+7b": {
540
+ "model_id": "prism-dinosiglip-controlled+7b",
541
+ "names": ["Prism-DINOSigLIP 7B (Controlled)", "Prism 7B (Controlled)"],
542
+ "description": {
543
+ "name": "DINOSigLIP Prism 7B (Controlled)",
544
+ "optimization_procedure": "single-stage",
545
+ "visual_representation": "DINOv2 ViT-L/14 + SigLIP ViT-SO/14 @ 384px",
546
+ "image_processing": "Naive Resize",
547
+ "language_model": "Llama-2 7B",
548
+ "datasets": ["LLaVa v1.5 Instruct"],
549
+ "train_epochs": 1,
550
+ }
551
+ },
552
+ "prism-dinosiglip-controlled+13b": {
553
+ "model_id": "prism-dinosiglip-controlled+13b",
554
+ "names": ["Prism-DINOSigLIP 13B (Controlled)", "Prism 13B (Controlled)"],
555
+ "description": {
556
+ "name": "DINOSigLIP Prism 13B (Controlled)",
557
+ "optimization_procedure": "single-stage",
558
+ "visual_representation": "DINOv2 ViT-L/14 + SigLIP ViT-SO/14 @ 384px",
559
+ "image_processing": "Naive Resize",
560
+ "language_model": "Llama-2 13B",
561
+ "datasets": ["LLaVa v1.5 Instruct"],
562
+ "train_epochs": 1,
563
+ }
564
+ },
565
+ "prism-dinosiglip+7b": {
566
+ "model_id": "prism-dinosiglip+7b",
567
+ "names": ["Prism-DINOSigLIP 7B"],
568
+ "description": {
569
+ "name": "DINOSigLIP Prism 7B",
570
+ "optimization_procedure": "single-stage",
571
+ "visual_representation": "DINOv2 ViT-L/14 + SigLIP ViT-SO/14 @ 384px",
572
+ "image_processing": "Naive Resize",
573
+ "language_model": "Llama-2 7B",
574
+ "datasets": ["LLaVa v1.5 Instruct", "LVIS-Instruct-4V", "LRV-Instruct"],
575
+ "train_epochs": 2,
576
+ },
577
+ },
578
+ "prism-dinosiglip+13b": {
579
+ "model_id": "prism-dinosiglip+13b",
580
+ "names": ["Prism-DINOSigLIP 13B"],
581
+ "description": {
582
+ "name": "DINOSigLIP Prism 13B",
583
+ "optimization_procedure": "single-stage",
584
+ "visual_representation": "DINOv2 ViT-L/14 + SigLIP ViT-SO/14 @ 384px",
585
+ "image_processing": "Naive Resize",
586
+ "language_model": "Llama-2 13B",
587
+ "datasets": ["LLaVa v1.5 Instruct", "LVIS-Instruct-4V", "LRV-Instruct"],
588
+ "train_epochs": 2,
589
+ },
590
+ },
591
+
592
+ # === DINOSigLIP 224px Prism Models ===
593
+ "prism-dinosiglip-224px-controlled+7b": {
594
+ "model_id": "prism-dinosiglip-224px-controlled+7b",
595
+ "names": ["Prism-DINOSigLIP 224px 7B (Controlled)"],
596
+ "description": {
597
+ "name": "DINOSigLIP 224px 7B (Controlled)",
598
+ "optimization_procedure": "single-stage",
599
+ "visual_representation": "DINOv2 ViT-L/14 + SigLIP ViT-SO/14 @ 224px",
600
+ "image_processing": "Naive Resize",
601
+ "language_model": "Llama-2 7B",
602
+ "datasets": ["LLaVa v1.5 Instruct"],
603
+ "train_epochs": 1,
604
+ }
605
+ },
606
+ "prism-dinosiglip-224px+7b": {
607
+ "model_id": "prism-dinosiglip-224px+7b",
608
+ "names": ["Prism-DINOSigLIP 224px 7B"],
609
+ "description": {
610
+ "name": "DINOSigLIP 224px 7B",
611
+ "optimization_procedure": "single-stage",
612
+ "visual_representation": "DINOv2 ViT-L/14 + SigLIP ViT-SO/14 @ 224px",
613
+ "image_processing": "Naive Resize",
614
+ "language_model": "Llama-2 7B",
615
+ "datasets": ["LLaVa v1.5 Instruct", "LVIS-Instruct-4V", "LRV-Instruct"],
616
+ "train_epochs": 2,
617
+ }
618
+ },
619
+
620
+ # === Additional LLM Backbones ===
621
+ "llama2-chat+7b": {
622
+ "model_id": "llama2-chat+7b",
623
+ "names": ["Llama-2 Chat 7B"],
624
+ "description": {
625
+ "name": "Llama-2 Chat 7B",
626
+ "optimization_procedure": "single-stage",
627
+ "visual_representation": "CLIP ViT-L/14 @ 336px",
628
+ "image_processing": "Letterbox",
629
+ "language_model": "Llama-2 Chat 7B",
630
+ "datasets": ["LLaVa v1.5 Instruct"],
631
+ "train_epochs": 1,
632
+ }
633
+ },
634
+ "llama2-chat+13b": {
635
+ "model_id": "llama2-chat+13b",
636
+ "names": ["Llama-2 Chat 13B"],
637
+ "description": {
638
+ "name": "Llama-2 Chat 13B",
639
+ "optimization_procedure": "single-stage",
640
+ "visual_representation": "CLIP ViT-L/14 @ 336px",
641
+ "image_processing": "Letterbox",
642
+ "language_model": "Llama-2 Chat 13B",
643
+ "datasets": ["LLaVa v1.5 Instruct"],
644
+ "train_epochs": 1,
645
+ }
646
+ },
647
+ "mistral-v0.1+7b": {
648
+ "model_id": "mistral-v0.1+7b",
649
+ "names": ["Mistral v0.1 7B"],
650
+ "description": {
651
+ "name": "Mistral v0.1 7B",
652
+ "optimization_procedure": "single-stage",
653
+ "visual_representation": "CLIP ViT-L/14 @ 336px",
654
+ "image_processing": "Letterbox",
655
+ "language_model": "Mistral v0.1 7B",
656
+ "datasets": ["LLaVa v1.5 Instruct"],
657
+ "train_epochs": 1,
658
+ }
659
+ },
660
+ "mistral-instruct-v0.1+7b": {
661
+ "model_id": "mistral-instruct-v0.1+7b",
662
+ "names": ["Mistral Instruct v0.1 7B"],
663
+ "description": {
664
+ "name": "Mistral Instruct v0.1 7B",
665
+ "optimization_procedure": "single-stage",
666
+ "visual_representation": "CLIP ViT-L/14 @ 336px",
667
+ "image_processing": "Letterbox",
668
+ "language_model": "Mistral Instruct v0.1 7B",
669
+ "datasets": ["LLaVa v1.5 Instruct"],
670
+ "train_epochs": 1,
671
+ }
672
+ },
673
+ "phi-2+3b": {
674
+ "model_id": "phi-2+3b",
675
+ "names": ["Phi-2 3B"],
676
+ "description": {
677
+ "name": "Phi-2 3B",
678
+ "optimization_procedure": "single-stage",
679
+ "visual_representation": "CLIP ViT-L/14 @ 336px",
680
+ "image_processing": "Letterbox",
681
+ "language_model": "Phi-2 3B",
682
+ "datasets": ["LLaVa v1.5 Instruct"],
683
+ "train_epochs": 1,
684
+ }
685
+ },
686
+ }
687
+
688
+ # Build Global Registry (Model ID, Name) -> Metadata
689
+ GLOBAL_REGISTRY = {name: v for k, v in MODEL_REGISTRY.items() for name in [k] + v["names"]}
690
+
691
+ # fmt: on
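Because `GLOBAL_REGISTRY` maps both the canonical model ID and every human-readable alias to the same metadata dict, lookups by either key resolve identically; the short sketch below only uses entries defined above.

from prismatic.models.registry import GLOBAL_REGISTRY, MODEL_REGISTRY

assert GLOBAL_REGISTRY["Prism-DINOSigLIP 7B"] is MODEL_REGISTRY["prism-dinosiglip+7b"]
print(GLOBAL_REGISTRY["prism-dinosiglip+7b"]["description"]["language_model"])  # "Llama-2 7B"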
capvector-oft/prismatic/models/vlas/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .openvla import OpenVLA
capvector-oft/prismatic/models/vlas/openvla.py ADDED
@@ -0,0 +1,131 @@
1
+ """
2
+ openvla.py
3
+
4
+ PyTorch Module defining OpenVLA as a lightweight wrapper around a PrismaticVLM; defines custom logic around
5
+ discretizing actions with the ActionTokenizer.
6
+ """
7
+
8
+ from typing import Dict, List, Optional
9
+
10
+ import numpy as np
11
+ import torch
12
+ from PIL import Image
13
+ from transformers import LlamaTokenizerFast
14
+
15
+ from prismatic.models.vlms.prismatic import PrismaticVLM
16
+ from prismatic.overwatch import initialize_overwatch
17
+ from prismatic.vla.action_tokenizer import ActionTokenizer
18
+
19
+ # Initialize Overwatch =>> Wraps `logging.Logger`
20
+ overwatch = initialize_overwatch(__name__)
21
+
22
+
23
+ class OpenVLA(PrismaticVLM):
24
+ def __init__(
25
+ self,
26
+ *args,
27
+ norm_stats: Dict[str, Dict[str, Dict[str, Dict[str, List[float]]]]],
28
+ action_tokenizer: ActionTokenizer,
29
+ **kwargs,
30
+ ) -> None:
31
+ super().__init__(*args, **kwargs)
32
+ self.norm_stats = norm_stats
33
+ self.action_tokenizer = action_tokenizer
34
+
35
+ @torch.inference_mode()
36
+ def predict_action(
37
+ self, image: Image, instruction: str, unnorm_key: Optional[str] = None, **kwargs: str
38
+ ) -> np.ndarray:
39
+ """
40
+ Core function for VLA inference; maps input image and task instruction to continuous action (de-tokenizes).
41
+
42
+ @param image: PIL Image as [height, width, 3]
43
+ @param instruction: Task instruction string
44
+ @param unnorm_key: Optional dataset name for retrieving un-normalizing statistics; if None, checks that model
45
+ was trained only on a single dataset, and retrieves those statistics.
46
+
47
+ @return Unnormalized (continuous) action vector --> end-effector deltas.
48
+ """
49
+ image_transform, tokenizer = self.vision_backbone.image_transform, self.llm_backbone.tokenizer
50
+
51
+ # Build VLA Prompt
52
+ prompt_builder = self.get_prompt_builder()
53
+ prompt_builder.add_turn(role="human", message=f"What action should the robot take to {instruction.lower()}?")
54
+ prompt_text = prompt_builder.get_prompt()
55
+
56
+ # Prepare Inputs
57
+ input_ids = tokenizer(prompt_text, truncation=True, return_tensors="pt").input_ids.to(self.device)
58
+ if isinstance(tokenizer, LlamaTokenizerFast):
59
+ # If the special empty token ('') does not already appear after the colon (':') token in the prompt
60
+ # (after "OUT:" or "ASSISTANT:"), insert it to match the inputs seen at training time
61
+ if not torch.all(input_ids[:, -1] == 29871):
62
+ input_ids = torch.cat(
63
+ (input_ids, torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(input_ids.device)), dim=1
64
+ )
65
+ else:
66
+ raise ValueError(f"Unsupported `tokenizer` type = {type(tokenizer)}")
67
+
68
+ # Preprocess Image
69
+ pixel_values = image_transform(image)
70
+ if isinstance(pixel_values, torch.Tensor):
71
+ pixel_values = pixel_values[None, ...].to(self.device)
72
+ elif isinstance(pixel_values, dict):
73
+ pixel_values = {k: v[None, ...].to(self.device) for k, v in pixel_values.items()}
74
+ else:
75
+ raise ValueError(f"Unsupported `pixel_values` type = {type(pixel_values)}")
76
+
77
+ # Invoke super().generate --> taps into `GenerationMixin` which (redirects) to `forward()`
78
+ autocast_dtype = self.llm_backbone.half_precision_dtype
79
+ with torch.autocast("cuda", dtype=autocast_dtype, enabled=self.enable_mixed_precision_training):
80
+ # fmt: off
81
+ generated_ids = super(PrismaticVLM, self).generate(
82
+ input_ids=input_ids, # Shape: [1, seq]
83
+ pixel_values=pixel_values, # Shape: [1, 3, res, res] or Dict[str, ...]
84
+ max_new_tokens=self.get_action_dim(unnorm_key),
85
+ **kwargs
86
+ )
87
+ # fmt: on
88
+
89
+ # Extract predicted action tokens and translate into (normalized) continuous actions
90
+ predicted_action_token_ids = generated_ids[0, -self.get_action_dim(unnorm_key) :]
91
+ normalized_actions = self.action_tokenizer.decode_token_ids_to_actions(predicted_action_token_ids.cpu().numpy())
92
+
93
+ # Un-normalize Actions
94
+ action_norm_stats = self.get_action_stats(unnorm_key)
95
+ mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats["q01"], dtype=bool))
96
+ action_high, action_low = np.array(action_norm_stats["q99"]), np.array(action_norm_stats["q01"])
97
+ actions = np.where(
98
+ mask,
99
+ 0.5 * (normalized_actions + 1) * (action_high - action_low) + action_low,
100
+ normalized_actions,
101
+ )
102
+
103
+ return actions
104
+
105
+ @staticmethod
106
+ def _check_unnorm_key(norm_stats: Dict, unnorm_key: str) -> str:
107
+ if unnorm_key is None:
108
+ assert len(norm_stats) == 1, (
109
+ f"Your model was trained on more than one dataset, please pass a `unnorm_key` from the following "
110
+ f"options to choose the statistics used for un-normalizing actions: {norm_stats.keys()}"
111
+ )
112
+ unnorm_key = next(iter(norm_stats.keys()))
113
+
114
+ # Error Handling
115
+ assert (
116
+ unnorm_key in norm_stats
117
+ ), f"The `unnorm_key` you chose is not in the set of available statistics; choose from: {norm_stats.keys()}"
118
+
119
+ return unnorm_key
120
+
121
+ def get_action_dim(self, unnorm_key: Optional[str] = None) -> int:
122
+ """Dimensionality of the policy's action space."""
123
+ unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
124
+
125
+ return len(self.norm_stats[unnorm_key]["action"]["q01"])
126
+
127
+ def get_action_stats(self, unnorm_key: Optional[str] = None) -> Dict:
128
+ """Dimensionality of the policy's action space."""
129
+ unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
130
+
131
+ return self.norm_stats[unnorm_key]["action"]
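
For orientation, a hedged usage sketch of `predict_action` (the image path and the `unnorm_key` value are placeholders, and `vla` is assumed to be an already-loaded `OpenVLA` instance); internally, the generated action tokens are decoded and mapped back through the affine un-normalization `0.5 * (a + 1) * (q99 - q01) + q01` wherever `mask` is set:

import numpy as np
from PIL import Image

# Placeholder inputs; `vla` is assumed to be an OpenVLA instance loaded elsewhere.
image = Image.open("frame.png").convert("RGB")
action = vla.predict_action(image, "pick up the red block", unnorm_key="bridge_orig")

# The returned action dimensionality matches the stored statistics for the chosen dataset.
assert isinstance(action, np.ndarray)
assert action.shape == (vla.get_action_dim("bridge_orig"),)
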
capvector-oft/prismatic/models/vlms/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .prismatic import PrismaticVLM
capvector-oft/prismatic/models/vlms/base_vlm.py ADDED
@@ -0,0 +1,108 @@
1
+ """
2
+ base_vlm.py
3
+
4
+ Abstract class definition of a Vision-Language Model (VLM), with full annotations of class methods, utility functions,
5
+ and initialization logic. This is mostly to future-proof the codebase; while all our experiments instantiate
6
+ from PrismaticVLM, theoretically, this base class should be general enough to cover almost all models (e.g., IDEFICS,
7
+ PALI, Fuyu) in the future.
8
+
9
+ We use Abstract base classes *sparingly* -- mostly as a way to encapsulate any redundant logic or nested inheritance
10
+ (e.g., dependence on nn.Module, HF PretrainedModel, etc.). For other abstract objects (e.g., Tokenizers/Transforms),
11
+ prefer Protocol definitions instead.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ from abc import ABC, abstractmethod
17
+ from pathlib import Path
18
+ from typing import Callable, List, Optional
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ from transformers import GenerationMixin, PretrainedConfig
23
+ from transformers.modeling_outputs import CausalLMOutputWithPast
24
+
25
+ from prismatic.models.backbones.llm import LLMBackbone
26
+ from prismatic.models.backbones.llm.prompting import PromptBuilder
27
+ from prismatic.models.backbones.vision import VisionBackbone
28
+
29
+
30
+ # === Abstract Base Class for arbitrary Vision-Language Models ===
31
+ class VLM(nn.Module, GenerationMixin, ABC):
32
+ def __init__(
33
+ self,
34
+ model_family: str,
35
+ model_id: str,
36
+ vision_backbone: VisionBackbone,
37
+ llm_backbone: LLMBackbone,
38
+ enable_mixed_precision_training: bool = True,
39
+ ) -> None:
40
+ super().__init__()
41
+ self.model_family, self.model_id = model_family, model_id
42
+ self.vision_backbone, self.llm_backbone = vision_backbone, llm_backbone
43
+ self.enable_mixed_precision_training = enable_mixed_precision_training
44
+
45
+ # Instance Attributes for a generic VLM
46
+ self.all_module_keys, self.trainable_module_keys = None, None
47
+
48
+ # === GenerationMixin Expected Attributes =>> *DO NOT MODIFY* ===
49
+ self.generation_config = self.llm_backbone.llm.generation_config
50
+ self.main_input_name = "input_ids"
51
+
52
+ @property
53
+ def device(self) -> torch.device:
54
+ """Borrowed from `transformers.modeling_utils.py` -- checks parameter device; assumes model on *ONE* device!"""
55
+ return next(self.parameters()).device
56
+
57
+ @classmethod
58
+ @abstractmethod
59
+ def from_pretrained(
60
+ cls,
61
+ pretrained_checkpoint: Path,
62
+ model_family: str,
63
+ model_id: str,
64
+ vision_backbone: VisionBackbone,
65
+ llm_backbone: LLMBackbone,
66
+ **kwargs: str,
67
+ ) -> VLM: ...
68
+
69
+ @abstractmethod
70
+ def get_prompt_builder(self, system_prompt: Optional[str] = None) -> PromptBuilder: ...
71
+
72
+ @abstractmethod
73
+ def freeze_backbones(self, stage: str) -> None: ...
74
+
75
+ @abstractmethod
76
+ def load_from_checkpoint(self, stage: str, run_dir: Path, pretrained_checkpoint: Optional[Path] = None) -> None: ...
77
+
78
+ @abstractmethod
79
+ def get_fsdp_wrapping_policy(self) -> Callable: ...
80
+
81
+ @abstractmethod
82
+ def forward(
83
+ self,
84
+ input_ids: Optional[torch.LongTensor] = None,
85
+ attention_mask: Optional[torch.Tensor] = None,
86
+ pixel_values: Optional[torch.FloatTensor] = None,
87
+ labels: Optional[torch.LongTensor] = None,
88
+ inputs_embeds: Optional[torch.FloatTensor] = None,
89
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
90
+ use_cache: Optional[bool] = None,
91
+ output_attentions: Optional[bool] = None,
92
+ output_hidden_states: Optional[bool] = None,
93
+ return_dict: Optional[bool] = None,
94
+ multimodal_indices: Optional[torch.LongTensor] = None,
95
+ ) -> CausalLMOutputWithPast: ...
96
+
97
+ # === GenerationMixin Expected Properties & Methods (DO NOT MODIFY) ===
98
+ @staticmethod
99
+ def can_generate() -> bool:
100
+ return True
101
+
102
+ @property
103
+ def config(self) -> PretrainedConfig:
104
+ return self.llm_backbone.llm.config
105
+
106
+ # => Beam Search Utility
107
+ def _reorder_cache(self, past_key_values, beam_idx):
108
+ return self.llm_backbone.llm._reorder_cache(past_key_values, beam_idx)
capvector-oft/prismatic/models/vlms/prismatic.py ADDED
@@ -0,0 +1,621 @@
1
+ """
2
+ prismatic.py
3
+
4
+ PyTorch Module defining a PrismaticVLM, our general interface for defining the various different VLMs in our work.
5
+
6
+ Notes:
7
+ - For now, we don't subclass `transformers.PretrainedModel` (or CausalLM). Instead, we assume a very limited subset
8
+ of the {Model}ForCausalLM API that enables dispatch to the underlying LLM's `generate` utilities (feeding inputs
9
+ through our custom projection shim).
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ from functools import partial
15
+ from pathlib import Path
16
+ from typing import Callable, Dict, List, Optional, Type, Union
17
+
18
+ import torch
19
+ from PIL import Image
20
+ from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy
21
+ from transformers.modeling_outputs import CausalLMOutputWithPast
22
+
23
+ from prismatic.models.backbones.llm import LLMBackbone
24
+ from prismatic.models.backbones.llm.prompting import PromptBuilder
25
+ from prismatic.models.backbones.vision import VisionBackbone
26
+ from prismatic.models.vlms.base_vlm import VLM
27
+ from prismatic.overwatch import initialize_overwatch
28
+ from prismatic.util.nn_utils import FusedMLPProjector, LinearProjector, MLPProjector
29
+
30
+ # Initialize Overwatch =>> Wraps `logging.Logger`
31
+ overwatch = initialize_overwatch(__name__)
32
+
33
+
34
+ # HuggingFace Default / LLaMa-2 IGNORE_INDEX (for labels)
35
+ IGNORE_INDEX = -100
36
+
37
+
38
+ class PrismaticVLM(VLM):
39
+ def __init__(
40
+ self,
41
+ model_id: str,
42
+ vision_backbone: VisionBackbone,
43
+ llm_backbone: LLMBackbone,
44
+ enable_mixed_precision_training: bool = True,
45
+ arch_specifier: str = "gelu-mlp",
46
+ **kwargs,
47
+ ) -> None:
48
+ super().__init__(
49
+ "prismatic",
50
+ model_id,
51
+ vision_backbone,
52
+ llm_backbone,
53
+ enable_mixed_precision_training=enable_mixed_precision_training,
54
+ )
55
+
56
+ # Set Weight Initialization Seed for Projector Consistency
57
+ torch.manual_seed(vision_backbone.embed_dim)
58
+
59
+ # Initialize Projection (Adapter) based on `arch_specifier`
60
+ self.arch_specifier = arch_specifier
61
+ if arch_specifier == "linear":
62
+ self.projector = LinearProjector(vision_backbone.embed_dim, llm_backbone.embed_dim)
63
+ elif arch_specifier.endswith("fused-gelu-mlp"):
64
+ self.projector = FusedMLPProjector(vision_backbone.embed_dim, llm_backbone.embed_dim)
65
+ elif arch_specifier.endswith("gelu-mlp"):
66
+ self.projector = MLPProjector(vision_backbone.embed_dim, llm_backbone.embed_dim)
67
+ else:
68
+ raise ValueError(f"PrismaticVLM with `{arch_specifier = }` is not supported!")
69
+
70
+ # Trackers
71
+ self.vision_backbone_requires_grad = False
72
+
73
+ # Set Module Keys =>> used in Checkpoint Saving / Model Loading
74
+ self.all_module_keys = ["vision_backbone", "llm_backbone", "projector"]
75
+ self.trainable_module_keys = []
76
+
77
+ # === Generation Utilities ===
78
+ # => For computing likelihoods --> get tokens corresponding to "True", "False" and "Yes", "No"
79
+ self.string2idx = {}
80
+ for trigger_string in ["True", "False", "Yes", "No"] + [chr(ord("A") + i) for i in range(26)]:
81
+ token_idx_list = self.llm_backbone.tokenizer.encode(trigger_string, add_special_tokens=False)
82
+ assert len(token_idx_list) == 1, f'String "{trigger_string}" is tokenized as more than one token!'
83
+ self.string2idx[trigger_string] = token_idx_list[0]
84
+
85
+ @classmethod
86
+ def from_pretrained(
87
+ cls,
88
+ pretrained_checkpoint: Path,
89
+ model_id: str,
90
+ vision_backbone: VisionBackbone,
91
+ llm_backbone: LLMBackbone,
92
+ enable_mixed_precision_training: bool = True,
93
+ arch_specifier: str = "gelu-mlp",
94
+ freeze_weights: bool = True,
95
+ **kwargs,
96
+ ) -> PrismaticVLM:
97
+ """Initialize a PrismaticVLM from a pretrained checkpoint, freezing all weights, tailored for inference."""
98
+ vlm = cls(
99
+ model_id,
100
+ vision_backbone,
101
+ llm_backbone,
102
+ enable_mixed_precision_training=enable_mixed_precision_training,
103
+ arch_specifier=arch_specifier,
104
+ **kwargs,
105
+ )
106
+
107
+ # Load from Checkpoint (Custom --> should load both *projector* and *llm* weights)
108
+ model_state_dict = torch.load(pretrained_checkpoint, map_location="cpu")["model"]
109
+ assert (
110
+ "projector" in model_state_dict and "llm_backbone" in model_state_dict
111
+ ), "PrismaticVLM `from_pretrained` expects checkpoint with keys for `projector` AND `llm_backbone`!"
112
+
113
+ vlm.projector.load_state_dict(model_state_dict["projector"])
114
+ vlm.llm_backbone.load_state_dict(model_state_dict["llm_backbone"])
115
+ if "vision_backbone" in model_state_dict.keys():
116
+ vlm.vision_backbone.load_state_dict(model_state_dict["vision_backbone"])
117
+
118
+ # Freeze Weights
119
+ if freeze_weights:
120
+ vlm.requires_grad_(False)
121
+ vlm.eval()
122
+
123
+ return vlm
124
+
125
+ def get_prompt_builder(self, system_prompt: Optional[str] = None) -> PromptBuilder:
126
+ prompt_initializer: Type[PromptBuilder] = self.llm_backbone.prompt_builder_fn
127
+ return prompt_initializer(self.model_family, system_prompt=system_prompt)
128
+
129
+ def freeze_backbones(self, stage: str) -> None:
130
+ """
131
+ This function sets `requires_grad_` on each of the component modules explicitly, depending on stage.
132
+
133
+ The two core stages are "align" and "finetune"; the additional variants ("full-finetune", "vla-*", etc.) are handled below.
134
+ => "align" --> vision_backbone*, llm_backbone* are frozen; only the `projector` is trained.
135
+ => "finetune" --> vision_backbone* is frozen; both `projector` and `llm_backbone` are trained.
136
+
137
+ :param stage: Pretraining stage in < "align" | "finetune" | "full-finetune" | "vla-train" | "vla-full-train" >
138
+ """
139
+ if stage == "align":
140
+ self.vision_backbone.requires_grad_(False)
141
+ self.llm_backbone.requires_grad_(False)
142
+ self.projector.requires_grad_(True)
143
+
144
+ # Add to `self.trainable_module_keys`
145
+ self.trainable_module_keys = ["projector"]
146
+
147
+ # Update Trackers
148
+ self.vision_backbone_requires_grad = False
149
+
150
+ # Explicitly Log Frozen / Trainable Components
151
+ overwatch.info(f"[Frozen] 🥶 =>> Vision Backbone `{self.vision_backbone.identifier}`", ctx_level=1)
152
+ overwatch.info(f"[Frozen] 🥶 =>> LLM Backbone `{self.llm_backbone.identifier}`", ctx_level=1)
153
+ overwatch.info(f"[TRAINABLE] 🔥 =>> Projector `{self.arch_specifier}`", ctx_level=1)
154
+
155
+ elif stage in {"finetune", "vla-train"}:
156
+ self.vision_backbone.requires_grad_(False)
157
+ self.llm_backbone.requires_grad_(True)
158
+ self.projector.requires_grad_(True)
159
+
160
+ # Add to `self.trainable_module_keys`
161
+ self.trainable_module_keys = ["projector", "llm_backbone"]
162
+
163
+ # Update Trackers
164
+ self.vision_backbone_requires_grad = False
165
+
166
+ # Explicitly Log Frozen / Unfrozen Components
167
+ overwatch.info(f"[Frozen] 🥶 =>> Vision Backbone `{self.vision_backbone.identifier}`", ctx_level=1)
168
+ overwatch.info(f"[TRAINABLE] 🔥 =>> LLM Backbone `{self.llm_backbone.identifier}`", ctx_level=1)
169
+ overwatch.info(f"[TRAINABLE] 🔥 =>> Projector `{self.arch_specifier}`", ctx_level=1)
170
+
171
+ elif stage in {"full-finetune", "vla-full-train"}:
172
+ self.vision_backbone.dtype = torch.float32
173
+ self.vision_backbone.requires_grad_(True)
174
+ self.llm_backbone.requires_grad_(True)
175
+ self.projector.requires_grad_(True)
176
+
177
+ # Add to `self.trainable_module_keys`
178
+ self.trainable_module_keys = ["vision_backbone", "projector", "llm_backbone"]
179
+
180
+ # Update Trackers
181
+ self.vision_backbone_requires_grad = True
182
+
183
+ # Explicitly Log Frozen / Unfrozen Components
184
+ overwatch.info(f"[TRAINABLE] 🔥 =>> Vision Backbone `{self.vision_backbone.identifier}`", ctx_level=1)
185
+ overwatch.info(f"[TRAINABLE] 🔥 =>> LLM Backbone `{self.llm_backbone.identifier}`", ctx_level=1)
186
+ overwatch.info(f"[TRAINABLE] 🔥 =>> Projector `{self.arch_specifier}`", ctx_level=1)
187
+
188
+ elif stage in {"last-layer-finetune", "vla-last-layer-train"}:
189
+ self.vision_backbone.requires_grad_(False)
190
+ self.projector.requires_grad_(False)
191
+ self.llm_backbone.requires_grad_(False)
192
+
193
+ # Unfreeze final LLM layer
194
+ for module in self.llm_backbone.last_layer_finetune_modules:
195
+ module.requires_grad_(True)
196
+
197
+ # Add to `self.trainable_module_keys`
198
+ self.trainable_module_keys = ["llm_backbone"]
199
+
200
+ # Update Trackers
201
+ self.vision_backbone_requires_grad = False
202
+
203
+ # Explicitly Log Frozen / Unfrozen Components
204
+ # fmt: off
205
+ overwatch.info(f"[Frozen] 🥶 =>> Vision Backbone `{self.vision_backbone.identifier}`", ctx_level=1) # noqa: E501
206
+ overwatch.info(f"[Frozen, except last layer] 🥶🔥 =>> LLM Backbone `{self.llm_backbone.identifier}`", ctx_level=1) # noqa: E501
207
+ overwatch.info(f"[Frozen] 🥶 =>> Projector `{self.arch_specifier}`", ctx_level=1)
208
+ # fmt: on
209
+
210
+ elif stage in {"vla-sandwich-train"}:
211
+ self.vision_backbone.dtype = torch.float32
212
+ self.vision_backbone.requires_grad_(True)
213
+ self.projector.requires_grad_(True)
214
+ self.llm_backbone.requires_grad_(False)
215
+
216
+ # Unfreeze final LLM layer
217
+ for module in self.llm_backbone.last_layer_finetune_modules:
218
+ module.requires_grad_(True)
219
+
220
+ # Add to `self.trainable_module_keys`
221
+ self.trainable_module_keys = ["vision_backbone", "projector", "llm_backbone"]
222
+
223
+ # Update Trackers
224
+ self.vision_backbone_requires_grad = True
225
+
226
+ # Explicitly Log Frozen / Unfrozen Components
227
+ # fmt: off
228
+ overwatch.info(f"[TRAINABLE] 🔥 =>> Vision Backbone `{self.vision_backbone.identifier}`", ctx_level=1) # noqa: E501
229
+ overwatch.info(f"[Frozen, except last layer] 🥶🔥 =>> LLM Backbone `{self.llm_backbone.identifier}`", ctx_level=1) # noqa: E501
230
+ overwatch.info(f"[TRAINABLE] 🔥 =>> Projector `{self.arch_specifier}`", ctx_level=1)
231
+ # fmt: on
232
+
233
+ else:
234
+ raise ValueError(f"Stage `{stage}` is not supported for LLaVa! Try < align | finetune >")
235
+
236
+ overwatch.debug("##################################################")
237
+ overwatch.debug("##### Trainable Network Parameters: #####")
238
+ overwatch.debug("##################################################")
239
+ for name, param in self.named_parameters():
240
+ if param.requires_grad:
241
+ overwatch.debug(name)
242
+
243
+ def load_from_checkpoint(self, stage: str, run_dir: Path, pretrained_checkpoint: Optional[Path] = None) -> None:
244
+ """Load weights from checkpoint (if required by the given stage)."""
245
+ assert stage in {"align", "finetune", "full-finetune"}, f"Stage {stage} is not supported!"
246
+
247
+ # If we're running a `no-align` architecture, we're good!
248
+ if self.arch_specifier.startswith("no-align"):
249
+ overwatch.info(
250
+ f"PrismaticVLM with `{self.arch_specifier = }` does not require pretrained weights!", ctx_level=1
251
+ )
252
+ return
253
+
254
+ # Otherwise, handle stage-specific logic!
255
+ if stage == "align":
256
+ overwatch.info("Stage `align` does not require pretrained weights =>> Starting Training", ctx_level=1)
257
+ return
258
+
259
+ # Otherwise, load from `pretrained_checkpoint` or match on `run_dir` (s/+stage-finetune/+stage-align/g)
260
+ overwatch.info("Stage `finetune` requires `align` pretrained weights", ctx_level=1)
261
+
262
+ # Config specifies path to a checkpoint to load
263
+ if pretrained_checkpoint is not None:
264
+ overwatch.info(f"Loading from Provided Checkpoint `{pretrained_checkpoint}`", ctx_level=1)
265
+ model_state_dict = torch.load(pretrained_checkpoint)["model"]
266
+ self.projector.load_state_dict(model_state_dict["projector"])
267
+
268
+ return
269
+
270
+ # [Contract] If no `pretrained_checkpoint`, assume `align` lives in the run directory; string substitution!
271
+ model, scale, _, seed = run_dir.name.split("+")
272
+ align_dirs = [
273
+ d
274
+ for d in run_dir.parent.iterdir()
275
+ if (d.name.startswith(f"{model}+{scale}") and d.name.endswith(f"+stage-align+{seed}"))
276
+ ]
277
+ assert len(align_dirs) == 1, "Multiple or No Valid Pretrained Directories Exist -- Double Check `runs`!"
278
+ if (pretrained_checkpoint := (align_dirs[0] / "checkpoints" / "latest-checkpoint.pt")).exists():
279
+ overwatch.info(f"Loading from Discovered Checkpoint `{pretrained_checkpoint}`", ctx_level=1)
280
+ model_state_dict = torch.load(pretrained_checkpoint)["model"]
281
+ self.projector.load_state_dict(model_state_dict["projector"])
282
+ else:
283
+ raise ValueError(f"Could not find valid `align` checkpoint at {pretrained_checkpoint}!")
284
+
285
+ def get_fsdp_wrapping_policy(self) -> Callable:
286
+ """Return an FSDP _or_policy over the policies returned by each individual backbone (and our VLM policy)."""
287
+ vision_fsdp_wrapping_policy = self.vision_backbone.get_fsdp_wrapping_policy()
288
+ llm_fsdp_wrapping_policy = self.llm_backbone.get_fsdp_wrapping_policy()
289
+
290
+ # Get Prismatic Wrapping Policy =>> just a module wrapping policy around `self.projector`
291
+ prismatic_fsdp_wrapping_policy = partial(
292
+ _module_wrap_policy,
293
+ module_classes={LinearProjector, MLPProjector, FusedMLPProjector},
294
+ )
295
+
296
+ # Return union (_or_) over constituent policies
297
+ # => Note: there is *not* a fall-through policy; any module that isn't covered by the above constituents will
298
+ # automatically be folded into the root VLM FSDP instance.
299
+ return partial(
300
+ _or_policy,
301
+ policies=[
302
+ vision_fsdp_wrapping_policy,
303
+ llm_fsdp_wrapping_policy,
304
+ prismatic_fsdp_wrapping_policy,
305
+ ],
306
+ )
307
+
308
+ # Note =>> We're not explicitly subclassing `PreTrainedModel` because we don't need the bloat; however, `forward()`
309
+ # *must* match the signature of a `{Model}ForCausalLM` so that we can inherit from `GenerationMixin`
310
+
311
+ # ruff: noqa: C901
312
+ def forward(
313
+ self,
314
+ input_ids: Optional[torch.LongTensor] = None,
315
+ attention_mask: Optional[torch.Tensor] = None,
316
+ pixel_values: Optional[torch.FloatTensor] = None,
317
+ labels: Optional[torch.LongTensor] = None,
318
+ inputs_embeds: Optional[torch.FloatTensor] = None,
319
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
320
+ use_cache: Optional[bool] = None,
321
+ output_attentions: Optional[bool] = None,
322
+ output_hidden_states: Optional[bool] = None,
323
+ return_dict: Optional[bool] = None,
324
+ multimodal_indices: Optional[torch.LongTensor] = None,
325
+ ) -> CausalLMOutputWithPast:
326
+ """Run a forward pass through the VLM, returning a CausalLMOutputWithPast instance (contains loss)."""
327
+
328
+ # Handle Inference (leverage cache, short-circuit on just LLM forward)
329
+ if input_ids.shape[1] == 1 and past_key_values is not None:
330
+ # We're leveraging the cache, so just redirect to `self.llm_backbone` with `input_ids` and `past_key_values`
331
+ output = self.llm_backbone(
332
+ input_ids=input_ids,
333
+ attention_mask=None,
334
+ position_ids=None,
335
+ past_key_values=past_key_values,
336
+ inputs_embeds=None,
337
+ labels=None,
338
+ use_cache=use_cache,
339
+ output_attentions=output_attentions,
340
+ output_hidden_states=output_hidden_states,
341
+ return_dict=return_dict,
342
+ )
343
+ return output
344
+
345
+ elif input_ids.shape[1] == 1 or pixel_values is None:
346
+ raise RuntimeError("Invalid `forward()` call!")
347
+
348
+ # Handle Multimodal Indices is None --> pretend like the batch is fully multimodal (always image + text)!
349
+ if multimodal_indices is None:
350
+ multimodal_indices = torch.arange(len(input_ids), dtype=torch.long, device=input_ids.device)
351
+
352
+ # Handle Multimodal Indices is Empty (len == 0) --> simple unimodal forward
353
+ elif len(multimodal_indices) == 0:
354
+ return self.llm_backbone(
355
+ input_ids=input_ids,
356
+ attention_mask=attention_mask,
357
+ position_ids=None,
358
+ past_key_values=past_key_values,
359
+ inputs_embeds=None,
360
+ labels=labels,
361
+ use_cache=use_cache,
362
+ output_attentions=output_attentions,
363
+ output_hidden_states=output_hidden_states,
364
+ return_dict=return_dict,
365
+ )
366
+
367
+ # Run Visual Feature Extraction
368
+ with torch.set_grad_enabled(self.vision_backbone_requires_grad):
369
+ if isinstance(pixel_values, dict):
370
+ patch_features = self.vision_backbone({k: pixel_values[k][multimodal_indices] for k in pixel_values})
371
+ else:
372
+ patch_features = self.vision_backbone(pixel_values[multimodal_indices])
373
+
374
+ # Projection Logic :: [bsz, num_patches, llm_embed_dim] =>> num_patches = (2 *) (256 + 1) for ViT-L + CLS
375
+ projected_patch_embeddings = self.projector(patch_features)
376
+ projected_patch_attention_mask = None
377
+ if attention_mask is not None:
378
+ projected_patch_attention_mask = torch.full(
379
+ (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
380
+ True,
381
+ dtype=attention_mask.dtype,
382
+ device=attention_mask.device,
383
+ )
384
+
385
+ # Get Input Embeddings from LLM Backbone :: [bsz, input_seq_len, llm_embed_dim]
386
+ input_embeddings = self.llm_backbone.embed_input_ids(input_ids)
387
+
388
+ # Build Multimodal Embeddings (and build resulting attention mask)
389
+ multimodal_embeddings = torch.cat(
390
+ [
391
+ input_embeddings[multimodal_indices, :1, :],
392
+ projected_patch_embeddings,
393
+ input_embeddings[multimodal_indices, 1:, :],
394
+ ],
395
+ dim=1,
396
+ )
397
+ multimodal_attention_mask = None
398
+ if attention_mask is not None:
399
+ multimodal_attention_mask = torch.cat(
400
+ [
401
+ attention_mask[multimodal_indices, :1],
402
+ projected_patch_attention_mask,
403
+ attention_mask[multimodal_indices, 1:],
404
+ ],
405
+ dim=1,
406
+ )
407
+
408
+ # [Contract] We assume the first token of `labels` (associated with <BOS>) is already marked as "IGNORE"
409
+ # => We'll ignore the per-token outputs for each of the patch embeddings as well!
410
+ multimodal_labels = None
411
+ if labels is not None:
412
+ projected_patch_labels = torch.full(
413
+ (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
414
+ IGNORE_INDEX,
415
+ dtype=labels.dtype,
416
+ device=labels.device,
417
+ )
418
+ multimodal_labels = torch.cat(
419
+ [labels[multimodal_indices, :1], projected_patch_labels, labels[multimodal_indices, 1:]], dim=1
420
+ )
421
+
422
+ # === Add Unimodal Handling ===
423
+
424
+ # Create Fused Embeddings, Attention Mask, and Labels by Merging with "unimodal" Inputs (if applicable)
425
+ unimodal_indices = torch.tensor(
426
+ [idx for idx in range(len(input_ids)) if idx not in multimodal_indices],
427
+ dtype=torch.long,
428
+ device=multimodal_indices.device,
429
+ )
430
+
431
+ # No "unimodal" data --> Fused == Multimodal
432
+ if len(unimodal_indices) == 0:
433
+ fused_embeddings = multimodal_embeddings
434
+ fused_attention_mask = multimodal_attention_mask
435
+ fused_labels = multimodal_labels
436
+
437
+ else:
438
+ # Otherwise --> Merge w/ unimodal data
439
+
440
+ # This doesn't matter --> but in the "normal" case this is the embedding of the <PAD> token
441
+ # => NOTE :: Verified that `zeros/randn/empty/<PAD> embedding` all return the same result!
442
+ unimodal_embeddings_pad = torch.zeros(
443
+ (len(unimodal_indices), projected_patch_embeddings.shape[1], input_embeddings.shape[2]),
444
+ dtype=input_embeddings.dtype,
445
+ device=input_embeddings.device,
446
+ )
447
+ unimodal_attention_pad = torch.full(
448
+ (len(unimodal_indices), projected_patch_embeddings.shape[1]),
449
+ False,
450
+ dtype=attention_mask.dtype,
451
+ device=attention_mask.device,
452
+ )
453
+ unimodal_labels_pad = torch.full(
454
+ (len(unimodal_indices), projected_patch_embeddings.shape[1]),
455
+ IGNORE_INDEX,
456
+ dtype=labels.dtype,
457
+ device=labels.device,
458
+ )
459
+
460
+ unimodal_embeddings = torch.cat([input_embeddings[unimodal_indices], unimodal_embeddings_pad], dim=1)
461
+ unimodal_attention_mask = torch.cat([attention_mask[unimodal_indices], unimodal_attention_pad], dim=1)
462
+ unimodal_labels = torch.cat([labels[unimodal_indices], unimodal_labels_pad], dim=1)
463
+
464
+ # Create "Fused" Tensors by Stacking Multimodal & Unimodal
465
+ fused_embeddings = torch.vstack([multimodal_embeddings, unimodal_embeddings])
466
+ fused_attention_mask = torch.vstack([multimodal_attention_mask, unimodal_attention_mask])
467
+ fused_labels = torch.vstack([multimodal_labels, unimodal_labels])
468
+
469
+ # Run LLM Forward --> returns CausalLMOutputWithPast!
470
+ return self.llm_backbone(
471
+ input_ids=None,
472
+ attention_mask=fused_attention_mask,
473
+ position_ids=None,
474
+ past_key_values=past_key_values,
475
+ inputs_embeds=fused_embeddings,
476
+ labels=fused_labels,
477
+ use_cache=use_cache,
478
+ output_attentions=output_attentions,
479
+ output_hidden_states=output_hidden_states,
480
+ return_dict=return_dict,
481
+ )
482
+
483
+ # === GenerationMixin Methods ===
484
+ # => Note: The following methods override the functionality of `transformers.GenerationMixin`; these expect the
485
+ # contract in each of the function signatures, and also expect our `forward` function to roughly take
486
+ # the same arguments as the underlying LLM (see `LlamaModelForCausalLM` as an example)
487
+
488
+ def prepare_inputs_for_generation(
489
+ self,
490
+ input_ids: Optional[torch.LongTensor] = None,
491
+ attention_mask: Optional[torch.Tensor] = None,
492
+ pixel_values: Optional[torch.FloatTensor] = None,
493
+ inputs_embeds: Optional[torch.FloatTensor] = None,
494
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
495
+ use_cache: Optional[bool] = None,
496
+ **kwargs: torch.Tensor,
497
+ ) -> Dict[str, torch.Tensor]:
498
+ """Borrowed from `LlamaForCausalLM` --> in general, just handles caching logic during generation."""
499
+ if past_key_values:
500
+ input_ids = input_ids[:, -1:]
501
+
502
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
503
+ if inputs_embeds is not None and past_key_values is None:
504
+ model_inputs = {"inputs_embeds": inputs_embeds}
505
+ else:
506
+ model_inputs = {"input_ids": input_ids}
507
+
508
+ # Make sure `pixel_values` are preserved in `model_inputs`
509
+ model_inputs.update(
510
+ {
511
+ "attention_mask": attention_mask,
512
+ "pixel_values": pixel_values,
513
+ "past_key_values": past_key_values,
514
+ "use_cache": use_cache,
515
+ }
516
+ )
517
+
518
+ return model_inputs
519
+
520
+ @torch.inference_mode()
521
+ def generate_batch(
522
+ self,
523
+ pixel_values: Union[torch.Tensor, Dict[str, torch.Tensor]],
524
+ texts: List[str],
525
+ return_string_probabilities: Optional[List[str]] = None,
526
+ **kwargs: str,
527
+ ) -> Union[List[str], List[List[float]]]:
528
+ # For now, only support generation with a batch size of 1 for simplicity
529
+ tokenizer = self.llm_backbone.tokenizer
530
+
531
+ # Prepare Inputs
532
+ batch_input_ids = [
533
+ tokenizer(text, truncation=True, return_tensors="pt").input_ids.to(self.device) for text in texts
534
+ ]
535
+ if isinstance(pixel_values, torch.Tensor):
536
+ pixel_values = pixel_values[None, ...].to(self.device)
537
+ elif isinstance(pixel_values, dict):
538
+ pixel_values = {k: v[None, ...].to(self.device) for k, v in pixel_values.items()}
539
+ else:
540
+ raise ValueError(f"Unsupported `pixel_values` type = {type(pixel_values)}")
541
+
542
+ # Create Output Lists
543
+ gen_texts, gen_probabilities = [], []
544
+
545
+ # Invoke super().generate --> taps into `GenerationMixin` which (redirects) to `forward()`
546
+ autocast_dtype = self.llm_backbone.half_precision_dtype
547
+ with torch.autocast("cuda", dtype=autocast_dtype, enabled=self.enable_mixed_precision_training):
548
+ for idx, input_ids in enumerate(batch_input_ids):
549
+ if isinstance(pixel_values, torch.Tensor):
550
+ pixel_values = pixel_values[idx]
551
+ elif isinstance(pixel_values, dict):
552
+ pixel_values = {k: pixel_values[k][idx] for k in pixel_values}
553
+ else:
554
+ raise ValueError(f"Unsupported `pixel_values` type = {type(pixel_values)}")
555
+
556
+ # Handle `return_string_probabilities`
557
+ if return_string_probabilities is None:
558
+ full_out_ids = super().generate(input_ids=input_ids, pixel_values=pixel_values, **kwargs)
559
+ gen_ids = full_out_ids[0, input_ids.shape[1] :]
560
+
561
+ # Decode `gen_ids` and strip any <EOS> tokens
562
+ gen_texts.append(tokenizer.decode(gen_ids, skip_special_tokens=True).strip())
563
+
564
+ else:
565
+ full_out_dict = super().generate(
566
+ input_ids=input_ids,
567
+ pixel_values=pixel_values,
568
+ output_scores=True,
569
+ return_dict_in_generate=True,
570
+ **kwargs,
571
+ )
572
+
573
+ # Generation pattern should usually be [TOKEN] <EOS> for True/False and Yes/No Generations
574
+ gen_ids = full_out_dict.sequences[0, input_ids.shape[1] :]
575
+
576
+ # [Debug] Verify that the first token generated is in `self.string2idx.values()`
577
+ # assert gen_ids[0] in self.string2idx.values(), "Generated ID not in mapping!"
578
+
579
+ # Decode `gen_ids` and strip any <EOS> tokens
580
+ gen_texts.append(tokenizer.decode(gen_ids, skip_special_tokens=True).strip())
581
+
582
+ # Get all token probabilities --> softmax over logits
583
+ token_probs = torch.softmax(full_out_dict.scores[0][0], dim=0)
584
+
585
+ # Get *normalized* probabilities for all values in `return_token_probabilities`
586
+ slice_idxs = torch.tensor([self.string2idx[s] for s in return_string_probabilities])
587
+ string_probs_unnormalized = token_probs[slice_idxs]
588
+ string_probs = string_probs_unnormalized / string_probs_unnormalized.sum()
589
+ gen_probabilities.append(string_probs.cpu().numpy().tolist())
590
+
591
+ return gen_texts if return_string_probabilities is None else gen_probabilities
592
+
593
+ @torch.inference_mode()
594
+ def generate(self, image: Image, prompt_text: str, **kwargs: str) -> str:
595
+ # For now, only support generation with a batch size of 1 for simplicity
596
+ image_transform, tokenizer = self.vision_backbone.image_transform, self.llm_backbone.tokenizer
597
+
598
+ # Prepare Inputs
599
+ input_ids = tokenizer(prompt_text, truncation=True, return_tensors="pt").input_ids.to(self.device)
600
+ pixel_values = image_transform(image)
601
+ if isinstance(pixel_values, torch.Tensor):
602
+ pixel_values = pixel_values[None, ...].to(self.device)
603
+ elif isinstance(pixel_values, dict):
604
+ pixel_values = {k: v[None, ...].to(self.device) for k, v in pixel_values.items()}
605
+ else:
606
+ raise ValueError(f"Unsupported `pixel_values` type = {type(pixel_values)}")
607
+
608
+ # Invoke super().generate --> taps into `GenerationMixin` which (redirects) to `forward()`
609
+ autocast_dtype = self.llm_backbone.half_precision_dtype
610
+ with torch.autocast("cuda", dtype=autocast_dtype, enabled=self.enable_mixed_precision_training):
611
+ # fmt: off
612
+ generated_ids = super().generate(
613
+ input_ids=input_ids, # Shape: [1, seq]
614
+ pixel_values=pixel_values, # Shape: [1, 3, res, res] or Dict[str, Shape[1, 3, res, res]]
615
+ **kwargs
616
+ )
617
+ # fmt: on
618
+
619
+ generated_text = tokenizer.decode(generated_ids[0, input_ids.shape[1] :], skip_special_tokens=True).strip()
620
+
621
+ return generated_text
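
To make the splice in `forward` concrete, here is a toy-shaped sketch (batch of 1, 4 text tokens, 3 patch embeddings, hidden size 8, all chosen arbitrarily) of how projected patch embeddings are inserted right after the <BOS> position, with matching attention-mask and IGNORE_INDEX label padding so the patches contribute no loss:

import torch

IGNORE_INDEX = -100
bsz, seq_len, n_patches, d = 1, 4, 3, 8            # toy sizes, not the real model dims

input_embeddings = torch.randn(bsz, seq_len, d)    # stand-in for llm_backbone.embed_input_ids(input_ids)
patch_embeddings = torch.randn(bsz, n_patches, d)  # stand-in for projector(vision_backbone(pixel_values))
attention_mask = torch.ones(bsz, seq_len, dtype=torch.bool)
labels = torch.tensor([[IGNORE_INDEX, 319, 4996, 2]])  # <BOS> already marked IGNORE per the contract above

# Splice the patches in right after the <BOS> embedding (index 0), as in PrismaticVLM.forward
multimodal_embeddings = torch.cat(
    [input_embeddings[:, :1], patch_embeddings, input_embeddings[:, 1:]], dim=1
)
multimodal_attention_mask = torch.cat(
    [attention_mask[:, :1], torch.ones(bsz, n_patches, dtype=torch.bool), attention_mask[:, 1:]], dim=1
)
multimodal_labels = torch.cat(
    [labels[:, :1], torch.full((bsz, n_patches), IGNORE_INDEX), labels[:, 1:]], dim=1
)

print(multimodal_embeddings.shape)  # torch.Size([1, 7, 8]) => <BOS> + 3 patches + 3 remaining text tokens
print(multimodal_labels)            # patch positions (and <BOS>) carry IGNORE_INDEX, so they add no loss
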
capvector-oft/prismatic/overwatch/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .overwatch import initialize_overwatch
capvector-oft/prismatic/overwatch/overwatch.py ADDED
@@ -0,0 +1,147 @@
1
+ """
2
+ overwatch.py
3
+
4
+ Utility class for creating a centralized/standardized logger (built on Rich) and accelerate handler.
5
+ """
6
+
7
+ import logging
8
+ import logging.config
9
+ import os
10
+ from contextlib import nullcontext
11
+ from logging import LoggerAdapter
12
+ from typing import Any, Callable, ClassVar, Dict, MutableMapping, Tuple, Union
13
+
14
+ # Overwatch Default Format String
15
+ RICH_FORMATTER, DATEFMT = "| >> %(message)s", "%m/%d [%H:%M:%S]"
16
+
17
+ # Set Logging Configuration
18
+ LOG_CONFIG = {
19
+ "version": 1,
20
+ "disable_existing_loggers": True,
21
+ "formatters": {"simple-console": {"format": RICH_FORMATTER, "datefmt": DATEFMT}},
22
+ "handlers": {
23
+ "console": {
24
+ "class": "rich.logging.RichHandler",
25
+ "formatter": "simple-console",
26
+ "markup": True,
27
+ "rich_tracebacks": True,
28
+ "show_level": True,
29
+ "show_path": True,
30
+ "show_time": True,
31
+ }
32
+ },
33
+ "root": {"level": "INFO", "handlers": ["console"]},
34
+ }
35
+ logging.config.dictConfig(LOG_CONFIG)
36
+
37
+
38
+ # === Custom Contextual Logging Logic ===
39
+ class ContextAdapter(LoggerAdapter):
40
+ CTX_PREFIXES: ClassVar[Dict[int, str]] = {**{0: "[*] "}, **{idx: "|=> ".rjust(4 + (idx * 4)) for idx in [1, 2, 3]}}
41
+
42
+ def process(self, msg: str, kwargs: MutableMapping[str, Any]) -> Tuple[str, MutableMapping[str, Any]]:
43
+ ctx_level = kwargs.pop("ctx_level", 0)
44
+ return f"{self.CTX_PREFIXES[ctx_level]}{msg}", kwargs
45
+
46
+
47
+ class DistributedOverwatch:
48
+ def __init__(self, name: str) -> None:
49
+ """Initializer for an Overwatch object that wraps logging & `accelerate.PartialState`."""
50
+ from accelerate import PartialState
51
+
52
+ # Note that PartialState is always safe to initialize regardless of `accelerate launch` or `torchrun`
53
+ # =>> However, might be worth actually figuring out if we need the `accelerate` dependency at all!
54
+ self.logger, self.distributed_state = ContextAdapter(logging.getLogger(name), extra={}), PartialState()
55
+
56
+ # Logger Delegation (for convenience; would be nice to just compose & dynamic dispatch eventually)
57
+ self.debug = self.logger.debug
58
+ self.info = self.logger.info
59
+ self.warning = self.logger.warning
60
+ self.error = self.logger.error
61
+ self.critical = self.logger.critical
62
+
63
+ # Logging Defaults =>> only Log `INFO` on Main Process, `ERROR` on others!
64
+ self.logger.setLevel(logging.INFO if self.distributed_state.is_main_process else logging.ERROR)
65
+
66
+ @property
67
+ def rank_zero_only(self) -> Callable[..., Any]:
68
+ return self.distributed_state.on_main_process
69
+
70
+ @property
71
+ def local_zero_only(self) -> Callable[..., Any]:
72
+ return self.distributed_state.on_local_main_process
73
+
74
+ @property
75
+ def rank_zero_first(self) -> Callable[..., Any]:
76
+ return self.distributed_state.main_process_first
77
+
78
+ @property
79
+ def local_zero_first(self) -> Callable[..., Any]:
80
+ return self.distributed_state.local_main_process_first
81
+
82
+ def is_rank_zero(self) -> bool:
83
+ return self.distributed_state.is_main_process
84
+
85
+ def rank(self) -> int:
86
+ return self.distributed_state.process_index
87
+
88
+ def local_rank(self) -> int:
89
+ return self.distributed_state.local_process_index
90
+
91
+ def world_size(self) -> int:
92
+ return self.distributed_state.num_processes
93
+
94
+
95
+ class PureOverwatch:
96
+ def __init__(self, name: str) -> None:
97
+ """Initializer for an Overwatch object that just wraps logging."""
98
+ self.logger = ContextAdapter(logging.getLogger(name), extra={})
99
+
100
+ # Logger Delegation (for convenience; would be nice to just compose & dynamic dispatch eventually)
101
+ self.debug = self.logger.debug
102
+ self.info = self.logger.info
103
+ self.warning = self.logger.warning
104
+ self.error = self.logger.error
105
+ self.critical = self.logger.critical
106
+
107
+ # Logging Defaults =>> INFO
108
+ self.logger.setLevel(logging.INFO)
109
+
110
+ @staticmethod
111
+ def get_identity_ctx() -> Callable[..., Any]:
112
+ def identity(fn: Callable[..., Any]) -> Callable[..., Any]:
113
+ return fn
114
+
115
+ return identity
116
+
117
+ @property
118
+ def rank_zero_only(self) -> Callable[..., Any]:
119
+ return self.get_identity_ctx()
120
+
121
+ @property
122
+ def local_zero_only(self) -> Callable[..., Any]:
123
+ return self.get_identity_ctx()
124
+
125
+ @property
126
+ def rank_zero_first(self) -> Callable[..., Any]:
127
+ return nullcontext
128
+
129
+ @property
130
+ def local_zero_first(self) -> Callable[..., Any]:
131
+ return nullcontext
132
+
133
+ @staticmethod
134
+ def is_rank_zero() -> bool:
135
+ return True
136
+
137
+ @staticmethod
138
+ def rank() -> int:
139
+ return 0
140
+
141
+ @staticmethod
142
+ def world_size() -> int:
143
+ return 1
144
+
145
+
146
+ def initialize_overwatch(name: str) -> Union[DistributedOverwatch, PureOverwatch]:
147
+ return DistributedOverwatch(name) if int(os.environ.get("WORLD_SIZE", -1)) != -1 else PureOverwatch(name)
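
A short usage sketch of the module above; `ctx_level` selects the `CTX_PREFIXES` indentation, and launching under `torchrun` / `accelerate launch` (which sets `WORLD_SIZE`) switches `initialize_overwatch` to the distributed variant automatically:

from prismatic.overwatch import initialize_overwatch

overwatch = initialize_overwatch(__name__)

overwatch.info("Loading Vision Backbone")                   # "[*] Loading Vision Backbone"
overwatch.info("Fetching pretrained weights", ctx_level=1)  # "    |=> Fetching pretrained weights"

# In a multi-process run, only rank zero logs at INFO; one-off side effects can be guarded the same way.
if overwatch.is_rank_zero():
    pass  # e.g., download datasets / artifacts exactly once
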
capvector-oft/prismatic/preprocessing/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .download import convert_to_jpg, download_extract
2
+ from .materialize import get_dataset_and_collator
capvector-oft/prismatic/preprocessing/datasets/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .datasets import AlignDataset, FinetuneDataset
capvector-oft/prismatic/preprocessing/datasets/datasets.py ADDED
@@ -0,0 +1,200 @@
1
+ """
2
+ datasets.py
3
+
4
+ PyTorch Dataset Definitions for Prismatic models; supports processing for both the `align` and `finetune` stages, with
5
+ utilities for formatting conversations during the `finetune` stage subject to the given LLM backbone's expected
6
+ formatting (e.g., SYS_PROMPT + USER: ... ASSISTANT: ... for Vicuña v1.5 Chat models).
7
+
8
+ We currently only support Map-style Datasets; assumes that all files (annotations, images) are on local disk, and that
9
+ random access image reading is relatively cheap/fast.
10
+ """
11
+
12
+ import copy
13
+ import json
14
+ from pathlib import Path
15
+ from typing import Dict, List, Tuple, Type
16
+
17
+ import torch
18
+ from PIL import Image
19
+ from torch.utils.data import Dataset
20
+ from transformers import CodeGenTokenizerFast, LlamaTokenizerFast, PreTrainedTokenizerBase
21
+
22
+ from prismatic.models.backbones.llm.prompting import PromptBuilder
23
+ from prismatic.models.backbones.vision import ImageTransform
24
+
25
+ # HuggingFace Default / LLaMa-2 IGNORE_INDEX (for labels)
26
+ IGNORE_INDEX = -100
27
+
28
+
29
+ class AlignDataset(Dataset[Dict[str, torch.Tensor]]):
30
+ def __init__(
31
+ self,
32
+ chat_json: Path,
33
+ image_dir: Path,
34
+ image_transform: ImageTransform,
35
+ tokenizer: PreTrainedTokenizerBase,
36
+ ) -> None:
37
+ super().__init__()
38
+ self.chat_json, self.image_dir = chat_json, image_dir
39
+ self.image_transform, self.tokenizer = image_transform, tokenizer
40
+ self.dataset_type = "align"
41
+
42
+ # Create Prompt Template
43
+ self.prompt_template = "{caption}" + self.tokenizer.eos_token
44
+
45
+ # Load Chat JSON
46
+ with open(self.chat_json, "r") as f:
47
+ self.examples = json.load(f)
48
+
49
+ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
50
+ """
51
+ Following the *actual* code executed from the LLaVa codebase, during the "align" phase, we actually discard
52
+ the "prompt" from the human, and instead directly predict the caption from the image.
53
+
54
+ As a concrete example given the "raw data" for the first example:
55
+ example = self.examples[0]["conversations"] = {
56
+ [
57
+ {"from": "human", "value": "Render a clear and concise summary of the photo.\n<image>"},
58
+ {"from": "gpt", "value": "select luxury furniture 3 - inch gel memory foam mattress topper"}
59
+ ]
60
+ }
61
+
62
+ Return =>> self.tokenizer("<image> select luxury furniture 3 - inch gel memory foam mattress topper\n")
63
+
64
+ :param idx: Index to retrieve from the dataset.
65
+
66
+ :return: Dictionary of {"pixel_values": torch.Tensor, "input_ids": torch.Tensor, "labels": torch.Tensor}
67
+ """
68
+ image_path, conversation = Path(self.examples[idx]["image"]), self.examples[idx]["conversations"]
69
+ assert (len(conversation) == 2) and ("<image>" not in conversation[-1]["value"]), "Unexpected text!"
70
+
71
+ # Format Caption --> {caption}{eos_token}
72
+ caption = self.prompt_template.format(caption=conversation[-1]["value"].strip())
73
+
74
+ # We treat image patches as "tokens = [p1 p2 p3, ...]"; we need to specify ordering of text/patch tokens.
75
+ # => Critically, we find that inserting *after* the BOS token leads to the strongest performance!
76
+ # - input_ids = "<s> p1 p2 p3 ... <caption_text> \n"
77
+ # - labels = "IGNORE IGNORE ..." (copy `input_ids` replacing <s> and p{1...K} with IGNORE)
78
+ #
79
+ # IMPORTANT => IF WE'RE USING HF LLM.forward(... labels=labels), SHIFTING HAPPENS _INSIDE_ MODEL!
80
+ input_ids = self.tokenizer(caption, truncation=True, return_tensors="pt").input_ids[0]
81
+ labels = copy.deepcopy(input_ids)
82
+
83
+ # Set the <BOS> token's label to IGNORE_INDEX (since we're inserting the image patches right after)
84
+ labels[0] = IGNORE_INDEX
85
+
86
+ # Process Image --> get "pixel_values" (will either be a torch.Tensor OR a Dict[str,torch.Tensor])
87
+ pixel_values = self.image_transform(Image.open(self.image_dir / image_path).convert("RGB"))
88
+
89
+ return dict(pixel_values=pixel_values, input_ids=input_ids, labels=labels)
90
+
91
+ def get_modality_lengths(self, n_image_patches: int) -> List[Tuple[bool, int]]:
92
+ """Get a list of modalities (unimodal / text-only vs. multimodal) and length of conversations per example."""
93
+ modality_lengths = []
94
+ for example in self.examples:
95
+ is_multimodal = "image" in example
96
+ n_words = sum([len(turn["value"].replace("<image>", "").split()) for turn in example["conversations"]])
97
+ modality_lengths.append((is_multimodal, (n_image_patches + n_words) if is_multimodal else n_words))
98
+ return modality_lengths
99
+
100
+ def __len__(self) -> int:
101
+ return len(self.examples)
102
+
103
+
104
+ class FinetuneDataset(Dataset[Dict[str, torch.Tensor]]):
105
+ def __init__(
106
+ self,
107
+ instruct_json: Path,
108
+ image_dir: Path,
109
+ image_transform: ImageTransform,
110
+ tokenizer: PreTrainedTokenizerBase,
111
+ prompt_builder_fn: Type[PromptBuilder],
112
+ ) -> None:
113
+ super().__init__()
114
+ self.instruct_json, self.image_dir = instruct_json, image_dir
115
+ self.image_transform, self.tokenizer = image_transform, tokenizer
116
+ self.prompt_builder_fn = prompt_builder_fn
117
+ self.dataset_type = "finetune"
118
+
119
+ # Load Instruct JSON
120
+ with open(self.instruct_json, "r") as f:
121
+ self.examples = json.load(f)
122
+
123
+ # === Unimodal + Multimodal Handling ===
124
+ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
125
+ """
126
+ Unlike the *align* stage handling, for the *finetune* stage, we actually need to handle multiple "turns" of
127
+ dialog grounded in a single image.
128
+
129
+ To do this, we leverage the `prompt_builder_fn` which instantiates a PromptBuilder object. By calling the
130
+ methods for adding turns and getting a prompt, we ensure proper formatting and consistency for each example.
131
+
132
+ :param idx: Index to retrieve from the dataset.
133
+
134
+ :return: Dictionary of {"pixel_values": torch.Tensor, "input_ids": torch.Tensor, "labels": torch.Tensor}
135
+ """
136
+ conversation = self.examples[idx]["conversations"]
137
+
138
+ # Create Prompt Builder --> add each message sequentially
139
+ prompt_builder, input_ids, labels = self.prompt_builder_fn(model_family="prismatic"), [], []
140
+ for turn_idx, turn in enumerate(conversation):
141
+ # Get "effective" string added to prompt --> handle whitespace for tokenizer type!
142
+ msg = prompt_builder.add_turn(turn["from"], turn["value"])
143
+
144
+ # Llama Tokenizer (Fast) adds extra character if a string ends in whitespace --> strip if non-empty!
145
+ if isinstance(self.tokenizer, LlamaTokenizerFast):
146
+ msg = msg.rstrip()
147
+
148
+ # Phi-2 Tokenizer == CodeGenTokenizer (Fast) -- no special handling!
149
+ elif isinstance(self.tokenizer, CodeGenTokenizerFast):
150
+ pass
151
+
152
+ else:
153
+ raise ValueError(f"Tokenizer of type `{type(self.tokenizer)}` is not explicitly handled!")
154
+
155
+ # Tokenize Input IDs
156
+ turn_input_ids = self.tokenizer(msg, add_special_tokens=turn_idx == 0).input_ids
157
+
158
+ # [CRITICAL] We do not want to take the loss for the "USER: <msg>" prompts =>> just the responses!
159
+ turn_labels = (
160
+ [IGNORE_INDEX for _ in range(len(turn_input_ids))] if (turn_idx % 2) == 0 else list(turn_input_ids)
161
+ )
162
+
163
+ # Add to Trackers
164
+ input_ids.extend(turn_input_ids)
165
+ labels.extend(turn_labels)
166
+
167
+ # Tensorize =>> Set the <BOS> token's label to IGNORE_INDEX (since we're inserting the image patches after)
168
+ # - IMPORTANT => IF WE'RE USING HF LLM.forward(... labels=labels), SHIFTING HAPPENS _INSIDE_ MODEL!
169
+ input_ids, labels = torch.tensor(input_ids), torch.tensor(labels)
170
+
171
+ # Handle Truncation (if necessary)
172
+ input_ids, labels = input_ids[: self.tokenizer.model_max_length], labels[: self.tokenizer.model_max_length]
173
+
174
+ # === Handle "unimodal" (language-only) vs. "multimodal" ===
175
+ if "image" in self.examples[idx]:
176
+ image_path = Path(self.examples[idx]["image"])
177
+
178
+ # Set the <BOS> token's label to IGNORE_INDEX (since we're inserting the image patches right after)
179
+ labels[0] = IGNORE_INDEX
180
+
181
+ # Process Image --> get "pixel_values" (will either be a torch.Tensor OR a Dict[str,torch.Tensor])
182
+ pixel_values = self.image_transform(Image.open(self.image_dir / image_path).convert("RGB"))
183
+
184
+ return dict(pixel_values=pixel_values, input_ids=input_ids, labels=labels)
185
+
186
+ else:
187
+ # No image --> return `pixel_values` = None; Collator will do the smart batch handling for us!
188
+ return dict(pixel_values=None, input_ids=input_ids, labels=labels)
189
+
190
+ def get_modality_lengths(self) -> List[Tuple[bool, int]]:
191
+ """Get a list of modalities (unimodal / text-only vs. multimodal) and length of conversations per example."""
192
+ modality_lengths = []
193
+ for example in self.examples:
194
+ is_multimodal = "image" in example
195
+ n_words = sum([len(turn["value"].split()) for turn in example["conversations"]])
196
+ modality_lengths.append((is_multimodal, n_words))
197
+ return modality_lengths
198
+
199
+ def __len__(self) -> int:
200
+ return len(self.examples)
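
As a concrete illustration of the per-turn masking in `FinetuneDataset.__getitem__` (all token IDs below are invented; real values depend on the tokenizer), even-indexed "human" turns contribute only IGNORE_INDEX labels while odd-indexed "gpt" turns are supervised:

IGNORE_INDEX = -100

turns = [
    ("human", [1, 887, 526, 263]),     # turn_idx 0 (even) => prompt, masked out of the loss
    ("gpt",   [450, 4799, 29889, 2]),  # turn_idx 1 (odd)  => response, supervised
]

input_ids, labels = [], []
for turn_idx, (_, turn_input_ids) in enumerate(turns):
    input_ids.extend(turn_input_ids)
    labels.extend([IGNORE_INDEX] * len(turn_input_ids) if turn_idx % 2 == 0 else turn_input_ids)

print(input_ids)  # [1, 887, 526, 263, 450, 4799, 29889, 2]
print(labels)     # [-100, -100, -100, -100, 450, 4799, 29889, 2]
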
capvector-oft/prismatic/preprocessing/download.py ADDED
@@ -0,0 +1,207 @@
1
+ """
2
+ download.py
3
+
4
+ Utility functions for downloading and extracting various datasets to (local) disk.
5
+ """
6
+
7
+ import os
8
+ import shutil
9
+ from pathlib import Path
10
+ from typing import Dict, List, TypedDict
11
+ from zipfile import ZipFile
12
+
13
+ import requests
14
+ from PIL import Image
15
+ from rich.progress import BarColumn, DownloadColumn, MofNCompleteColumn, Progress, TextColumn, TransferSpeedColumn
16
+ from tqdm import tqdm
17
+
18
+ from prismatic.overwatch import initialize_overwatch
19
+
20
+ # Initialize Overwatch =>> Wraps `logging.Logger`
21
+ overwatch = initialize_overwatch(__name__)
22
+
23
+
24
+ # === Dataset Registry w/ Links ===
25
+ # fmt: off
26
+ DatasetComponent = TypedDict(
27
+ "DatasetComponent",
28
+ {"name": str, "extract": bool, "extract_type": str, "url": str, "do_rename": bool},
29
+ total=False
30
+ )
31
+
32
+ DATASET_REGISTRY: Dict[str, List[DatasetComponent]] = {
33
+ # === LLaVa v1.5 Dataset(s) ===
34
+
35
+ # Note =>> This is the full suite of datasets included in the LLaVa 1.5 "finetuning" stage; all the LLaVa v1.5
36
+ # models are finetuned on this split. We use this dataset for all experiments in our paper.
37
+ "llava-laion-cc-sbu-558k": [
38
+ {
39
+ "name": "chat.json", # Contains the "chat" traces :: {"human" => <prompt>, "gpt" => <caption>}
40
+ "extract": False,
41
+ "url": "https://huggingface.co/datasets/liuhaotian/LLaVA-Pretrain/resolve/main/blip_laion_cc_sbu_558k.json",
42
+ "do_rename": True,
43
+ },
44
+ {
45
+ "name": "images", # Contains the LLaVa Processed Images (jpgs, 224x224 resolution)
46
+ "extract": True,
47
+ "extract_type": "directory",
48
+ "url": "https://huggingface.co/datasets/liuhaotian/LLaVA-Pretrain/resolve/main/images.zip",
49
+ "do_rename": False,
50
+ }
51
+ ],
52
+
53
+ "llava-v1.5-instruct": [
54
+ {
55
+ "name": "llava_v1_5_mix665k.json",
56
+ "extract": False,
57
+ "url": (
58
+ "https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/resolve/main/llava_v1_5_mix665k.json"
59
+ ),
60
+ "do_rename": True,
61
+ },
62
+ {
63
+ "name": "coco/train2017", # Visual Instruct Tuning images are all sourced from COCO Train 2017
64
+ "extract": True,
65
+ "extract_type": "directory",
66
+ "url": "http://images.cocodataset.org/zips/train2017.zip",
67
+ "do_rename": True,
68
+ },
69
+ {
70
+ "name": "gqa/images",
71
+ "extract": True,
72
+ "extract_type": "directory",
73
+ "url": "https://downloads.cs.stanford.edu/nlp/data/gqa/images.zip",
74
+ "do_rename": True,
75
+ },
76
+ {
77
+ "name": "ocr_vqa/images",
78
+ "extract": True,
79
+ "extract_type": "directory",
80
+ "url": "https://huggingface.co/datasets/qnguyen3/ocr_vqa/resolve/main/ocr_vqa.zip",
81
+ "do_rename": True,
82
+ },
83
+ {
84
+ "name": "textvqa/train_images",
85
+ "extract": True,
86
+ "extract_type": "directory",
87
+ "url": "https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip",
88
+ "do_rename": True,
89
+ },
90
+ {
91
+ "name": "vg/VG_100K",
92
+ "extract": True,
93
+ "extract_type": "directory",
94
+ "url": "https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip",
95
+ "do_rename": True,
96
+ },
97
+ {
98
+ "name": "vg/VG_100K_2",
99
+ "extract": True,
100
+ "extract_type": "directory",
101
+ "url": "https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip",
102
+ "do_rename": True,
103
+ },
104
+ ]
105
+ }
106
+ # fmt: on
107
+
108
+
109
+ def convert_to_jpg(image_dir: Path) -> None:
110
+ """Handling for OCR-VQA Images specifically; iterates through directory, converts all GIFs/PNGs."""
111
+ overwatch.info(f"Converting all Images in `{image_dir}` to JPG")
112
+
113
+ for image_fn in tqdm(list(image_dir.iterdir())):
114
+ if image_fn.suffix in {".jpg", ".jpeg"} or (jpg_fn := image_dir / f"{image_fn.stem}.jpg").exists():
115
+ continue
116
+
117
+ if image_fn.suffix == ".gif":
118
+ gif = Image.open(image_fn)
119
+ gif.seek(0)
120
+ gif.convert("RGB").save(jpg_fn)
121
+ elif image_fn.suffix == ".png":
122
+ Image.open(image_fn).convert("RGB").save(jpg_fn)
123
+ else:
124
+ raise ValueError(f"Unexpected image format `{image_fn.suffix}`")
125
+
126
+
127
+ def download_with_progress(url: str, download_dir: Path, chunk_size_bytes: int = 1024) -> Path:
128
+ """Utility function for downloading files from the internet, with a handy Rich-based progress bar."""
129
+ overwatch.info(f"Downloading {(dest_path := download_dir / Path(url).name)} from `{url}`", ctx_level=1)
130
+ if dest_path.exists():
131
+ return dest_path
132
+
133
+ # Otherwise --> fire an HTTP Request, with `stream = True`
134
+ response = requests.get(url, stream=True)
135
+
136
+ # Download w/ Transfer-Aware Progress
137
+ # => Reference: https://github.com/Textualize/rich/blob/master/examples/downloader.py
138
+ with Progress(
139
+ TextColumn("[bold]{task.description} - {task.fields[fname]}"),
140
+ BarColumn(bar_width=None),
141
+ "[progress.percentage]{task.percentage:>3.1f}%",
142
+ "•",
143
+ DownloadColumn(),
144
+ "•",
145
+ TransferSpeedColumn(),
146
+ transient=True,
147
+ ) as dl_progress:
148
+ dl_tid = dl_progress.add_task(
149
+ "Downloading", fname=dest_path.name, total=int(response.headers["content-length"]) if "content-length" in response.headers else None
150
+ )
151
+ with open(dest_path, "wb") as f:
152
+ for data in response.iter_content(chunk_size=chunk_size_bytes):
153
+ dl_progress.advance(dl_tid, f.write(data))
154
+
155
+ return dest_path
156
+
157
+
158
+ def extract_with_progress(archive_path: Path, download_dir: Path, extract_type: str, cleanup: bool = False) -> Path:
159
+ """Utility function for extracting compressed archives, with a handy Rich-based progress bar."""
160
+ assert archive_path.suffix == ".zip", "Only `.zip` compressed archives are supported for now!"
161
+ overwatch.info(f"Extracting {archive_path.name} to `{download_dir}`", ctx_level=1)
162
+
163
+ # Extract w/ Progress
164
+ with Progress(
165
+ TextColumn("[bold]{task.description} - {task.fields[aname]}"),
166
+ BarColumn(bar_width=None),
167
+ "[progress.percentage]{task.percentage:>3.1f}%",
168
+ "•",
169
+ MofNCompleteColumn(),
170
+ transient=True,
171
+ ) as ext_progress:
172
+ with ZipFile(archive_path) as zf:
173
+ ext_tid = ext_progress.add_task("Extracting", aname=archive_path.name, total=len(members := zf.infolist()))
174
+ extract_path = Path(zf.extract(members[0], download_dir))
175
+ if extract_type == "file":
176
+ assert len(members) == 1, f"Archive `{archive_path}` with extract type `{extract_type}` has > 1 member!"
177
+ elif extract_type == "directory":
178
+ for member in members[1:]:
179
+ zf.extract(member, download_dir)
180
+ ext_progress.advance(ext_tid)
181
+ else:
182
+ raise ValueError(f"Extract type `{extract_type}` for archive `{archive_path}` is not defined!")
183
+
184
+ # Cleanup (if specified)
185
+ if cleanup:
186
+ archive_path.unlink()
187
+
188
+ return extract_path
189
+
190
+
191
+ def download_extract(dataset_id: str, root_dir: Path) -> None:
192
+ """Download all files for a given dataset (querying registry above), extracting archives if necessary."""
193
+ os.makedirs(download_dir := root_dir / "download" / dataset_id, exist_ok=True)
194
+
195
+ # Download Files => Single-Threaded, with Progress Bar
196
+ dl_tasks = [d for d in DATASET_REGISTRY[dataset_id] if not (download_dir / d["name"]).exists()]
197
+ for dl_task in dl_tasks:
198
+ dl_path = download_with_progress(dl_task["url"], download_dir)
199
+
200
+ # Extract Files (if specified) --> Note (assumes ".zip" ONLY!)
201
+ if dl_task["extract"]:
202
+ dl_path = extract_with_progress(dl_path, download_dir, dl_task["extract_type"])
203
+ dl_path = dl_path.parent if dl_path.is_file() else dl_path
204
+
205
+ # Rename Path --> dl_task["name"]
206
+ if dl_task["do_rename"]:
207
+ shutil.move(dl_path, download_dir / dl_task["name"])
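For reference, a minimal driver sketch for this module (not part of the committed file); the local `data` root and the OCR-VQA path below are illustrative assumptions.

from pathlib import Path

from prismatic.preprocessing.download import DATASET_REGISTRY, convert_to_jpg, download_extract

if __name__ == "__main__":
    root_dir = Path("data")  # assumed local storage root; adjust to your layout
    for dataset_id in DATASET_REGISTRY:
        download_extract(dataset_id, root_dir=root_dir)

    # OCR-VQA ships GIFs/PNGs; normalize to JPG as expected by the finetune annotations
    convert_to_jpg(root_dir / "download" / "llava-v1.5-instruct" / "ocr_vqa" / "images")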
capvector-oft/prismatic/preprocessing/materialize.py ADDED
@@ -0,0 +1,69 @@
1
+ """
2
+ materialize.py
3
+
4
+ Factory class for initializing pretraining datasets on a per-VLM basis; provides and exports individual functions for
5
+ clear control flow.
6
+ """
7
+
8
+ from typing import Tuple, Type
9
+
10
+ from torch.utils.data import Dataset
11
+ from transformers import PreTrainedTokenizerBase
12
+
13
+ from prismatic.conf import DatasetConfig
14
+ from prismatic.models.backbones.llm.prompting import PromptBuilder
15
+ from prismatic.models.backbones.vision import ImageTransform
16
+ from prismatic.preprocessing.datasets import AlignDataset, FinetuneDataset
17
+ from prismatic.util.data_utils import PaddedCollatorForLanguageModeling
18
+
19
+ # Dataset Initializers =>> Maps Stage --> cls()
20
+ DATASET_INITIALIZER = {"align": AlignDataset, "finetune": FinetuneDataset, "full-finetune": FinetuneDataset}
21
+
22
+
23
+ def get_dataset_and_collator(
24
+ stage: str,
25
+ dataset_cfg: DatasetConfig,
26
+ image_transform: ImageTransform,
27
+ tokenizer: PreTrainedTokenizerBase,
28
+ prompt_builder_fn: Type[PromptBuilder],
29
+ default_image_resolution: Tuple[int, int, int],
30
+ padding_side: str = "right",
31
+ ) -> Tuple[Dataset, PaddedCollatorForLanguageModeling]:
32
+ dataset_cls = DATASET_INITIALIZER[stage]
33
+ dataset_root_dir = dataset_cfg.dataset_root_dir
34
+ collator = PaddedCollatorForLanguageModeling(
35
+ tokenizer.model_max_length, tokenizer.pad_token_id, default_image_resolution, padding_side=padding_side
36
+ )
37
+
38
+ # Switch on `stage`
39
+ if stage == "align":
40
+ annotation_json, image_dir = dataset_cfg.align_stage_components
41
+ dataset = dataset_cls(
42
+ dataset_root_dir / annotation_json, dataset_root_dir / image_dir, image_transform, tokenizer
43
+ )
44
+ return dataset, collator
45
+
46
+ elif stage == "finetune":
47
+ annotation_json, image_dir = dataset_cfg.finetune_stage_components
48
+ dataset = dataset_cls(
49
+ dataset_root_dir / annotation_json,
50
+ dataset_root_dir / image_dir,
51
+ image_transform,
52
+ tokenizer,
53
+ prompt_builder_fn=prompt_builder_fn,
54
+ )
55
+ return dataset, collator
56
+
57
+ elif stage == "full-finetune":
58
+ annotation_json, image_dir = dataset_cfg.finetune_stage_components
59
+ dataset = dataset_cls(
60
+ dataset_root_dir / annotation_json,
61
+ dataset_root_dir / image_dir,
62
+ image_transform,
63
+ tokenizer,
64
+ prompt_builder_fn=prompt_builder_fn,
65
+ )
66
+ return dataset, collator
67
+
68
+ else:
69
+ raise ValueError(f"Stage `{stage}` is not supported!")
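A minimal wiring sketch for this factory (illustrative, not part of the committed file); the backbone objects are assumed to expose the `image_transform`, `tokenizer`, `prompt_builder_fn`, and `default_image_resolution` attributes provided elsewhere in the codebase.

from prismatic.preprocessing.materialize import get_dataset_and_collator


def build_finetune_data(dataset_cfg, vision_backbone, llm_backbone):
    # Returns (FinetuneDataset, PaddedCollatorForLanguageModeling) for the LLaVa-style mix
    return get_dataset_and_collator(
        stage="finetune",
        dataset_cfg=dataset_cfg,
        image_transform=vision_backbone.image_transform,
        tokenizer=llm_backbone.tokenizer,
        prompt_builder_fn=llm_backbone.prompt_builder_fn,
        default_image_resolution=vision_backbone.default_image_resolution,
        padding_side=llm_backbone.tokenizer.padding_side,
    )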
capvector-oft/prismatic/training/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .materialize import get_train_strategy
2
+ from .metrics import Metrics, VLAMetrics
capvector-oft/prismatic/training/materialize.py ADDED
@@ -0,0 +1,66 @@
1
+ """
2
+ materialize.py
3
+
4
+ Factory class defining functions for instantiating various Training Strategies, supporting different VLMs, backbones,
5
+ and strategy configurations.
6
+ """
7
+
8
+ from typing import Callable, Optional
9
+
10
+ import torch
11
+
12
+ from prismatic.models.vlms import PrismaticVLM
13
+ from prismatic.training.strategies import FSDPStrategy, TrainingStrategy
14
+
15
+ # Registry =>> Maps ID --> {cls(), kwargs} :: supports FSDP for now, but DDP handler is also implemented!
16
+ TRAIN_STRATEGIES = {
17
+ "fsdp-shard-grad-op": {"cls": FSDPStrategy, "kwargs": {"sharding_strategy": "shard-grad-op"}},
18
+ "fsdp-full-shard": {"cls": FSDPStrategy, "kwargs": {"sharding_strategy": "full-shard"}},
19
+ }
20
+
21
+
22
+ def get_train_strategy(
23
+ train_strategy: str,
24
+ vlm: PrismaticVLM,
25
+ device_id: int,
26
+ stage: str,
27
+ epochs: int,
28
+ max_steps: Optional[int],
29
+ global_batch_size: int,
30
+ per_device_batch_size: int,
31
+ learning_rate: float,
32
+ weight_decay: float,
33
+ max_grad_norm: float,
34
+ lr_scheduler_type: str,
35
+ warmup_ratio: float,
36
+ enable_gradient_checkpointing: bool = True,
37
+ enable_mixed_precision_training: bool = True,
38
+ reduce_in_full_precision: bool = False,
39
+ mixed_precision_dtype: torch.dtype = torch.bfloat16,
40
+ worker_init_fn: Optional[Callable[[int], None]] = None,
41
+ ) -> TrainingStrategy:
42
+ if train_strategy in TRAIN_STRATEGIES:
43
+ strategy_cfg = TRAIN_STRATEGIES[train_strategy]
44
+ strategy = strategy_cfg["cls"](
45
+ vlm=vlm,
46
+ device_id=device_id,
47
+ stage=stage,
48
+ epochs=epochs,
49
+ max_steps=max_steps,
50
+ global_batch_size=global_batch_size,
51
+ per_device_batch_size=per_device_batch_size,
52
+ learning_rate=learning_rate,
53
+ weight_decay=weight_decay,
54
+ max_grad_norm=max_grad_norm,
55
+ lr_scheduler_type=lr_scheduler_type,
56
+ warmup_ratio=warmup_ratio,
57
+ enable_gradient_checkpointing=enable_gradient_checkpointing,
58
+ enable_mixed_precision_training=enable_mixed_precision_training,
59
+ reduce_in_full_precision=reduce_in_full_precision,
60
+ mixed_precision_dtype=mixed_precision_dtype,
61
+ worker_init_fn=worker_init_fn,
62
+ **strategy_cfg["kwargs"],
63
+ )
64
+ return strategy
65
+ else:
66
+ raise ValueError(f"Train Strategy `{train_strategy}` is not supported!")
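A minimal sketch of instantiating a strategy via this factory (illustrative; the hyperparameter values below are assumptions, not recommendations from this commit).

import torch

from prismatic.training import get_train_strategy


def build_fsdp_strategy(vlm, device_id: int):
    return get_train_strategy(
        train_strategy="fsdp-full-shard",
        vlm=vlm,
        device_id=device_id,
        stage="finetune",
        epochs=1,
        max_steps=None,
        global_batch_size=128,
        per_device_batch_size=16,
        learning_rate=2e-5,
        weight_decay=0.1,
        max_grad_norm=1.0,
        lr_scheduler_type="linear-warmup+cosine-decay",
        warmup_ratio=0.03,
        mixed_precision_dtype=torch.bfloat16,
    )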
capvector-oft/prismatic/training/metrics.py ADDED
@@ -0,0 +1,348 @@
1
+ """
2
+ metrics.py
3
+
4
+ Utility classes defining a Metrics container and multiple Trackers to enable model/stage-specific logging to various
5
+ endpoints (e.g., JSONL local logs, Weights & Biases).
6
+ """
7
+
8
+ import time
9
+ from collections import defaultdict, deque
10
+ from pathlib import Path
11
+ from typing import Any, Dict, Optional, Protocol, Tuple, Union
12
+
13
+ import jsonlines
14
+ import numpy as np
15
+ import torch
16
+ import wandb
17
+
18
+ from prismatic.overwatch import initialize_overwatch
19
+
20
+ # Initialize Overwatch =>> Wraps `logging.Logger`
21
+ overwatch = initialize_overwatch(__name__)
22
+
23
+
24
+ # === Define Tracker Interface ===
25
+ class Tracker(Protocol):
26
+ def write_hyperparameters(self) -> None: ...
27
+
28
+ def write(self, global_step: int, metrics: Dict[str, Union[int, float]]) -> None: ...
29
+
30
+ def finalize(self) -> None: ...
31
+
32
+
33
+ # === Individual Tracker Definitions ===
34
+ class JSONLinesTracker:
35
+ def __init__(self, run_id: str, run_dir: Path, hparams: Dict[str, Any]) -> None:
36
+ self.run_id, self.run_dir, self.hparams = run_id, run_dir, hparams
37
+
38
+ @overwatch.rank_zero_only
39
+ def write_hyperparameters(self) -> None:
40
+ with jsonlines.open(self.run_dir / "run-metrics.jsonl", mode="w", sort_keys=True) as js_tracker:
41
+ js_tracker.write({"run_id": self.run_id, "hparams": self.hparams})
42
+
43
+ @overwatch.rank_zero_only
44
+ def write(self, _: int, metrics: Dict[str, Union[int, float]]) -> None:
45
+ with jsonlines.open(self.run_dir / f"{self.run_id}.jsonl", mode="a", sort_keys=True) as js_tracker:
46
+ js_tracker.write(metrics)
47
+
48
+ def finalize(self) -> None:
49
+ return
50
+
51
+
52
+ class WeightsBiasesTracker:
53
+ def __init__(
54
+ self,
55
+ run_id: str,
56
+ run_dir: Path,
57
+ hparams: Dict[str, Any],
58
+ project: str = "prismatic",
59
+ entity: Optional[str] = None,
60
+ group: str = "align",
61
+ ) -> None:
62
+ self.run_id, self.run_dir, self.hparams = run_id, run_dir, hparams
63
+
64
+ # Get W&B-Specific Initialization Parameters
65
+ self.project, self.entity, self.group, self.wandb_dir = project, entity, group, self.run_dir
66
+
67
+ # Call W&B.init()
68
+ self.initialize()
69
+
70
+ @overwatch.rank_zero_only
71
+ def initialize(self) -> None:
72
+ wandb.init(
73
+ name=self.run_id,
74
+ dir=self.wandb_dir,
75
+ config=self.hparams,
76
+ project=self.project,
77
+ entity=self.entity,
78
+ group=self.group,
79
+ )
80
+
81
+ @overwatch.rank_zero_only
82
+ def write_hyperparameters(self) -> None:
83
+ wandb.config = self.hparams
84
+
85
+ @overwatch.rank_zero_only
86
+ def write(self, global_step: int, metrics: Dict[str, Union[int, float]]) -> None:
87
+ wandb.log(metrics, step=global_step)
88
+
89
+ @staticmethod
90
+ def finalize() -> None:
91
+ if overwatch.is_rank_zero():
92
+ wandb.finish()
93
+
94
+ # A job gets 210 seconds to get its affairs in order
95
+ time.sleep(210)
96
+
97
+
98
+ # === Core Metrics Container :: Initializes Trackers => Compiles/Pushes Metrics ===
99
+
100
+
101
+ class Metrics:
102
+ def __init__(
103
+ self,
104
+ active_trackers: Tuple[str, ...],
105
+ run_id: str,
106
+ run_dir: Path,
107
+ hparams: Dict[str, Any],
108
+ stage: str,
109
+ wandb_project: str = "prismatic",
110
+ wandb_entity: Optional[str] = None,
111
+ grad_accumulation_steps: int = 1,
112
+ window_size: int = 128,
113
+ ) -> None:
114
+ self.run_id, self.run_dir, self.hparams, self.stage = run_id, run_dir, hparams, stage
115
+
116
+ # Initialize Trackers
117
+ self.trackers = []
118
+ for tracker_type in active_trackers:
119
+ if tracker_type == "jsonl":
120
+ tracker = JSONLinesTracker(run_id, run_dir, hparams)
121
+ elif tracker_type == "wandb":
122
+ tracker = WeightsBiasesTracker(
123
+ run_id, run_dir, hparams, project=wandb_project, entity=wandb_entity, group=self.stage
124
+ )
125
+ else:
126
+ raise ValueError(f"Tracker with type `{tracker_type}` is not supported!")
127
+
128
+ # Add Hyperparameters --> add to `self.trackers`
129
+ tracker.write_hyperparameters()
130
+ self.trackers.append(tracker)
131
+
132
+ # Create Universal Metrics Buffers
133
+ self.global_step, self.start_time, self.step_start_time = 0, time.time(), time.time()
134
+ self.state = {
135
+ "loss_raw": deque(maxlen=grad_accumulation_steps),
136
+ "loss": deque(maxlen=window_size),
137
+ "step_time": deque(maxlen=window_size),
138
+ "lr": [],
139
+ }
140
+
141
+ def log(self, global_step: int, metrics: Dict[str, Union[int, float]]) -> None:
142
+ for tracker in self.trackers:
143
+ tracker.write(global_step, metrics)
144
+
145
+ def get_status(self, loss: Optional[torch.Tensor] = None) -> str:
146
+ lr = self.state["lr"][-1] if len(self.state["lr"]) > 0 else 0
147
+ if loss is None:
148
+ return f"=>> [Global Step] {self.global_step:06d} =>> LR :: {lr:.6f}"
149
+
150
+ # Otherwise, embed `loss` in status report!
151
+ return f"=>> [Global Step] {self.global_step:06d} =>> LR :: {lr:.6f} -- Loss :: {loss:.4f}"
152
+
153
+ def commit(
154
+ self, *, global_step: Optional[int] = None, lr: Optional[float] = None, update_step_time: bool = False, **kwargs
155
+ ) -> None:
156
+ """Update all metrics in `self.state` by iterating through special positional arguments & kwargs."""
157
+ if global_step is not None:
158
+ self.global_step = global_step
159
+
160
+ # For all other variables --> only track on rank zero!
161
+ if not overwatch.is_rank_zero():
162
+ return
163
+
164
+ # Special Positional Arguments
165
+ if lr is not None:
166
+ self.state["lr"].append(lr)
167
+
168
+ if update_step_time:
169
+ self.state["step_time"].append(time.time() - self.step_start_time)
170
+ self.step_start_time = time.time()
171
+
172
+ # Generic Keyword Arguments
173
+ for key, value in kwargs.items():
174
+ if key == "loss":
175
+ loss_val = value.detach()
176
+ self.state["loss_raw"].append(loss_val)
177
+ self.state["loss"].append(loss_val)
178
+ else:
179
+ self.state[key].append(value.detach())
180
+
181
+ @overwatch.rank_zero_only
182
+ def push(self) -> str:
183
+ # Note :: Raw Loss is an Average over Gradient Accumulation Steps --> No Smoothing!
184
+ loss_raw = torch.stack(list(self.state["loss_raw"])).mean().item()
185
+ loss = torch.stack(list(self.state["loss"])).mean().item()
186
+ step_time, lr = np.mean(list(self.state["step_time"])), self.state["lr"][-1]
187
+ status = self.get_status(loss)
188
+
189
+ # Fire to Trackers
190
+ prefix = self.stage.capitalize()
191
+ self.log(
192
+ self.global_step,
193
+ metrics={
194
+ f"{prefix}/Step": self.global_step,
195
+ f"{prefix}/Loss": loss,
196
+ f"{prefix}/Loss (Raw)": loss_raw,
197
+ f"{prefix}/Learning Rate": lr,
198
+ f"{prefix}/Step Time": step_time,
199
+ },
200
+ )
201
+ return status
202
+
203
+ def finalize(self) -> None:
204
+ for tracker in self.trackers:
205
+ tracker.finalize()
206
+
207
+
208
+ class VLAMetrics:
209
+ def __init__(
210
+ self,
211
+ active_trackers: Tuple[str, ...],
212
+ run_id: str,
213
+ run_dir: Path,
214
+ hparams: Dict[str, Any],
215
+ wandb_project: str = "openvla",
216
+ wandb_entity: Optional[str] = "stanford-voltron",
217
+ grad_accumulation_steps: int = 1,
218
+ window_size: int = 1,
219
+ resume_step: Optional[int] = None,
220
+ resume_epoch: Optional[int] = None,
221
+ ) -> None:
222
+ self.run_id, self.run_dir, self.hparams = run_id, run_dir, hparams
223
+
224
+ # Initialize Trackers
225
+ self.trackers = []
226
+ for tracker_type in active_trackers:
227
+ if tracker_type == "jsonl":
228
+ tracker = JSONLinesTracker(run_id, run_dir, hparams)
229
+ elif tracker_type == "wandb":
230
+ tracker = WeightsBiasesTracker(
231
+ run_id, run_dir, hparams, project=wandb_project, entity=wandb_entity, group="vla-train"
232
+ )
233
+ else:
234
+ raise ValueError(f"Tracker with type `{tracker_type}` is not supported!")
235
+
236
+ # Add Hyperparameters --> add to `self.trackers`
237
+ tracker.write_hyperparameters()
238
+ self.trackers.append(tracker)
239
+
240
+ # Create Universal Metrics Buffers
241
+ self.global_step = 0 if resume_step is None else resume_step
242
+ self.epoch = 0 if resume_epoch is None else resume_epoch
243
+ self.start_time, self.step_start_time = time.time(), time.time()
244
+ self.state = {
245
+ "loss_raw": deque(maxlen=grad_accumulation_steps),
246
+ "loss": deque(maxlen=window_size),
247
+ "l1_loss": deque(maxlen=window_size),
248
+ "action_accuracy": deque(maxlen=window_size),
+ # Buffers for the next-actions metrics committed by the VLA training loop
+ "next_actions_accuracy": deque(maxlen=window_size),
+ "next_actions_l1_loss": deque(maxlen=window_size),
249
+ "step_time": deque(maxlen=window_size),
250
+ "lr": [],
251
+ }
252
+
253
+ # Created metrics buffers for individual tracked datasets
254
+ self.dataset_trackers = defaultdict(lambda: VLAMetrics([], "", "", {}))
255
+
256
+ def log(self, global_step: int, metrics: Dict[str, Union[int, float]]) -> None:
257
+ for tracker in self.trackers:
258
+ tracker.write(global_step, metrics)
259
+
260
+ def get_status(self, loss: Optional[torch.Tensor] = None) -> str:
261
+ lr = self.state["lr"][-1] if len(self.state["lr"]) > 0 else 0
262
+ if loss is None:
263
+ return f"=>> [Epoch {self.epoch:03d}] Global Step {self.global_step:06d} =>> LR :: {lr:.6f}"
264
+
265
+ # Otherwise, embed `loss` in status report!
266
+ return f"=>> [Epoch {self.epoch:03d}] Global Step {self.global_step:06d} =>> LR :: {lr:.6f} - Loss :: {loss:.4f}"
267
+
268
+ def commit(
269
+ self,
270
+ *,
271
+ global_step: Optional[int] = None,
272
+ epoch: Optional[int] = None,
273
+ lr: Optional[float] = None,
274
+ update_step_time: bool = False,
275
+ **kwargs,
276
+ ) -> None:
277
+ """Update all metrics in `self.state` by iterating through special positional arguments & kwargs."""
278
+ if global_step is not None:
279
+ self.global_step = global_step
280
+
281
+ if epoch is not None:
282
+ self.epoch = epoch
283
+
284
+ # For all other variables --> only track on rank zero!
285
+ if not overwatch.is_rank_zero():
286
+ return
287
+
288
+ # Special Positional Arguments
289
+ if lr is not None:
290
+ self.state["lr"].append(lr)
291
+
292
+ if update_step_time:
293
+ self.state["step_time"].append(time.time() - self.step_start_time)
294
+ self.step_start_time = time.time()
295
+
296
+ # Generic Keyword Arguments
297
+ for key, value in kwargs.items():
298
+ if key == "loss":
299
+ loss_val = value.detach()
300
+ self.state["loss_raw"].append(loss_val)
301
+ self.state["loss"].append(loss_val)
302
+ else:
303
+ self.state[key].append(value.detach())
304
+
305
+ def commit_for_dataset(self, dataset_name: str, **kwargs) -> None:
306
+ self.dataset_trackers[dataset_name].commit(**kwargs)
307
+
308
+ @overwatch.rank_zero_only
309
+ def push(self) -> str:
310
+ # Note :: Raw Loss is an Average over Gradient Accumulation Steps --> No Smoothing!
311
+ loss_raw = torch.stack(list(self.state["loss_raw"])).mean().item()
312
+ loss = torch.stack(list(self.state["loss"])).mean().item()
313
+ l1_loss = torch.stack(list(self.state["l1_loss"])).mean().item()
314
+ action_accuracy = torch.stack(list(self.state["action_accuracy"])).mean().item()
315
+ step_time, lr = np.mean(list(self.state["step_time"])), self.state["lr"][-1]
316
+ status = self.get_status(loss)
317
+
318
+ # Get metrics per dataset
319
+ dataset_metrics = {}
320
+ for ds, tracker in self.dataset_trackers.items():
321
+ dataset_metrics.update(
322
+ {
323
+ f"{ds}/L1 Loss": torch.stack(list(tracker.state["l1_loss"])).mean().item(),
324
+ f"{ds}/Action Token Accuracy": torch.stack(list(tracker.state["action_accuracy"])).mean().item(),
325
+ }
326
+ )
327
+
328
+ # Fire to Trackers
329
+ prefix = "VLA Train"
330
+ self.log(
331
+ self.global_step,
332
+ metrics={
333
+ f"{prefix}/Step": self.global_step,
334
+ f"{prefix}/Epoch": self.epoch,
335
+ f"{prefix}/Loss": loss,
336
+ f"{prefix}/L1 Loss": l1_loss,
337
+ f"{prefix}/Action Token Accuracy": action_accuracy,
338
+ f"{prefix}/Loss (Raw)": loss_raw,
339
+ f"{prefix}/Learning Rate": lr,
340
+ f"{prefix}/Step Time": step_time,
341
+ **dataset_metrics,
342
+ },
343
+ )
344
+ return status
345
+
346
+ def finalize(self) -> None:
347
+ for tracker in self.trackers:
348
+ tracker.finalize()
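Because `Tracker` is a structural `Protocol`, new logging backends only need to implement the three methods above; below is a minimal stdout tracker sketch (illustrative only, not part of the committed file). Plugging it into `Metrics` / `VLAMetrics` would additionally require extending their `active_trackers` dispatch.

from pathlib import Path
from typing import Any, Dict, Union


class StdoutTracker:
    """Minimal Tracker-compatible backend that just prints to stdout."""

    def __init__(self, run_id: str, run_dir: Path, hparams: Dict[str, Any]) -> None:
        self.run_id, self.run_dir, self.hparams = run_id, run_dir, hparams

    def write_hyperparameters(self) -> None:
        print(f"[{self.run_id}] hparams = {self.hparams}")

    def write(self, global_step: int, metrics: Dict[str, Union[int, float]]) -> None:
        print(f"[{self.run_id}] step {global_step} :: {metrics}")

    def finalize(self) -> None:
        print(f"[{self.run_id}] finished")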
capvector-oft/prismatic/training/strategies/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .base_strategy import TrainingStrategy
2
+ from .ddp import DDPStrategy
3
+ from .fsdp import FSDPStrategy
capvector-oft/prismatic/training/strategies/base_strategy.py ADDED
@@ -0,0 +1,417 @@
1
+ """
2
+ base_strategy.py
3
+
4
+ Abstract class definition of a (distributed) training strategy, with full annotations of class methods, utility
5
+ functions, and initialization logic.
6
+
7
+ Training Strategies (DDP, FSDP-Grad, FSDP-Full) tend to have a lot of repeated components; this class does a lot of
8
+ heavy lifting.
9
+ """
10
+
11
+ from abc import ABC, abstractmethod
12
+ from pathlib import Path
13
+ from typing import Callable, Optional
14
+
15
+ import numpy as np
16
+ import torch
17
+ import torch.distributed as dist
18
+ from torch.utils.data import DataLoader, Dataset, DistributedSampler, IterableDataset
19
+ from tqdm import tqdm
20
+ from transformers.modeling_outputs import CausalLMOutputWithPast
21
+
22
+ from prismatic.models.vlms import PrismaticVLM
23
+ from prismatic.overwatch import initialize_overwatch
24
+ from prismatic.training.metrics import Metrics, VLAMetrics
25
+ from prismatic.training.train_utils import (
26
+ compute_actions_l1_loss,
27
+ compute_token_accuracy,
28
+ get_current_action_mask,
29
+ get_next_actions_mask,
30
+ )
31
+ from prismatic.util import check_bloat16_supported
32
+ from prismatic.util.batching_utils import SplitModalitySampler
33
+ from prismatic.util.data_utils import PaddedCollatorForActionPrediction, PaddedCollatorForLanguageModeling
34
+ from prismatic.vla.action_tokenizer import ActionTokenizer
35
+
36
+ # HuggingFace Default / LLaMa-2 IGNORE_INDEX (for labels)
37
+ from prismatic.vla.constants import ACTION_DIM, ACTION_TOKEN_BEGIN_IDX, NUM_ACTIONS_CHUNK, IGNORE_INDEX
38
+ NEWLINE_INDEX = 13 # '\n'
39
+ STOP_INDEX = 2 # '</s>'
40
+
41
+ # Initialize Overwatch =>> Wraps `logging.Logger`
42
+ overwatch = initialize_overwatch(__name__)
43
+
44
+
45
+ # === Abstract Base Class for an arbitrary Training Strategy ===
46
+ class TrainingStrategy(ABC):
47
+ def __init__(
48
+ self,
49
+ vlm: PrismaticVLM,
50
+ device_id: int,
51
+ stage: str,
52
+ epochs: int,
53
+ max_steps: Optional[int],
54
+ global_batch_size: int,
55
+ per_device_batch_size: int,
56
+ learning_rate: float,
57
+ weight_decay: float,
58
+ max_grad_norm: float,
59
+ lr_scheduler_type: str,
60
+ warmup_ratio: float,
61
+ enable_gradient_checkpointing: bool = True,
62
+ enable_mixed_precision_training: bool = True,
63
+ reduce_in_full_precision: bool = False,
64
+ mixed_precision_dtype: torch.dtype = torch.bfloat16,
65
+ worker_init_fn: Optional[Callable[[int], None]] = None,
66
+ **_: str,
67
+ ) -> None:
68
+ self.vlm, self.device_id, self.stage = vlm, device_id, stage
69
+
70
+ # Get relevant VLM instance parameters before they get (potentially) wrapped
71
+ self.all_module_keys, self.trainable_module_keys = self.vlm.all_module_keys, self.vlm.trainable_module_keys
72
+ self.llm_transformer_layer_cls = self.vlm.llm_backbone.transformer_layer_cls
73
+
74
+ # Optimization Parameters
75
+ self.epochs, self.max_steps = epochs, max_steps
76
+ self.global_batch_size, self.per_device_batch_size = global_batch_size, per_device_batch_size
77
+
78
+ self.learning_rate, self.weight_decay, self.max_grad_norm = learning_rate, weight_decay, max_grad_norm
79
+ self.lr_scheduler_type, self.warmup_ratio = lr_scheduler_type, warmup_ratio
80
+
81
+ # Generic Strategy Parameters
82
+ self.enable_gradient_checkpointing = enable_gradient_checkpointing
83
+ self.enable_mixed_precision_training = enable_mixed_precision_training
84
+ self.reduce_in_full_precision = reduce_in_full_precision
85
+ self.mixed_precision_dtype = mixed_precision_dtype
86
+
87
+ # DataLoader Parameters
88
+ self.worker_init_fn = worker_init_fn
89
+
90
+ # Optimizers & Scheduler (initialized in `run_setup`)
91
+ self.optimizer, self.lr_scheduler = None, None
92
+
93
+ # Lightweight Validation
94
+ assert (
95
+ self.global_batch_size % self.per_device_batch_size == 0
96
+ ), "Per-device batch size must evenly divide global batch size!"
97
+ self.grad_accumulation_steps = self.global_batch_size // self.per_device_batch_size // overwatch.world_size()
98
+ if self.enable_mixed_precision_training:
99
+ assert self.mixed_precision_dtype == torch.bfloat16, "Only BF16 mixed precision training is supported!"
100
+ assert check_bloat16_supported(), "BFloat16 is not supported on this hardware; unset `mixed_precision`"
101
+
102
+ @abstractmethod
103
+ def save_checkpoint(
104
+ self,
105
+ run_dir: Path,
106
+ global_step: int,
107
+ epoch: int,
108
+ train_loss: Optional[float] = None,
109
+ only_trainable: bool = True,
110
+ ) -> None: ...
111
+
112
+ @abstractmethod
113
+ def run_setup(self, run_dir: Path, n_train_examples: int) -> None: ...
114
+
115
+ @abstractmethod
116
+ def clip_grad_norm(self) -> None: ...
117
+
118
+ def run_training(
119
+ self,
120
+ dataset: Dataset,
121
+ collator: PaddedCollatorForLanguageModeling,
122
+ metrics: Metrics,
123
+ stage: str = "finetune",
124
+ batch_construction_strategy: str = "split-modality",
125
+ seed: int = 7,
126
+ ) -> None:
127
+ """Run the training loop for the given `dataset` and `collator`; log losses, results to `metrics`"""
128
+ if "finetune" in stage and batch_construction_strategy == "split-modality":
129
+ # Instantiate the split-modality sampler; if you want to extend with other batch construction schemes,
130
+ # (e.g., grouping by length) =>> can easily add them here!
131
+ modality_lengths = dataset.get_modality_lengths()
132
+ sampler = SplitModalitySampler(
133
+ dataset,
134
+ modality_lengths,
135
+ global_batch_size=self.global_batch_size,
136
+ num_replicas=overwatch.world_size(),
137
+ rank=overwatch.rank(),
138
+ seed=seed,
139
+ drop_last=False,
140
+ )
141
+
142
+ else:
143
+ sampler = DistributedSampler(
144
+ dataset,
145
+ num_replicas=overwatch.world_size(),
146
+ rank=overwatch.rank(),
147
+ shuffle=True,
148
+ seed=seed,
149
+ drop_last=False,
150
+ )
151
+
152
+ # Create a DataLoader with the initialized sampler, per-device-bsz, and collator
153
+ dataloader = DataLoader(
154
+ dataset,
155
+ batch_size=self.per_device_batch_size,
156
+ sampler=sampler,
157
+ collate_fn=collator,
158
+ num_workers=2,
159
+ worker_init_fn=self.worker_init_fn,
160
+ )
161
+
162
+ # Max Steps vs. Epochs Computation
163
+ steps_per_epoch = len(dataloader) // self.grad_accumulation_steps
164
+ if self.max_steps is not None and steps_per_epoch < self.max_steps:
165
+ # Just set `epochs` to some large number --> we'll short-circuit based on steps anyway
166
+ self.epochs = 100
167
+
168
+ # === Train ===
169
+ status = metrics.get_status()
170
+ with tqdm(
171
+ total=(
172
+ (self.epochs * (len(dataloader) // self.grad_accumulation_steps))
173
+ if self.max_steps is None
174
+ else self.max_steps
175
+ ),
176
+ desc=status,
177
+ leave=False,
178
+ disable=not overwatch.is_rank_zero(),
179
+ ) as progress:
180
+ for epoch in range(self.epochs):
181
+ self.vlm.train()
182
+ sampler.set_epoch(epoch)
183
+
184
+ # Zero-Gradients (just in case)
185
+ self.optimizer.zero_grad()
186
+
187
+ # Note that we'll unpack batch (and let AMP/FSDP do its thing) in the VLM.forward() call
188
+ # => Basically, if we're using mixed precision (or not), autocast()/FSDP will move to device!
189
+ for train_idx, batch in enumerate(dataloader):
190
+ # [Contract] self.vlm.forward() must automatically compute `loss` and return!
191
+ with torch.autocast(
192
+ "cuda",
193
+ dtype=self.mixed_precision_dtype,
194
+ enabled=self.enable_mixed_precision_training,
195
+ ):
196
+ output: CausalLMOutputWithPast = self.vlm(
197
+ input_ids=batch["input_ids"],
198
+ attention_mask=batch["attention_mask"],
199
+ pixel_values=batch["pixel_values"],
200
+ labels=batch["labels"],
201
+ multimodal_indices=batch["multimodal_indices"],
202
+ )
203
+ loss = output.loss
204
+
205
+ # Commit Loss (Prior to Gradient Accumulation Normalization)
206
+ metrics.commit(loss=loss)
207
+
208
+ # Normalize Loss to account for Gradient Accumulation --> Backward!
209
+ # [IMPORTANT] Technically speaking, doing gradient accumulation in this way is "incorrect"; this is
210
+ # because in general, each batch has a *different number of masked out tokens* (because
211
+ # we're instruct-tuning). Taking the mean over two unbalanced means != the right thing!
212
+ #
213
+ # HOWEVER -- at least at the 7B scale, the "naive" approach is just as performant as
214
+ # the "correct" implementation, without adding extra complexity.
215
+ #
216
+ # That being said =>> at the 13B scale, no matter what we tried, ANY gradient accumulation is just
+ # really bad for downstream performance. Initial investigation shows that BF16 accumulation
+ # just really tanks in precision... and we don't have a good/clean way to fix this. Would love for
+ # someone to PR and fix this (and I'd greatly appreciate it!!!)
220
+ normalized_loss = loss / self.grad_accumulation_steps
221
+ normalized_loss.backward()
222
+
223
+ # Step =>> Only if Done w/ Gradient Accumulation
224
+ if (train_idx + 1) % self.grad_accumulation_steps == 0:
225
+ metrics.commit(update_step_time=True)
226
+
227
+ # Clip Gradients --> this is custom, per-strategy because of DDP vs. FSDP locality-assumptions
228
+ self.clip_grad_norm()
229
+
230
+ # Optimizer & LR Scheduler Step
231
+ self.optimizer.step()
232
+ self.lr_scheduler.step()
233
+ self.optimizer.zero_grad()
234
+
235
+ # Push Metrics
236
+ metrics.commit(global_step=metrics.global_step + 1, lr=self.lr_scheduler.get_last_lr()[0])
237
+ status = metrics.push()
238
+
239
+ # Check for Termination & Save Final Checkpoint (in case `max_steps` is not None)
240
+ if self.max_steps is not None and metrics.global_step >= self.max_steps:
241
+ self.save_checkpoint(metrics.run_dir, metrics.global_step, epoch, loss.item())
242
+ dist.barrier()
243
+
244
+ return
245
+
246
+ # Update Progress Bar
247
+ progress.update()
248
+ progress.set_description(status)
249
+
250
+ # Save checkpoint at end each epoch (if `self.max_steps` is None)
251
+ if self.max_steps is None:
252
+ self.save_checkpoint(metrics.run_dir, metrics.global_step, epoch, loss.item())
253
+ dist.barrier()
254
+
255
+ # === VLA Training ===
256
+
257
+ def run_vla_training(
258
+ self,
259
+ vla_dataset: IterableDataset,
260
+ collator: PaddedCollatorForActionPrediction,
261
+ action_tokenizer: ActionTokenizer,
262
+ metrics: VLAMetrics,
263
+ save_interval: int = 2500,
264
+ save_full_model: bool = True,
265
+ ) -> None:
266
+ """Run the VLA training loop for the given `dataset` and `collator`; log losses, action metrics to `metrics`."""
267
+ assert isinstance(vla_dataset, IterableDataset), "VLA training expects an IterableDataset!"
268
+ assert self.grad_accumulation_steps == 1, "VLA training does not support gradient accumulation!"
269
+
270
+ # Create a DataLoader =>> Set `num_workers` to 0; RLDS loader handles parallelism!
271
+ dataloader = DataLoader(
272
+ vla_dataset,
273
+ batch_size=self.per_device_batch_size,
274
+ sampler=None,
275
+ collate_fn=collator,
276
+ num_workers=0,
277
+ worker_init_fn=self.worker_init_fn,
278
+ )
279
+
280
+ # === Train ===
281
+ status = metrics.get_status()
282
+ with tqdm(
283
+ total=(self.epochs * len(dataloader)) if self.max_steps is None else self.max_steps,
284
+ desc=status,
285
+ leave=False,
286
+ disable=not overwatch.is_rank_zero(),
287
+ ) as progress:
288
+ self.vlm.train()
289
+
290
+ # Zero Gradients (just in case)
291
+ self.optimizer.zero_grad()
292
+
293
+ # [Contract] DataLoader wraps RLDS Loader (`.as_numpy_iterator() =>> implicit `.repeat()`)
294
+ # => This means looping over the DataLoader is basically "infinite" (so no outer loop over epochs).
295
+ # Slightly breaks default PyTorch semantics, which is why we adaptively compute `epoch` below.
296
+ for batch in dataloader:
297
+ # Note that we'll unpack batch (and let AMP/FSDP do its thing) in the VLM.forward() call
298
+ # => Basically, if we're using mixed precision (or not), autocast()/FSDP will move to device!
299
+ with torch.autocast(
300
+ "cuda", dtype=self.mixed_precision_dtype, enabled=self.enable_mixed_precision_training
301
+ ):
302
+ # [Contract] self.vlm.forward() must automatically compute `loss` and return!
303
+ output: CausalLMOutputWithPast = self.vlm(
304
+ input_ids=batch["input_ids"],
305
+ attention_mask=batch["attention_mask"],
306
+ pixel_values=batch["pixel_values"],
307
+ labels=batch["labels"],
308
+ )
309
+ loss = output.loss
310
+
311
+ # Commit Loss =>> Backward!
312
+ metrics.commit(loss=loss)
313
+ loss.backward()
314
+
315
+ # Get predicted and ground-truth token IDs
316
+ predicted_token_ids = output.logits[:, self.vlm.vision_backbone.num_patches : -1].argmax(dim=2)
317
+ ground_truth_token_ids = batch["labels"][:, 1:].to(predicted_token_ids.device)
318
+
319
+ #######################################################################
320
+ # === Compute Current Action Token Accuracy & L1 Loss ===
321
+ #######################################################################
322
+
323
+ # Get current action mask: Target the first ACTION_DIM non-ignore tokens
324
+ current_action_mask = get_current_action_mask(ground_truth_token_ids)
325
+
326
+ # Compute Accuracy
327
+ action_accuracy = compute_token_accuracy(predicted_token_ids, ground_truth_token_ids, mask=current_action_mask)
328
+
329
+ # Compute L1 Loss on Predicted (Continuous) Actions
330
+ action_l1_loss = compute_actions_l1_loss(action_tokenizer, predicted_token_ids, ground_truth_token_ids, mask=current_action_mask)
331
+
332
+ #######################################################################
333
+ # === Compute Next Actions Token Accuracy & L1 Loss ===
334
+ #######################################################################
335
+
336
+ # Get next actions mask: Target all tokens after the first ACTION_DIM non-ignore tokens (excluding the last token, which is the stop token)
337
+ next_actions_mask = get_next_actions_mask(ground_truth_token_ids)
338
+
339
+ # Compute Accuracy
340
+ next_actions_accuracy = compute_token_accuracy(predicted_token_ids, ground_truth_token_ids, mask=next_actions_mask)
341
+
342
+ # Compute L1 Loss on Predicted (Continuous) Actions
343
+ next_actions_l1_loss = compute_actions_l1_loss(action_tokenizer, predicted_token_ids, ground_truth_token_ids, mask=next_actions_mask)
344
+
345
+ #######################################################################
346
+ # === Log ===
347
+ #######################################################################
348
+
349
+ # Commit Metrics
350
+ metrics.commit(
351
+ action_accuracy=action_accuracy,
352
+ l1_loss=action_l1_loss,
353
+ next_actions_accuracy=next_actions_accuracy,
354
+ next_actions_l1_loss=next_actions_l1_loss,
355
+ update_step_time=True,
356
+ )
357
+
358
+ # Compute metrics per dataset --> only on rank_zero since we don't log them on other workers anyways
359
+ if overwatch.is_rank_zero():
360
+ datasets = set(batch["dataset_names"])
361
+ if len(datasets) > 1:
362
+ for ds in datasets:
363
+ # Mask of batch rows belonging to this dataset; per-dataset stats reuse the current-action mask
+ ds_mask = torch.tensor([elem == ds for elem in batch["dataset_names"]], device=predicted_token_ids.device)
+ correct_preds = (predicted_token_ids == ground_truth_token_ids) & current_action_mask
+ action_accuracy_ds = correct_preds[ds_mask].sum().float() / current_action_mask[ds_mask].sum().float()
+ pred_continuous_actions_ds = torch.tensor(
+ action_tokenizer.decode_token_ids_to_actions(
+ predicted_token_ids[ds_mask][current_action_mask[ds_mask]].cpu().numpy()
+ )
+ )
+ continuous_actions_gt_ds = torch.tensor(
+ action_tokenizer.decode_token_ids_to_actions(
+ ground_truth_token_ids[ds_mask][current_action_mask[ds_mask]].cpu().numpy()
+ )
+ )
+ action_l1_loss_ds = torch.nn.functional.l1_loss(
+ pred_continuous_actions_ds, continuous_actions_gt_ds
+ )
378
+ metrics.commit_for_dataset(
379
+ dataset_name=ds.decode(),
380
+ action_accuracy=action_accuracy_ds,
381
+ l1_loss=action_l1_loss_ds,
382
+ next_actions_accuracy=next_actions_accuracy,
383
+ next_actions_l1_loss=next_actions_l1_loss,
384
+ )
385
+
386
+ # === Gradient Step ===
387
+
388
+ # Clip Gradients --> this is custom, per-strategy because of DDP vs. FSDP locality assumptions
389
+ self.clip_grad_norm()
390
+
391
+ # Optimizer & LR Scheduler Step
392
+ self.optimizer.step()
393
+ self.lr_scheduler.step()
394
+ self.optimizer.zero_grad()
395
+
396
+ # Compute epoch value using number of completed gradient steps
397
+ epoch = (metrics.global_step + 1) // (len(vla_dataset) // self.global_batch_size)
398
+
399
+ # Push Metrics
400
+ metrics.commit(global_step=metrics.global_step + 1, epoch=epoch, lr=self.lr_scheduler.get_last_lr()[0])
401
+ status = metrics.push()
402
+
403
+ # Check for Save Interval or Max Steps & Save Checkpoint
404
+ if (terminate := (self.max_steps is not None and metrics.global_step >= self.max_steps)) or (
405
+ (metrics.global_step % save_interval) == 0
406
+ ):
407
+ self.save_checkpoint(
408
+ metrics.run_dir, metrics.global_step, epoch, loss.item(), only_trainable=not save_full_model
409
+ )
410
+ dist.barrier()
411
+
412
+ if terminate:
413
+ return
414
+
415
+ # Update Progress Bar
416
+ progress.update()
417
+ progress.set_description(status)
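The expected lifecycle of a `TrainingStrategy` subclass, as a short sketch (illustrative; the surrounding objects are assumed to come from the corresponding materialize factories). Note that `run_training` handles end-of-epoch / `max_steps` checkpointing internally.

def train(strategy, dataset, collator, metrics, run_dir):
    # 1) wrap/shard the VLM, build optimizer + LR scheduler
    strategy.run_setup(run_dir=run_dir, n_train_examples=len(dataset))

    # 2) run the loop; losses/LR are pushed through `metrics`, checkpoints saved inside
    strategy.run_training(dataset, collator, metrics, stage="finetune")

    # 3) flush trackers (e.g., wandb.finish)
    metrics.finalize()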
capvector-oft/prismatic/training/strategies/ddp.py ADDED
@@ -0,0 +1,128 @@
1
+ """
2
+ ddp.py
3
+
4
+ Core class definition for a strategy implementing Torch native Distributed Data Parallel Training; note that on most
5
+ GPU hardware and LLM backbones >= 5-7B parameters, DDP training will OOM, which is why we opt for FSDP.
6
+ """
7
+
8
+ import shutil
9
+ from pathlib import Path
10
+ from typing import Optional
11
+
12
+ import torch
13
+ from torch.nn.parallel import DistributedDataParallel as DDP
14
+ from torch.optim import AdamW
15
+ from transformers.optimization import get_constant_schedule, get_cosine_schedule_with_warmup
16
+
17
+ from prismatic.overwatch import initialize_overwatch
18
+ from prismatic.training.strategies.base_strategy import TrainingStrategy
19
+
20
+ # Initialize Overwatch =>> Wraps `logging.Logger`
21
+ overwatch = initialize_overwatch(__name__)
22
+
23
+
24
+ class DDPStrategy(TrainingStrategy):
25
+ @overwatch.rank_zero_only
26
+ def save_checkpoint(
27
+ self,
28
+ run_dir: Path,
29
+ global_step: int,
30
+ epoch: int,
31
+ train_loss: Optional[float] = None,
32
+ only_trainable: bool = True,
33
+ ) -> None:
34
+ """Save a checkpoint to the `run_dir` only containing the state_dicts for trainable parameters by default."""
35
+ assert isinstance(self.vlm, DDP), "save_checkpoint assumes VLM is already wrapped in DDP!"
36
+
37
+ # Splinter State Dictionary by Top-Level Submodules (or subset, if `only_trainable`)
38
+ model_state_dicts = {
39
+ mkey: getattr(self.vlm.module, mkey).state_dict()
40
+ for mkey in (self.trainable_module_keys if only_trainable else self.all_module_keys)
41
+ }
42
+ optimizer_state_dict = self.optimizer.state_dict()
43
+
44
+ # Set Checkpoint Path =>> Embed *minimal* training statistics!
45
+ checkpoint_dir = run_dir / "checkpoints"
46
+ if train_loss is None:
47
+ checkpoint_path = checkpoint_dir / f"step-{global_step:06d}-epoch-{epoch:02d}-loss=inf.pt"
48
+ else:
49
+ checkpoint_path = checkpoint_dir / f"step-{global_step:06d}-epoch-{epoch:02d}-loss={train_loss:.4f}.pt"
50
+
51
+ # Save Checkpoint & Copy Latest to `latest-checkpoint.pt`
52
+ torch.save({"model": model_state_dicts, "optimizer": optimizer_state_dict}, checkpoint_path)
53
+ shutil.copy(checkpoint_path, checkpoint_dir / "latest-checkpoint.pt")
54
+
55
+ def run_setup(self, run_dir: Path, n_train_examples: int) -> None:
56
+ # Gradient Checkpointing Setup
57
+ if self.enable_gradient_checkpointing:
58
+ # For Gradient Checkpointing --> we make the assumption that the "bulk" of activation memory is taken up
59
+ # by the LLM; because we also make the explicit assumption that each LLM is derived from a HF
60
+ # pretrained model, the only thing we *need* to do (technically) is call `gradient_checkpoint_enable`
61
+ # on `self.llm_backbone`.
62
+ #
63
+ # What does it actually do? --> runs the *generic* custom_forward + torch.utils.checkpoint.checkpoint logic
64
+ # => github.com/huggingface/transformers/.../models/llama/modeling_llama.py#L692-L706
65
+ #
66
+ # Additional Reference (to better understand gradient checkpointing in PyTorch writ large)
67
+ # => github.com/prigoyal/pytorch_memonger/blob/master/tutorial/Checkpointing_for_PyTorch_models.ipynb
68
+ overwatch.info("Enabling Gradient Checkpointing on LLM Backbone", ctx_level=1)
69
+ self.vlm.llm_backbone.gradient_checkpointing_enable()
70
+
71
+ # Move to Device =>> Note parameters are in full precision (*mixed precision* will only autocast as appropriate)
72
+ overwatch.info("Placing Entire VLM (Vision Backbone, LLM Backbone, Projector Weights) on GPU", ctx_level=1)
73
+ self.vlm.to(self.device_id)
74
+
75
+ # Wrap with Distributed Data Parallel
76
+ # => Note: By default, wrapping naively with DDP(self.vlm) will initialize a *separate* buffer on GPU that
77
+ # is the same size/dtype as the model parameters; this will *double* GPU memory!
78
+ # - stackoverflow.com/questions/68949954/model-takes-twice-the-memory-footprint-with-distributed-data-parallel
79
+ overwatch.info("Wrapping VLM with Distributed Data Parallel", ctx_level=1)
80
+ self.vlm = DDP(self.vlm, device_ids=[self.device_id], gradient_as_bucket_view=True)
81
+
82
+ # Create Optimizer and LR Scheduler =>> note that most of the LR Schedulers we use require `max_steps/epochs`
83
+ # => Optimizer should only operate on parameters that are *unfrozen* / trainable!
84
+ trainable_params = [param for param in self.vlm.parameters() if param.requires_grad]
85
+ if self.max_steps is None:
86
+ num_training_steps = (n_train_examples * self.epochs) // self.global_batch_size
87
+ else:
88
+ num_training_steps = self.max_steps
89
+
90
+ if self.lr_scheduler_type == "linear-warmup+cosine-decay":
91
+ # Set warmup steps (floor) based on `warmup_ratio` (should be 0.03 - 0.05)
92
+ num_warmup_steps = int(num_training_steps * self.warmup_ratio)
93
+
94
+ assert self.weight_decay == 0, "DDP training does not currently support `weight_decay` > 0!"
95
+ self.optimizer = AdamW(trainable_params, lr=self.learning_rate, weight_decay=self.weight_decay)
96
+ self.lr_scheduler = get_cosine_schedule_with_warmup(self.optimizer, num_warmup_steps, num_training_steps)
97
+ for param_group in self.optimizer.param_groups:
98
+ param_group["lr"] = 0.0
99
+
100
+ elif self.lr_scheduler_type == "constant":
101
+ num_warmup_steps = 0
102
+
103
+ assert self.weight_decay == 0, "DDP training does not currently support `weight_decay` > 0!"
104
+ self.optimizer = AdamW(trainable_params, lr=self.learning_rate, weight_decay=self.weight_decay)
105
+ self.lr_scheduler = get_constant_schedule(self.optimizer)
106
+
107
+ else:
108
+ raise ValueError(f"Learning Rate Schedule with type `{self.lr_scheduler_type}` is not supported!")
109
+
110
+ # Finalize Setup =>> Log
111
+ overwatch.info(
112
+ "DDP Strategy =>> Finalized Training Setup:\n"
113
+ f" |-> Global (Effective) Batch Size = {self.global_batch_size}\n"
114
+ f" |-> Per-Device Batch Size = {self.per_device_batch_size}\n"
115
+ f" |-> Distributed World Size = {overwatch.world_size()}\n"
116
+ f" |-> Gradient Accumulation Steps = {self.grad_accumulation_steps}\n\n"
117
+ f" |-> LLM Backbone Gradient Checkpointing = {self.enable_gradient_checkpointing}\n"
118
+ f" |-> Use Native AMP = {self.enable_mixed_precision_training} ({self.mixed_precision_dtype})\n\n"
119
+ f" |-> Default AdamW LR = {self.learning_rate}\n"
120
+ f" |-> AdamW Weight Decay = {self.weight_decay}\n"
121
+ f" |-> LR Scheduler Type = {self.lr_scheduler_type}\n"
122
+ f" |-> LR Scheduler Warmup Steps (Ratio) = {num_warmup_steps} ({self.warmup_ratio})\n"
123
+ f" |-> Dataset Size = {n_train_examples} Examples\n"
124
+ f" |-> Max Steps = {num_training_steps}\n"
125
+ )
126
+
127
+ def clip_grad_norm(self) -> None:
128
+ torch.nn.utils.clip_grad_norm_(self.vlm.parameters(), max_norm=self.max_grad_norm)
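`DDPStrategy` assumes the default process group is already initialized with one process per GPU (e.g., via `torchrun --nproc_per_node=8 train.py`); a minimal initialization sketch under that assumption (illustrative only):

import os

import torch
import torch.distributed as dist


def init_distributed() -> int:
    """Initialize the NCCL process group from torchrun-provided env vars; return the local rank."""
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    return local_rank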
capvector-oft/prismatic/training/strategies/fsdp.py ADDED
@@ -0,0 +1,270 @@
1
+ """
2
+ fsdp.py
3
+
4
+ Core class definition for a strategy implementing Torch native Fully Sharded Data Parallel Training (with support for
5
+ fine-grained control over wrapping policies and mixed precision per component).
6
+ """
7
+
8
+ import math
9
+ from collections import OrderedDict
10
+ from functools import partial
11
+ from pathlib import Path
12
+ from typing import Callable, Optional
13
+
14
+ import torch
15
+ import torch.distributed as dist
16
+ import torch.nn as nn
17
+ from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
18
+ CheckpointImpl,
19
+ apply_activation_checkpointing,
20
+ checkpoint_wrapper,
21
+ )
22
+ from torch.distributed.fsdp import (
23
+ FullStateDictConfig,
24
+ MixedPrecision,
25
+ ShardingStrategy,
26
+ StateDictType,
27
+ )
28
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
29
+ from torch.optim import AdamW
30
+ from transformers.optimization import get_constant_schedule, get_cosine_schedule_with_warmup
31
+
32
+ from prismatic.models.vlms import PrismaticVLM
33
+ from prismatic.overwatch import initialize_overwatch
34
+ from prismatic.training.strategies.base_strategy import TrainingStrategy
35
+
36
+ # Initialize Overwatch =>> Wraps `logging.Logger`
37
+ overwatch = initialize_overwatch(__name__)
38
+
39
+
40
+ class FSDPStrategy(TrainingStrategy):
41
+ def __init__(
42
+ self,
43
+ vlm: PrismaticVLM,
44
+ device_id: int,
45
+ stage: str,
46
+ epochs: int,
47
+ max_steps: Optional[int],
48
+ global_batch_size: int,
49
+ per_device_batch_size: int,
50
+ learning_rate: float,
51
+ weight_decay: float,
52
+ max_grad_norm: float,
53
+ lr_scheduler_type: str,
54
+ warmup_ratio: float,
55
+ enable_gradient_checkpointing: bool = True,
56
+ enable_mixed_precision_training: bool = True,
57
+ reduce_in_full_precision: bool = False,
58
+ mixed_precision_dtype: torch.dtype = torch.bfloat16,
59
+ worker_init_fn: Optional[Callable[[int], None]] = None,
60
+ sharding_strategy: str = "shard-grad-op",
61
+ state_dict_type: StateDictType = StateDictType.FULL_STATE_DICT,
62
+ ) -> None:
63
+ super().__init__(
64
+ vlm=vlm,
65
+ device_id=device_id,
66
+ stage=stage,
67
+ epochs=epochs,
68
+ max_steps=max_steps,
69
+ global_batch_size=global_batch_size,
70
+ per_device_batch_size=per_device_batch_size,
71
+ learning_rate=learning_rate,
72
+ weight_decay=weight_decay,
73
+ max_grad_norm=max_grad_norm,
74
+ lr_scheduler_type=lr_scheduler_type,
75
+ warmup_ratio=warmup_ratio,
76
+ enable_gradient_checkpointing=enable_gradient_checkpointing,
77
+ enable_mixed_precision_training=enable_mixed_precision_training,
78
+ reduce_in_full_precision=reduce_in_full_precision,
79
+ mixed_precision_dtype=mixed_precision_dtype,
80
+ worker_init_fn=worker_init_fn,
81
+ )
82
+
83
+ # FSDP-Specific Parameters
84
+ if sharding_strategy == "shard-grad-op":
85
+ self.fsdp_sharding_strategy = ShardingStrategy._HYBRID_SHARD_ZERO2
86
+ elif sharding_strategy == "full-shard":
87
+ self.fsdp_sharding_strategy = ShardingStrategy.HYBRID_SHARD
88
+ else:
89
+ raise ValueError(f"FSDP Sharding Strategy {sharding_strategy} is not supported!")
90
+
91
+ assert state_dict_type == StateDictType.FULL_STATE_DICT, "Sharded state saving is not yet implemented!"
92
+ self.fsdp_state_dict_type = state_dict_type
93
+ self.fsdp_save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
94
+
95
+ def save_checkpoint(
96
+ self,
97
+ run_dir: Path,
98
+ global_step: int,
99
+ epoch: int,
100
+ train_loss: Optional[float] = None,
101
+ only_trainable: bool = True,
102
+ ) -> None:
103
+ """Save a checkpoint to the `run_dir` only containing the state_dicts for trainable parameters by default."""
104
+ assert isinstance(self.vlm, FSDP), "FSDPStrategy.save_checkpoint assumes VLM is already wrapped in FSDP!"
105
+
106
+ # Summon Full State Dictionary =>> Reconstitute from Shards
107
+ with FSDP.state_dict_type(self.vlm, self.fsdp_state_dict_type, self.fsdp_save_policy):
108
+ full_vlm_state_dict = self.vlm.state_dict()
109
+ model_state_dicts = {
110
+ mkey: OrderedDict() for mkey in (self.trainable_module_keys if only_trainable else self.all_module_keys)
111
+ }
112
+
113
+ # Iterate through `full_vlm_state_dict` and split `mkey.{full_dotted_path}` -> `mkey: {full_dotted_path}`
114
+ for key, param in full_vlm_state_dict.items():
115
+ for mkey in model_state_dicts:
116
+ if key.startswith(mprefix := f"{mkey}."):
117
+ model_state_dicts[mkey][key.removeprefix(mprefix)] = param
118
+
119
+ # Save on rank zero *only*
120
+ if overwatch.is_rank_zero():
121
+ checkpoint_dir = run_dir / "checkpoints"
122
+ if train_loss is None:
123
+ checkpoint_path = checkpoint_dir / f"step-{global_step:06d}-epoch-{epoch:02d}-loss=inf.pt"
124
+ else:
125
+ checkpoint_path = (
126
+ checkpoint_dir / f"step-{global_step:06d}-epoch-{epoch:02d}-loss={train_loss:.4f}.pt"
127
+ )
128
+
129
+ # Save Checkpoint & Copy Latest to `latest-checkpoint.pt`
130
+ torch.save({"model": model_state_dicts}, checkpoint_path)
131
+
132
+ # TODO (siddk) :: This breaks w/ Sagemaker default permissions (root vs. <user>)... skip?
133
+ # shutil.copy(checkpoint_path, checkpoint_dir / "latest-checkpoint.pt")
134
+
135
+ def run_setup(self, run_dir: Path, n_train_examples: int) -> None:
136
+ # Iteratively Assemble FSDP Wrapping Policy by fetching the wrapping policies for each backbone/constituent
137
+ vlm_fsdp_wrapping_policy = self.vlm.get_fsdp_wrapping_policy()
138
+
139
+ # Assemble the Default FSDP Mixed Precision Policy
140
+ if self.enable_mixed_precision_training and self.mixed_precision_dtype == torch.bfloat16:
141
+ # MixedPrecision `param_dtype` specifies *compute* dtype (for forward/backward only)
142
+ # => Reference: https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.MixedPrecision
143
+ reduce_buffer_dtype = torch.bfloat16 if not self.reduce_in_full_precision else torch.float32
144
+ fsdp_precision_policy = MixedPrecision(
145
+ param_dtype=torch.bfloat16, reduce_dtype=reduce_buffer_dtype, buffer_dtype=reduce_buffer_dtype
146
+ )
147
+
148
+ # When running FSDP with a frozen vision backbone --> move to half precision!
149
+ if self.stage not in {"full-finetune", "vla-full-train", "vla-sandwich-train"}:
150
+ overwatch.info("Casting Vision Backbone to *Half Precision* via `.to(dtype=...)`")
151
+ self.vlm.vision_backbone.to(dtype=self.vlm.vision_backbone.half_precision_dtype)
152
+
153
+ else:
154
+ # If we're not using mixed precision, everything is in default full precision!
155
+ fsdp_precision_policy = MixedPrecision(
156
+ param_dtype=torch.float32, reduce_dtype=torch.float32, buffer_dtype=torch.float32
157
+ )
158
+
159
+ # <FSDP> => note that FSDP will automatically take care of device placement (similar to `autocast`)
160
+ self.vlm = FSDP(
161
+ self.vlm,
162
+ auto_wrap_policy=vlm_fsdp_wrapping_policy,
163
+ mixed_precision=fsdp_precision_policy,
164
+ sharding_strategy=self.fsdp_sharding_strategy,
165
+ device_id=torch.cuda.current_device(),
166
+ limit_all_gathers=True,
167
+ use_orig_params=True,
168
+ )
169
+
170
+ # Gradient Checkpoint Setup
171
+ if self.enable_gradient_checkpointing:
172
+ # For Gradient Checkpointing under FSDP --> we make the same assumption as in the DDP/other strategies; the
173
+ # bulk of activation memory is taken up by the LLM activations. However, unlike other strategies, we
174
+ # cannot rely on the HF Transformers default `gradient_checkpointing_enable()` --> FSDP breaks semantics!
175
+ #
176
+ # Instead, we need to write our own *NO-REENTRANT* wrapper, and apply it to the LLM's Transformer Layer.
177
+ non_reentrant_wrapper = partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT)
178
+
179
+ def check_fn(submodule: nn.Module) -> bool:
180
+ return isinstance(submodule, self.llm_transformer_layer_cls)
181
+
182
+ # Note that the terms "activation checkpointing" and "gradient checkpointing" are synonymous!
183
+ apply_activation_checkpointing(self.vlm, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn)
184
+
185
+ # Barrier =>> Sharding takes a minute?
186
+ dist.barrier()
187
+
188
+ # Create Optimizer and LR Scheduler =>> note that most of the LR Schedulers we use require `max_steps/epochs`
189
+ # => Optimizer should only operate on parameters that are *unfrozen* / trainable!
190
+ n_train_examples = math.ceil(n_train_examples / self.global_batch_size) * self.global_batch_size
191
+ if self.max_steps is None:
192
+ num_training_steps = (n_train_examples * self.epochs) // self.global_batch_size
193
+ else:
194
+ num_training_steps = self.max_steps
195
+
196
+ if self.lr_scheduler_type == "linear-warmup+cosine-decay":
197
+ # Set warmup steps (floor) based on `warmup_ratio` (should be 0.03 - 0.05)
198
+ num_warmup_steps = int(num_training_steps * self.warmup_ratio)
199
+
200
+ # Default AdamW w/ specified LR & Linear Warmup / Cosine Decay & Weight Decay
201
+ # => Create Parameter Groups --> bias terms, normalization layer parameters shouldn't be decayed!
202
+ decay, no_decay = [], []
203
+ for name, param in self.vlm.named_parameters():
204
+ if not param.requires_grad:
205
+ continue
206
+
207
+ # Skip weight decay for any parameters with fewer than 2 dimensions or with "bias" in the name
208
+ if param.ndim <= 1 or name.endswith(".bias"):
209
+ no_decay.append(param)
210
+ else:
211
+ decay.append(param)
212
+
213
+ # Build Parameter Groups
214
+ groups = [{"params": decay, "weight_decay": self.weight_decay}, {"params": no_decay, "weight_decay": 0.0}]
215
+
216
+ # Create Optimizer & LR Scheduler
217
+ self.optimizer = AdamW(groups, lr=self.learning_rate)
218
+ self.lr_scheduler = get_cosine_schedule_with_warmup(self.optimizer, num_warmup_steps, num_training_steps)
219
+ for param_group in self.optimizer.param_groups:
220
+ param_group["lr"] = 0.0
221
+
222
+ elif self.lr_scheduler_type == "constant":
223
+ num_warmup_steps = 0
224
+
225
+ # Default AdamW w/ specified LR & Constant Schedule & Weight Decay
226
+ # => Create Parameter Groups --> bias terms, normalization layer parameters shouldn't be decayed!
227
+ decay, no_decay = [], []
228
+ for name, param in self.vlm.named_parameters():
229
+ if not param.requires_grad:
230
+ continue
231
+
232
+ # Skip weight decay for any parameters with fewer than 2 dimensions or with "bias" in the name
233
+ if param.ndim <= 1 or name.endswith(".bias"):
234
+ no_decay.append(param)
235
+ else:
236
+ decay.append(param)
237
+
238
+ # Build Parameter Groups
239
+ groups = [{"params": decay, "weight_decay": self.weight_decay}, {"params": no_decay, "weight_decay": 0.0}]
240
+
241
+ # Create Optimizer & LR Scheduler
242
+ self.optimizer = AdamW(groups, lr=self.learning_rate)
243
+ self.lr_scheduler = get_constant_schedule(self.optimizer)
244
+
245
+ else:
246
+ raise ValueError(f"Learning Rate Schedule with type `{self.lr_scheduler_type}` is not supported!")
247
+
248
+ # Finalize Setup =>> Log!
249
+ overwatch.info(
250
+ "FSDP Full-Shard Strategy =>> Finalized Training Setup:\n"
251
+ f" |-> Global (Effective) Batch Size = {self.global_batch_size}\n"
252
+ f" |-> Per-Device Batch Size = {self.per_device_batch_size}\n"
253
+ f" |-> Distributed World Size = {overwatch.world_size()}\n"
254
+ f" |-> Gradient Accumulation Steps = {self.grad_accumulation_steps}\n\n"
255
+ f" |-> LLM Backbone FSDP Gradient Checkpointing = {self.enable_gradient_checkpointing}\n"
256
+ f" |-> Use FSDP Mixed Precision = {self.enable_mixed_precision_training}\n"
257
+ f" |-> Parameter Precision = {fsdp_precision_policy.param_dtype}\n"
258
+ f" |-> Reduction Precision = {fsdp_precision_policy.reduce_dtype}\n"
259
+ f" |-> Buffer Precision = {fsdp_precision_policy.buffer_dtype}\n\n"
260
+ f" |-> Default AdamW LR = {self.learning_rate}\n"
261
+ f" |-> AdamW Weight Decay = {self.weight_decay}\n"
262
+ f" |-> LR Scheduler Type = {self.lr_scheduler_type}\n"
263
+ f" |-> LR Scheduler Warmup Steps (Ratio) = {num_warmup_steps} ({self.warmup_ratio})\n"
264
+ f" |-> Dataset Size = {n_train_examples} Examples\n"
265
+ f" |-> Max Steps = {num_training_steps}\n"
266
+ )
267
+
268
+ def clip_grad_norm(self) -> None:
269
+ # Note =>> FSDP uses a custom `clip_grad_norm_` function; requires *uniform grad dtype*
270
+ self.vlm.clip_grad_norm_(max_norm=self.max_grad_norm)
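
For reference, the step arithmetic and decay/no-decay parameter grouping in `run_setup` above can be reproduced standalone. The following is a minimal sketch assuming a toy `nn.Sequential` model and made-up hyperparameters (the real values come from the training config, not from this snippet):

import math
import torch.nn as nn
from torch.optim import AdamW
from transformers import get_cosine_schedule_with_warmup

# Illustrative hyperparameters (placeholders, not taken from any real config)
n_train_examples, global_batch_size, epochs, warmup_ratio = 10_000, 256, 2, 0.03
learning_rate, weight_decay = 2e-5, 0.1

# Round dataset size up to a multiple of the global batch size, then derive total/warmup steps
n_train_examples = math.ceil(n_train_examples / global_batch_size) * global_batch_size
num_training_steps = (n_train_examples * epochs) // global_batch_size
num_warmup_steps = int(num_training_steps * warmup_ratio)

# Toy model standing in for the (trainable parts of the) VLM
model = nn.Sequential(nn.Linear(16, 32), nn.LayerNorm(32), nn.Linear(32, 8))

# Split trainable parameters into decay / no-decay groups (biases & 1-D params skip weight decay)
decay, no_decay = [], []
for name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    (no_decay if param.ndim <= 1 or name.endswith(".bias") else decay).append(param)

groups = [{"params": decay, "weight_decay": weight_decay}, {"params": no_decay, "weight_decay": 0.0}]
optimizer = AdamW(groups, lr=learning_rate)
lr_scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)
print(num_training_steps, num_warmup_steps)  # 80 total steps, 2 warmup steps for these toy numbers
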
capvector-oft/prismatic/training/train_utils.py ADDED
@@ -0,0 +1,56 @@
1
+ """Utils for training/fine-tuning scripts."""
2
+
3
+ import torch
4
+
5
+ from prismatic.vla.constants import ACTION_DIM, ACTION_TOKEN_BEGIN_IDX, IGNORE_INDEX
6
+
7
+
8
+ def get_current_action_mask(token_ids):
9
+ # Mark positions that are *not* IGNORE_INDEX (i.e., positions carrying valid label tokens)
10
+ newline_positions = token_ids != IGNORE_INDEX
11
+
12
+ # Cumulative count of valid positions --> gives each token its ordinal within the label region
13
+ cumsum = torch.cumsum(newline_positions, dim=1)
14
+
15
+ # Create the mask
16
+ mask = (1 <= cumsum) & (cumsum <= ACTION_DIM)
17
+
18
+ # Extract the action part only
19
+ action_tokens_only_mask = token_ids > ACTION_TOKEN_BEGIN_IDX
20
+ mask = action_tokens_only_mask * mask
21
+
22
+ return mask
23
+
24
+
25
+ def get_next_actions_mask(token_ids):
26
+ # Mark positions that are *not* IGNORE_INDEX (i.e., positions carrying valid label tokens)
27
+ newline_positions = token_ids != IGNORE_INDEX
28
+
29
+ # Cumulative count of valid positions --> gives each token its ordinal within the label region
30
+ cumsum = torch.cumsum(newline_positions, dim=1)
31
+
32
+ # Create the mask
33
+ mask = cumsum > ACTION_DIM
34
+
35
+ # Extract the action part only
36
+ action_tokens_only_mask = token_ids > ACTION_TOKEN_BEGIN_IDX
37
+ mask = action_tokens_only_mask * mask
38
+
39
+ return mask
40
+
41
+
42
+ def compute_token_accuracy(predicted_token_ids, ground_truth_token_ids, mask):
43
+ correct_preds = (predicted_token_ids == ground_truth_token_ids) & mask
44
+ accuracy = correct_preds.sum().float() / mask.sum().float()
45
+ return accuracy
46
+
47
+
48
+ def compute_actions_l1_loss(action_tokenizer, predicted_token_ids, ground_truth_token_ids, mask):
49
+ pred_continuous_actions = torch.tensor(
50
+ action_tokenizer.decode_token_ids_to_actions(predicted_token_ids[mask].cpu().numpy())
51
+ )
52
+ true_continuous_actions = torch.tensor(
53
+ action_tokenizer.decode_token_ids_to_actions(ground_truth_token_ids[mask].cpu().numpy())
54
+ )
55
+ l1_loss = torch.nn.functional.l1_loss(pred_continuous_actions, true_continuous_actions)
56
+ return l1_loss
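
To see what `get_current_action_mask` / `get_next_actions_mask` actually select, here is a minimal standalone sketch of the same cumsum logic. The constants below are illustrative placeholders, not the real values from `prismatic.vla.constants`:

import torch

# Illustrative constants -- the real values live in prismatic.vla.constants
IGNORE_INDEX, ACTION_DIM, ACTION_TOKEN_BEGIN_IDX = -100, 3, 1000

# One sequence: 4 ignored prompt positions, then 6 action tokens (ids above ACTION_TOKEN_BEGIN_IDX)
token_ids = torch.tensor([[-100, -100, -100, -100, 1001, 1002, 1003, 1004, 1005, 1006]])

valid = token_ids != IGNORE_INDEX               # positions that carry real labels
ordinal = torch.cumsum(valid, dim=1)            # 1-based index of each valid position
is_action_token = token_ids > ACTION_TOKEN_BEGIN_IDX

current_mask = (1 <= ordinal) & (ordinal <= ACTION_DIM) & is_action_token
next_mask = (ordinal > ACTION_DIM) & is_action_token

print(current_mask.int())  # first ACTION_DIM valid tokens -> the current action
print(next_mask.int())     # remaining action tokens -> subsequent actions in the chunk
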
capvector-oft/prismatic/util/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .torch_utils import check_bloat16_supported, set_global_seed
capvector-oft/prismatic/util/batching_utils.py ADDED
@@ -0,0 +1,212 @@
1
+ """
2
+ batching_utils.py
3
+
4
+ Core definitions of (Distributed) Samplers for VLM finetuning; provides functionality for constructing and allocating
5
+ "split-modality" batches as described in the LLaVa paper; this makes sure that a given device/batch is either entirely
6
+ (vision, language) or (language-only) data, which leads to sizeable efficiency gains.
7
+ """
8
+
9
+ import math
10
+ from typing import Iterator, List, Optional, Tuple
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.distributed as dist
15
+ from torch.utils.data import Dataset, Sampler
16
+
17
+
18
+ # High-Fidelity Bitwise Reproduction of the LLaVa Codebase Sampler Strategy + Per-Rank Allocation Scheme (following
19
+ # the default batching behavior of HF's Trainer Class --> derived from `accelerate`).
20
+ #
21
+ # =>> Reference: https://github.com/haotian-liu/LLaVA/blob/main/llava/train/llava_trainer.py#L60
22
+ # =>> Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/trainer_pt_utils.py#L603
23
+ class SplitModalitySampler(Sampler):
24
+ def __init__(
25
+ self,
26
+ dataset: Dataset,
27
+ modality_lengths: List[Tuple[bool, int]],
28
+ global_batch_size: int,
29
+ num_replicas: Optional[int] = None,
30
+ rank: Optional[int] = None,
31
+ seed: int = 0,
32
+ drop_last: bool = False,
33
+ ) -> None:
34
+ super().__init__()
35
+ self.num_replicas = num_replicas if num_replicas is not None else dist.get_world_size()
36
+ self.rank = rank if rank is not None else dist.get_rank()
37
+ self.seed, self.epoch = seed, 0
38
+
39
+ # Custom Parameters
40
+ self.dataset, self.modality_lengths, self.drop_last = dataset, modality_lengths, drop_last
41
+ self.global_batch_size = global_batch_size
42
+
43
+ # For our purposes, `drop_last` is always False!
44
+ assert not self.drop_last, "SplitModalitySampler must set `drop_last = False`!"
45
+ self.total_size = math.ceil(len(self.dataset) / self.global_batch_size) * self.global_batch_size
46
+ self.num_samples = self.total_size // self.num_replicas
47
+
48
+ @staticmethod
49
+ def reindex_batch(batch_idxs: List[int], idx2lengths: List[int], n_buckets: int) -> List[List[int]]:
50
+ """Re-indexes a batch in a way that is conducive to DistributedSampler + grouping by seqlen per rank."""
51
+ assert len(batch_idxs) % n_buckets == 0, "Batch length is not divisible by `num_replicas`!"
52
+
53
+ # Establish initial buckets, capacities, and max number of elements per bucket
54
+ n_examples_per_bucket = len(batch_idxs) // n_buckets
55
+ bucket_indices = [[] for _ in range(n_buckets)]
56
+ bucket_lengths = [0 for _ in range(n_buckets)]
57
+
58
+ # Note that `batch_idxs` is already sorted by corresponding length (in descending order)
59
+ for idx in batch_idxs:
60
+ shortest_bucket_idx = bucket_lengths.index(min(bucket_lengths))
61
+ bucket_indices[shortest_bucket_idx].append(idx)
62
+
63
+ # Update `bucket_lengths` --> set length to infinity if at capacity!
64
+ bucket_lengths[shortest_bucket_idx] += idx2lengths[idx]
65
+ if len(bucket_indices[shortest_bucket_idx]) == n_examples_per_bucket:
66
+ bucket_lengths[shortest_bucket_idx] = float("inf")
67
+
68
+ return bucket_indices
69
+
70
+ def get_modality_and_length_grouped_indices(self, generator: torch.Generator) -> List[int]:
71
+ """
72
+ Returns a list of indices so that each slice of `global_batch_size` consecutive indices corresponds to elements
73
+ of the same modality, where each sub-sequence of `per_replica_batch_size` (the batch size each unique device sees
74
+ during distributed training) is roughly grouped by sequence length (for training efficiency).
75
+ """
76
+ multimodal_indices, multimodal_lengths = zip(
77
+ *[(idx, length) for idx, (is_multimodal, length) in enumerate(self.modality_lengths) if is_multimodal]
78
+ )
79
+
80
+ # Handle Special Case --> no "unimodal" inputs
81
+ unimodal_split = [
82
+ (idx, length) for idx, (is_multimodal, length) in enumerate(self.modality_lengths) if not is_multimodal
83
+ ]
84
+ if len(unimodal_split) == 0:
85
+ unimodal_indices, unimodal_lengths = [], []
86
+ else:
87
+ unimodal_indices, unimodal_lengths = zip(*unimodal_split)
88
+
89
+ # Create a permutation of indices for each of the multimodal and unimodal data
90
+ mm_shuffled_idxs = torch.randperm(len(multimodal_indices), generator=generator)
91
+ uni_shuffled_idxs = torch.randperm(len(unimodal_indices), generator=generator)
92
+
93
+ # We're going to be running sorting/grouping relative to `self.global_batch_size` and `self.num_replicas`
94
+ g_bsz = self.global_batch_size
95
+
96
+ # Break each of the permutations into batches of length `global_batch_size`
97
+ mm_batch_idxs = [mm_shuffled_idxs[i : i + g_bsz].tolist() for i in range(0, len(mm_shuffled_idxs), g_bsz)]
98
+ uni_batch_idxs = [uni_shuffled_idxs[i : i + g_bsz].tolist() for i in range(0, len(uni_shuffled_idxs), g_bsz)]
99
+
100
+ # If "last" batch is not of length `g_bsz` --> PAD by stealing indices from the first batch!
101
+ if len(mm_batch_idxs[-1]) < g_bsz:
102
+ n_missing = g_bsz - len(mm_batch_idxs[-1])
103
+ mm_batch_idxs[-1].extend(mm_batch_idxs[0][:n_missing])
104
+
105
+ if len(uni_batch_idxs) > 0 and len(uni_batch_idxs[-1]) < g_bsz:
106
+ n_missing = g_bsz - len(uni_batch_idxs[-1])
107
+ uni_batch_idxs[-1].extend(uni_batch_idxs[0][:n_missing])
108
+
109
+ # Now we're going to sort each batch by length --> this will aid in grouping by length by rank (efficiency!)
110
+ mm_sorted_batch_idxs = [sorted(b, key=lambda i: multimodal_lengths[i], reverse=True) for b in mm_batch_idxs]
111
+ uni_sorted_batch_idxs = [sorted(b, key=lambda i: unimodal_lengths[i], reverse=True) for b in uni_batch_idxs]
112
+
113
+ # IMPORTANT :: At this point, for each modality, we have a list of "batches" (made up of indices) where indices
114
+ # are sorted by example sequence length *within* each batch. To make this more concrete, consider the following:
115
+ # => World Size (`num_replicas`) = 2
116
+ # => Global Batch Size (`g_bsz`) = 4
117
+ # => `multimodal_indices` = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
118
+ # `multimodal_lengths` = [20, 90, 21, 22, 91, 18, 89, 19, 93, 88, 92, 17]
119
+ #
120
+ # At this point in the code, `mm_sorted_batch_idxs` might then look like the following (length in parenthesis):
121
+ # => `mm_sorted_batch_idxs`: [
122
+ # [4 (91), 3 (21), 0 (20), 5 (18)] => Batch 1
123
+ # [6 (89), 9 (88), 7 (19), 11 (17)] => Batch 2
124
+ # [8 (93), 10 (92), 1 (90), 2 (21)] => Batch 3
125
+ # ]
126
+ #
127
+ # In practice: `g_bsz` is large (= 128), and for contiguous mini-batch "slices", length variance is low.
128
+
129
+ # PROBLEM :: We want to split these "global batches" into equal-sized pieces, so that each "replica" (GPU)
130
+ # sees a "mini-batch" of roughly the same sequence lengths; this is super useful for efficient training.
131
+
132
+ # HOWEVER :: The default "access pattern" for splitting a large batch into mini-batches by a DistributedSampler
133
+ # is akin to a "take every k" where `k` is equal to the number of replicas (GPUs) you're training on. Or, in
134
+ # Python notation --> `rank_k_indices = flatten(mm_sorted_batch_idxs)[k::num_replicas]`.
135
+ #
136
+ # Naively translating this to our example means each GPU (in our world of 2 total) sees the following indices
137
+ # (grouped by "mini-batch" = `g_bsz / num_replicas` = 2 for convenience):
138
+ # => `rank_0_indices`: [ [4 (91), 0 (20)] =>> [6 (89), 7 (19)] =>> [8 (93), 1 (90)] ]
139
+ # => `rank_1_indices`: [ [3 (21), 5 (18)] =>> [9 (88), 11 (17)] =>> [10 (92), 2 (21)] ]
140
+ #
141
+ # We get lucky sometimes, but for the most part, each "mini-batch" has VASTLY DIFFERENT lengths! Bad!
142
+
143
+ # FIX :: If we "undo" the access pattern with the following code and re-arrange the way we allocate batches
144
+ # inside the __iter__ method below, we can allocate indices appropriately. Running the following code gives us
145
+ # the following indices (grouped by "mini-batch" again for convenience):
146
+ # => `rank_0_indices`: [ [4 (91), 3 (21)] =>> [6 (89), 9 (88)] =>> [8 (93), 10 (92)] ]
147
+ # => `rank_1_indices`: [ [5 (18), 0 (20)] =>> [11 (17), 7 (19)] =>> [2 (21), 1 (90)] ]
148
+ #
149
+ # Much better! As `g_bsz` and `dataset` grow, we're more often than not getting *decent* groupings!
150
+ mm_length_bucketed_idxs = [
151
+ self.reindex_batch(batch, multimodal_lengths, self.num_replicas) for batch in mm_sorted_batch_idxs
152
+ ]
153
+ uni_length_bucketed_idxs = [
154
+ self.reindex_batch(batch, unimodal_lengths, self.num_replicas) for batch in uni_sorted_batch_idxs
155
+ ]
156
+
157
+ # Note :: Because of the initial `randperm` --> we're indexing both sets from 0 (we're clobbering the range)
158
+ # => Flatten indices --> index into original `{modality}_indices` then re-batch!
159
+ mm_output_idxs = [idx for batch in mm_length_bucketed_idxs for bucket in batch for idx in bucket]
160
+ mm_reindexed = [multimodal_indices[idx] for idx in mm_output_idxs]
161
+ mm_batches = [mm_reindexed[i : i + g_bsz] for i in range(0, len(mm_reindexed), g_bsz)]
162
+
163
+ uni_output_idxs = [idx for batch in uni_length_bucketed_idxs for bucket in batch for idx in bucket]
164
+ uni_reindexed = [unimodal_indices[idx] for idx in uni_output_idxs]
165
+ uni_batches = [uni_reindexed[i : i + g_bsz] for i in range(0, len(uni_reindexed), g_bsz)]
166
+
167
+ # Finally, randomly permute the multimodal & unimodal batches, merging into a single stream of indices
168
+ merged_batches = mm_batches + uni_batches
169
+ merge_idxs = torch.randperm(len(merged_batches), generator=generator)
170
+ all_batches = [merged_batches[idx] for idx in merge_idxs]
171
+
172
+ # [Quality of Life] Shift "max length" batch to index 0 --> if we OOM, it happens immediately!
173
+ all_lengths = [length + ((_n_patches := 24 * 24) if is_mm else 0) for is_mm, length in self.modality_lengths]
174
+ all_batches_max_lengths = []
175
+ for batch in all_batches:
176
+ all_batches_max_lengths.append(max([all_lengths[idx] for idx in batch]))
177
+
178
+ # Identify Batch with "max length" --> Swap into Index 0
179
+ longest_batch_idx = np.argmax(all_batches_max_lengths)
180
+ all_batches[0], all_batches[longest_batch_idx] = all_batches[longest_batch_idx], all_batches[0]
181
+
182
+ # Flatten & Return all Indices
183
+ indices = [idx for batch in all_batches for idx in batch]
184
+ return indices
185
+
186
+ def __iter__(self) -> Iterator:
187
+ """Deterministically shuffle, then split indices by modality and length."""
188
+ g = torch.Generator()
189
+ g.manual_seed(self.seed + self.epoch)
190
+ indices = self.get_modality_and_length_grouped_indices(g)
191
+ assert len(set(indices)) == len(self.modality_lengths) == len(self.dataset), "Oops!"
192
+ assert (len(indices) % self.global_batch_size == 0) and (len(indices) % self.num_replicas) == 0, "Oops"
193
+
194
+ # Note :: We compute per-replica batch size as a function of `global_batch` and `num_replicas` to ensure that
195
+ # gradient accumulation doesn't affect what indices are assigned a given rank.
196
+ per_replica_batch_size = self.global_batch_size // self.num_replicas
197
+
198
+ # Tensorize & Unravel --> rather than yielding via a `take_every` --> we want to partition a global batch
199
+ # across replicas by assigning each a contiguous sub-sequence.
200
+ indices_t = torch.as_tensor(indices)
201
+ per_replica_batch_indices_t = indices_t.reshape(-1, per_replica_batch_size)
202
+ replica_indices_t = per_replica_batch_indices_t[self.rank :: self.num_replicas]
203
+
204
+ replica_indices = replica_indices_t.flatten().tolist()
205
+ return iter(replica_indices)
206
+
207
+ def __len__(self) -> int:
208
+ return self.num_samples
209
+
210
+ def set_epoch(self, epoch: int) -> None:
211
+ """To be called *between* epochs, prior to DataLoader instantiation; ensures random order across epochs."""
212
+ self.epoch = epoch
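
A quick way to see the greedy length-bucketing performed by `reindex_batch` is to run it directly on the toy lengths from the walkthrough above. This sketch assumes the `prismatic` package is importable; since `reindex_batch` is a `@staticmethod`, no sampler instance (and no `torch.distributed` initialization) is required:

from prismatic.util.batching_utils import SplitModalitySampler

# Toy lengths from the walkthrough above (index -> sequence length)
idx2lengths = [20, 90, 21, 22, 91, 18, 89, 19, 93, 88, 92, 17]

# One "global batch" of 4 indices, sorted by length (descending), to be split across 2 replicas
batch_idxs = sorted([4, 3, 0, 5], key=lambda i: idx2lengths[i], reverse=True)

# Greedily assign each index to the currently-shortest bucket (one bucket per replica)
buckets = SplitModalitySampler.reindex_batch(batch_idxs, idx2lengths, 2)
print(buckets)  # two per-rank buckets of 2 indices each, roughly balanced by total length
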
capvector-oft/prismatic/util/data_utils.py ADDED
@@ -0,0 +1,156 @@
1
+ """
2
+ data_utils.py
3
+
4
+ General utilities and classes for facilitating data loading and collation.
5
+ """
6
+
7
+ from dataclasses import dataclass
8
+ from typing import Callable, Dict, Sequence, Tuple
9
+
10
+ import numpy as np
11
+ import torch
12
+ from torch.nn.utils.rnn import pad_sequence
13
+
14
+ # HuggingFace Default / LLaMa-2 IGNORE_INDEX (for labels)
15
+ IGNORE_INDEX = -100
16
+
17
+
18
+ def tree_map(fn: Callable, tree: dict) -> dict:
19
+ """Maps a function over a nested dictionary."""
20
+ return {k: tree_map(fn, v) if isinstance(v, dict) else fn(v) for k, v in tree.items()}
21
+
22
+
23
+ def tree_map_with_key(fn: Callable, tree: dict, keys: Sequence = ()) -> dict:
24
+ """Maps a function over a nested dictionary."""
25
+ return {
26
+ k: tree_map_with_key(fn, v, (*keys, k)) if isinstance(v, dict) else fn((*keys, k), v) for k, v in tree.items()
27
+ }
28
+
29
+
30
+ @dataclass
31
+ class PaddedCollatorForLanguageModeling:
32
+ model_max_length: int
33
+ pad_token_id: int
34
+ default_image_resolution: Tuple[int, int, int]
35
+ padding_side: str = "right"
36
+ pixel_values_dtype: torch.dtype = torch.float32
37
+
38
+ def __post_init__(self) -> None:
39
+ self.dummy_pixel_values = torch.zeros(self.default_image_resolution, dtype=self.pixel_values_dtype)
40
+
41
+ def __call__(self, instances: Sequence[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
42
+ input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
43
+ pixel_values = [instance["pixel_values"] for instance in instances]
44
+
45
+ # For now, we only support Tokenizers with `padding_side = "right"` during Training (but plan to extend!)
46
+ # => Handle padding via RNN Utils => `pad_sequence`
47
+ input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.pad_token_id)
48
+ labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
49
+
50
+ # Truncate (if necessary)
51
+ input_ids, labels = input_ids[:, : self.model_max_length], labels[:, : self.model_max_length]
52
+
53
+ # Get `attention_mask` by checking for `pad_token_id`
54
+ attention_mask = input_ids.ne(self.pad_token_id)
55
+
56
+ # === Handle "unimodal" (language-only) vs. "multimodal" ===
57
+
58
+ # Some examples are "language-only" --> build a Tensor of `multimodal_indices` that we can slice into easily
59
+ multimodal_indices = torch.tensor(
60
+ [idx for idx in range(len(pixel_values)) if pixel_values[idx] is not None], dtype=torch.long
61
+ )
62
+
63
+ # Stack all `pixel_values` --> depending on type (torch.Tensor, or Dict[str, torch.Tensor]) & presence of None
64
+ if len(multimodal_indices) == 0:
65
+ pixel_values = torch.stack([self.dummy_pixel_values for _ in range(len(input_ids))])
66
+ elif isinstance(pv_example := pixel_values[multimodal_indices[0]], torch.Tensor):
67
+ pixel_values = torch.stack(
68
+ [
69
+ pixel_values[idx] if idx in multimodal_indices else self.dummy_pixel_values
70
+ for idx in range(len(input_ids))
71
+ ]
72
+ )
73
+ elif isinstance(pv_example, dict):
74
+ pixel_values = {
75
+ k: torch.stack(
76
+ [
77
+ pixel_values[idx][k] if idx in multimodal_indices else self.dummy_pixel_values
78
+ for idx in range(len(input_ids))
79
+ ]
80
+ )
81
+ for k in pv_example
82
+ }
83
+ else:
84
+ raise ValueError(f"Unsupported `pixel_values` type = {type(pixel_values)}")
85
+
86
+ return dict(
87
+ pixel_values=pixel_values,
88
+ input_ids=input_ids,
89
+ attention_mask=attention_mask,
90
+ labels=labels,
91
+ multimodal_indices=multimodal_indices,
92
+ )
93
+
94
+
95
+ @dataclass
96
+ class PaddedCollatorForActionPrediction:
97
+ model_max_length: int
98
+ pad_token_id: int
99
+ padding_side: str = "right"
100
+ pixel_values_dtype: torch.dtype = torch.float32
101
+
102
+ def __call__(self, instances: Sequence[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
103
+ input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
104
+ pixel_values = [instance["pixel_values"] for instance in instances]
105
+ if "dataset_name" in instances[0]:
106
+ dataset_names = [instance["dataset_name"] for instance in instances]
107
+ else:
108
+ dataset_names = None
109
+
110
+ # For now, we only support Tokenizers with `padding_side = "right"` during training
111
+ # => Handle padding via RNN Utils => `pad_sequence`
112
+ assert self.padding_side == "right", f"Invalid Tokenizer `{self.padding_side = }`"
113
+ input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.pad_token_id)
114
+ labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
115
+
116
+ # Truncate (if necessary)
117
+ input_ids, labels = input_ids[:, : self.model_max_length], labels[:, : self.model_max_length]
118
+
119
+ # Get `attention_mask` by checking for `pad_token_id`
120
+ attention_mask = input_ids.ne(self.pad_token_id)
121
+
122
+ # [Contract] For VLA Training =>> No "Unimodal" Data!
123
+ assert all([pv is not None for pv in pixel_values]), "Invalid VLA Example with `pixel_values = None`!"
124
+
125
+ # Stack all `pixel_values` --> depending on type is torch.Tensor or Dict[str, torch.Tensor]
126
+ if isinstance(pixel_values[0], torch.Tensor):
127
+ if "pixel_values_wrist" in instances[0]:
128
+ pixel_values_wrist = [instance["pixel_values_wrist"] for instance in instances]
129
+ pixel_values = torch.cat((torch.stack(pixel_values), torch.stack(pixel_values_wrist)), dim=1)
130
+ else:
131
+ pixel_values = torch.stack(pixel_values)
132
+ else:
133
+ raise ValueError(f"Unsupported `pixel_values` type = {type(pixel_values)}")
134
+
135
+ # Stack all actions
136
+ actions = [torch.from_numpy(np.copy(instance["actions"])) for instance in instances]
137
+ actions = torch.stack(actions)
138
+
139
+ # Stack proprio
140
+ if "proprio" in instances[0]:
141
+ proprio = [instance["proprio"] for instance in instances]
142
+ proprio = torch.Tensor(np.squeeze(np.stack(proprio)))
143
+ else:
144
+ proprio = None
145
+
146
+ output = dict(
147
+ pixel_values=pixel_values,
148
+ proprio=proprio,
149
+ input_ids=input_ids,
150
+ attention_mask=attention_mask,
151
+ labels=labels,
152
+ actions=actions,
153
+ )
154
+ if dataset_names is not None:
155
+ output["dataset_names"] = dataset_names
156
+ return output
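
A minimal usage sketch for `PaddedCollatorForLanguageModeling`, assuming the `prismatic` package is importable; the `pad_token_id`, sequence contents, and image resolution below are illustrative only:

import torch
from prismatic.util.data_utils import PaddedCollatorForLanguageModeling

# Illustrative settings -- in practice these come from the tokenizer / image transform
collator = PaddedCollatorForLanguageModeling(
    model_max_length=16, pad_token_id=0, default_image_resolution=(3, 224, 224)
)

instances = [
    {   # multimodal example
        "input_ids": torch.tensor([5, 6, 7, 8]),
        "labels": torch.tensor([-100, -100, 7, 8]),
        "pixel_values": torch.randn(3, 224, 224),
    },
    {   # language-only example (no image)
        "input_ids": torch.tensor([5, 6]),
        "labels": torch.tensor([-100, 6]),
        "pixel_values": None,
    },
]

batch = collator(instances)
print(batch["input_ids"].shape, batch["attention_mask"].shape)  # both padded to (2, 4)
print(batch["multimodal_indices"])                              # tensor([0]) -> only example 0 has an image
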
capvector-oft/prismatic/util/nn_utils.py ADDED
@@ -0,0 +1,53 @@
1
+ """
2
+ nn_utils.py
3
+
4
+ Utility functions and PyTorch submodule definitions.
5
+ """
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+
11
+ # === Definitions for Various Projection Modules, with Signature :: [..., in_dim] --> [..., out_dim] ===
12
+ class LinearProjector(nn.Module):
13
+ def __init__(self, vision_dim: int, llm_dim: int) -> None:
14
+ super().__init__()
15
+ self.projector = nn.Linear(vision_dim, llm_dim, bias=True)
16
+
17
+ def forward(self, img_patches: torch.Tensor) -> torch.Tensor:
18
+ return self.projector(img_patches)
19
+
20
+
21
+ class MLPProjector(nn.Module):
22
+ def __init__(self, vision_dim: int, llm_dim: int, mlp_type: str = "gelu-mlp") -> None:
23
+ super().__init__()
24
+ if mlp_type == "gelu-mlp":
25
+ self.projector = nn.Sequential(
26
+ nn.Linear(vision_dim, llm_dim, bias=True),
27
+ nn.GELU(),
28
+ nn.Linear(llm_dim, llm_dim, bias=True),
29
+ )
30
+ else:
31
+ raise ValueError(f"Projector with `{mlp_type = }` is not supported!")
32
+
33
+ def forward(self, img_patches: torch.Tensor) -> torch.Tensor:
34
+ return self.projector(img_patches)
35
+
36
+
37
+ class FusedMLPProjector(nn.Module):
38
+ def __init__(self, fused_vision_dim: int, llm_dim: int, mlp_type: str = "fused-gelu-mlp") -> None:
39
+ super().__init__()
40
+ self.initial_projection_dim = fused_vision_dim * 4
41
+ if mlp_type == "fused-gelu-mlp":
42
+ self.projector = nn.Sequential(
43
+ nn.Linear(fused_vision_dim, self.initial_projection_dim, bias=True),
44
+ nn.GELU(),
45
+ nn.Linear(self.initial_projection_dim, llm_dim, bias=True),
46
+ nn.GELU(),
47
+ nn.Linear(llm_dim, llm_dim, bias=True),
48
+ )
49
+ else:
50
+ raise ValueError(f"Fused Projector with `{mlp_type = }` is not supported!")
51
+
52
+ def forward(self, fused_img_patches: torch.Tensor) -> torch.Tensor:
53
+ return self.projector(fused_img_patches)
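
A quick shape sanity-check for the projectors above; the dimensions are illustrative placeholders and are not tied to any specific vision backbone or LLM:

import torch
from prismatic.util.nn_utils import FusedMLPProjector, LinearProjector, MLPProjector

# Illustrative dims: 1024-d patches (single backbone), 2176-d fused patches, 4096-d LLM embeddings
patches = torch.randn(2, 256, 1024)        # [batch, num_patches, vision_dim]
fused_patches = torch.randn(2, 256, 2176)  # [batch, num_patches, fused_vision_dim]

print(LinearProjector(1024, 4096)(patches).shape)          # torch.Size([2, 256, 4096])
print(MLPProjector(1024, 4096)(patches).shape)             # torch.Size([2, 256, 4096])
print(FusedMLPProjector(2176, 4096)(fused_patches).shape)  # torch.Size([2, 256, 4096])
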
capvector-oft/prismatic/util/torch_utils.py ADDED
@@ -0,0 +1,95 @@
1
+ """
2
+ torch_utils.py
3
+
4
+ General utilities for randomness, mixed precision training, and miscellaneous checks in PyTorch.
5
+
6
+ Random `set_global_seed` functionality is taken directly from PyTorch Lightning:
7
+ > Ref: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/utilities/seed.py
8
+
9
+ This is pretty important to get right if we're ever randomly generating our masks (or prefix dropout) inside our
10
+ Dataset __getitem__() with multiple workers... if not handled properly, we will get repeated augmentations anytime
11
+ we inject randomness from non-PyTorch sources (e.g., numpy, random)!
12
+ > Ref: https://tanelp.github.io/posts/a-bug-that-plagues-thousands-of-open-source-ml-projects/
13
+
14
+ Terminology
15
+ -> World Size :: Total number of processes distributed over (# nodes x # devices) -- assumed homogeneous!
16
+ -> Rank :: Integer index of current process in the total world size
17
+ -> Local Rank :: Local index on a given node, in [0, devices per node)
18
+ """
19
+
20
+ import os
21
+ import random
22
+ from typing import Callable, Optional
23
+
24
+ import numpy as np
25
+ import torch
26
+
27
+ # === Randomness ===
28
+
29
+
30
+ def set_global_seed(seed: int, get_worker_init_fn: bool = False) -> Optional[Callable[[int], None]]:
31
+ """Sets seed for all randomness libraries (mostly random, numpy, torch) and produces a `worker_init_fn`"""
32
+ assert np.iinfo(np.uint32).min < seed < np.iinfo(np.uint32).max, "Seed outside the np.uint32 bounds!"
33
+
34
+ # Set Seed as an Environment Variable
35
+ os.environ["EXPERIMENT_GLOBAL_SEED"] = str(seed)
36
+ random.seed(seed)
37
+ np.random.seed(seed)
38
+ torch.manual_seed(seed)
39
+
40
+ return worker_init_function if get_worker_init_fn else None
41
+
42
+
43
+ def worker_init_function(worker_id: int) -> None:
44
+ """
45
+ Borrowed directly from PyTorch-Lightning; inspired by this issue comment in the PyTorch repo:
46
+ > Ref: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562
47
+
48
+ Intuition: You can think of the seed sequence spawn function as a "janky" torch.Generator() or jax.PRNGKey that
49
+ you can run iterative splitting on to get new (predictable) randomness.
50
+
51
+ :param worker_id: Identifier for the given worker [0, num_workers) for the Dataloader in question.
52
+ """
53
+ # Get current `rank` (if running distributed) and `process_seed`
54
+ global_rank, process_seed = int(os.environ["LOCAL_RANK"]), torch.initial_seed()
55
+
56
+ # Back out the "base" (original) seed - the per-worker seed is set in PyTorch:
57
+ # > https://pytorch.org/docs/stable/data.html#data-loading-randomness
58
+ base_seed = process_seed - worker_id
59
+
60
+ # "Magic" code --> basically creates a seed sequence that mixes different "sources" and seeds every library...
61
+ seed_seq = np.random.SeedSequence([base_seed, worker_id, global_rank])
62
+
63
+ # Use 128 bits (4 x 32-bit words) to represent seed --> generate_state(k) produces a `k` element array!
64
+ np.random.seed(seed_seq.generate_state(4))
65
+
66
+ # Spawn distinct child sequences for PyTorch (reseed) and stdlib random
67
+ torch_seed_seq, random_seed_seq = seed_seq.spawn(2)
68
+
69
+ # Torch manual seed takes 64 bits (so just specify a dtype of uint64)
70
+ torch.manual_seed(torch_seed_seq.generate_state(1, dtype=np.uint64)[0])
71
+
72
+ # Use 128 Bits for `random`, but express as integer instead of as an array
73
+ random_seed = (random_seed_seq.generate_state(2, dtype=np.uint64).astype(list) * [1 << 64, 1]).sum()
74
+ random.seed(random_seed)
75
+
76
+
77
+ # === BFloat16 Support ===
78
+
79
+
80
+ def check_bloat16_supported() -> bool:
81
+ try:
82
+ import packaging.version
83
+ import torch.cuda.nccl as nccl
84
+ import torch.distributed as dist
85
+
86
+ return (
87
+ (torch.version.cuda is not None)
88
+ and torch.cuda.is_bf16_supported()
89
+ and (packaging.version.parse(torch.version.cuda).release >= (11, 0))
90
+ and dist.is_nccl_available()
91
+ and (nccl.version() >= (2, 10))
92
+ )
93
+
94
+ except Exception:
95
+ return False
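
A minimal sketch of wiring `set_global_seed` and its returned `worker_init_fn` into a `DataLoader`. Note that `worker_init_function` reads `LOCAL_RANK`, so the sketch sets a default for single-process runs; the dataset is a toy stand-in:

import os
import torch
from torch.utils.data import DataLoader, TensorDataset
from prismatic.util.torch_utils import set_global_seed

# `worker_init_function` reads LOCAL_RANK, so provide a default when not launched via torchrun
os.environ.setdefault("LOCAL_RANK", "0")

# Seed python/numpy/torch globally and get a per-worker init function back
worker_init_fn = set_global_seed(7, get_worker_init_fn=True)

# Toy dataset stand-in; with `worker_init_fn`, numpy/random draws inside __getitem__ differ per worker
dataset = TensorDataset(torch.arange(32).float())
loader = DataLoader(dataset, batch_size=8, num_workers=2, worker_init_fn=worker_init_fn)
print(next(iter(loader))[0])
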