diff --git a/capvector-oft/prismatic/models/__init__.py b/capvector-oft/prismatic/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..85a3ebb94a024811e1e567cab3d80d805f9a48f5
--- /dev/null
+++ b/capvector-oft/prismatic/models/__init__.py
@@ -0,0 +1,2 @@
+from .load import available_model_names, available_models, get_model_description, load, load_vla
+from .materialize import get_llm_backbone_and_tokenizer, get_vision_backbone_and_transform, get_vlm
diff --git a/capvector-oft/prismatic/models/backbones/__init__.py b/capvector-oft/prismatic/models/backbones/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/capvector-oft/prismatic/models/backbones/llm/__init__.py b/capvector-oft/prismatic/models/backbones/llm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a040f37d9c4f1e5354b74b6e24483c76fc11bf2c
--- /dev/null
+++ b/capvector-oft/prismatic/models/backbones/llm/__init__.py
@@ -0,0 +1,4 @@
+from .base_llm import LLMBackbone
+from .llama2 import LLaMa2LLMBackbone
+from .mistral import MistralLLMBackbone
+from .phi import PhiLLMBackbone
diff --git a/capvector-oft/prismatic/models/backbones/llm/base_llm.py b/capvector-oft/prismatic/models/backbones/llm/base_llm.py
new file mode 100644
index 0000000000000000000000000000000000000000..fab6d971e71f8251421c91fbcb009f6c882111ae
--- /dev/null
+++ b/capvector-oft/prismatic/models/backbones/llm/base_llm.py
@@ -0,0 +1,223 @@
+"""
+base_llm.py
+
+Abstract class definition of a large (autoregressive) language model backbone (LLM), with full annotations of class
+methods, utility functions, and initialization logic.
+
+We also define the generic HFCausalLLMBackbone class here, providing a default interface for loading any HF
+AutoModelForCausalLM (e.g., LlamaForCausalLM). In general, we make the assumption that any given LLM backbone implements
+the AutoModelForCausalLM API (though we may add Seq2Seq models in the future).
+
+We make this assumption to keep the LLM handling in this codebase relatively lightweight, and to inherit all the nice HF
+utilities around different types of decoding/generation strategies.
+"""
+
+import warnings
+from abc import ABC, abstractmethod
+from functools import partial
+from typing import Callable, List, Optional, Sequence, Type
+
+import torch
+import torch.nn as nn
+from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
+from transformers import AutoConfig, AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
+from prismatic.models.backbones.llm.prompting import PromptBuilder
+from prismatic.overwatch import initialize_overwatch
+
+# Suppress HF Deprecation Warnings
+warnings.filterwarnings("ignore", category=FutureWarning)
+
+# Initialize Overwatch =>> Wraps `logging.Logger`
+overwatch = initialize_overwatch(__name__)
+
+
+# === Abstract Base Class for arbitrary HF LLM Backbones ===
+class LLMBackbone(nn.Module, ABC):
+    def __init__(self, llm_backbone_id: str) -> None:
+        super().__init__()
+        self.identifier = llm_backbone_id
+
+        # Instance attributes for an LLM Backbone
+        self.llm: PreTrainedModel = None
+        self.tokenizer: PreTrainedTokenizerBase = None
+
+    def get_tokenizer(self) -> PreTrainedTokenizerBase:
+        return self.tokenizer
+
+    @abstractmethod
+    def get_fsdp_wrapping_policy(self) -> Callable: ...

+    @abstractmethod
+    def enable_gradient_checkpointing(self) -> None: ...
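+    # [Illustration] Concrete subclasses fill in the hooks above/below; a minimal sketch (the hypothetical
+    # `MyLLMBackbone` / `MyDecoderLayer` names are for exposition only, not part of this codebase):
+    #
+    #   class MyLLMBackbone(LLMBackbone):
+    #       def get_fsdp_wrapping_policy(self) -> Callable:
+    #           return partial(transformer_auto_wrap_policy, transformer_layer_cls={MyDecoderLayer})
+    #
+    #       def enable_gradient_checkpointing(self) -> None:
+    #           self.llm.gradient_checkpointing_enable()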
+
+    @abstractmethod
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> CausalLMOutputWithPast:
+        """Run a forward pass through the LLM given targets (labels), returning the scalar Cross-Entropy Loss"""
+        raise NotImplementedError
+
+    @abstractmethod
+    def embed_input_ids(self, input_ids: torch.LongTensor) -> torch.Tensor: ...
+
+    @property
+    @abstractmethod
+    def prompt_builder_fn(self) -> Type[PromptBuilder]: ...
+
+    @property
+    @abstractmethod
+    def transformer_layer_cls(self) -> Type[nn.Module]: ...
+
+    @property
+    @abstractmethod
+    def half_precision_dtype(self) -> torch.dtype: ...
+
+    @property
+    @abstractmethod
+    def last_layer_finetune_modules(self) -> Sequence[nn.Module]: ...
+
+    @property
+    def embed_dim(self) -> int:
+        return self.llm.config.hidden_size
+
+    @property
+    def pad_token_id(self) -> int:
+        return self.tokenizer.pad_token_id
+
+
+# === Abstract Base Class for Arbitrary HF Causal LLMs ===
+class HFCausalLLMBackbone(LLMBackbone, ABC):
+    def __init__(
+        self,
+        llm_backbone_id: str,
+        llm_family: str,
+        llm_cls: Type[PreTrainedModel],
+        hf_hub_path: str,
+        llm_max_length: int = 2048,
+        hf_token: Optional[str] = None,
+        inference_mode: bool = False,
+        use_flash_attention_2: bool = False,
+    ) -> None:
+        super().__init__(llm_backbone_id)
+        self.llm_family = llm_family
+        self.llm_max_length = llm_max_length
+        self.inference_mode = inference_mode
+
+        # Initialize LLM (downloading from HF Hub if necessary) --> `llm_cls` is the actual {Model}ForCausalLM class!
+        #   => Note: We're eschewing use of the AutoModel API so that we can be more explicit about LLM-specific details
+        if not self.inference_mode:
+            overwatch.info(f"Loading [bold]{llm_family}[/] LLM from [underline]`{hf_hub_path}`[/]", ctx_level=1)
+            self.llm = llm_cls.from_pretrained(
+                hf_hub_path,
+                token=hf_token,
+                use_flash_attention_2=use_flash_attention_2 if not self.inference_mode else False,
+                # The following parameters are set to prevent `UserWarnings` from HF; we want greedy decoding!
+                do_sample=False,
+                temperature=1.0,
+                top_p=1.0,
+            )
+
+        # [Contract] `inference_mode` means we're loading from a pretrained checkpoint; no need to load base weights!
+        else:
+            overwatch.info(f"Building empty [bold]{llm_family}[/] LLM from [underline]`{hf_hub_path}`[/]", ctx_level=1)
+            llm_config = AutoConfig.from_pretrained(hf_hub_path, token=hf_token)
+            self.llm = llm_cls._from_config(llm_config)
+
+        # Lightweight Handling (with extended explanation) for setting some LLM Parameters
+        #   => Set `decoder.use_cache = False` --> incompatible with gradient checkpointing (+ training in general)
+        #
+        #      Reference: https://discuss.huggingface.co/t/what-is-the-purpose-of-use-cache-in-decoder/958
+        self.llm.config.use_cache = False if not self.inference_mode else True
+
+        # => Turns out that when gradient checkpointing is on and the underlying LLM has no "trainable" parameters
+        #    (requires_grad is False), backprop will fail; setting `enable_input_require_grads()` registers a new
+        #    forward hook that fixes this =>> also totally safe for the "full finetuning" setting!
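+        # (Per the HF docs, `enable_input_require_grads` registers this hook on the input embeddings, so gradients
+        # can reach adapter/projector parameters even when every base LLM weight stays frozen.)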
+        if not self.inference_mode:
+            self.llm.enable_input_require_grads()
+
+        # Load (Fast) Tokenizer
+        overwatch.info(f"Loading [bold]{llm_family}[/] (Fast) Tokenizer via the AutoTokenizer API", ctx_level=1)
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            hf_hub_path, model_max_length=self.llm_max_length, token=hf_token, padding_side="right"
+        )
+
+        # Validation =>> Our VLM logic currently operates under the assumption that the tokenization of a new input
+        #                starts with a <BOS> token unless `add_special_tokens = False`; for these models, we empirically
+        #                find that adding image patches *after* the BOS leads to much better performance.
+        #
+        # As a result we explicitly validate that a tokenizer conforms to the expected behavior; if you're reading this
+        # line, it's probably because you're adding a new LLM with a different tokenizer behavior. If so, feel free to
+        # override the `SPECIAL_CASES` set below, but make sure to make the appropriate changes in the `datasets.py`
+        # and VLM `forward()` logic!
+        SPECIAL_CASES = {
+            # Phi-2 Tokenizer doesn't add any BOS tokens by default, and sets BOS == EOS == "<|endoftext|>"
+            #   =>> We'll prepend BOS to the first input (to play nicely with image token insertion logic); verified
+            #       that this works well with base LLM generation.
+            #   =>> Like Llama-2 Tokenizers -- we'll add a special PAD token for training purposes.
+            "phi-2-3b",
+        }
+        if self.identifier in SPECIAL_CASES:
+            return
+
+        # Note =>> this assert should hold for all Llama-derived tokenizers (`LlamaTokenizerFast`) =>> includes Mistral!
+        assert (self.tokenizer("Test 123", add_special_tokens=True).input_ids[0] == self.tokenizer.bos_token_id) and (
+            self.tokenizer("Test 123", add_special_tokens=False).input_ids[0] != self.tokenizer.bos_token_id
+        ), (
+            f"Default Tokenizer of type `{type(self.tokenizer)}` does not automatically prefix inputs with BOS token!\n"
+            "Please read the comment in `base_llm.py` for more information!"
+        )
+
+    def get_fsdp_wrapping_policy(self) -> Callable:
+        """Return a `transformer_auto_wrap_policy` where we wrap each instance of `self.transformer_layer_cls`"""
+        transformer_block_policy = partial(
+            transformer_auto_wrap_policy, transformer_layer_cls={self.transformer_layer_cls}
+        )
+
+        return transformer_block_policy
+
+    def enable_gradient_checkpointing(self) -> None:
+        """Dispatch to underlying LLM instance's `gradient_checkpointing_enable`; defined for all `PreTrainedModel`."""
+        self.llm.gradient_checkpointing_enable()
+
+    def embed_input_ids(self, input_ids: torch.LongTensor) -> torch.Tensor:
+        return self.llm.get_input_embeddings()(input_ids)
+
+    # [Contract] Should match the `forward` call of the underlying `llm` instance!
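+    # [Note] When `labels` are passed, the underlying HF CausalLM computes the shifted cross-entropy loss
+    # internally and surfaces it as `output.loss` on the returned `CausalLMOutputWithPast`.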
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> CausalLMOutputWithPast:
+        output: CausalLMOutputWithPast = self.llm(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            labels=labels,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        return output
diff --git a/capvector-oft/prismatic/models/backbones/llm/llama2.py b/capvector-oft/prismatic/models/backbones/llm/llama2.py
new file mode 100644
index 0000000000000000000000000000000000000000..559409e2e54d7c0c4d06f2691d104ffba991b48f
--- /dev/null
+++ b/capvector-oft/prismatic/models/backbones/llm/llama2.py
@@ -0,0 +1,102 @@
+"""
+llama2.py
+
+Class definition for all LLMs derived from LlamaForCausalLM.
+"""
+
+from typing import Optional, Sequence, Type
+
+import torch
+from torch import nn as nn
+from transformers import LlamaForCausalLM
+from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+
+from prismatic.models.backbones.llm.base_llm import HFCausalLLMBackbone
+from prismatic.models.backbones.llm.prompting import (
+    LLaMa2ChatPromptBuilder,
+    PromptBuilder,
+    PurePromptBuilder,
+    VicunaV15ChatPromptBuilder,
+)
+
+# Registry =>> Support LLaMa-2 Models (from HF Transformers)
+# fmt: off
+LLAMA2_MODELS = {
+    # === Pure Meta LLaMa-2 (non-instruct/chat-tuned) Models ===
+    "llama2-7b-pure": {
+        "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "meta-llama/Llama-2-7b-hf"
+    },
+
+    "llama2-13b-pure": {
+        "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "meta-llama/Llama-2-13b-hf"
+    },
+
+    # === Meta LLaMa-2 Chat Models ===
+    "llama2-7b-chat": {
+        "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "meta-llama/Llama-2-7b-chat-hf"
+    },
+
+    "llama2-13b-chat": {
+        "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "meta-llama/Llama-2-13b-chat-hf"
+    },
+
+    # === Vicuna v1.5 Chat Models ===
+    "vicuna-v15-7b": {
+        "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "lmsys/vicuna-7b-v1.5"
+    },
+
+    "vicuna-v15-13b": {
+        "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "lmsys/vicuna-13b-v1.5"
+    },
+}
+# fmt: on
+
+
+class LLaMa2LLMBackbone(HFCausalLLMBackbone):
+    def __init__(
+        self,
+        llm_backbone_id: str,
+        llm_max_length: int = 2048,
+        hf_token: Optional[str] = None,
+        inference_mode: bool = False,
+        use_flash_attention_2: bool = True,
+    ) -> None:
+        super().__init__(
+            llm_backbone_id,
+            llm_max_length=llm_max_length,
+            hf_token=hf_token,
+            inference_mode=inference_mode,
+            use_flash_attention_2=use_flash_attention_2,
+            **LLAMA2_MODELS[llm_backbone_id],
+        )
+
+        # [Special Case] LLaMa-2 PAD Token Handling --> for clarity, we add an extra token (and resize)
+        self.tokenizer.add_special_tokens({"pad_token": "<PAD>"})
+        self.llm.config.pad_token_id = self.tokenizer.pad_token_id
+        self.llm.resize_token_embeddings(len(self.tokenizer), pad_to_multiple_of=64)
+
+    @property
+    def prompt_builder_fn(self) -> Type[PromptBuilder]:
+        if self.identifier.startswith("llama2-") and self.identifier.endswith("-pure"):
+            return PurePromptBuilder
+
+        elif self.identifier.startswith("llama2-") and self.identifier.endswith("-chat"):
+            return LLaMa2ChatPromptBuilder
+
+        elif self.identifier.startswith("vicuna"):
+            return VicunaV15ChatPromptBuilder
+
+        raise ValueError(f"No PromptBuilder defined for LLM Backbone `{self.identifier}`")
+
+    @property
+    def transformer_layer_cls(self) -> Type[nn.Module]:
+        return LlamaDecoderLayer
+
+    @property
+    def half_precision_dtype(self) -> torch.dtype:
+        """LLaMa-2 was trained in BF16; see https://huggingface.co/docs/transformers/main/model_doc/llama2."""
+        return torch.bfloat16
+
+    @property
+    def last_layer_finetune_modules(self) -> Sequence[nn.Module]:
+        return (self.llm.model.embed_tokens, self.llm.model.layers[-1], self.llm.lm_head)
diff --git a/capvector-oft/prismatic/models/backbones/llm/mistral.py b/capvector-oft/prismatic/models/backbones/llm/mistral.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d2a41fc33a1acc5f92860625c0018736358338e
--- /dev/null
+++ b/capvector-oft/prismatic/models/backbones/llm/mistral.py
@@ -0,0 +1,72 @@
+"""
+mistral.py
+
+Class definition for all LLMs derived from MistralForCausalLM.
+"""
+
+from typing import Optional, Type
+
+import torch
+from torch import nn as nn
+from transformers import MistralForCausalLM
+from transformers.models.mistral.modeling_mistral import MistralDecoderLayer
+
+from prismatic.models.backbones.llm.base_llm import HFCausalLLMBackbone
+from prismatic.models.backbones.llm.prompting import MistralInstructPromptBuilder, PromptBuilder, PurePromptBuilder
+
+# Registry =>> Support Mistral Models (from HF Transformers)
+# fmt: off
+MISTRAL_MODELS = {
+    # === Base Mistral v0.1 ===
+    "mistral-v0.1-7b-pure": {
+        "llm_family": "mistral", "llm_cls": MistralForCausalLM, "hf_hub_path": "mistralai/Mistral-7B-v0.1"
+    },
+
+    # === Mistral Instruct v0.1 ===
+    "mistral-v0.1-7b-instruct": {
+        "llm_family": "mistral", "llm_cls": MistralForCausalLM, "hf_hub_path": "mistralai/Mistral-7B-Instruct-v0.1"
+    }
+}
+# fmt: on
+
+
+class MistralLLMBackbone(HFCausalLLMBackbone):
+    def __init__(
+        self,
+        llm_backbone_id: str,
+        llm_max_length: int = 2048,
+        hf_token: Optional[str] = None,
+        inference_mode: bool = False,
+        use_flash_attention_2: bool = True,
+    ) -> None:
+        super().__init__(
+            llm_backbone_id,
+            llm_max_length=llm_max_length,
+            hf_token=hf_token,
+            inference_mode=inference_mode,
+            use_flash_attention_2=use_flash_attention_2,
+            **MISTRAL_MODELS[llm_backbone_id],
+        )
+
+        # [Special Case] Mistral PAD Token Handling --> for clarity, we add an extra token (and resize)
+        self.tokenizer.add_special_tokens({"pad_token": "<PAD>"})
+        self.llm.config.pad_token_id = self.tokenizer.pad_token_id
+        self.llm.resize_token_embeddings(len(self.tokenizer), pad_to_multiple_of=64)
+
+    @property
+    def prompt_builder_fn(self) -> Type[PromptBuilder]:
+        if self.identifier.endswith("-pure"):
+            return PurePromptBuilder
+
+        elif self.identifier.endswith("-instruct"):
+            return MistralInstructPromptBuilder
+
+        raise ValueError(f"No PromptBuilder defined for LLM Backbone `{self.identifier}`")
+
+    @property
+    def transformer_layer_cls(self) -> Type[nn.Module]:
+        return MistralDecoderLayer
+
+    @property
+    def half_precision_dtype(self) -> torch.dtype:
+        return torch.bfloat16
diff --git a/capvector-oft/prismatic/models/backbones/llm/phi.py b/capvector-oft/prismatic/models/backbones/llm/phi.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9063b3f9ccd1b9b792400600a7fbb0a02150c1b
--- /dev/null
+++ b/capvector-oft/prismatic/models/backbones/llm/phi.py
@@ -0,0 +1,64 @@
+"""
+phi.py
+
+Class definition for all LLMs derived from PhiForCausalLM.
+"""
+
+from typing import Optional, Type
+
+import torch
+from torch import nn as nn
+from transformers import PhiForCausalLM
+from transformers.models.phi.modeling_phi import PhiDecoderLayer
+
+from prismatic.models.backbones.llm.base_llm import HFCausalLLMBackbone
+from prismatic.models.backbones.llm.prompting import PhiPromptBuilder, PromptBuilder
+
+# Registry =>> Support Phi Models (from HF Transformers)
+# fmt: off
+PHI_MODELS = {
+    # === Phi-2 ===
+    "phi-2-3b": {
+        "llm_family": "phi", "llm_cls": PhiForCausalLM, "hf_hub_path": "microsoft/phi-2"
+    }
+}
+# fmt: on
+
+
+class PhiLLMBackbone(HFCausalLLMBackbone):
+    def __init__(
+        self,
+        llm_backbone_id: str,
+        llm_max_length: int = 2048,
+        hf_token: Optional[str] = None,
+        inference_mode: bool = False,
+        use_flash_attention_2: bool = True,
+    ) -> None:
+        super().__init__(
+            llm_backbone_id,
+            llm_max_length=llm_max_length,
+            hf_token=hf_token,
+            inference_mode=inference_mode,
+            use_flash_attention_2=use_flash_attention_2,
+            **PHI_MODELS[llm_backbone_id],
+        )
+
+        # [Special Case] Phi PAD Token Handling --> for clarity, we add an extra token (and resize)
+        self.tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
+        self.llm.config.pad_token_id = self.tokenizer.pad_token_id
+        self.llm.resize_token_embeddings(len(self.tokenizer), pad_to_multiple_of=64)
+
+    @property
+    def prompt_builder_fn(self) -> Type[PromptBuilder]:
+        if self.identifier.startswith("phi-2"):
+            return PhiPromptBuilder
+
+        raise ValueError(f"No PromptBuilder defined for LLM Backbone `{self.identifier}`")
+
+    @property
+    def transformer_layer_cls(self) -> Type[nn.Module]:
+        return PhiDecoderLayer
+
+    @property
+    def half_precision_dtype(self) -> torch.dtype:
+        return torch.bfloat16
diff --git a/capvector-oft/prismatic/models/backbones/llm/prompting/__init__.py b/capvector-oft/prismatic/models/backbones/llm/prompting/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c73b61119d11b4d246f9e4a98d1aa70aa821621
--- /dev/null
+++ b/capvector-oft/prismatic/models/backbones/llm/prompting/__init__.py
@@ -0,0 +1,5 @@
+from .base_prompter import PromptBuilder, PurePromptBuilder
+from .llama2_chat_prompter import LLaMa2ChatPromptBuilder
+from .mistral_instruct_prompter import MistralInstructPromptBuilder
+from .phi_prompter import PhiPromptBuilder
+from .vicuna_v15_prompter import VicunaV15ChatPromptBuilder
diff --git a/capvector-oft/prismatic/models/backbones/llm/prompting/base_prompter.py b/capvector-oft/prismatic/models/backbones/llm/prompting/base_prompter.py
new file mode 100644
index 0000000000000000000000000000000000000000..65c2e16d703d42c7849cd8de9a9cdbbd3361fee3
--- /dev/null
+++ b/capvector-oft/prismatic/models/backbones/llm/prompting/base_prompter.py
@@ -0,0 +1,73 @@
+"""
+base_prompter.py
+
+Abstract class definition of a multi-turn prompt builder for ensuring consistent formatting for chat-based LLMs.
+"""
+
+from abc import ABC, abstractmethod
+from typing import Optional
+
+
+class PromptBuilder(ABC):
+    def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None:
+        self.model_family = model_family
+
+        # Only some models define a system prompt => let subclasses handle this logic!
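+        # (For example, the LLaMa-2 chat builder below folds the system prompt into the first human turn,
+        # while the pure and Mistral Instruct builders ignore it entirely.)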
+        self.system_prompt = system_prompt
+
+    @abstractmethod
+    def add_turn(self, role: str, message: str) -> str: ...
+
+    @abstractmethod
+    def get_potential_prompt(self, user_msg: str) -> str: ...
+
+    @abstractmethod
+    def get_prompt(self) -> str: ...
+
+
+class PurePromptBuilder(PromptBuilder):
+    def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None:
+        super().__init__(model_family, system_prompt)
+
+        # TODO (siddk) =>> Can't always assume LlamaTokenizer --> FIX ME!
+        self.bos, self.eos = "<s>", "</s>"
+
+        # Get role-specific "wrap" functions
+        self.wrap_human = lambda msg: f"In: {msg}\nOut: "
+        self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.eos}"
+
+        # === `self.prompt` gets built up over multiple turns ===
+        self.prompt, self.turn_count = "", 0
+
+    def add_turn(self, role: str, message: str) -> str:
+        assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt")
+        message = message.replace("<image>", "").strip()
+
+        if (self.turn_count % 2) == 0:
+            human_message = self.wrap_human(message)
+            wrapped_message = human_message
+        else:
+            gpt_message = self.wrap_gpt(message)
+            wrapped_message = gpt_message
+
+        # Update Prompt
+        self.prompt += wrapped_message
+
+        # Bump Turn Counter
+        self.turn_count += 1
+
+        # Return "wrapped_message" (effective string added to context)
+        return wrapped_message
+
+    def get_potential_prompt(self, message: str) -> str:
+        # Assumes that it's always the user's (human's) turn!
+        prompt_copy = str(self.prompt)
+
+        human_message = self.wrap_human(message)
+        prompt_copy += human_message
+
+        return prompt_copy.removeprefix(self.bos).rstrip()
+
+    def get_prompt(self) -> str:
+        # Remove prefix <s> (if exists) because it gets auto-inserted by tokenizer!
+        return self.prompt.removeprefix(self.bos).rstrip()
diff --git a/capvector-oft/prismatic/models/backbones/llm/prompting/llama2_chat_prompter.py b/capvector-oft/prismatic/models/backbones/llm/prompting/llama2_chat_prompter.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b5609aaec9f8f6b688182f96fed9a1db038ba95
--- /dev/null
+++ b/capvector-oft/prismatic/models/backbones/llm/prompting/llama2_chat_prompter.py
@@ -0,0 +1,91 @@
+"""
+llama2_chat_prompter.py
+
+Defines a PromptBuilder for building LLaMa-2 Chat Prompts --> not sure if this is "optimal", but this is the pattern
+that's used by HF and other online tutorials.
+
+Reference: https://huggingface.co/blog/llama2#how-to-prompt-llama-2
+"""
+
+from typing import Optional
+
+from prismatic.models.backbones.llm.prompting.base_prompter import PromptBuilder
+
+# Default System Prompt for Prismatic Models
+SYS_PROMPTS = {
+    "prismatic": (
+        "You are a helpful language and vision assistant. "
+        "You are able to understand the visual content that the user provides, "
+        "and assist the user with a variety of tasks using natural language."
+    ),
+    "openvla": (
+        "You are a helpful language and vision assistant. "
+        "You are able to understand the visual content that the user provides, "
+        "and assist the user with a variety of tasks using natural language."
+    ),
+}
+
+
+def format_system_prompt(system_prompt: str) -> str:
+    return f"<<SYS>>\n{system_prompt.strip()}\n<</SYS>>\n\n"
+
+
+class LLaMa2ChatPromptBuilder(PromptBuilder):
+    def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None:
+        super().__init__(model_family, system_prompt)
+        self.system_prompt = format_system_prompt(
+            SYS_PROMPTS[self.model_family] if system_prompt is None else system_prompt
+        )
+
+        # LLaMa-2 Specific
+        self.bos, self.eos = "<s>", "</s>"
+
+        # Get role-specific "wrap" functions
+        self.wrap_human = lambda msg: f"[INST] {msg} [/INST] "
+        self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.eos}"
+
+        # === `self.prompt` gets built up over multiple turns ===
+        self.prompt, self.turn_count = "", 0
+
+    def add_turn(self, role: str, message: str) -> str:
+        assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt")
+        message = message.replace("<image>", "").strip()
+
+        # Special Handling for "system" prompt (turn_count == 0)
+        if self.turn_count == 0:
+            sys_message = self.wrap_human(self.system_prompt + message)
+            wrapped_message = sys_message
+        elif (self.turn_count % 2) == 0:
+            human_message = self.wrap_human(message)
+            wrapped_message = human_message
+        else:
+            gpt_message = self.wrap_gpt(message)
+            wrapped_message = gpt_message
+
+        # Update Prompt
+        self.prompt += wrapped_message
+
+        # Bump Turn Counter
+        self.turn_count += 1
+
+        # Return "wrapped_message" (effective string added to context)
+        return wrapped_message
+
+    def get_potential_prompt(self, message: str) -> str:
+        # Assumes that it's always the user's (human's) turn!
+        prompt_copy = str(self.prompt)
+
+        # Special Handling for "system" prompt (turn_count == 0)
+        if self.turn_count == 0:
+            sys_message = self.wrap_human(self.system_prompt + message)
+            prompt_copy += sys_message
+
+        else:
+            human_message = self.wrap_human(message)
+            prompt_copy += human_message
+
+        return prompt_copy.removeprefix(self.bos).rstrip()
+
+    def get_prompt(self) -> str:
+        # Remove prefix <s> because it gets auto-inserted by tokenizer!
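+        # [Illustration] After one human/gpt exchange, the returned prompt looks roughly like:
+        #   "[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{user} [/INST] {response}</s>"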
+        return self.prompt.removeprefix(self.bos).rstrip()
diff --git a/capvector-oft/prismatic/models/backbones/llm/prompting/mistral_instruct_prompter.py b/capvector-oft/prismatic/models/backbones/llm/prompting/mistral_instruct_prompter.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9a22b541ff9aabe3a69fa663dc422ed14acae31
--- /dev/null
+++ b/capvector-oft/prismatic/models/backbones/llm/prompting/mistral_instruct_prompter.py
@@ -0,0 +1,60 @@
+"""
+mistral_instruct_prompter.py
+
+Defines a PromptBuilder for building Mistral Instruct Chat Prompts --> recommended pattern used by HF / online tutorials.
+
+Reference: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1#instruction-format
+"""
+
+from typing import Optional
+
+from prismatic.models.backbones.llm.prompting.base_prompter import PromptBuilder
+
+
+class MistralInstructPromptBuilder(PromptBuilder):
+    def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None:
+        super().__init__(model_family, system_prompt)
+
+        # Note =>> Mistral Tokenizer is an instance of `LlamaTokenizer(Fast)`
+        #      =>> Mistral Instruct *does not* use a System Prompt
+        self.bos, self.eos = "<s>", "</s>"
+
+        # Get role-specific "wrap" functions
+        self.wrap_human = lambda msg: f"[INST] {msg} [/INST] "
+        self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.eos}"
+
+        # === `self.prompt` gets built up over multiple turns ===
+        self.prompt, self.turn_count = "", 0
+
+    def add_turn(self, role: str, message: str) -> str:
+        assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt")
+        message = message.replace("<image>", "").strip()
+
+        if (self.turn_count % 2) == 0:
+            human_message = self.wrap_human(message)
+            wrapped_message = human_message
+        else:
+            gpt_message = self.wrap_gpt(message)
+            wrapped_message = gpt_message
+
+        # Update Prompt
+        self.prompt += wrapped_message
+
+        # Bump Turn Counter
+        self.turn_count += 1
+
+        # Return "wrapped_message" (effective string added to context)
+        return wrapped_message
+
+    def get_potential_prompt(self, message: str) -> str:
+        # Assumes that it's always the user's (human's) turn!
+        prompt_copy = str(self.prompt)
+
+        human_message = self.wrap_human(message)
+        prompt_copy += human_message
+
+        return prompt_copy.removeprefix(self.bos).rstrip()
+
+    def get_prompt(self) -> str:
+        # Remove prefix <s> because it gets auto-inserted by tokenizer!
+        return self.prompt.removeprefix(self.bos).rstrip()
diff --git a/capvector-oft/prismatic/models/backbones/llm/prompting/phi_prompter.py b/capvector-oft/prismatic/models/backbones/llm/prompting/phi_prompter.py
new file mode 100644
index 0000000000000000000000000000000000000000..3843a33bc86a716002ef52a0e067bff5101407f2
--- /dev/null
+++ b/capvector-oft/prismatic/models/backbones/llm/prompting/phi_prompter.py
@@ -0,0 +1,65 @@
+"""
+phi_prompter.py
+
+Defines a PromptBuilder for building Phi-2 Input/Output Prompts --> recommended pattern used by HF / Microsoft.
+Also handles Phi special case BOS token additions.
+
+Reference: https://huggingface.co/microsoft/phi-2#qa-format
+"""
+
+from typing import Optional
+
+from prismatic.models.backbones.llm.prompting.base_prompter import PromptBuilder
+
+
+class PhiPromptBuilder(PromptBuilder):
+    def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None:
+        super().__init__(model_family, system_prompt)
+
+        # Note =>> Phi Tokenizer is an instance of `CodeGenTokenizer(Fast)`
+        #      =>> By default, does *not* append <BOS> / <EOS> tokens --> we handle that here (IMPORTANT)!
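+        #      =>> (Phi-2 sets BOS == EOS == "<|endoftext|>"; `add_turn` below manually prepends it on the very
+        #           first human turn.)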
+        self.bos, self.eos = "<|endoftext|>", "<|endoftext|>"
+
+        # Get role-specific "wrap" functions
+        #   =>> Note that placement of <BOS> / <EOS> were based on experiments generating from Phi-2 in Input/Output mode
+        self.wrap_human = lambda msg: f"Input: {msg}\nOutput: "
+        self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}\n{self.eos}"
+
+        # === `self.prompt` gets built up over multiple turns ===
+        self.prompt, self.turn_count = "", 0
+
+    def add_turn(self, role: str, message: str) -> str:
+        assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt")
+        message = message.replace("<image>", "").strip()
+
+        # Special Handling for "first" input --> prepend a <BOS> token (expected by Prismatic)
+        if self.turn_count == 0:
+            bos_human_message = f"{self.bos}{self.wrap_human(message)}"
+            wrapped_message = bos_human_message
+        elif (self.turn_count % 2) == 0:
+            human_message = self.wrap_human(message)
+            wrapped_message = human_message
+        else:
+            gpt_message = self.wrap_gpt(message)
+            wrapped_message = gpt_message
+
+        # Update Prompt
+        self.prompt += wrapped_message
+
+        # Bump Turn Counter
+        self.turn_count += 1
+
+        # Return "wrapped_message" (effective string added to context)
+        return wrapped_message
+
+    def get_potential_prompt(self, message: str) -> str:
+        # Assumes that it's always the user's (human's) turn!
+        prompt_copy = str(self.prompt)
+
+        human_message = self.wrap_human(message)
+        prompt_copy += human_message
+
+        return prompt_copy.rstrip()
+
+    def get_prompt(self) -> str:
+        return self.prompt.rstrip()
diff --git a/capvector-oft/prismatic/models/backbones/llm/prompting/vicuna_v15_prompter.py b/capvector-oft/prismatic/models/backbones/llm/prompting/vicuna_v15_prompter.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ea246a16533f580332716361755f98f5aabe01d
--- /dev/null
+++ b/capvector-oft/prismatic/models/backbones/llm/prompting/vicuna_v15_prompter.py
@@ -0,0 +1,82 @@
+"""
+vicuna_v15_prompter.py
+
+Defines a PromptBuilder for building Vicuna-v1.5 Chat Prompts.
+
+Reference: https://huggingface.co/lmsys/vicuna-13b-v1.5
+"""
+
+from typing import Optional
+
+from prismatic.models.backbones.llm.prompting.base_prompter import PromptBuilder
+
+# Default System Prompt for LLaVa Models
+SYS_PROMPTS = {
+    "prismatic": (
+        "A chat between a curious user and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the user's questions."
+    ),
+    "openvla": (
+        "A chat between a curious user and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the user's questions."
+    ),
+}
+
+
+class VicunaV15ChatPromptBuilder(PromptBuilder):
+    def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None:
+        super().__init__(model_family, system_prompt)
+        self.system_prompt = (SYS_PROMPTS[self.model_family] if system_prompt is None else system_prompt).strip() + " "
+
+        # LLaMa-2 Specific
+        self.bos, self.eos = "<s>", "</s>"
+
+        # Get role-specific "wrap" functions
+        self.wrap_human = lambda msg: f"USER: {msg} ASSISTANT: "
+        self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.eos}"
+
+        # === `self.prompt` gets built up over multiple turns ===
+        self.prompt, self.turn_count = "", 0
+
+    def add_turn(self, role: str, message: str) -> str:
+        assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt")
+        message = message.replace("<image>", "").strip()
+
+        # Special Handling for "system" prompt (turn_count == 0)
+        if self.turn_count == 0:
+            sys_message = self.system_prompt + self.wrap_human(message)
+            wrapped_message = sys_message
+        elif (self.turn_count % 2) == 0:
+            human_message = self.wrap_human(message)
+            wrapped_message = human_message
+        else:
+            gpt_message = self.wrap_gpt(message)
+            wrapped_message = gpt_message
+
+        # Update Prompt
+        self.prompt += wrapped_message
+
+        # Bump Turn Counter
+        self.turn_count += 1
+
+        # Return "wrapped_message" (effective string added to context)
+        return wrapped_message
+
+    def get_potential_prompt(self, message: str) -> str:
+        # Assumes that it's always the user's (human's) turn!
+        prompt_copy = str(self.prompt)
+
+        # Special Handling for "system" prompt (turn_count == 0)
+        if self.turn_count == 0:
+            sys_message = self.system_prompt + self.wrap_human(message)
+            prompt_copy += sys_message
+
+        else:
+            human_message = self.wrap_human(message)
+            prompt_copy += human_message
+
+        return prompt_copy.removeprefix(self.bos).rstrip()
+
+    def get_prompt(self) -> str:
+        # Remove prefix <s> (if exists) because it gets auto-inserted by tokenizer!
+        return self.prompt.removeprefix(self.bos).rstrip()
diff --git a/capvector-oft/prismatic/models/backbones/vision/__init__.py b/capvector-oft/prismatic/models/backbones/vision/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c6da9a186cb68050eb11688b20177fc0ee4359c
--- /dev/null
+++ b/capvector-oft/prismatic/models/backbones/vision/__init__.py
@@ -0,0 +1,7 @@
+from .base_vision import ImageTransform, VisionBackbone
+from .clip_vit import CLIPViTBackbone
+from .dinoclip_vit import DinoCLIPViTBackbone
+from .dinosiglip_vit import DinoSigLIPViTBackbone
+from .dinov2_vit import DinoV2ViTBackbone
+from .in1k_vit import IN1KViTBackbone
+from .siglip_vit import SigLIPViTBackbone
diff --git a/capvector-oft/prismatic/models/backbones/vision/base_vision.py b/capvector-oft/prismatic/models/backbones/vision/base_vision.py
new file mode 100644
index 0000000000000000000000000000000000000000..8268c4dd53c2caa7bc98efbf1818644d46515cd3
--- /dev/null
+++ b/capvector-oft/prismatic/models/backbones/vision/base_vision.py
@@ -0,0 +1,207 @@
+"""
+base_vision.py
+
+Abstract class definition of a Vision Backbone (Visual Featurizer), with full annotations of class methods, utility
+functions, and initialization logic.
+
+We also define the generic TimmViTBackbone class here, providing a default interface for loading any TIMM Vision
+Transformer model for feature extraction.
+""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from functools import partial +from typing import Any, Callable, Dict, Optional, Protocol, Tuple, Union + +import timm +import torch +import torch.nn as nn +import torchvision.transforms.functional as TVF +from PIL.Image import Image +from timm.models.vision_transformer import Block, VisionTransformer +from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy +from torchvision.transforms import Compose, Resize + + +# === Utility Functions for Monkey-Patching === +def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]: + def wrapper(*args: Any, **kwargs: Any) -> Any: + result = fn(*args, **kwargs) + return result[0] if isinstance(result, tuple) else result + + return wrapper + + +# === Interface for an Image Transform === +class ImageTransform(Protocol): + def __call__(self, img: Image, **kwargs: str) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: ... + + +# === Custom Torchvision Image Transforms === +@dataclass +class LetterboxPad: + padding_fill_value: Tuple[int, int, int] + + def __call__(self, image: Image) -> Image: + """Given a PIL.Image, pad to square by adding a symmetric border around the height/width.""" + (w, h), max_wh = image.size, max(image.size) + horizontal_pad, vertical_pad = int((max_wh - w) / 2), int((max_wh - h) / 2) + padding = (horizontal_pad, vertical_pad, horizontal_pad, vertical_pad) + return TVF.pad(image, padding, fill=self.padding_fill_value, padding_mode="constant") + + +# === Abstract Base Class for arbitrary Vision Backbones === +class VisionBackbone(nn.Module, ABC): + def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None: + super().__init__() + self.identifier: str = vision_backbone_id + self.image_resize_strategy: str = image_resize_strategy + self.default_image_size: int = default_image_size + + # Instance attributes for a Vision Backbone + self.featurizer: nn.Module = None + self.image_transform: ImageTransform = None + + def get_image_transform(self) -> ImageTransform: + return self.image_transform + + @abstractmethod + def get_fsdp_wrapping_policy(self) -> Callable: ... + + @abstractmethod + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + """Run a forward pass through the featurizer given a set of processed images, returning patch/grid features.""" + raise NotImplementedError + + @property + @abstractmethod + def default_image_resolution(self) -> Tuple[int, int, int]: ... + + @property + @abstractmethod + def embed_dim(self) -> int: ... + + @property + @abstractmethod + def num_patches(self) -> int: ... + + @property + @abstractmethod + def half_precision_dtype(self) -> torch.dtype: ... 
+ + +# === Abstract Base Class for Arbitrary TIMM Vision Transformer Backbones === +class TimmViTBackbone(VisionBackbone, ABC): + def __init__( + self, + vision_backbone_id: str, + timm_path_or_url: str, + image_resize_strategy: str, + default_image_size: int = 224, + override_act_layer: Optional[str] = None, + ) -> None: + super().__init__(vision_backbone_id, image_resize_strategy, default_image_size=default_image_size) + self.timm_path_or_url = timm_path_or_url + self.override_act_layer = override_act_layer + self.dtype = torch.bfloat16 + + # Initialize Featurizer (ViT) by downloading from HF / TIMM Hub if necessary + if self.override_act_layer is None: + self.featurizer: VisionTransformer = timm.create_model( + self.timm_path_or_url, pretrained=True, num_classes=0, img_size=self.default_image_size + ) + else: + self.featurizer: VisionTransformer = timm.create_model( + self.timm_path_or_url, + pretrained=True, + num_classes=0, + img_size=self.default_image_size, + act_layer=self.override_act_layer, + ) + self.featurizer.eval() + + # Monkey-Patch the `forward()` function of the featurizer to ensure FSDP-compatibility + # => Note: By default set `get_intermediate_layers` to return the *SECOND-TO-LAST* layer patches! + # => TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385 + self.featurizer.forward = unpack_tuple( + partial(self.featurizer.get_intermediate_layers, n={len(self.featurizer.blocks) - 2}) + ) + + # Validation =>> for now, this class *only* supports TIMM Vision Transformers (but can be extended!) + assert isinstance(self.featurizer, VisionTransformer), ( + "Featurizer is not a TIMM VisionTransformer; if you would like to support a new visual representation, " + "file an issue or implement the requisite logic (see `prismatic/models/backbones/vision/base_vision.py`)!" + ) + + # Get Config =>> Note :: Override default image size to ensure correct image transform + self.data_cfg = timm.data.resolve_model_data_config(self.featurizer) + self.data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size) + + # Initialize Default Image Transform --> Modified by `self.image_resize_strategy` + default_image_transform = timm.data.create_transform(**self.data_cfg, is_training=False) + + # Fix =>> SigLIP & IN1K default transforms resize to *larger* than `self.default_image_size` (crops image)! + if "siglip" in self.timm_path_or_url or "in1k" in self.timm_path_or_url: + assert isinstance(default_image_transform, Compose), "Unexpected `default_image_transform`!" + assert isinstance(default_image_transform.transforms[0], Resize) + default_image_transform = Compose( + [ + Resize(self.default_image_size, interpolation=default_image_transform.transforms[0].interpolation), + *default_image_transform.transforms[1:], + ] + ) + + # Switch on `image_resize_strategy` + if self.image_resize_strategy == "resize-naive": + assert isinstance(default_image_transform, Compose), "Unexpected `default_image_transform`!" 
+ assert isinstance(default_image_transform.transforms[0], Resize) + + target_size = (self.default_image_size, self.default_image_size) + self.image_transform = Compose( + [ + Resize(target_size, interpolation=default_image_transform.transforms[0].interpolation), + *default_image_transform.transforms[1:], + ] + ) + + elif self.image_resize_strategy == "resize-crop": + self.image_transform = default_image_transform + + elif self.image_resize_strategy == "letterbox": + assert isinstance(default_image_transform, Compose), "Unexpected `default_image_transform`!" + assert "mean" in self.data_cfg, "TIMM `data_cfg` missing image normalization mean!" + + # Compute Padding Fill Value (rescaled normalization mean if applicable) + fill = tuple([int(x * 255) for x in self.data_cfg["mean"]]) + + # Build New Transform + self.image_transform = Compose([LetterboxPad(fill), *default_image_transform.transforms]) + + else: + raise ValueError(f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!") + + def get_fsdp_wrapping_policy(self) -> Callable: + """Return a simple FSDP policy that wraps each ViT block and then the _entire_ featurizer.""" + vit_wrap_policy = partial(_module_wrap_policy, module_classes={VisionTransformer}) + transformer_block_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block}) + return partial(_or_policy, policies=[vit_wrap_policy, transformer_block_policy]) + + def forward(self, pixel_values: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> torch.Tensor: + """Runs transformed image/pixel tensor through vision backbone, returning _all_ patch features.""" + return self.featurizer(pixel_values) + + @property + def default_image_resolution(self) -> Tuple[int, int, int]: + return self.data_cfg["input_size"] + + @property + def embed_dim(self) -> int: + return self.featurizer.embed_dim + + @property + def num_patches(self) -> int: + return self.featurizer.patch_embed.num_patches + + @property + def half_precision_dtype(self) -> torch.dtype: + return self.dtype diff --git a/capvector-oft/prismatic/models/backbones/vision/clip_vit.py b/capvector-oft/prismatic/models/backbones/vision/clip_vit.py new file mode 100644 index 0000000000000000000000000000000000000000..1023d0b8ee9500547c1649bfb3d82493e6c2659d --- /dev/null +++ b/capvector-oft/prismatic/models/backbones/vision/clip_vit.py @@ -0,0 +1,27 @@ +""" +clip_vit.py +""" + +from prismatic.models.backbones.vision.base_vision import TimmViTBackbone + +# Registry =>> Supported CLIP Vision Backbones (from TIMM) +CLIP_VISION_BACKBONES = { + "clip-vit-b": "vit_base_patch16_clip_224.openai", + "clip-vit-l": "vit_large_patch14_clip_224.openai", + "clip-vit-l-336px": "vit_large_patch14_clip_336.openai", +} + + +# [IMPORTANT] By Default, TIMM initialized OpenAI CLIP models with the standard GELU activation from PyTorch. 
+# HOWEVER =>> Original OpenAI models were trained with the quick_gelu *approximation* -- while it's +# a decent approximation, the resulting features are *worse*; this was a super tricky bug +# to identify, but luckily there's an easy fix (`override_act_layer`) +class CLIPViTBackbone(TimmViTBackbone): + def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None: + super().__init__( + vision_backbone_id, + CLIP_VISION_BACKBONES[vision_backbone_id], + image_resize_strategy, + default_image_size=default_image_size, + override_act_layer="quick_gelu" if CLIP_VISION_BACKBONES[vision_backbone_id].endswith(".openai") else None, + ) diff --git a/capvector-oft/prismatic/models/backbones/vision/dinoclip_vit.py b/capvector-oft/prismatic/models/backbones/vision/dinoclip_vit.py new file mode 100644 index 0000000000000000000000000000000000000000..318598a3197ef72afb2ab10c7c48285a7e6d4284 --- /dev/null +++ b/capvector-oft/prismatic/models/backbones/vision/dinoclip_vit.py @@ -0,0 +1,147 @@ +""" +dinoclip_vit.py + +Vision backbone that returns concatenated features from both DINOv2 and CLIP. +""" + +from dataclasses import dataclass +from functools import partial +from typing import Callable, Dict, Tuple + +import timm +import torch +from PIL import Image +from timm.models.vision_transformer import Block, VisionTransformer +from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy +from torchvision.transforms import Compose, Resize + +from prismatic.models.backbones.vision.base_vision import ImageTransform, LetterboxPad, VisionBackbone, unpack_tuple + +# Registry =>> Supported DinoCLIP Pairs (as TIMM identifiers) +DINOCLIP_VISION_BACKBONES = { + "dinoclip-vit-l-336px": { + "dino": "vit_large_patch14_reg4_dinov2.lvd142m", + "clip": "vit_large_patch14_clip_336.openai", + }, +} + + +@dataclass +class DinoCLIPImageTransform: + dino_image_transform: ImageTransform + clip_image_transform: ImageTransform + is_prismatic: bool = True + + def __call__(self, img: Image, **kwargs: str) -> Dict[str, torch.Tensor]: + return {"dino": self.dino_image_transform(img, **kwargs), "clip": self.clip_image_transform(img, **kwargs)} + + +class DinoCLIPViTBackbone(VisionBackbone): + def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None: + super().__init__(vision_backbone_id, image_resize_strategy, default_image_size=default_image_size) + self.dino_timm_path_or_url = DINOCLIP_VISION_BACKBONES[vision_backbone_id]["dino"] + self.clip_timm_path_or_url = DINOCLIP_VISION_BACKBONES[vision_backbone_id]["clip"] + + # Initialize both Featurizers (ViTs) by downloading from HF / TIMM Hub if necessary + self.dino_featurizer: VisionTransformer = timm.create_model( + self.dino_timm_path_or_url, pretrained=True, num_classes=0, img_size=self.default_image_size + ) + self.dino_featurizer.eval() + + self.clip_featurizer: VisionTransformer = timm.create_model( + self.clip_timm_path_or_url, pretrained=True, num_classes=0, img_size=self.default_image_size + ) + self.clip_featurizer.eval() + + # Monkey-Patch the `forward()` function of the featurizers to ensure FSDP-compatibility + # => Note: By default set `get_intermediate_layers` to return the *SECOND-TO-LAST* layer patches! 
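+        # => (Using penultimate-layer patch features follows common VLM practice, e.g., LLaVA-style models, where
+        #     the final block tends to specialize for the pretraining objective rather than for transfer.)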
+ # => TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385 + self.dino_featurizer.forward = unpack_tuple( + partial(self.dino_featurizer.get_intermediate_layers, n={len(self.dino_featurizer.blocks) - 2}) + ) + self.clip_featurizer.forward = unpack_tuple( + partial(self.clip_featurizer.get_intermediate_layers, n={len(self.clip_featurizer.blocks) - 2}) + ) + + # Get Configs for _both_ Featurizers =>> Note :: Override default image size for larger resolution models + self.dino_data_cfg = timm.data.resolve_model_data_config(self.dino_featurizer) + self.dino_data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size) + + self.clip_data_cfg = timm.data.resolve_model_data_config(self.clip_featurizer) + self.clip_data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size) + + # Initialize *both* Transforms + default_dino_transform = timm.data.create_transform(**self.dino_data_cfg, is_training=False) + default_clip_transform = timm.data.create_transform(**self.clip_data_cfg, is_training=False) + if self.image_resize_strategy == "resize-naive": + assert isinstance(default_dino_transform, Compose), "Unexpected `default_dino_image_transform`!" + assert isinstance(default_clip_transform, Compose), "Unexpected `default_clip_image_transform`!" + assert isinstance(default_dino_transform.transforms[0], Resize) + assert isinstance(default_clip_transform.transforms[0], Resize) + + target_size = (self.default_image_size, self.default_image_size) + dino_transform = Compose( + [ + Resize(target_size, interpolation=default_dino_transform.transforms[0].interpolation), + *default_dino_transform.transforms[1:], + ] + ) + clip_transform = Compose( + [ + Resize(target_size, interpolation=default_clip_transform.transforms[0].interpolation), + *default_clip_transform.transforms[1:], + ] + ) + + self.image_transform = DinoCLIPImageTransform(dino_transform, clip_transform) + + elif self.image_resize_strategy == "resize-crop": + self.image_transform = DinoCLIPImageTransform(default_dino_transform, default_clip_transform) + + elif self.image_resize_strategy == "letterbox": + assert isinstance(default_dino_transform, Compose), "Unexpected `default_dino_transform`!" + assert isinstance(default_clip_transform, Compose), "Unexpected `default_clip_transform`!" + assert "mean" in self.dino_data_cfg and "mean" in self.clip_data_cfg, "DinoCLIP `data_cfg` missing `mean`!" 
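+            # (e.g., an ImageNet-style mean of (0.485, 0.456, 0.406) rescales to a fill of (123, 116, 103))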
+ + # Compute Padding Fill Value(s) (rescaled normalization mean if applicable) + dino_fill = tuple([int(x * 255) for x in self.dino_data_cfg["mean"]]) + clip_fill = tuple([int(x * 255) for x in self.clip_data_cfg["mean"]]) + + # Build New Transform + self.image_transform = DinoCLIPImageTransform( + Compose([LetterboxPad(dino_fill), *default_dino_transform.transforms]), + Compose([LetterboxPad(clip_fill), *default_clip_transform.transforms]), + ) + + else: + raise ValueError(f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!") + + def get_fsdp_wrapping_policy(self) -> Callable: + """Return a simple FSDP policy that wraps each ViT block and then both of the _entire_ featurizers.""" + vit_wrap_policy = partial(_module_wrap_policy, module_classes={VisionTransformer}) + transformer_block_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block}) + return partial(_or_policy, policies=[vit_wrap_policy, transformer_block_policy]) + + def forward(self, pixel_values: Dict[str, torch.Tensor]) -> torch.Tensor: + """Runs the transformed image/pixel tensors through each vision backbone, returning concatenated patches.""" + dino_patches = self.dino_featurizer(pixel_values["dino"]) + clip_patches = self.clip_featurizer(pixel_values["clip"]) + + return torch.cat([dino_patches, clip_patches], dim=2) + + @property + def default_image_resolution(self) -> Tuple[int, int, int]: + return self.dino_data_cfg["input_size"] + + @property + def embed_dim(self) -> int: + return self.dino_featurizer.embed_dim + self.clip_featurizer.embed_dim + + @property + def num_patches(self) -> int: + assert self.dino_featurizer.patch_embed.num_patches == self.clip_featurizer.patch_embed.num_patches + return self.dino_featurizer.patch_embed.num_patches + + @property + def half_precision_dtype(self) -> torch.dtype: + return torch.bfloat16 diff --git a/capvector-oft/prismatic/models/backbones/vision/dinosiglip_vit.py b/capvector-oft/prismatic/models/backbones/vision/dinosiglip_vit.py new file mode 100644 index 0000000000000000000000000000000000000000..c8762dadf0ff756bd9b12642917c21ac58c5d5eb --- /dev/null +++ b/capvector-oft/prismatic/models/backbones/vision/dinosiglip_vit.py @@ -0,0 +1,164 @@ +""" +dinosiglip_vit.py + +Vision backbone that returns concatenated features from both DINOv2 and SigLIP. 
+""" + +from dataclasses import dataclass +from functools import partial +from typing import Callable, Dict, Tuple + +import timm +import torch +from PIL import Image +from timm.models.vision_transformer import Block, VisionTransformer +from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy +from torchvision.transforms import Compose, Resize + +from prismatic.models.backbones.vision.base_vision import ImageTransform, LetterboxPad, VisionBackbone, unpack_tuple + +# Registry =>> Supported DinoSigLIP Pairs (as TIMM identifiers) +DINOSigLIP_VISION_BACKBONES = { + "dinosiglip-vit-so-224px": { + "dino": "vit_large_patch14_reg4_dinov2.lvd142m", + "siglip": "vit_so400m_patch14_siglip_224", + }, + "dinosiglip-vit-so-384px": { + "dino": "vit_large_patch14_reg4_dinov2.lvd142m", + "siglip": "vit_so400m_patch14_siglip_384", + }, +} + + +@dataclass +class DinoSigLIPImageTransform: + dino_image_transform: ImageTransform + siglip_image_transform: ImageTransform + is_prismatic: bool = True + + def __call__(self, img: Image, **kwargs: str) -> Dict[str, torch.Tensor]: + return {"dino": self.dino_image_transform(img, **kwargs), "siglip": self.siglip_image_transform(img, **kwargs)} + + +class DinoSigLIPViTBackbone(VisionBackbone): + def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None: + super().__init__(vision_backbone_id, image_resize_strategy, default_image_size=default_image_size) + self.dino_timm_path_or_url = DINOSigLIP_VISION_BACKBONES[vision_backbone_id]["dino"] + self.siglip_timm_path_or_url = DINOSigLIP_VISION_BACKBONES[vision_backbone_id]["siglip"] + + # Initialize both Featurizers (ViTs) by downloading from HF / TIMM Hub if necessary + self.dino_featurizer: VisionTransformer = timm.create_model( + self.dino_timm_path_or_url, pretrained=True, num_classes=0, img_size=self.default_image_size + ) + self.dino_featurizer.eval() + + self.siglip_featurizer: VisionTransformer = timm.create_model( + self.siglip_timm_path_or_url, pretrained=True, num_classes=0, img_size=self.default_image_size + ) + self.siglip_featurizer.eval() + + # Monkey-Patch the `forward()` function of the featurizers to ensure FSDP-compatibility + # => Note: By default set `get_intermediate_layers` to return the *SECOND-TO-LAST* layer patches! + # => TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385 + self.dino_featurizer.forward = unpack_tuple( + partial(self.dino_featurizer.get_intermediate_layers, n={len(self.dino_featurizer.blocks) - 2}) + ) + self.siglip_featurizer.forward = unpack_tuple( + partial(self.siglip_featurizer.get_intermediate_layers, n={len(self.siglip_featurizer.blocks) - 2}) + ) + + # Get Configs for _both_ Featurizers =>> Note :: Override default image size for larger resolution models + self.dino_data_cfg = timm.data.resolve_model_data_config(self.dino_featurizer) + self.dino_data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size) + + self.siglip_data_cfg = timm.data.resolve_model_data_config(self.siglip_featurizer) + self.siglip_data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size) + + # Initialize *both* Transforms + default_dino_transform = timm.data.create_transform(**self.dino_data_cfg, is_training=False) + default_siglip_transform = timm.data.create_transform(**self.siglip_data_cfg, is_training=False) + + # Fix =>> SigLIP default transform resizes to *larger* than `self.default_image_size` (crops image)!! 
+ assert isinstance(default_siglip_transform, Compose), "Unexpected `default_image_transform`!" + assert isinstance(default_siglip_transform.transforms[0], Resize) + default_siglip_transform = Compose( + [ + Resize(self.default_image_size, interpolation=default_siglip_transform.transforms[0].interpolation), + *default_siglip_transform.transforms[1:], + ] + ) + + if self.image_resize_strategy == "resize-naive": + assert isinstance(default_dino_transform, Compose), "Unexpected `default_dino_image_transform`!" + assert isinstance(default_siglip_transform, Compose), "Unexpected `default_siglip_image_transform`!" + assert isinstance(default_dino_transform.transforms[0], Resize) + assert isinstance(default_siglip_transform.transforms[0], Resize) + + target_size = (self.default_image_size, self.default_image_size) + dino_transform = Compose( + [ + Resize(target_size, interpolation=default_dino_transform.transforms[0].interpolation), + *default_dino_transform.transforms[1:], + ] + ) + siglip_transform = Compose( + [ + Resize(target_size, interpolation=default_siglip_transform.transforms[0].interpolation), + *default_siglip_transform.transforms[1:], + ] + ) + + self.image_transform = DinoSigLIPImageTransform(dino_transform, siglip_transform) + + elif self.image_resize_strategy == "resize-crop": + self.image_transform = DinoSigLIPImageTransform(default_dino_transform, default_siglip_transform) + + elif self.image_resize_strategy == "letterbox": + assert isinstance(default_dino_transform, Compose), "Unexpected `default_dino_transform`!" + assert isinstance(default_siglip_transform, Compose), "Unexpected `default_siglip_transform`!" + assert ( + "mean" in self.dino_data_cfg and "mean" in self.siglip_data_cfg + ), "DinoSigLIP `data_cfg` missing `mean`!" + + # Compute Padding Fill Value(s) (rescaled normalization mean if applicable) + dino_fill = tuple([int(x * 255) for x in self.dino_data_cfg["mean"]]) + siglip_fill = tuple([int(x * 255) for x in self.siglip_data_cfg["mean"]]) + + # Build New Transform + self.image_transform = DinoSigLIPImageTransform( + Compose([LetterboxPad(dino_fill), *default_dino_transform.transforms]), + Compose([LetterboxPad(siglip_fill), *default_siglip_transform.transforms]), + ) + + else: + raise ValueError(f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!") + + def get_fsdp_wrapping_policy(self) -> Callable: + """Return a simple FSDP policy that wraps each ViT block and then both of the _entire_ featurizers.""" + vit_wrap_policy = partial(_module_wrap_policy, module_classes={VisionTransformer}) + transformer_block_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block}) + return partial(_or_policy, policies=[vit_wrap_policy, transformer_block_policy]) + + def forward(self, pixel_values: Dict[str, torch.Tensor]) -> torch.Tensor: + """Runs the transformed image/pixel tensors through each vision backbone, returning concatenated patches.""" + dino_patches = self.dino_featurizer(pixel_values["dino"]) + siglip_patches = self.siglip_featurizer(pixel_values["siglip"]) + + return torch.cat([dino_patches, siglip_patches], dim=2) + + @property + def default_image_resolution(self) -> Tuple[int, int, int]: + return self.dino_data_cfg["input_size"] + + @property + def embed_dim(self) -> int: + return self.dino_featurizer.embed_dim + self.siglip_featurizer.embed_dim + + @property + def num_patches(self) -> int: + assert self.dino_featurizer.patch_embed.num_patches == self.siglip_featurizer.patch_embed.num_patches + return 
self.dino_featurizer.patch_embed.num_patches + + @property + def half_precision_dtype(self) -> torch.dtype: + return torch.bfloat16 diff --git a/capvector-oft/prismatic/models/backbones/vision/dinov2_vit.py b/capvector-oft/prismatic/models/backbones/vision/dinov2_vit.py new file mode 100644 index 0000000000000000000000000000000000000000..d36acee29fffd3a202d24c1ee1d7c55cd53fa8ea --- /dev/null +++ b/capvector-oft/prismatic/models/backbones/vision/dinov2_vit.py @@ -0,0 +1,19 @@ +""" +dinov2_vit.py +""" + +from prismatic.models.backbones.vision.base_vision import TimmViTBackbone + +# Registry =>> Supported DINOv2 Vision Backbones (from TIMM) =>> Note:: Using DINOv2 w/ Registers! +# => Reference: https://arxiv.org/abs/2309.16588 +DINOv2_VISION_BACKBONES = {"dinov2-vit-l": "vit_large_patch14_reg4_dinov2.lvd142m"} + + +class DinoV2ViTBackbone(TimmViTBackbone): + def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None: + super().__init__( + vision_backbone_id, + DINOv2_VISION_BACKBONES[vision_backbone_id], + image_resize_strategy, + default_image_size=default_image_size, + ) diff --git a/capvector-oft/prismatic/models/backbones/vision/in1k_vit.py b/capvector-oft/prismatic/models/backbones/vision/in1k_vit.py new file mode 100644 index 0000000000000000000000000000000000000000..ba8fb0ee919851e5b9698b998b35513371873f6a --- /dev/null +++ b/capvector-oft/prismatic/models/backbones/vision/in1k_vit.py @@ -0,0 +1,22 @@ +""" +in1k_vit.py + +Vision Transformers trained / finetuned on ImageNet (ImageNet-21K =>> ImageNet-1K) +""" + +from prismatic.models.backbones.vision.base_vision import TimmViTBackbone + +# Registry =>> Supported Vision Backbones (from TIMM) +IN1K_VISION_BACKBONES = { + "in1k-vit-l": "vit_large_patch16_224.augreg_in21k_ft_in1k", +} + + +class IN1KViTBackbone(TimmViTBackbone): + def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None: + super().__init__( + vision_backbone_id, + IN1K_VISION_BACKBONES[vision_backbone_id], + image_resize_strategy, + default_image_size=default_image_size, + ) diff --git a/capvector-oft/prismatic/models/backbones/vision/siglip_vit.py b/capvector-oft/prismatic/models/backbones/vision/siglip_vit.py new file mode 100644 index 0000000000000000000000000000000000000000..618ff087134003b1ee7d237d500cfe90a42ef798 --- /dev/null +++ b/capvector-oft/prismatic/models/backbones/vision/siglip_vit.py @@ -0,0 +1,24 @@ +""" +siglip_vit.py +""" + +from prismatic.models.backbones.vision.base_vision import TimmViTBackbone + +# Registry =>> Supported SigLIP Vision Backbones (from TIMM) =>> Note:: Using SigLIP w/ Patch = 14 (but SO400M Arch) +SIGLIP_VISION_BACKBONES = { + "siglip-vit-b16-224px": "vit_base_patch16_siglip_224", + "siglip-vit-b16-256px": "vit_base_patch16_siglip_256", + "siglip-vit-b16-384px": "vit_base_patch16_siglip_384", + "siglip-vit-so400m": "vit_so400m_patch14_siglip_224", + "siglip-vit-so400m-384px": "vit_so400m_patch14_siglip_384", +} + + +class SigLIPViTBackbone(TimmViTBackbone): + def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None: + super().__init__( + vision_backbone_id, + SIGLIP_VISION_BACKBONES[vision_backbone_id], + image_resize_strategy, + default_image_size=default_image_size, + ) diff --git a/capvector-oft/prismatic/models/load.py b/capvector-oft/prismatic/models/load.py new file mode 100644 index 0000000000000000000000000000000000000000..76cc3a3ae2362806d4d179fc4e427a9ac89eeb84 --- 
/dev/null
+++ b/capvector-oft/prismatic/models/load.py
@@ -0,0 +1,226 @@
+"""
+load.py
+
+Entry point for loading pretrained VLMs for inference; exposes functions for listing available models (with canonical
+IDs, mappings to paper experiments, and short descriptions), as well as for loading models (from disk or HF Hub).
+"""
+
+import json
+import os
+from pathlib import Path
+from typing import List, Optional, Union
+
+from huggingface_hub import HfFileSystem, hf_hub_download
+
+from prismatic.conf import ModelConfig
+from prismatic.models.materialize import get_llm_backbone_and_tokenizer, get_vision_backbone_and_transform
+from prismatic.models.registry import GLOBAL_REGISTRY, MODEL_REGISTRY
+from prismatic.models.vlas import OpenVLA
+from prismatic.models.vlms import PrismaticVLM
+from prismatic.overwatch import initialize_overwatch
+from prismatic.vla.action_tokenizer import ActionTokenizer
+
+# Initialize Overwatch =>> Wraps `logging.Logger`
+overwatch = initialize_overwatch(__name__)
+
+
+# === HF Hub Repository ===
+HF_HUB_REPO = "TRI-ML/prismatic-vlms"
+VLA_HF_HUB_REPO = "openvla/openvla-dev"
+
+
+# === Available Models ===
+def available_models() -> List[str]:
+    return list(MODEL_REGISTRY.keys())
+
+
+def available_model_names() -> List[str]:
+    return list(GLOBAL_REGISTRY.keys())
+
+
+def get_model_description(model_id_or_name: str) -> str:
+    if model_id_or_name not in GLOBAL_REGISTRY:
+        raise ValueError(f"Couldn't find `{model_id_or_name = }`; check `prismatic.available_model_names()`")
+
+    # Print Description & Return
+    print(json.dumps(description := GLOBAL_REGISTRY[model_id_or_name]["description"], indent=2))
+
+    return description
+
+
+# === Load Pretrained Model ===
+def load(
+    model_id_or_path: Union[str, Path],
+    hf_token: Optional[str] = None,
+    cache_dir: Optional[Union[str, Path]] = None,
+    load_for_training: bool = False,
+) -> PrismaticVLM:
+    """Loads a pretrained PrismaticVLM from either local disk or the HuggingFace Hub."""
+    if os.path.isdir(model_id_or_path):
+        overwatch.info(f"Loading from local path `{(run_dir := Path(model_id_or_path))}`")
+
+        # Get paths for `config.json` and pretrained checkpoint
+        config_json, checkpoint_pt = run_dir / "config.json", run_dir / "checkpoints" / "latest-checkpoint.pt"
+        assert config_json.exists(), f"Missing `config.json` for `{run_dir = }`"
+        assert checkpoint_pt.exists(), f"Missing checkpoint for `{run_dir = }`"
+    else:
+        if model_id_or_path not in GLOBAL_REGISTRY:
+            raise ValueError(f"Couldn't find `{model_id_or_path = }`; check `prismatic.available_model_names()`")
+
+        overwatch.info(f"Downloading `{(model_id := GLOBAL_REGISTRY[model_id_or_path]['model_id'])}` from HF Hub")
+        with overwatch.local_zero_first():
+            config_json = hf_hub_download(repo_id=HF_HUB_REPO, filename=f"{model_id}/config.json", cache_dir=cache_dir)
+            checkpoint_pt = hf_hub_download(
+                repo_id=HF_HUB_REPO, filename=f"{model_id}/checkpoints/latest-checkpoint.pt", cache_dir=cache_dir
+            )
+
+    # Load Model Config from `config.json`
+    with open(config_json, "r") as f:
+        model_cfg = json.load(f)["model"]
+
+    # = Load Individual Components necessary for Instantiating a VLM =
+    #   =>> Print Minimal Config
+    overwatch.info(
+        f"Found Config =>> Loading & Freezing [bold blue]{model_cfg['model_id']}[/] with:\n"
+        f"             Vision Backbone =>> [bold]{model_cfg['vision_backbone_id']}[/]\n"
+        f"             LLM Backbone    =>> [bold]{model_cfg['llm_backbone_id']}[/]\n"
+        f"             Arch Specifier  =>> [bold]{model_cfg['arch_specifier']}[/]\n"
+        f"             Checkpoint Path =>> [underline]`{checkpoint_pt}`[/]"
+    )
+
+    # Load Vision Backbone
+    overwatch.info(f"Loading Vision Backbone [bold]{model_cfg['vision_backbone_id']}[/]")
+    vision_backbone, image_transform = get_vision_backbone_and_transform(
+        model_cfg["vision_backbone_id"],
+        model_cfg["image_resize_strategy"],
+    )
+
+    # Load LLM Backbone --> note `inference_mode = True` by default when calling `load()`
+    overwatch.info(f"Loading Pretrained LLM [bold]{model_cfg['llm_backbone_id']}[/] via HF Transformers")
+    llm_backbone, tokenizer = get_llm_backbone_and_tokenizer(
+        model_cfg["llm_backbone_id"],
+        llm_max_length=model_cfg.get("llm_max_length", 2048),
+        hf_token=hf_token,
+        inference_mode=not load_for_training,
+    )
+
+    # Load VLM using `from_pretrained` (clobbers HF syntax... eventually should reconcile)
+    overwatch.info(f"Loading VLM [bold blue]{model_cfg['model_id']}[/] from Checkpoint")
+    vlm = PrismaticVLM.from_pretrained(
+        checkpoint_pt,
+        model_cfg["model_id"],
+        vision_backbone,
+        llm_backbone,
+        arch_specifier=model_cfg["arch_specifier"],
+        freeze_weights=not load_for_training,
+    )
+
+    return vlm
+
+
+# === Load Pretrained VLA Model ===
+def load_vla(
+    model_id_or_path: Union[str, Path],
+    hf_token: Optional[str] = None,
+    cache_dir: Optional[Union[str, Path]] = None,
+    load_for_training: bool = False,
+    step_to_load: Optional[int] = None,
+    model_type: str = "pretrained",
+) -> OpenVLA:
+    """Loads a pretrained OpenVLA from either local disk or the HuggingFace Hub."""
+
+    # TODO (siddk, moojink) :: Unify semantics with `load()` above; right now, `load_vla()` assumes path points to
+    #   checkpoint `.pt` file, rather than the top-level run directory!
+    if os.path.isfile(model_id_or_path):
+        overwatch.info(f"Loading from local checkpoint path `{(checkpoint_pt := Path(model_id_or_path))}`")
+
+        # [Validate] Checkpoint Path should look like `.../<run_dir>/checkpoints/<checkpoint>.pt`
+        assert (checkpoint_pt.suffix == ".pt") and (checkpoint_pt.parent.name == "checkpoints"), "Invalid checkpoint!"
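+        # [Example] A hypothetical path satisfying this contract (sketch only, not a real run):
+        #   runs/prism-dinosiglip+7b+stage-vla-train+x7/checkpoints/step-077500-epoch-00-loss=0.0120.pt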
+        run_dir = checkpoint_pt.parents[1]
+
+        # Get paths for `config.json`, `dataset_statistics.json` and pretrained checkpoint
+        config_json, dataset_statistics_json = run_dir / "config.json", run_dir / "dataset_statistics.json"
+        assert config_json.exists(), f"Missing `config.json` for `{run_dir = }`"
+        assert dataset_statistics_json.exists(), f"Missing `dataset_statistics.json` for `{run_dir = }`"
+
+    # Otherwise =>> try looking for a match on `model_id_or_path` on the HF Hub (`VLA_HF_HUB_REPO`)
+    else:
+        # Search HF Hub Repo via fsspec API
+        overwatch.info(f"Checking HF for `{(hf_path := str(Path(VLA_HF_HUB_REPO) / model_type / model_id_or_path))}`")
+        if not (tmpfs := HfFileSystem()).exists(hf_path):
+            raise ValueError(f"Couldn't find valid HF Hub Path `{hf_path = }`")
+
+        # Identify Checkpoint to Load (via `step_to_load`)
+        step_to_load = f"{step_to_load:06d}" if step_to_load is not None else None
+        valid_ckpts = tmpfs.glob(f"{hf_path}/checkpoints/step-{step_to_load if step_to_load is not None else ''}*.pt")
+        if (len(valid_ckpts) == 0) or (step_to_load is not None and len(valid_ckpts) != 1):
+            raise ValueError(f"Couldn't find a valid checkpoint to load from HF Hub Path `{hf_path}/checkpoints/`")
+
+        # Call to `glob` will sort steps in ascending order (if `step_to_load` is None); just grab last element
+        target_ckpt = Path(valid_ckpts[-1]).name
+
+        overwatch.info(f"Downloading Model `{model_id_or_path}` Config & Checkpoint `{target_ckpt}`")
+        with overwatch.local_zero_first():
+            relpath = Path(model_type) / model_id_or_path
+            config_json = hf_hub_download(
+                repo_id=VLA_HF_HUB_REPO, filename=f"{(relpath / 'config.json')!s}", cache_dir=cache_dir
+            )
+            dataset_statistics_json = hf_hub_download(
+                repo_id=VLA_HF_HUB_REPO, filename=f"{(relpath / 'dataset_statistics.json')!s}", cache_dir=cache_dir
+            )
+            checkpoint_pt = hf_hub_download(
+                repo_id=VLA_HF_HUB_REPO, filename=f"{(relpath / 'checkpoints' / target_ckpt)!s}", cache_dir=cache_dir
+            )
+
+    # Load VLA Config (and corresponding base VLM `ModelConfig`) from `config.json`
+    with open(config_json, "r") as f:
+        vla_cfg = json.load(f)["vla"]
+        model_cfg = ModelConfig.get_choice_class(vla_cfg["base_vlm"])()
+
+    # Load Dataset Statistics for Action Denormalization
+    with open(dataset_statistics_json, "r") as f:
+        norm_stats = json.load(f)
+
+    # = Load Individual Components necessary for Instantiating a VLA (via base VLM components) =
+    #   =>> Print Minimal Config
+    overwatch.info(
+        f"Found Config =>> Loading & Freezing [bold blue]{model_cfg.model_id}[/] with:\n"
+        f"             Vision Backbone =>> [bold]{model_cfg.vision_backbone_id}[/]\n"
+        f"             LLM Backbone    =>> [bold]{model_cfg.llm_backbone_id}[/]\n"
+        f"             Arch Specifier  =>> [bold]{model_cfg.arch_specifier}[/]\n"
+        f"             Checkpoint Path =>> [underline]`{checkpoint_pt}`[/]"
+    )
+
+    # Load Vision Backbone
+    overwatch.info(f"Loading Vision Backbone [bold]{model_cfg.vision_backbone_id}[/]")
+    vision_backbone, image_transform = get_vision_backbone_and_transform(
+        model_cfg.vision_backbone_id,
+        model_cfg.image_resize_strategy,
+    )
+
+    # Load LLM Backbone --> note `inference_mode = True` by default when calling `load_vla()`
+    overwatch.info(f"Loading Pretrained LLM [bold]{model_cfg.llm_backbone_id}[/] via HF Transformers")
+    llm_backbone, tokenizer = get_llm_backbone_and_tokenizer(
+        model_cfg.llm_backbone_id,
+        llm_max_length=model_cfg.llm_max_length,
+        hf_token=hf_token,
+        inference_mode=not load_for_training,
+    )
+
+    # Create Action Tokenizer
+    action_tokenizer = ActionTokenizer(llm_backbone.get_tokenizer())
+
+    # 
Load VLM using `from_pretrained` (clobbers HF syntax... eventually should reconcile) + overwatch.info(f"Loading VLA [bold blue]{model_cfg.model_id}[/] from Checkpoint") + vla = OpenVLA.from_pretrained( + checkpoint_pt, + model_cfg.model_id, + vision_backbone, + llm_backbone, + arch_specifier=model_cfg.arch_specifier, + freeze_weights=not load_for_training, + norm_stats=norm_stats, + action_tokenizer=action_tokenizer, + ) + + return vla diff --git a/capvector-oft/prismatic/models/materialize.py b/capvector-oft/prismatic/models/materialize.py new file mode 100644 index 0000000000000000000000000000000000000000..90b1fd4ba4dc15fe1a1da47db5ec970424f7c066 --- /dev/null +++ b/capvector-oft/prismatic/models/materialize.py @@ -0,0 +1,130 @@ +""" +materialize.py + +Factory class for initializing Vision Backbones, LLM Backbones, and VLMs from a set registry; provides and exports +individual functions for clear control flow. +""" + +from typing import Optional, Tuple + +from transformers import PreTrainedTokenizerBase + +from prismatic.models.backbones.llm import LLaMa2LLMBackbone, LLMBackbone, MistralLLMBackbone, PhiLLMBackbone +from prismatic.models.backbones.vision import ( + CLIPViTBackbone, + DinoCLIPViTBackbone, + DinoSigLIPViTBackbone, + DinoV2ViTBackbone, + ImageTransform, + IN1KViTBackbone, + SigLIPViTBackbone, + VisionBackbone, +) +from prismatic.models.vlms import PrismaticVLM + +# === Registries =>> Maps ID --> {cls(), kwargs} :: Different Registries for Vision Backbones, LLM Backbones, VLMs === +# fmt: off + +# === Vision Backbone Registry === +VISION_BACKBONES = { + # === 224px Backbones === + "clip-vit-l": {"cls": CLIPViTBackbone, "kwargs": {"default_image_size": 224}}, + "siglip-vit-so400m": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 224}}, + "dinov2-vit-l": {"cls": DinoV2ViTBackbone, "kwargs": {"default_image_size": 224}}, + "in1k-vit-l": {"cls": IN1KViTBackbone, "kwargs": {"default_image_size": 224}}, + "dinosiglip-vit-so-224px": {"cls": DinoSigLIPViTBackbone, "kwargs": {"default_image_size": 224}}, + + # === Assorted CLIP Backbones === + "clip-vit-b": {"cls": CLIPViTBackbone, "kwargs": {"default_image_size": 224}}, + "clip-vit-l-336px": {"cls": CLIPViTBackbone, "kwargs": {"default_image_size": 336}}, + + # === Assorted SigLIP Backbones === + "siglip-vit-b16-224px": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 224}}, + "siglip-vit-b16-256px": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 256}}, + "siglip-vit-b16-384px": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 384}}, + "siglip-vit-so400m-384px": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 384}}, + + # === Fused Backbones === + "dinoclip-vit-l-336px": {"cls": DinoCLIPViTBackbone, "kwargs": {"default_image_size": 336}}, + "dinosiglip-vit-so-384px": {"cls": DinoSigLIPViTBackbone, "kwargs": {"default_image_size": 384}}, +} + + +# === Language Model Registry === +LLM_BACKBONES = { + # === LLaMa-2 Pure (Non-Chat) Backbones === + "llama2-7b-pure": {"cls": LLaMa2LLMBackbone, "kwargs": {}}, + "llama2-13b-pure": {"cls": LLaMa2LLMBackbone, "kwargs": {}}, + + # === LLaMa-2 Chat Backbones === + "llama2-7b-chat": {"cls": LLaMa2LLMBackbone, "kwargs": {}}, + "llama2-13b-chat": {"cls": LLaMa2LLMBackbone, "kwargs": {}}, + + # === Vicuna-v1.5 Backbones === + "vicuna-v15-7b": {"cls": LLaMa2LLMBackbone, "kwargs": {}}, + "vicuna-v15-13b": {"cls": LLaMa2LLMBackbone, "kwargs": {}}, + + # === Mistral v0.1 Backbones === + "mistral-v0.1-7b-pure": {"cls": MistralLLMBackbone, 
"kwargs": {}}, + "mistral-v0.1-7b-instruct": {"cls": MistralLLMBackbone, "kwargs": {}}, + + # === Phi-2 Backbone === + "phi-2-3b": {"cls": PhiLLMBackbone, "kwargs": {}}, +} + +# fmt: on + + +def get_vision_backbone_and_transform( + vision_backbone_id: str, image_resize_strategy: str +) -> Tuple[VisionBackbone, ImageTransform]: + """Instantiate a Vision Backbone, returning both the nn.Module wrapper class and default Image Transform.""" + if vision_backbone_id in VISION_BACKBONES: + vision_cfg = VISION_BACKBONES[vision_backbone_id] + vision_backbone: VisionBackbone = vision_cfg["cls"]( + vision_backbone_id, image_resize_strategy, **vision_cfg["kwargs"] + ) + image_transform = vision_backbone.get_image_transform() + return vision_backbone, image_transform + + else: + raise ValueError(f"Vision Backbone `{vision_backbone_id}` is not supported!") + + +def get_llm_backbone_and_tokenizer( + llm_backbone_id: str, + llm_max_length: int = 2048, + hf_token: Optional[str] = None, + inference_mode: bool = False, +) -> Tuple[LLMBackbone, PreTrainedTokenizerBase]: + if llm_backbone_id in LLM_BACKBONES: + llm_cfg = LLM_BACKBONES[llm_backbone_id] + llm_backbone: LLMBackbone = llm_cfg["cls"]( + llm_backbone_id, + llm_max_length=llm_max_length, + hf_token=hf_token, + inference_mode=inference_mode, + **llm_cfg["kwargs"], + ) + tokenizer = llm_backbone.get_tokenizer() + return llm_backbone, tokenizer + + else: + raise ValueError(f"LLM Backbone `{llm_backbone_id}` is not supported!") + + +def get_vlm( + model_id: str, + arch_specifier: str, + vision_backbone: VisionBackbone, + llm_backbone: LLMBackbone, + enable_mixed_precision_training: bool = True, +) -> PrismaticVLM: + """Lightweight wrapper around initializing a VLM, mostly for future-proofing (if one wants to add a new VLM).""" + return PrismaticVLM( + model_id, + vision_backbone, + llm_backbone, + enable_mixed_precision_training=enable_mixed_precision_training, + arch_specifier=arch_specifier, + ) diff --git a/capvector-oft/prismatic/models/projectors.py b/capvector-oft/prismatic/models/projectors.py new file mode 100644 index 0000000000000000000000000000000000000000..80aee7f02198b6a271b122e1427dad3faafb4be0 --- /dev/null +++ b/capvector-oft/prismatic/models/projectors.py @@ -0,0 +1,49 @@ +"""Implementation of additional projectors for additional inputs to the VLA models.""" +import torch +import torch.nn as nn + + +class ProprioProjector(nn.Module): + """ + Projects proprio state inputs into the LLM's embedding space. + """ + def __init__(self, llm_dim: int, proprio_dim: int) -> None: + super().__init__() + self.llm_dim = llm_dim + self.proprio_dim = proprio_dim + + self.fc1 = nn.Linear(self.proprio_dim, self.llm_dim, bias=True) + self.fc2 = nn.Linear(self.llm_dim, self.llm_dim, bias=True) + self.act_fn1 = nn.GELU() + + def forward(self, proprio: torch.Tensor = None) -> torch.Tensor: + # proprio: (bsz, proprio_dim) + projected_features = self.fc1(proprio) + projected_features = self.act_fn1(projected_features) + projected_features = self.fc2(projected_features) + return projected_features + + +class NoisyActionProjector(nn.Module): + """ + [Diffusion] Projects noisy action inputs into the LLM's embedding space. + + Note that since each action is tokenized into 7 tokens in OpenVLA (rather + than having 1 token per action), each noisy action token will have dimension 1 + instead of 7. 
+ """ + def __init__(self, llm_dim: int) -> None: + super().__init__() + self.llm_dim = llm_dim + self.action_token_dim = 1 + + self.fc1 = nn.Linear(self.action_token_dim, self.llm_dim, bias=True) + self.fc2 = nn.Linear(self.llm_dim, self.llm_dim, bias=True) + self.act_fn1 = nn.GELU() + + def forward(self, noisy_actions: torch.Tensor = None) -> torch.Tensor: + # noisy_actions: (bsz, num_action_tokens=chunk_len*action_dim, 1) + projected_features = self.fc1(noisy_actions) + projected_features = self.act_fn1(projected_features) + projected_features = self.fc2(projected_features) + return projected_features diff --git a/capvector-oft/prismatic/models/registry.py b/capvector-oft/prismatic/models/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..cde181f1f50998e166fcbc376407c87777ec642b --- /dev/null +++ b/capvector-oft/prismatic/models/registry.py @@ -0,0 +1,691 @@ +""" +registry.py + +Exhaustive list of pretrained VLMs (with full descriptions / links to corresponding names and sections of paper). +""" + +# === Pretrained Model Registry === +# fmt: off +MODEL_REGISTRY = { + # === LLaVa v1.5 Reproductions === + "reproduction-llava-v15+7b": { + "model_id": "reproduction-llava-v15+7b", + "names": ["LLaVa v1.5 7B (Reproduction)"], + "description": { + "name": "LLaVa v1.5 7B (Reproduction)", + "optimization_procedure": "multi-stage", + "visual_representation": "CLIP ViT-L/14 @ 336px", + "image_processing": "Letterbox", + "language_model": "Vicuña v1.5 7B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1, + } + }, + "reproduction-llava-v15+13b": { + "model_id": "reproduction-llava-v15+13b", + "names": ["LLaVa v1.5 13B (Reproduction)"], + "description": { + "name": "LLaVa v1.5 13B (Reproduction)", + "optimization_procedure": "multi-stage", + "visual_representation": "CLIP ViT-L/14 @ 336px", + "image_processing": "Letterbox", + "language_model": "Vicuña v1.5 13B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1, + } + }, + + # === Section 4.1 :: Optimization Procedure === + "one-stage+7b": { + "model_id": "one-stage+7b", + "names": [ + "One-Stage 7B", + "Single-Stage 7B", + "Frozen ViT (Single-Stage)", + "CLIP ViT-L 336px (Letterbox)", + "CLIP ViT-L 336px", + "Vicuña v1.5 7B", + "1 Epoch", + "Base", + ], + "description": { + "name": "Single-Stage 7B", + "optimization_procedure": "single-stage", + "visual_representation": "CLIP ViT-L/14 @ 336px", + "image_processing": "Letterbox", + "language_model": "Vicuña v1.5 7B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1, + } + }, + "one-stage+13b": { + "model_id": "one-stage+13b", + "names": [ + "One-Stage 13B", + "Single-Stage 13B", + "Vicuña v1.5 13B", + ], + "description": { + "name": "Single-Stage 13B", + "optimization_procedure": "single-stage", + "visual_representation": "CLIP ViT-L/14 @ 336px", + "image_processing": "Letterbox", + "language_model": "Vicuña v1.5 13B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1, + } + }, + + "full-ft-multi-stage+7b": { + "model_id": "full-ft-multi-stage+7b", + "names": ["Finetune ViT (Multi-Stage)"], + "description": { + "name": "Finetune ViT (Multi-Stage)", + "optimization_procedure": "multi-stage-full-finetune", + "visual_representation": "CLIP ViT-L/14 @ 336px", + "image_processing": "Letterbox", + "language_model": "Vicuña v1.5 7B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1, + } + }, + "full-ft-one-stage+7b": { + "model_id": "full-ft-one-stage+7b", + "names": ["Finetune ViT (Single-Stage)"], + "description": { + "name": 
"Finetune ViT (Single-Stage)", + "optimization_procedure": "single-stage-full-finetune", + "visual_representation": "CLIP ViT-L/14 @ 336px", + "image_processing": "Letterbox", + "language_model": "Vicuña v1.5 7B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1, + } + }, + + # === Section 4.2 :: Image Processing and Visual Representations === + "in1k-224px+7b": { + "model_id": "in1k-224px+7b", + "names": ["IN1K ViT-L 224px"], + "description": { + "name": "IN1K ViT-L 224px", + "optimization_procedure": "single-stage", + "visual_representation": "ImageNet-21K+1K ViT-L/16 @ 224px", + "image_processing": "Letterbox", + "language_model": "Vicuña v1.5 7B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1, + }, + }, + "dinov2-224px+7b": { + "model_id": "dinov2-224px+7b", + "names": ["DINOv2 ViT-L 224px"], + "description": { + "name": "DINOv2 ViT-L 224px", + "optimization_procedure": "single-stage", + "visual_representation": "DINOv2 ViT-L/14 @ 224px", + "image_processing": "Letterbox", + "language_model": "Vicuña v1.5 7B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1, + }, + }, + "clip-224px+7b": { + "model_id": "clip-224px+7b", + "names": ["CLIP ViT-L 224px"], + "description": { + "name": "CLIP ViT-L 224px", + "optimization_procedure": "single-stage", + "visual_representation": "CLIP ViT-L/14 @ 224px", + "image_processing": "Letterbox", + "language_model": "Vicuña v1.5 7B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1, + }, + }, + "siglip-224px+7b": { + "model_id": "siglip-224px+7b", + "names": ["SigLIP ViT-SO 224px"], + "description": { + "name": "SigLIP ViT-SO 224px", + "optimization_procedure": "single-stage", + "visual_representation": "SigLIP ViT-SO/14 @ 224px", + "image_processing": "Letterbox", + "language_model": "Vicuña v1.5 7B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1, + }, + }, + + "clip-336px-resize-crop+7b": { + "model_id": "clip-336px-resize-crop+7b", + "names": ["CLIP ViT-L 336px (Resize Crop)"], + "description": { + "name": "CLIP ViT-L 336px (Resize Crop)", + "optimization_procedure": "single-stage", + "visual_representation": "CLIP ViT-L/14 @ 336px", + "image_processing": "Resize Crop", + "language_model": "Vicuña v1.5 7B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1, + } + }, + "clip-336px-resize-naive+7b": { + "model_id": "clip-336px-resize-naive+7b", + "names": ["CLIP ViT-L 336px (Naive Resize)", "CLIP 336px (Naive Resize)"], + "description": { + "name": "CLIP ViT-L 336px (Naive Resize)", + "optimization_procedure": "single-stage", + "visual_representation": "CLIP ViT-L/14 @ 336px", + "image_processing": "Naive Resize", + "language_model": "Vicuña v1.5 7B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1, + } + }, + "siglip-384px-letterbox+7b": { + "model_id": "siglip-384px-letterbox+7b", + "names": ["SigLIP ViT-SO 384px (Letterbox)", "SigLIP ViT-SO 384px"], + "description": { + "name": "SigLIP ViT-SO 384px (Letterbox)", + "optimization_procedure": "single-stage", + "visual_representation": "SigLIP ViT-SO/14 @ 384px", + "image_processing": "Letterbox", + "language_model": "Vicuña v1.5 7B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1, + } + }, + "siglip-384px-resize-crop+7b": { + "model_id": "siglip-384px-resize-crop+7b", + "names": ["SigLIP ViT-SO 384px (Resize Crop)"], + "description": { + "name": "SigLIP ViT-SO 384px (Resize Crop)", + "optimization_procedure": "single-stage", + "visual_representation": "SigLIP ViT-SO/14 @ 384px", + "image_processing": "Resize Crop", 
+            "language_model": "Vicuña v1.5 7B",
+            "datasets": ["LLaVa v1.5 Instruct"],
+            "train_epochs": 1,
+        }
+    },
+    "siglip-384px-resize-naive+7b": {
+        "model_id": "siglip-384px-resize-naive+7b",
+        "names": ["SigLIP ViT-SO 384px (Naive Resize)", "SigLIP 384px (Naive Resize)"],
+        "description": {
+            "name": "SigLIP ViT-SO 384px (Naive Resize)",
+            "optimization_procedure": "single-stage",
+            "visual_representation": "SigLIP ViT-SO/14 @ 384px",
+            "image_processing": "Naive Resize",
+            "language_model": "Vicuña v1.5 7B",
+            "datasets": ["LLaVa v1.5 Instruct"],
+            "train_epochs": 1,
+        }
+    },
+
+    "dinoclip-336px-letterbox+7b": {
+        "model_id": "dinoclip-336px-letterbox+7b",
+        "names": ["DINOv2 + CLIP 336px (Letterbox)"],
+        "description": {
+            "name": "DINOv2 + CLIP 336px (Letterbox)",
+            "optimization_procedure": "single-stage",
+            "visual_representation": "DINOv2 ViT-L/14 + CLIP ViT-L/14 @ 336px",
+            "image_processing": "Letterbox",
+            "language_model": "Vicuña v1.5 7B",
+            "datasets": ["LLaVa v1.5 Instruct"],
+            "train_epochs": 1,
+        }
+    },
+    "dinoclip-336px-resize-naive+7b": {
+        "model_id": "dinoclip-336px-resize-naive+7b",
+        "names": ["DINOv2 + CLIP 336px (Naive Resize)"],
+        "description": {
+            "name": "DINOv2 + CLIP 336px (Naive Resize)",
+            "optimization_procedure": "single-stage",
+            "visual_representation": "DINOv2 ViT-L/14 + CLIP ViT-L/14 @ 336px",
+            "image_processing": "Naive Resize",
+            "language_model": "Vicuña v1.5 7B",
+            "datasets": ["LLaVa v1.5 Instruct"],
+            "train_epochs": 1,
+        }
+    },
+    "dinosiglip-384px-letterbox+7b": {
+        "model_id": "dinosiglip-384px-letterbox+7b",
+        "names": ["DINOv2 + SigLIP 384px (Letterbox)"],
+        "description": {
+            "name": "DINOv2 + SigLIP 384px (Letterbox)",
+            "optimization_procedure": "single-stage",
+            "visual_representation": "DINOv2 ViT-L/14 + SigLIP ViT-SO/14 @ 384px",
+            "image_processing": "Letterbox",
+            "language_model": "Vicuña v1.5 7B",
+            "datasets": ["LLaVa v1.5 Instruct"],
+            "train_epochs": 1,
+        }
+    },
+    "dinosiglip-384px-resize-naive+7b": {
+        "model_id": "dinosiglip-384px-resize-naive+7b",
+        "names": ["DINOv2 + SigLIP 384px (Naive Resize)"],
+        "description": {
+            "name": "DINOv2 + SigLIP 384px (Naive Resize)",
+            "optimization_procedure": "single-stage",
+            "visual_representation": "DINOv2 ViT-L/14 + SigLIP ViT-SO/14 @ 384px",
+            "image_processing": "Naive Resize",
+            "language_model": "Vicuña v1.5 7B",
+            "datasets": ["LLaVa v1.5 Instruct"],
+            "train_epochs": 1,
+        }
+    },
+
+    # === Section 4.3 :: Language Models ===
+    "llama2+7b": {
+        "model_id": "llama2+7b",
+        "names": ["Llama-2 7B"],
+        "description": {
+            "name": "Llama-2 7B",
+            "optimization_procedure": "single-stage",
+            "visual_representation": "CLIP ViT-L/14 @ 336px",
+            "image_processing": "Letterbox",
+            "language_model": "Llama-2 7B",
+            "datasets": ["LLaVa v1.5 Instruct"],
+            "train_epochs": 1,
+        },
+    },
+    "llama2+13b": {
+        "model_id": "llama2+13b",
+        "names": ["Llama-2 13B"],
+        "description": {
+            "name": "Llama-2 13B",
+            "optimization_procedure": "single-stage",
+            "visual_representation": "CLIP ViT-L/14 @ 336px",
+            "image_processing": "Letterbox",
+            "language_model": "Llama-2 13B",
+            "datasets": ["LLaVa v1.5 Instruct"],
+            "train_epochs": 1,
+        },
+    },
+
+    "vicuna-no-cotraining+7b": {
+        "model_id": "vicuna-no-cotraining+7b",
+        "names": ["Vicuña v1.5 7B (No Co-training)"],
+        "description": {
+            "name": "Vicuña v1.5 7B (No Co-training)",
+            "optimization_procedure": "single-stage",
+            "visual_representation": "CLIP ViT-L/14 @ 336px",
+            "image_processing": "Letterbox",
+            "language_model": "Vicuña v1.5 7B",
+            "datasets": ["LLaVa 
v1.5 Multimodal-Only"], + "train_epochs": 1, + }, + }, + "llama2-no-cotraining+7b": { + "model_id": "llama2-no-cotraining+7b", + "names": ["Llama-2 7B (No Co-training)"], + "description": { + "name": "Llama-2 7B (No Co-training)", + "optimization_procedure": "single-stage", + "visual_representation": "CLIP ViT-L/14 @ 336px", + "image_processing": "Letterbox", + "language_model": "Llama-2 7B", + "datasets": ["LLaVa v1.5 Multimodal-Only"], + "train_epochs": 1, + }, + }, + + # === Section 4.4 :: Scaling Properties === + "train-1.25-epochs+7b": { + "model_id": "train-1.25-epochs+7b", + "names": ["1.25 Epochs"], + "description": { + "name": "1.25 Epochs", + "optimization_procedure": "single-stage", + "visual_representation": "CLIP ViT-L/14 @ 336px", + "image_processing": "Letterbox", + "language_model": "Vicuña v1.5 7B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1.25, + } + }, + "train-1.5-epochs+7b": { + "model_id": "train-1.5-epochs+7b", + "names": ["1.5 Epochs"], + "description": { + "name": "1.5 Epochs", + "optimization_procedure": "single-stage", + "visual_representation": "CLIP ViT-L/14 @ 336px", + "image_processing": "Letterbox", + "language_model": "Vicuña v1.5 7B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1.5, + } + }, + "train-2-epochs+7b": { + "model_id": "train-2-epochs+7b", + "names": ["2 Epochs"], + "description": { + "name": "2 Epochs", + "optimization_procedure": "single-stage", + "visual_representation": "CLIP ViT-L/14 @ 336px", + "image_processing": "Letterbox", + "language_model": "Vicuña v1.5 7B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 2, + } + }, + "train-3-epochs+7b": { + "model_id": "train-3-epochs+7b", + "names": ["3 Epochs"], + "description": { + "name": "3 Epochs", + "optimization_procedure": "single-stage", + "visual_representation": "CLIP ViT-L/14 @ 336px", + "image_processing": "Letterbox", + "language_model": "Vicuña v1.5 7B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 3, + } + }, + + "llava-lvis4v+7b": { + "model_id": "llava-lvis4v+7b", + "names": ["Base + LVIS-4V"], + "description": { + "name": "Base + LVIS-4V", + "optimization_procedure": "single-stage", + "visual_representation": "CLIP ViT-L/14 @ 336px", + "image_processing": "Letterbox", + "language_model": "Vicuña v1.5 7B", + "datasets": ["LLaVa v1.5 Instruct", "LVIS-Instruct-4V"], + "train_epochs": 1, + } + }, + "llava-lrv+7b": { + "model_id": "llava-lrv+7b", + "names": ["Base + LRV"], + "description": { + "name": "Base + LRV", + "optimization_procedure": "single-stage", + "visual_representation": "CLIP ViT-L/14 @ 336px", + "image_processing": "Letterbox", + "language_model": "Vicuña v1.5 7B", + "datasets": ["LLaVa v1.5 Instruct", "LRV-Instruct"], + "train_epochs": 1, + } + }, + "llava-lvis4v-lrv+7b": { + "model_id": "llava-lvis4v-lrv+7b", + "names": ["Base + LVIS-4V + LRV"], + "description": { + "name": "Base + LVIS-4V + LRV", + "optimization_procedure": "single-stage", + "visual_representation": "CLIP ViT-L/14 @ 336px", + "image_processing": "Letterbox", + "language_model": "Vicuña v1.5 7B", + "datasets": ["LLaVa v1.5 Instruct", "LVIS-Instruct-4V", "LRV-Instruct"], + "train_epochs": 1, + } + }, + + # === + + # === CLIP Prism Models === + "prism-clip-controlled+7b": { + "model_id": "prism-clip-controlled+7b", + "names": ["Prism-CLIP 7B (Controlled)"], + "description": { + "name": "CLIP Prism 7B (Controlled)", + "optimization_procedure": "single-stage", + "visual_representation": "CLIP ViT-L/14 @ 336px", + "image_processing": "Naive Resize", + 
"language_model": "Llama-2 7B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1, + } + }, + "prism-clip-controlled+13b": { + "model_id": "prism-clip-controlled+13b", + "names": ["Prism-CLIP 13B (Controlled)"], + "description": { + "name": "CLIP Prism 13B (Controlled)", + "optimization_procedure": "single-stage", + "visual_representation": "CLIP ViT-L/14 @ 336px", + "image_processing": "Naive Resize", + "language_model": "Llama-2 13B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1, + } + }, + "prism-clip+7b": { + "model_id": "prism-clip+7b", + "names": ["Prism-CLIP 7B"], + "description": { + "name": "CLIP Prism 7B", + "optimization_procedure": "single-stage", + "visual_representation": "CLIP ViT-L/14 @ 336px", + "image_processing": "Naive Resize", + "language_model": "Llama-2 7B", + "datasets": ["LLaVa v1.5 Instruct", "LVIS-Instruct-4V", "LRV-Instruct"], + "train_epochs": 2, + }, + }, + "prism-clip+13b": { + "model_id": "prism-clip+13b", + "names": ["Prism-CLIP 13B"], + "description": { + "name": "CLIP Prism 13B", + "optimization_procedure": "single-stage", + "visual_representation": "CLIP ViT-L/14 @ 336px", + "image_processing": "Naive Resize", + "language_model": "Llama-2 13B", + "datasets": ["LLaVa v1.5 Instruct", "LVIS-Instruct-4V", "LRV-Instruct"], + "train_epochs": 2, + }, + }, + + # === SigLIP Prism Models == + "prism-siglip-controlled+7b": { + "model_id": "prism-siglip-controlled+7b", + "names": ["Prism-SigLIP 7B (Controlled)"], + "description": { + "name": "SigLIP Prism 7B (Controlled)", + "optimization_procedure": "single-stage", + "visual_representation": "SigLIP ViT-SO/14 @ 384px", + "image_processing": "Naive Resize", + "language_model": "Llama-2 7B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1, + } + }, + "prism-siglip-controlled+13b": { + "model_id": "prism-siglip-controlled+7b", + "names": ["Prism-SigLIP 13B (Controlled)"], + "description": { + "name": "SigLIP Prism 13B (Controlled)", + "optimization_procedure": "single-stage", + "visual_representation": "SigLIP ViT-SO/14 @ 384px", + "image_processing": "Naive Resize", + "language_model": "Llama-2 13B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1, + } + }, + "prism-siglip+7b": { + "model_id": "prism-siglip+7b", + "names": ["Prism-SigLIP 7B"], + "description": { + "name": "SigLIP Prism 7B", + "optimization_procedure": "single-stage", + "visual_representation": "SigLIP ViT-SO/14 @ 384px", + "image_processing": "Naive Resize", + "language_model": "Llama-2 7B", + "datasets": ["LLaVa v1.5 Instruct", "LVIS-Instruct-4V", "LRV-Instruct"], + "train_epochs": 2, + } + }, + "prism-siglip+13b": { + "model_id": "prism-siglip+13b", + "names": ["Prism-SigLIP 13B"], + "description": { + "name": "SigLIP Prism 13B", + "optimization_procedure": "single-stage", + "visual_representation": "SigLIP ViT-SO/14 @ 384px", + "image_processing": "Naive Resize", + "language_model": "Llama-2 13B", + "datasets": ["LLaVa v1.5 Instruct", "LVIS-Instruct-4V", "LRV-Instruct"], + "train_epochs": 2, + } + }, + + # === DINOSigLIP Prism Models === + "prism-dinosiglip-controlled+7b": { + "model_id": "prism-dinosiglip-controlled+7b", + "names": ["Prism-DINOSigLIP 7B (Controlled)", "Prism 7B (Controlled)"], + "description": { + "name": "DINOSigLIP Prism 7B (Controlled)", + "optimization_procedure": "single-stage", + "visual_representation": "DINOv2 ViT-L/14 + SigLIP ViT-SO/14 @ 384px", + "image_processing": "Naive Resize", + "language_model": "Llama-2 7B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1, 
+        }
+    },
+    "prism-dinosiglip-controlled+13b": {
+        "model_id": "prism-dinosiglip-controlled+13b",
+        "names": ["Prism-DINOSigLIP 13B (Controlled)", "Prism 13B (Controlled)"],
+        "description": {
+            "name": "DINOSigLIP Prism 13B (Controlled)",
+            "optimization_procedure": "single-stage",
+            "visual_representation": "DINOv2 ViT-L/14 + SigLIP ViT-SO/14 @ 384px",
+            "image_processing": "Naive Resize",
+            "language_model": "Llama-2 13B",
+            "datasets": ["LLaVa v1.5 Instruct"],
+            "train_epochs": 1,
+        }
+    },
+    "prism-dinosiglip+7b": {
+        "model_id": "prism-dinosiglip+7b",
+        "names": ["Prism-DINOSigLIP 7B"],
+        "description": {
+            "name": "DINOSigLIP Prism 7B",
+            "optimization_procedure": "single-stage",
+            "visual_representation": "DINOv2 ViT-L/14 + SigLIP ViT-SO/14 @ 384px",
+            "image_processing": "Naive Resize",
+            "language_model": "Llama-2 7B",
+            "datasets": ["LLaVa v1.5 Instruct", "LVIS-Instruct-4V", "LRV-Instruct"],
+            "train_epochs": 2,
+        },
+    },
+    "prism-dinosiglip+13b": {
+        "model_id": "prism-dinosiglip+13b",
+        "names": ["Prism-DINOSigLIP 13B"],
+        "description": {
+            "name": "DINOSigLIP Prism 13B",
+            "optimization_procedure": "single-stage",
+            "visual_representation": "DINOv2 ViT-L/14 + SigLIP ViT-SO/14 @ 384px",
+            "image_processing": "Naive Resize",
+            "language_model": "Llama-2 13B",
+            "datasets": ["LLaVa v1.5 Instruct", "LVIS-Instruct-4V", "LRV-Instruct"],
+            "train_epochs": 2,
+        },
+    },
+
+    # === DINOSigLIP 224px Prism Models ===
+    "prism-dinosiglip-224px-controlled+7b": {
+        "model_id": "prism-dinosiglip-224px-controlled+7b",
+        "names": ["Prism-DINOSigLIP 224px 7B (Controlled)"],
+        "description": {
+            "name": "DINOSigLIP 224px 7B (Controlled)",
+            "optimization_procedure": "single-stage",
+            "visual_representation": "DINOv2 ViT-L/14 + SigLIP ViT-SO/14 @ 224px",
+            "image_processing": "Naive Resize",
+            "language_model": "Llama-2 7B",
+            "datasets": ["LLaVa v1.5 Instruct"],
+            "train_epochs": 1,
+        }
+    },
+    "prism-dinosiglip-224px+7b": {
+        "model_id": "prism-dinosiglip-224px+7b",
+        "names": ["Prism-DINOSigLIP 224px 7B"],
+        "description": {
+            "name": "DINOSigLIP 224px 7B",
+            "optimization_procedure": "single-stage",
+            "visual_representation": "DINOv2 ViT-L/14 + SigLIP ViT-SO/14 @ 224px",
+            "image_processing": "Naive Resize",
+            "language_model": "Llama-2 7B",
+            "datasets": ["LLaVa v1.5 Instruct", "LVIS-Instruct-4V", "LRV-Instruct"],
+            "train_epochs": 2,
+        }
+    },
+
+    # === Additional LLM Backbones ===
+    "llama2-chat+7b": {
+        "model_id": "llama2-chat+7b",
+        "names": ["Llama-2 Chat 7B"],
+        "description": {
+            "name": "Llama-2 Chat 7B",
+            "optimization_procedure": "single-stage",
+            "visual_representation": "CLIP ViT-L/14 @ 336px",
+            "image_processing": "Letterbox",
+            "language_model": "Llama-2 Chat 7B",
+            "datasets": ["LLaVa v1.5 Instruct"],
+            "train_epochs": 1,
+        }
+    },
+    "llama2-chat+13b": {
+        "model_id": "llama2-chat+13b",
+        "names": ["Llama-2 Chat 13B"],
+        "description": {
+            "name": "Llama-2 Chat 13B",
+            "optimization_procedure": "single-stage",
+            "visual_representation": "CLIP ViT-L/14 @ 336px",
+            "image_processing": "Letterbox",
+            "language_model": "Llama-2 Chat 13B",
+            "datasets": ["LLaVa v1.5 Instruct"],
+            "train_epochs": 1,
+        }
+    },
+    "mistral-v0.1+7b": {
+        "model_id": "mistral-v0.1+7b",
+        "names": ["Mistral v0.1 7B"],
+        "description": {
+            "name": "Mistral v0.1 7B",
+            "optimization_procedure": "single-stage",
+            "visual_representation": "CLIP ViT-L/14 @ 336px",
+            "image_processing": "Letterbox",
+            "language_model": "Mistral v0.1 7B",
+            "datasets": ["LLaVa v1.5 Instruct"],
+            "train_epochs": 1,
+        }
+    },
+    
"mistral-instruct-v0.1+7b": { + "model_id": "mistral-instruct-v0.1+7b", + "names": ["Mistral Instruct v0.1 7B"], + "description": { + "name": "Mistral Instruct v0.1 7B", + "optimization_procedure": "single-stage", + "visual_representation": "CLIP ViT-L/14 @ 336px", + "image_processing": "Letterbox", + "language_model": "Mistral Instruct v0.1 7B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1, + } + }, + "phi-2+3b": { + "model_id": "phi-2+3b", + "names": ["Phi-2 3B"], + "description": { + "name": "Phi-2 3B", + "optimization_procedure": "single-stage", + "visual_representation": "CLIP ViT-L/14 @ 336px", + "image_processing": "Letterbox", + "language_model": "Phi-2 3B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1, + } + }, +} + +# Build Global Registry (Model ID, Name) -> Metadata +GLOBAL_REGISTRY = {name: v for k, v in MODEL_REGISTRY.items() for name in [k] + v["names"]} + +# fmt: on diff --git a/capvector-oft/prismatic/models/vlas/__init__.py b/capvector-oft/prismatic/models/vlas/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0f6889016694a373807c125c4459ecf9b9369811 --- /dev/null +++ b/capvector-oft/prismatic/models/vlas/__init__.py @@ -0,0 +1 @@ +from .openvla import OpenVLA diff --git a/capvector-oft/prismatic/models/vlas/openvla.py b/capvector-oft/prismatic/models/vlas/openvla.py new file mode 100644 index 0000000000000000000000000000000000000000..4aa1e3fe8d69f12401d80e8f18708ce8cdaf47be --- /dev/null +++ b/capvector-oft/prismatic/models/vlas/openvla.py @@ -0,0 +1,131 @@ +""" +openvla.py + +PyTorch Module defining OpenVLA as a lightweight wrapper around a PrismaticVLM; defines custom logic around +discretizing actions with the ActionTokenizer. +""" + +from typing import Dict, List, Optional + +import numpy as np +import torch +from PIL import Image +from transformers import LlamaTokenizerFast + +from prismatic.models.vlms.prismatic import PrismaticVLM +from prismatic.overwatch import initialize_overwatch +from prismatic.vla.action_tokenizer import ActionTokenizer + +# Initialize Overwatch =>> Wraps `logging.Logger` +overwatch = initialize_overwatch(__name__) + + +class OpenVLA(PrismaticVLM): + def __init__( + self, + *args, + norm_stats: Dict[str, Dict[str, Dict[str, Dict[str, List[float]]]]], + action_tokenizer: ActionTokenizer, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.norm_stats = norm_stats + self.action_tokenizer = action_tokenizer + + @torch.inference_mode() + def predict_action( + self, image: Image, instruction: str, unnorm_key: Optional[str] = None, **kwargs: str + ) -> np.ndarray: + """ + Core function for VLA inference; maps input image and task instruction to continuous action (de-tokenizes). + + @param image: PIL Image as [height, width, 3] + @param instruction: Task instruction string + @param unnorm_key: Optional dataset name for retrieving un-normalizing statistics; if None, checks that model + was trained only on a single dataset, and retrieves those statistics. + + @return Unnormalized (continuous) action vector --> end-effector deltas. 
+        """
+        image_transform, tokenizer = self.vision_backbone.image_transform, self.llm_backbone.tokenizer
+
+        # Build VLA Prompt
+        prompt_builder = self.get_prompt_builder()
+        prompt_builder.add_turn(role="human", message=f"What action should the robot take to {instruction.lower()}?")
+        prompt_text = prompt_builder.get_prompt()
+
+        # Prepare Inputs
+        input_ids = tokenizer(prompt_text, truncation=True, return_tensors="pt").input_ids.to(self.device)
+        if isinstance(tokenizer, LlamaTokenizerFast):
+            # If the special empty token ('') does not already appear after the colon (':') token in the prompt
+            # (after "OUT:" or "ASSISTANT:"), insert it to match the inputs seen at training time
+            if not torch.all(input_ids[:, -1] == 29871):
+                input_ids = torch.cat(
+                    (input_ids, torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(input_ids.device)), dim=1
+                )
+        else:
+            raise ValueError(f"Unsupported `tokenizer` type = {type(tokenizer)}")
+
+        # Preprocess Image
+        pixel_values = image_transform(image)
+        if isinstance(pixel_values, torch.Tensor):
+            pixel_values = pixel_values[None, ...].to(self.device)
+        elif isinstance(pixel_values, dict):
+            pixel_values = {k: v[None, ...].to(self.device) for k, v in pixel_values.items()}
+        else:
+            raise ValueError(f"Unsupported `pixel_values` type = {type(pixel_values)}")
+
+        # Invoke super().generate --> taps into `GenerationMixin` which (redirects) to `forward()`
+        autocast_dtype = self.llm_backbone.half_precision_dtype
+        with torch.autocast("cuda", dtype=autocast_dtype, enabled=self.enable_mixed_precision_training):
+            # fmt: off
+            generated_ids = super(PrismaticVLM, self).generate(
+                input_ids=input_ids,                            # Shape: [1, seq]
+                pixel_values=pixel_values,                      # Shape: [1, 3, res, res] or Dict[str, ...]
+                max_new_tokens=self.get_action_dim(unnorm_key),
+                **kwargs
+            )
+            # fmt: on
+
+        # Extract predicted action tokens and translate into (normalized) continuous actions
+        predicted_action_token_ids = generated_ids[0, -self.get_action_dim(unnorm_key) :]
+        normalized_actions = self.action_tokenizer.decode_token_ids_to_actions(predicted_action_token_ids.cpu().numpy())
+
+        # Un-normalize Actions
+        action_norm_stats = self.get_action_stats(unnorm_key)
+        mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats["q01"], dtype=bool))
+        action_high, action_low = np.array(action_norm_stats["q99"]), np.array(action_norm_stats["q01"])
+        actions = np.where(
+            mask,
+            0.5 * (normalized_actions + 1) * (action_high - action_low) + action_low,
+            normalized_actions,
+        )
+
+        return actions
+
+    @staticmethod
+    def _check_unnorm_key(norm_stats: Dict, unnorm_key: str) -> str:
+        if unnorm_key is None:
+            assert len(norm_stats) == 1, (
+                f"Your model was trained on more than one dataset, please pass a `unnorm_key` from the following "
+                f"options to choose the statistics used for un-normalizing actions: {norm_stats.keys()}"
+            )
+            unnorm_key = next(iter(norm_stats.keys()))
+
+        # Error Handling
+        assert (
+            unnorm_key in norm_stats
+        ), f"The `unnorm_key` you chose is not in the set of available statistics; choose from: {norm_stats.keys()}"
+
+        return unnorm_key
+
+    def get_action_dim(self, unnorm_key: Optional[str] = None) -> int:
+        """Dimensionality of the policy's action space."""
+        unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
+
+        return len(self.norm_stats[unnorm_key]["action"]["q01"])
+
+    def get_action_stats(self, unnorm_key: Optional[str] = None) -> Dict:
+        """Statistics (e.g., `q01`/`q99` quantiles and `mask`) used to un-normalize the policy's actions."""
+        unnorm_key = self._check_unnorm_key(self.norm_stats, 
unnorm_key) + + return self.norm_stats[unnorm_key]["action"] diff --git a/capvector-oft/prismatic/models/vlms/__init__.py b/capvector-oft/prismatic/models/vlms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..07b1d606a27c25694d7f2f24ac2bbf686adfe0ab --- /dev/null +++ b/capvector-oft/prismatic/models/vlms/__init__.py @@ -0,0 +1 @@ +from .prismatic import PrismaticVLM diff --git a/capvector-oft/prismatic/models/vlms/base_vlm.py b/capvector-oft/prismatic/models/vlms/base_vlm.py new file mode 100644 index 0000000000000000000000000000000000000000..24e180470926cccb7f2d059f83197efb17b1a573 --- /dev/null +++ b/capvector-oft/prismatic/models/vlms/base_vlm.py @@ -0,0 +1,108 @@ +""" +base_vlm.py + +Abstract class definition of a Vision-Language Model (VLM), with full annotations of class methods, utility functions, +and initialization logic. This is mostly to future-proof the codebase; while all our experiments instantiate +from PrismaticVLM, theoretically, this base class should be general enough to cover almost all models (e.g., IDEFICS, +PALI, Fuyu) in the future. + +We use Abstract base classes *sparingly* -- mostly as a way to encapsulate any redundant logic or nested inheritance +(e.g., dependence on nn.Module, HF PretrainedModel, etc.). For other abstract objects (e.g., Tokenizers/Transforms), +prefer Protocol definitions instead. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Callable, List, Optional + +import torch +import torch.nn as nn +from transformers import GenerationMixin, PretrainedConfig +from transformers.modeling_outputs import CausalLMOutputWithPast + +from prismatic.models.backbones.llm import LLMBackbone +from prismatic.models.backbones.llm.prompting import PromptBuilder +from prismatic.models.backbones.vision import VisionBackbone + + +# === Abstract Base Class for arbitrary Vision-Language Models === +class VLM(nn.Module, GenerationMixin, ABC): + def __init__( + self, + model_family: str, + model_id: str, + vision_backbone: VisionBackbone, + llm_backbone: LLMBackbone, + enable_mixed_precision_training: bool = True, + ) -> None: + super().__init__() + self.model_family, self.model_id = model_family, model_id + self.vision_backbone, self.llm_backbone = vision_backbone, llm_backbone + self.enable_mixed_precision_training = enable_mixed_precision_training + + # Instance Attributes for a generic VLM + self.all_module_keys, self.trainable_module_keys = None, None + + # === GenerationMixin Expected Attributes =>> *DO NOT MODIFY* === + self.generation_config = self.llm_backbone.llm.generation_config + self.main_input_name = "input_ids" + + @property + def device(self) -> torch.device: + """Borrowed from `transformers.modeling_utils.py` -- checks parameter device; assumes model on *ONE* device!""" + return next(self.parameters()).device + + @classmethod + @abstractmethod + def from_pretrained( + cls, + pretrained_checkpoint: Path, + model_family: str, + model_id: str, + vision_backbone: VisionBackbone, + llm_backbone: LLMBackbone, + **kwargs: str, + ) -> VLM: ... + + @abstractmethod + def get_prompt_builder(self, system_prompt: Optional[str] = None) -> PromptBuilder: ... + + @abstractmethod + def freeze_backbones(self, stage: str) -> None: ... + + @abstractmethod + def load_from_checkpoint(self, stage: str, run_dir: Path, pretrained_checkpoint: Optional[Path] = None) -> None: ... + + @abstractmethod + def get_fsdp_wrapping_policy(self) -> Callable: ... 
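+    # [Note] Concrete implementations typically compose per-component policies (vision blocks, LLM transformer
+    #        layers, projector) via `_or_policy` -- see `DinoSigLIPViTBackbone.get_fsdp_wrapping_policy` above
+    #        for the vision-side pattern; this is a convention in this codebase, not an enforced contract.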
+ + @abstractmethod + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + multimodal_indices: Optional[torch.LongTensor] = None, + ) -> CausalLMOutputWithPast: ... + + # === GenerationMixin Expected Properties & Methods (DO NOT MODIFY) === + @staticmethod + def can_generate() -> bool: + return True + + @property + def config(self) -> PretrainedConfig: + return self.llm_backbone.llm.config + + # => Beam Search Utility + def _reorder_cache(self, past_key_values, beam_idx): + return self.llm_backbone.llm._reorder_cache(past_key_values, beam_idx) diff --git a/capvector-oft/prismatic/models/vlms/prismatic.py b/capvector-oft/prismatic/models/vlms/prismatic.py new file mode 100644 index 0000000000000000000000000000000000000000..07477f2a14b1da650c4da747a24e407e2633e8a6 --- /dev/null +++ b/capvector-oft/prismatic/models/vlms/prismatic.py @@ -0,0 +1,621 @@ +""" +prismatic.py + +PyTorch Module defining a PrismaticVLM, our general interface for defining the various different VLMs in our work. + +Notes: + - For now, we don't subclass `transformers.PretrainedModel` (or CausalLM). Instead, we assume a very limited subset + of the {Model}ForCausalLM API that enables dispatch to the underlying LLM's `generate` utilities (feeding inputs + through our custom projection shim). +""" + +from __future__ import annotations + +from functools import partial +from pathlib import Path +from typing import Callable, Dict, List, Optional, Type, Union + +import torch +from PIL import Image +from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy +from transformers.modeling_outputs import CausalLMOutputWithPast + +from prismatic.models.backbones.llm import LLMBackbone +from prismatic.models.backbones.llm.prompting import PromptBuilder +from prismatic.models.backbones.vision import VisionBackbone +from prismatic.models.vlms.base_vlm import VLM +from prismatic.overwatch import initialize_overwatch +from prismatic.util.nn_utils import FusedMLPProjector, LinearProjector, MLPProjector + +# Initialize Overwatch =>> Wraps `logging.Logger` +overwatch = initialize_overwatch(__name__) + + +# HuggingFace Default / LLaMa-2 IGNORE_INDEX (for labels) +IGNORE_INDEX = -100 + + +class PrismaticVLM(VLM): + def __init__( + self, + model_id: str, + vision_backbone: VisionBackbone, + llm_backbone: LLMBackbone, + enable_mixed_precision_training: bool = True, + arch_specifier: str = "gelu-mlp", + **kwargs, + ) -> None: + super().__init__( + "prismatic", + model_id, + vision_backbone, + llm_backbone, + enable_mixed_precision_training=enable_mixed_precision_training, + ) + + # Set Weight Initialization Seed for Projector Consistency + torch.manual_seed(vision_backbone.embed_dim) + + # Initialize Projection (Adapter) based on `arch_specifier` + self.arch_specifier = arch_specifier + if arch_specifier == "linear": + self.projector = LinearProjector(vision_backbone.embed_dim, llm_backbone.embed_dim) + elif arch_specifier.endswith("fused-gelu-mlp"): + self.projector = FusedMLPProjector(vision_backbone.embed_dim, llm_backbone.embed_dim) + elif arch_specifier.endswith("gelu-mlp"): + self.projector = 
MLPProjector(vision_backbone.embed_dim, llm_backbone.embed_dim)
+        else:
+            raise ValueError(f"PrismaticVLM with `{arch_specifier = }` is not supported!")
+
+        # Trackers
+        self.vision_backbone_requires_grad = False
+
+        # Set Module Keys =>> used in Checkpoint Saving / Model Loading
+        self.all_module_keys = ["vision_backbone", "llm_backbone", "projector"]
+        self.trainable_module_keys = []
+
+        # === Generation Utilities ===
+        #   => For computing likelihoods --> get tokens corresponding to "True", "False" and "Yes", "No"
+        self.string2idx = {}
+        for trigger_string in ["True", "False", "Yes", "No"] + [chr(ord("A") + i) for i in range(26)]:
+            token_idx_list = self.llm_backbone.tokenizer.encode(trigger_string, add_special_tokens=False)
+            assert len(token_idx_list) == 1, f'String "{trigger_string}" is tokenized as more than one token!'
+            self.string2idx[trigger_string] = token_idx_list[0]
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        pretrained_checkpoint: Path,
+        model_id: str,
+        vision_backbone: VisionBackbone,
+        llm_backbone: LLMBackbone,
+        enable_mixed_precision_training: bool = True,
+        arch_specifier: str = "gelu-mlp",
+        freeze_weights: bool = True,
+        **kwargs,
+    ) -> PrismaticVLM:
+        """Initialize a PrismaticVLM from a pretrained checkpoint, freezing all weights, tailored for inference."""
+        vlm = cls(
+            model_id,
+            vision_backbone,
+            llm_backbone,
+            enable_mixed_precision_training=enable_mixed_precision_training,
+            arch_specifier=arch_specifier,
+            **kwargs,
+        )
+
+        # Load from Checkpoint (Custom --> should load both *projector* and *llm* weights)
+        model_state_dict = torch.load(pretrained_checkpoint, map_location="cpu")["model"]
+        assert (
+            "projector" in model_state_dict and "llm_backbone" in model_state_dict
+        ), "PrismaticVLM `from_pretrained` expects checkpoint with keys for `projector` AND `llm_backbone`!"
+
+        vlm.projector.load_state_dict(model_state_dict["projector"])
+        vlm.llm_backbone.load_state_dict(model_state_dict["llm_backbone"])
+        if "vision_backbone" in model_state_dict.keys():
+            vlm.vision_backbone.load_state_dict(model_state_dict["vision_backbone"])
+
+        # Freeze Weights
+        if freeze_weights:
+            vlm.requires_grad_(False)
+            vlm.eval()
+
+        return vlm
+
+    def get_prompt_builder(self, system_prompt: Optional[str] = None) -> PromptBuilder:
+        prompt_initializer: Type[PromptBuilder] = self.llm_backbone.prompt_builder_fn
+        return prompt_initializer(self.model_family, system_prompt=system_prompt)
+
+    def freeze_backbones(self, stage: str) -> None:
+        """
+        This function sets `requires_grad_` on each of the component modules explicitly, depending on stage.
+
+        We support multiple stages; the two core stages are "align" and "finetune" ("vla-train" mirrors "finetune").
+            => "align" --> vision_backbone*, llm_backbone* are frozen; only the `projector` is trained.
+            => "finetune" --> vision_backbone* is frozen; both `projector` and `llm_backbone` are trained.
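+            => "full-finetune" / "vla-full-train" --> vision_backbone, projector, AND llm_backbone are all trained.
+            => "last-layer-finetune" / "vla-last-layer-train" --> everything is frozen except the final LLM layer.
+            => "vla-sandwich-train" --> vision_backbone and projector are trained; LLM frozen except final layer.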
+ + :param stage: Pretraining stage in < "align" | "finetune" | "full-finetune" | "vla-train" | "vla-full-train" > + """ + if stage == "align": + self.vision_backbone.requires_grad_(False) + self.llm_backbone.requires_grad_(False) + self.projector.requires_grad_(True) + + # Add to `self.trainable_module_keys` + self.trainable_module_keys = ["projector"] + + # Update Trackers + self.vision_backbone_requires_grad = False + + # Explicitly Log Frozen / Trainable Components + overwatch.info(f"[Frozen] 🥶 =>> Vision Backbone `{self.vision_backbone.identifier}`", ctx_level=1) + overwatch.info(f"[Frozen] 🥶 =>> LLM Backbone `{self.llm_backbone.identifier}`", ctx_level=1) + overwatch.info(f"[TRAINABLE] 🔥 =>> Projector `{self.arch_specifier}`", ctx_level=1) + + elif stage in {"finetune", "vla-train"}: + self.vision_backbone.requires_grad_(False) + self.llm_backbone.requires_grad_(True) + self.projector.requires_grad_(True) + + # Add to `self.trainable_module_keys` + self.trainable_module_keys = ["projector", "llm_backbone"] + + # Update Trackers + self.vision_backbone_requires_grad = False + + # Explicitly Log Frozen / Unfrozen Components + overwatch.info(f"[Frozen] 🥶 =>> Vision Backbone `{self.vision_backbone.identifier}`", ctx_level=1) + overwatch.info(f"[TRAINABLE] 🔥 =>> LLM Backbone `{self.llm_backbone.identifier}`", ctx_level=1) + overwatch.info(f"[TRAINABLE] 🔥 =>> Projector `{self.arch_specifier}`", ctx_level=1) + + elif stage in {"full-finetune", "vla-full-train"}: + self.vision_backbone.dtype = torch.float32 + self.vision_backbone.requires_grad_(True) + self.llm_backbone.requires_grad_(True) + self.projector.requires_grad_(True) + + # Add to `self.trainable_module_keys` + self.trainable_module_keys = ["vision_backbone", "projector", "llm_backbone"] + + # Update Trackers + self.vision_backbone_requires_grad = True + + # Explicitly Log Frozen / Unfrozen Components + overwatch.info(f"[TRAINABLE] 🔥 =>> Vision Backbone `{self.vision_backbone.identifier}`", ctx_level=1) + overwatch.info(f"[TRAINABLE] 🔥 =>> LLM Backbone `{self.llm_backbone.identifier}`", ctx_level=1) + overwatch.info(f"[TRAINABLE] 🔥 =>> Projector `{self.arch_specifier}`", ctx_level=1) + + elif stage in {"last-layer-finetune", "vla-last-layer-train"}: + self.vision_backbone.requires_grad_(False) + self.projector.requires_grad_(False) + self.llm_backbone.requires_grad_(False) + + # Unfreeze final LLM layer + for module in self.llm_backbone.last_layer_finetune_modules: + module.requires_grad_(True) + + # Add to `self.trainable_module_keys` + self.trainable_module_keys = ["llm_backbone"] + + # Update Trackers + self.vision_backbone_requires_grad = False + + # Explicitly Log Frozen / Unfrozen Components + # fmt: off + overwatch.info(f"[Frozen] 🥶 =>> Vision Backbone `{self.vision_backbone.identifier}`", ctx_level=1) # noqa: E501 + overwatch.info(f"[Frozen, except last layer] 🥶🔥 =>> LLM Backbone `{self.llm_backbone.identifier}`", ctx_level=1) # noqa: E501 + overwatch.info(f"[Frozen] 🥶 =>> Projector `{self.arch_specifier}`", ctx_level=1) + # fmt: on + + elif stage in {"vla-sandwich-train"}: + self.vision_backbone.dtype = torch.float32 + self.vision_backbone.requires_grad_(True) + self.projector.requires_grad_(True) + self.llm_backbone.requires_grad_(False) + + # Unfreeze final LLM layer + for module in self.llm_backbone.last_layer_finetune_modules: + module.requires_grad_(True) + + # Add to `self.trainable_module_keys` + self.trainable_module_keys = ["vision_backbone", "projector", "llm_backbone"] + + # Update Trackers + 
self.vision_backbone_requires_grad = True
+
+            # Explicitly Log Frozen / Unfrozen Components
+            # fmt: off
+            overwatch.info(f"[TRAINABLE] 🔥 =>> Vision Backbone `{self.vision_backbone.identifier}`", ctx_level=1)  # noqa: E501
+            overwatch.info(f"[Frozen, except last layer] 🥶🔥 =>> LLM Backbone `{self.llm_backbone.identifier}`", ctx_level=1)  # noqa: E501
+            overwatch.info(f"[TRAINABLE] 🔥 =>> Projector `{self.arch_specifier}`", ctx_level=1)
+            # fmt: on
+
+        else:
+            raise ValueError(
+                f"Stage `{stage}` is not supported for PrismaticVLM! "
+                "Try < align | finetune | full-finetune | last-layer-finetune | vla-train | vla-full-train | "
+                "vla-last-layer-train | vla-sandwich-train >"
+            )
+
+        overwatch.debug("##################################################")
+        overwatch.debug("#####      Trainable Network Parameters:    #####")
+        overwatch.debug("##################################################")
+        for name, param in self.named_parameters():
+            if param.requires_grad:
+                overwatch.debug(name)
+
+    def load_from_checkpoint(self, stage: str, run_dir: Path, pretrained_checkpoint: Optional[Path] = None) -> None:
+        """Load weights from checkpoint (if required by the given stage)."""
+        assert stage in {"align", "finetune", "full-finetune"}, f"Stage {stage} is not supported!"
+
+        # If we're running a `no-align` architecture, we're good!
+        if self.arch_specifier.startswith("no-align"):
+            overwatch.info(
+                f"PrismaticVLM with `{self.arch_specifier = }` does not require pretrained weights!", ctx_level=1
+            )
+            return
+
+        # Otherwise, handle stage-specific logic!
+        if stage == "align":
+            overwatch.info("Stage `align` does not require pretrained weights =>> Starting Training", ctx_level=1)
+            return
+
+        # Otherwise, load from `pretrained_checkpoint` or match on `run_dir` (s/+stage-finetune/+stage-align/g)
+        overwatch.info("Stage `finetune` requires `align` pretrained weights", ctx_level=1)
+
+        # Config specifies path to a checkpoint to load
+        if pretrained_checkpoint is not None:
+            overwatch.info(f"Loading from Provided Checkpoint `{pretrained_checkpoint}`", ctx_level=1)
+            model_state_dict = torch.load(pretrained_checkpoint)["model"]
+            self.projector.load_state_dict(model_state_dict["projector"])
+
+            return
+
+        # [Contract] If no `pretrained_checkpoint`, assume `align` lives in the run directory; string substitution!
+        model, scale, _, seed = run_dir.name.split("+")
+        align_dirs = [
+            d
+            for d in run_dir.parent.iterdir()
+            if (d.name.startswith(f"{model}+{scale}") and d.name.endswith(f"+stage-align+{seed}"))
+        ]
+        assert len(align_dirs) == 1, "Multiple or No Valid Pretrained Directories Exist -- Double Check `runs`!" 
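+        # [Example] (hypothetical run names) a finetune run `prism-clip+7b+stage-finetune+x7` resolves to the
+        #   sibling directory `prism-clip+7b+stage-align+x7`, from which the projector weights are pulled below.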
+ if (pretrained_checkpoint := (align_dirs[0] / "checkpoints" / "latest-checkpoint.pt")).exists(): + overwatch.info(f"Loading from Discovered Checkpoint `{pretrained_checkpoint}`", ctx_level=1) + model_state_dict = torch.load(pretrained_checkpoint)["model"] + self.projector.load_state_dict(model_state_dict["projector"]) + else: + raise ValueError(f"Could not find valid `align` checkpoint at {pretrained_checkpoint}!") + + def get_fsdp_wrapping_policy(self) -> Callable: + """Return an FSDP _or_policy over the policies returned by each individual backbone (and our VLM policy).""" + vision_fsdp_wrapping_policy = self.vision_backbone.get_fsdp_wrapping_policy() + llm_fsdp_wrapping_policy = self.llm_backbone.get_fsdp_wrapping_policy() + + # Get Prismatic Wrapping Policy =>> just a module wrapping policy around `self.projector` + prismatic_fsdp_wrapping_policy = partial( + _module_wrap_policy, + module_classes={LinearProjector, MLPProjector, FusedMLPProjector}, + ) + + # Return union (_or_) over constituent policies + # => Note: there is *not* a fall-through policy; any module that isn't covered by the above constituents will + # automatically be folded into the root VLM FSDP instance. + return partial( + _or_policy, + policies=[ + vision_fsdp_wrapping_policy, + llm_fsdp_wrapping_policy, + prismatic_fsdp_wrapping_policy, + ], + ) + + # Note =>> We're not explicitly subclassing `PreTrainedModel` because we don't need the bloat; however, `forward()` + # *must* match the signature of a `{Model}ForCausalLM` so that we can inherit from `GenerationMixin` + + # ruff: noqa: C901 + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + multimodal_indices: Optional[torch.LongTensor] = None, + ) -> CausalLMOutputWithPast: + """Run a forward pass through the VLM, returning a CausalLMOutputWithPast instance (contains loss).""" + + # Handle Inference (leverage cache, short-circuit on just LLM forward) + if input_ids.shape[1] == 1 and past_key_values is not None: + # We're leveraging the cache, so just redirect to `self.llm_backbone` with `input_ids` and `past_key_values` + output = self.llm_backbone( + input_ids=input_ids, + attention_mask=None, + position_ids=None, + past_key_values=past_key_values, + inputs_embeds=None, + labels=None, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + return output + + elif input_ids.shape[1] == 1 or pixel_values is None: + raise RuntimeError("Invalid `forward()` call!") + + # Handle Multimodal Indices is None --> pretend like the batch is fully multimodal (always image + text)! 
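+        #   => [Example] (illustrative) for a batch of 4 where only rows 0 and 2 carry images, a caller would pass
+        #      `multimodal_indices = torch.tensor([0, 2])`; rows 1 and 3 are then routed through the unimodal
+        #      merge logic further below.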
+        if multimodal_indices is None:
+            multimodal_indices = torch.arange(len(input_ids), dtype=torch.long, device=input_ids.device)
+
+        # Handle Multimodal Indices is Empty (len == 0) --> simple unimodal forward
+        elif len(multimodal_indices) == 0:
+            return self.llm_backbone(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                position_ids=None,
+                past_key_values=past_key_values,
+                inputs_embeds=None,
+                labels=labels,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+
+        # Run Visual Feature Extraction
+        with torch.set_grad_enabled(self.vision_backbone_requires_grad):
+            if isinstance(pixel_values, dict):
+                patch_features = self.vision_backbone({k: pixel_values[k][multimodal_indices] for k in pixel_values})
+            else:
+                patch_features = self.vision_backbone(pixel_values[multimodal_indices])
+
+        # Projection Logic :: [bsz, num_patches, llm_embed_dim] =>> num_patches = (2 *) (256 + 1) for ViT-L + CLS
+        projected_patch_embeddings = self.projector(patch_features)
+        projected_patch_attention_mask = None
+        if attention_mask is not None:
+            projected_patch_attention_mask = torch.full(
+                (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
+                True,
+                dtype=attention_mask.dtype,
+                device=attention_mask.device,
+            )
+
+        # Get Input Embeddings from LLM Backbone :: [bsz, input_seq_len, llm_embed_dim]
+        input_embeddings = self.llm_backbone.embed_input_ids(input_ids)
+
+        # Build Multimodal Embeddings (and build resulting attention mask)
+        multimodal_embeddings = torch.cat(
+            [
+                input_embeddings[multimodal_indices, :1, :],
+                projected_patch_embeddings,
+                input_embeddings[multimodal_indices, 1:, :],
+            ],
+            dim=1,
+        )
+        multimodal_attention_mask = None
+        if attention_mask is not None:
+            multimodal_attention_mask = torch.cat(
+                [
+                    attention_mask[multimodal_indices, :1],
+                    projected_patch_attention_mask,
+                    attention_mask[multimodal_indices, 1:],
+                ],
+                dim=1,
+            )
+
+        # [Contract] We assume the first token of `labels` (associated with <BOS>) is already marked as "IGNORE"
+        #   => We'll ignore the per-token outputs for each of the patch embeddings as well!
+        multimodal_labels = None
+        if labels is not None:
+            projected_patch_labels = torch.full(
+                (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
+                IGNORE_INDEX,
+                dtype=labels.dtype,
+                device=labels.device,
+            )
+            multimodal_labels = torch.cat(
+                [labels[multimodal_indices, :1], projected_patch_labels, labels[multimodal_indices, 1:]], dim=1
+            )
+
+        # === Add Unimodal Handling ===
+
+        # Create Fused Embeddings, Attention Mask, and Labels by Merging with "unimodal" Inputs (if applicable)
+        unimodal_indices = torch.tensor(
+            [idx for idx in range(len(input_ids)) if idx not in multimodal_indices],
+            dtype=torch.long,
+            device=multimodal_indices.device,
+        )
+
+        # No "unimodal" data --> Fused == Multimodal
+        if len(unimodal_indices) == 0:
+            fused_embeddings = multimodal_embeddings
+            fused_attention_mask = multimodal_attention_mask
+            fused_labels = multimodal_labels
+
+        else:
+            # Otherwise --> Merge w/ unimodal data
+
+            # This doesn't matter --> but in the "normal" case this is the embedding of the <PAD> token
+            # => NOTE :: Verified that `zeros/randn/empty/<PAD> embedding` all return the same result! 
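+            #   => [Example] (illustrative) with 2 multimodal rows, 1 unimodal row, 3 patch tokens, and a text
+            #      length of 5, every fused row below ends up with length 1 + 3 + 4 = 5 + 3 = 8; the unimodal rows
+            #      are padded with this zero block (labels = IGNORE_INDEX, attention_mask = False), so the padding
+            #      never contributes to attention or the loss.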
+ unimodal_embeddings_pad = torch.zeros( + (len(unimodal_indices), projected_patch_embeddings.shape[1], input_embeddings.shape[2]), + dtype=input_embeddings.dtype, + device=input_embeddings.device, + ) + unimodal_attention_pad = torch.full( + (len(unimodal_indices), projected_patch_embeddings.shape[1]), + False, + dtype=attention_mask.dtype, + device=attention_mask.device, + ) + unimodal_labels_pad = torch.full( + (len(unimodal_indices), projected_patch_embeddings.shape[1]), + IGNORE_INDEX, + dtype=labels.dtype, + device=labels.device, + ) + + unimodal_embeddings = torch.cat([input_embeddings[unimodal_indices], unimodal_embeddings_pad], dim=1) + unimodal_attention_mask = torch.cat([attention_mask[unimodal_indices], unimodal_attention_pad], dim=1) + unimodal_labels = torch.cat([labels[unimodal_indices], unimodal_labels_pad], dim=1) + + # Create "Fused" Tensors by Stacking Multimodal & Unimodal + fused_embeddings = torch.vstack([multimodal_embeddings, unimodal_embeddings]) + fused_attention_mask = torch.vstack([multimodal_attention_mask, unimodal_attention_mask]) + fused_labels = torch.vstack([multimodal_labels, unimodal_labels]) + + # Run LLM Forward --> returns CausalLMOutputWithPast! + return self.llm_backbone( + input_ids=None, + attention_mask=fused_attention_mask, + position_ids=None, + past_key_values=past_key_values, + inputs_embeds=fused_embeddings, + labels=fused_labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + # === GenerationMixin Methods === + # => Note: The following methods override the functionality of `transformers.GenerationMixin`; these expect the + # contract in each of the function signatures, and also expect our `forward` function to roughly take + # the same arguments as the underlying LLM (see `LlamaModelForCausalLM` as an example) + + def prepare_inputs_for_generation( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + **kwargs: torch.Tensor, + ) -> Dict[str, torch.Tensor]: + """Borrowed from `LlamaForCausalLM` --> in general, just handles caching logic during generation.""" + if past_key_values: + input_ids = input_ids[:, -1:] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + # Make sure `pixel_values` are preserved in `model_inputs` + model_inputs.update( + { + "attention_mask": attention_mask, + "pixel_values": pixel_values, + "past_key_values": past_key_values, + "use_cache": use_cache, + } + ) + + return model_inputs + + @torch.inference_mode() + def generate_batch( + self, + pixel_values: Union[torch.Tensor, Dict[str, torch.Tensor]], + texts: List[str], + return_string_probabilities: Optional[List[str]] = None, + **kwargs: str, + ) -> Union[List[str], List[List[float]]]: + # For now, only support generation with a batch size of 1 for simplicity + tokenizer = self.llm_backbone.tokenizer + + # Prepare Inputs + batch_input_ids = [ + tokenizer(text, truncation=True, return_tensors="pt").input_ids.to(self.device) for text in texts + ] + if isinstance(pixel_values, torch.Tensor): + pixel_values = pixel_values[None, 
...].to(self.device)
+        elif isinstance(pixel_values, dict):
+            pixel_values = {k: v[None, ...].to(self.device) for k, v in pixel_values.items()}
+        else:
+            raise ValueError(f"Unsupported `pixel_values` type = {type(pixel_values)}")
+
+        # Create Output Lists
+        gen_texts, gen_probabilities = [], []
+
+        # Invoke super().generate --> taps into `GenerationMixin` which (redirects) to `forward()`
+        autocast_dtype = self.llm_backbone.half_precision_dtype
+        with torch.autocast("cuda", dtype=autocast_dtype, enabled=self.enable_mixed_precision_training):
+            for idx, input_ids in enumerate(batch_input_ids):
+                # Slice the `idx`-th example into a *fresh* variable so the batched `pixel_values` stays intact
+                #   across iterations (re-binding `pixel_values` itself would break every iteration after the
+                #   first); assumes `pixel_values` is batched along dim 0 (one entry per element of `texts`)
+                #   prior to the leading `[None, ...]` unsqueeze above.
+                if isinstance(pixel_values, torch.Tensor):
+                    example_pixel_values = pixel_values[:, idx]
+                elif isinstance(pixel_values, dict):
+                    example_pixel_values = {k: pixel_values[k][:, idx] for k in pixel_values}
+                else:
+                    raise ValueError(f"Unsupported `pixel_values` type = {type(pixel_values)}")
+
+                # Handle `return_string_probabilities`
+                if return_string_probabilities is None:
+                    full_out_ids = super().generate(input_ids=input_ids, pixel_values=example_pixel_values, **kwargs)
+                    gen_ids = full_out_ids[0, input_ids.shape[1] :]
+
+                    # Decode `gen_ids` and strip any <EOS> tokens
+                    gen_texts.append(tokenizer.decode(gen_ids, skip_special_tokens=True).strip())
+
+                else:
+                    full_out_dict = super().generate(
+                        input_ids=input_ids,
+                        pixel_values=example_pixel_values,
+                        output_scores=True,
+                        return_dict_in_generate=True,
+                        **kwargs,
+                    )
+
+                    # Generation pattern should usually be a single `TOKEN` (+ <EOS>) for True/False and Yes/No Generations
+                    gen_ids = full_out_dict.sequences[0, input_ids.shape[1] :]
+
+                    # [Debug] Verify that the first token generated is in `self.string2idx.values()`
+                    # assert gen_ids[0] in self.string2idx.values(), "Generated ID not in mapping!"
+
+                    # Decode `gen_ids` and strip any <EOS> tokens
+                    gen_texts.append(tokenizer.decode(gen_ids, skip_special_tokens=True).strip())
+
+                    # Get all token probabilities --> softmax over logits
+                    token_probs = torch.softmax(full_out_dict.scores[0][0], dim=0)
+
+                    # Get *normalized* probabilities for all values in `return_string_probabilities`
+                    slice_idxs = torch.tensor([self.string2idx[s] for s in return_string_probabilities])
+                    string_probs_unnormalized = token_probs[slice_idxs]
+                    string_probs = string_probs_unnormalized / string_probs_unnormalized.sum()
+                    gen_probabilities.append(string_probs.cpu().numpy().tolist())
+
+        return gen_texts if return_string_probabilities is None else gen_probabilities
+
+    @torch.inference_mode()
+    def generate(self, image: Image.Image, prompt_text: str, **kwargs: str) -> str:
+        # For now, only support generation with a batch size of 1 for simplicity
+        image_transform, tokenizer = self.vision_backbone.image_transform, self.llm_backbone.tokenizer
+
+        # Prepare Inputs
+        input_ids = tokenizer(prompt_text, truncation=True, return_tensors="pt").input_ids.to(self.device)
+        pixel_values = image_transform(image)
+        if isinstance(pixel_values, torch.Tensor):
+            pixel_values = pixel_values[None, ...].to(self.device)
+        elif isinstance(pixel_values, dict):
+            pixel_values = {k: v[None, ...].to(self.device) for k, v in pixel_values.items()}
+        else:
+            raise ValueError(f"Unsupported `pixel_values` type = {type(pixel_values)}")
+
+        # Invoke super().generate --> taps into `GenerationMixin` which (redirects) to `forward()`
+        autocast_dtype = self.llm_backbone.half_precision_dtype
+        with torch.autocast("cuda", dtype=autocast_dtype, enabled=self.enable_mixed_precision_training):
+            # fmt: off
+            generated_ids = super().generate(
+                input_ids=input_ids,            # Shape: [1, seq]
+                pixel_values=pixel_values,      # Shape: [1, 3, res, res] or 
Dict[str, Shape[1, 3, res, res]] + **kwargs + ) + # fmt: on + + generated_text = tokenizer.decode(generated_ids[0, input_ids.shape[1] :], skip_special_tokens=True).strip() + + return generated_text diff --git a/capvector-oft/prismatic/overwatch/__init__.py b/capvector-oft/prismatic/overwatch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..157cd648c6b711bc24f59ea2b356b1c0816a1c11 --- /dev/null +++ b/capvector-oft/prismatic/overwatch/__init__.py @@ -0,0 +1 @@ +from .overwatch import initialize_overwatch diff --git a/capvector-oft/prismatic/overwatch/overwatch.py b/capvector-oft/prismatic/overwatch/overwatch.py new file mode 100644 index 0000000000000000000000000000000000000000..2e72048bddbf9a4b622b736ba7d25b6c70ac9a04 --- /dev/null +++ b/capvector-oft/prismatic/overwatch/overwatch.py @@ -0,0 +1,147 @@ +""" +overwatch.py + +Utility class for creating a centralized/standardized logger (built on Rich) and accelerate handler. +""" + +import logging +import logging.config +import os +from contextlib import nullcontext +from logging import LoggerAdapter +from typing import Any, Callable, ClassVar, Dict, MutableMapping, Tuple, Union + +# Overwatch Default Format String +RICH_FORMATTER, DATEFMT = "| >> %(message)s", "%m/%d [%H:%M:%S]" + +# Set Logging Configuration +LOG_CONFIG = { + "version": 1, + "disable_existing_loggers": True, + "formatters": {"simple-console": {"format": RICH_FORMATTER, "datefmt": DATEFMT}}, + "handlers": { + "console": { + "class": "rich.logging.RichHandler", + "formatter": "simple-console", + "markup": True, + "rich_tracebacks": True, + "show_level": True, + "show_path": True, + "show_time": True, + } + }, + "root": {"level": "INFO", "handlers": ["console"]}, +} +logging.config.dictConfig(LOG_CONFIG) + + +# === Custom Contextual Logging Logic === +class ContextAdapter(LoggerAdapter): + CTX_PREFIXES: ClassVar[Dict[int, str]] = {**{0: "[*] "}, **{idx: "|=> ".rjust(4 + (idx * 4)) for idx in [1, 2, 3]}} + + def process(self, msg: str, kwargs: MutableMapping[str, Any]) -> Tuple[str, MutableMapping[str, Any]]: + ctx_level = kwargs.pop("ctx_level", 0) + return f"{self.CTX_PREFIXES[ctx_level]}{msg}", kwargs + + +class DistributedOverwatch: + def __init__(self, name: str) -> None: + """Initializer for an Overwatch object that wraps logging & `accelerate.PartialState`.""" + from accelerate import PartialState + + # Note that PartialState is always safe to initialize regardless of `accelerate launch` or `torchrun` + # =>> However, might be worth actually figuring out if we need the `accelerate` dependency at all! + self.logger, self.distributed_state = ContextAdapter(logging.getLogger(name), extra={}), PartialState() + + # Logger Delegation (for convenience; would be nice to just compose & dynamic dispatch eventually) + self.debug = self.logger.debug + self.info = self.logger.info + self.warning = self.logger.warning + self.error = self.logger.error + self.critical = self.logger.critical + + # Logging Defaults =>> only Log `INFO` on Main Process, `ERROR` on others! 
+ self.logger.setLevel(logging.INFO if self.distributed_state.is_main_process else logging.ERROR) + + @property + def rank_zero_only(self) -> Callable[..., Any]: + return self.distributed_state.on_main_process + + @property + def local_zero_only(self) -> Callable[..., Any]: + return self.distributed_state.on_local_main_process + + @property + def rank_zero_first(self) -> Callable[..., Any]: + return self.distributed_state.main_process_first + + @property + def local_zero_first(self) -> Callable[..., Any]: + return self.distributed_state.local_main_process_first + + def is_rank_zero(self) -> bool: + return self.distributed_state.is_main_process + + def rank(self) -> int: + return self.distributed_state.process_index + + def local_rank(self) -> int: + return self.distributed_state.local_process_index + + def world_size(self) -> int: + return self.distributed_state.num_processes + + +class PureOverwatch: + def __init__(self, name: str) -> None: + """Initializer for an Overwatch object that just wraps logging.""" + self.logger = ContextAdapter(logging.getLogger(name), extra={}) + + # Logger Delegation (for convenience; would be nice to just compose & dynamic dispatch eventually) + self.debug = self.logger.debug + self.info = self.logger.info + self.warning = self.logger.warning + self.error = self.logger.error + self.critical = self.logger.critical + + # Logging Defaults =>> INFO + self.logger.setLevel(logging.INFO) + + @staticmethod + def get_identity_ctx() -> Callable[..., Any]: + def identity(fn: Callable[..., Any]) -> Callable[..., Any]: + return fn + + return identity + + @property + def rank_zero_only(self) -> Callable[..., Any]: + return self.get_identity_ctx() + + @property + def local_zero_only(self) -> Callable[..., Any]: + return self.get_identity_ctx() + + @property + def rank_zero_first(self) -> Callable[..., Any]: + return nullcontext + + @property + def local_zero_first(self) -> Callable[..., Any]: + return nullcontext + + @staticmethod + def is_rank_zero() -> bool: + return True + + @staticmethod + def rank() -> int: + return 0 + + @staticmethod + def world_size() -> int: + return 1 + + +def initialize_overwatch(name: str) -> Union[DistributedOverwatch, PureOverwatch]: + return DistributedOverwatch(name) if int(os.environ.get("WORLD_SIZE", -1)) != -1 else PureOverwatch(name) diff --git a/capvector-oft/prismatic/preprocessing/__init__.py b/capvector-oft/prismatic/preprocessing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5b3a1dcb91afb745463f05da3fe7a547f030a4e7 --- /dev/null +++ b/capvector-oft/prismatic/preprocessing/__init__.py @@ -0,0 +1,2 @@ +from .download import convert_to_jpg, download_extract +from .materialize import get_dataset_and_collator diff --git a/capvector-oft/prismatic/preprocessing/datasets/__init__.py b/capvector-oft/prismatic/preprocessing/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..04f5fe1a8dd8122308f1a5f707b48e7cc87a2311 --- /dev/null +++ b/capvector-oft/prismatic/preprocessing/datasets/__init__.py @@ -0,0 +1 @@ +from .datasets import AlignDataset, FinetuneDataset diff --git a/capvector-oft/prismatic/preprocessing/datasets/datasets.py b/capvector-oft/prismatic/preprocessing/datasets/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..be86002805411363db96840878e039b86923f109 --- /dev/null +++ b/capvector-oft/prismatic/preprocessing/datasets/datasets.py @@ -0,0 +1,200 @@ +""" +datasets.py + +PyTorch Dataset Definitions for Prismatic models; supports 
processing for both the `align` and `finetune` stages, with
+utilities for formatting conversations during the `finetune` stage subject to the given LLM backbone's expected
+formatting (e.g., SYS_PROMPT + USER: ... ASSISTANT: ... for Vicuña v1.5 Chat models).
+
+We currently only support Map-style Datasets; assumes that all files (annotations, images) are on local disk, and that
+random access image reading is relatively cheap/fast.
+"""
+
+import copy
+import json
+from pathlib import Path
+from typing import Dict, List, Tuple, Type
+
+import torch
+from PIL import Image
+from torch.utils.data import Dataset
+from transformers import CodeGenTokenizerFast, LlamaTokenizerFast, PreTrainedTokenizerBase
+
+from prismatic.models.backbones.llm.prompting import PromptBuilder
+from prismatic.models.backbones.vision import ImageTransform
+
+# HuggingFace Default / LLaMa-2 IGNORE_INDEX (for labels)
+IGNORE_INDEX = -100
+
+
+class AlignDataset(Dataset[Dict[str, torch.Tensor]]):
+    def __init__(
+        self,
+        chat_json: Path,
+        image_dir: Path,
+        image_transform: ImageTransform,
+        tokenizer: PreTrainedTokenizerBase,
+    ) -> None:
+        super().__init__()
+        self.chat_json, self.image_dir = chat_json, image_dir
+        self.image_transform, self.tokenizer = image_transform, tokenizer
+        self.dataset_type = "align"
+
+        # Create Prompt Template
+        self.prompt_template = "{caption}" + self.tokenizer.eos_token
+
+        # Load Chat JSON
+        with open(self.chat_json, "r") as f:
+            self.examples = json.load(f)
+
+    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
+        """
+        Following the *actual* code executed from the LLaVa codebase, during the "align" phase, we actually discard
+        the "prompt" from the human, and instead directly predict the caption from the image.
+
+        As a concrete example given the "raw data" for the first example:
+            example = self.examples[0]["conversations"] = [
+                {"from": "human", "value": "Render a clear and concise summary of the photo.\n<image>"},
+                {"from": "gpt", "value": "select luxury furniture 3 - inch gel memory foam mattress topper"}
+            ]
+
+        Return =>> self.tokenizer("<image> select luxury furniture 3 - inch gel memory foam mattress topper\n")
+
+        :param idx: Index to retrieve from the dataset.
+
+        :return: Dictionary of {"pixel_values": torch.Tensor, "input_ids": torch.Tensor, "labels": torch.Tensor}
+        """
+        image_path, conversation = Path(self.examples[idx]["image"]), self.examples[idx]["conversations"]
+        assert (len(conversation) == 2) and ("<image>" not in conversation[-1]["value"]), "Unexpected text!"
+
+        # Format Caption --> {caption}{eos_token}
+        caption = self.prompt_template.format(caption=conversation[-1]["value"].strip())
+
+        # We treat image patches as "tokens = [p1 p2 p3, ...]"; we need to specify ordering of text/patch tokens.
+        #   => Critically, we find that inserting *after* the BOS token leads to the strongest performance!
+        #       - input_ids = "<BOS> p1 p2 p3 ... <caption_text> \n"
+        #       - labels = "IGNORE IGNORE ..." (copy `input_ids` replacing <BOS> and p{1...K} with IGNORE)
+        #
+        #   IMPORTANT => IF WE'RE USING HF LLM.forward(... labels=labels), SHIFTING HAPPENS _INSIDE_ MODEL! 
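+        # [Example] (illustrative) for the caption "a red chair", the tensors built below look like:
+        #       input_ids = [<BOS>,  "a", "red", "chair", <EOS>]
+        #       labels    = [IGNORE, "a", "red", "chair", <EOS>]
+        #   with the patch embeddings spliced in right after <BOS> inside `PrismaticVLM.forward`.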
+        input_ids = self.tokenizer(caption, truncation=True, return_tensors="pt").input_ids[0]
+        labels = copy.deepcopy(input_ids)
+
+        # Set the <BOS> token's label to IGNORE_INDEX (since we're inserting the image patches right after)
+        labels[0] = IGNORE_INDEX
+
+        # Process Image --> get "pixel_values" (will either be a torch.Tensor OR a Dict[str,torch.Tensor])
+        pixel_values = self.image_transform(Image.open(self.image_dir / image_path).convert("RGB"))
+
+        return dict(pixel_values=pixel_values, input_ids=input_ids, labels=labels)
+
+    def get_modality_lengths(self, n_image_patches: int) -> List[Tuple[bool, int]]:
+        """Get a list of modalities (unimodal / text-only vs. multimodal) and length of conversations per example."""
+        modality_lengths = []
+        for example in self.examples:
+            is_multimodal = "image" in example
+            n_words = sum([len(turn["value"].replace("<image>", "").split()) for turn in example["conversations"]])
+            modality_lengths.append((is_multimodal, (n_image_patches + n_words) if is_multimodal else n_words))
+        return modality_lengths
+
+    def __len__(self) -> int:
+        return len(self.examples)
+
+
+class FinetuneDataset(Dataset[Dict[str, torch.Tensor]]):
+    def __init__(
+        self,
+        instruct_json: Path,
+        image_dir: Path,
+        image_transform: ImageTransform,
+        tokenizer: PreTrainedTokenizerBase,
+        prompt_builder_fn: Type[PromptBuilder],
+    ) -> None:
+        super().__init__()
+        self.instruct_json, self.image_dir = instruct_json, image_dir
+        self.image_transform, self.tokenizer = image_transform, tokenizer
+        self.prompt_builder_fn = prompt_builder_fn
+        self.dataset_type = "finetune"
+
+        # Load Instruct JSON
+        with open(self.instruct_json, "r") as f:
+            self.examples = json.load(f)
+
+    # === Unimodal + Multimodal Handling ===
+    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
+        """
+        Unlike the *align* stage handling, for the *finetune* stage, we actually need to handle multiple "turns" of
+        dialog grounded in a single image.
+
+        To do this, we leverage the `prompt_builder_fn` which instantiates a PromptBuilder object. By calling the
+        methods for adding turns and getting a prompt, we ensure proper formatting and consistency for each example.
+
+        :param idx: Index to retrieve from the dataset.
+
+        :return: Dictionary of {"pixel_values": torch.Tensor, "input_ids": torch.Tensor, "labels": torch.Tensor}
+        """
+        conversation = self.examples[idx]["conversations"]
+
+        # Create Prompt Builder --> add each message sequentially
+        prompt_builder, input_ids, labels = self.prompt_builder_fn(model_family="prismatic"), [], []
+        for turn_idx, turn in enumerate(conversation):
+            # Get "effective" string added to prompt --> handle whitespace for tokenizer type!
+            msg = prompt_builder.add_turn(turn["from"], turn["value"])
+
+            # Llama Tokenizer (Fast) adds extra character if a string ends in whitespace --> strip if non-empty!
+            if isinstance(self.tokenizer, LlamaTokenizerFast):
+                msg = msg.rstrip()
+
+            # Phi-2 Tokenizer == CodeGenTokenizer (Fast) -- no special handling!
+            elif isinstance(self.tokenizer, CodeGenTokenizerFast):
+                pass
+
+            else:
+                raise ValueError(f"Tokenizer of type `{type(self.tokenizer)}` is not explicitly handled!")
+
+            # Tokenize Input IDs
+            turn_input_ids = self.tokenizer(msg, add_special_tokens=turn_idx == 0).input_ids
+
+            # [CRITICAL] We do not want to take the loss for the "USER: " prompts =>> just the responses! 
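+            #   => [Example] (illustrative) turn 0 `USER: What is shown?` --> every label = IGNORE_INDEX, while
+            #      turn 1 `ASSISTANT: A dog on a beach.` --> labels copy the token ids; only assistant responses
+            #      contribute to the loss.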
+            turn_labels = (
+                [IGNORE_INDEX for _ in range(len(turn_input_ids))] if (turn_idx % 2) == 0 else list(turn_input_ids)
+            )
+
+            # Add to Trackers
+            input_ids.extend(turn_input_ids)
+            labels.extend(turn_labels)
+
+        # Tensorize =>> Set the <BOS> token's label to IGNORE_INDEX (since we're inserting the image patches after)
+        #   - IMPORTANT => IF WE'RE USING HF LLM.forward(... labels=labels), SHIFTING HAPPENS _INSIDE_ MODEL!
+        input_ids, labels = torch.tensor(input_ids), torch.tensor(labels)
+
+        # Handle Truncation (if necessary)
+        input_ids, labels = input_ids[: self.tokenizer.model_max_length], labels[: self.tokenizer.model_max_length]
+
+        # === Handle "unimodal" (language-only) vs. "multimodal" ===
+        if "image" in self.examples[idx]:
+            image_path = Path(self.examples[idx]["image"])
+
+            # Set the <BOS> token's label to IGNORE_INDEX (since we're inserting the image patches right after)
+            labels[0] = IGNORE_INDEX
+
+            # Process Image --> get "pixel_values" (will either be a torch.Tensor OR a Dict[str,torch.Tensor])
+            pixel_values = self.image_transform(Image.open(self.image_dir / image_path).convert("RGB"))
+
+            return dict(pixel_values=pixel_values, input_ids=input_ids, labels=labels)
+
+        else:
+            # No image --> return `pixel_values` = None; Collator will do the smart batch handling for us!
+            return dict(pixel_values=None, input_ids=input_ids, labels=labels)
+
+    def get_modality_lengths(self) -> List[Tuple[bool, int]]:
+        """Get a list of modalities (unimodal / text-only vs. multimodal) and length of conversations per example."""
+        modality_lengths = []
+        for example in self.examples:
+            is_multimodal = "image" in example
+            n_words = sum([len(turn["value"].split()) for turn in example["conversations"]])
+            modality_lengths.append((is_multimodal, n_words))
+        return modality_lengths
+
+    def __len__(self) -> int:
+        return len(self.examples)
diff --git a/capvector-oft/prismatic/preprocessing/download.py b/capvector-oft/prismatic/preprocessing/download.py
new file mode 100644
index 0000000000000000000000000000000000000000..300bc0f42e311a33927789ad48a3e143a255aa5f
--- /dev/null
+++ b/capvector-oft/prismatic/preprocessing/download.py
@@ -0,0 +1,207 @@
+"""
+download.py
+
+Utility functions for downloading and extracting various datasets to (local) disk.
+"""
+
+import os
+import shutil
+from pathlib import Path
+from typing import Dict, List, TypedDict
+from zipfile import ZipFile
+
+import requests
+from PIL import Image
+from rich.progress import BarColumn, DownloadColumn, MofNCompleteColumn, Progress, TextColumn, TransferSpeedColumn
+from tqdm import tqdm
+
+from prismatic.overwatch import initialize_overwatch
+
+# Initialize Overwatch =>> Wraps `logging.Logger`
+overwatch = initialize_overwatch(__name__)
+
+
+# === Dataset Registry w/ Links ===
+# fmt: off
+DatasetComponent = TypedDict(
+    "DatasetComponent",
+    {"name": str, "extract": bool, "extract_type": str, "url": str, "do_rename": bool},
+    total=False
+)
+
+DATASET_REGISTRY: Dict[str, List[DatasetComponent]] = {
+    # === LLaVa v1.5 Dataset(s) ===
+
+    # Note =>> This is the full suite of datasets included in the LLaVa 1.5 "finetuning" stage; all the LLaVa v1.5
+    #          models are finetuned on this split. We use this dataset for all experiments in our paper. 
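+    # [Example] (hypothetical) a new dataset registers components of the same shape, e.g.:
+    #   "my-dataset": [{"name": "annotations.json", "extract": False, "url": "https://.../annotations.json", "do_rename": True}],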
+ "llava-laion-cc-sbu-558k": [ + { + "name": "chat.json", # Contains the "chat" traces :: {"human" => , "gpt" => } + "extract": False, + "url": "https://huggingface.co/datasets/liuhaotian/LLaVA-Pretrain/resolve/main/blip_laion_cc_sbu_558k.json", + "do_rename": True, + }, + { + "name": "images", # Contains the LLaVa Processed Images (jpgs, 224x224 resolution) + "extract": True, + "extract_type": "directory", + "url": "https://huggingface.co/datasets/liuhaotian/LLaVA-Pretrain/resolve/main/images.zip", + "do_rename": False, + } + ], + + "llava-v1.5-instruct": [ + { + "name": "llava_v1_5_mix665k.json", + "extract": False, + "url": ( + "https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/resolve/main/llava_v1_5_mix665k.json" + ), + "do_rename": True, + }, + { + "name": "coco/train2017", # Visual Instruct Tuning images are all sourced from COCO Train 2017 + "extract": True, + "extract_type": "directory", + "url": "http://images.cocodataset.org/zips/train2017.zip", + "do_rename": True, + }, + { + "name": "gqa/images", + "extract": True, + "extract_type": "directory", + "url": "https://downloads.cs.stanford.edu/nlp/data/gqa/images.zip", + "do_rename": True, + }, + { + "name": "ocr_vqa/images", + "extract": True, + "extract_type": "directory", + "url": "https://huggingface.co/datasets/qnguyen3/ocr_vqa/resolve/main/ocr_vqa.zip", + "do_rename": True, + }, + { + "name": "textvqa/train_images", + "extract": True, + "extract_type": "directory", + "url": "https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip", + "do_rename": True, + }, + { + "name": "vg/VG_100K", + "extract": True, + "extract_type": "directory", + "url": "https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip", + "do_rename": True, + }, + { + "name": "vg/VG_100K_2", + "extract": True, + "extract_type": "directory", + "url": "https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip", + "do_rename": True, + }, + ] +} +# fmt: on + + +def convert_to_jpg(image_dir: Path) -> None: + """Handling for OCR-VQA Images specifically; iterates through directory, converts all GIFs/PNGs.""" + overwatch.info(f"Converting all Images in `{image_dir}` to JPG") + + for image_fn in tqdm(list(image_dir.iterdir())): + if image_fn.suffix in {".jpg", ".jpeg"} or (jpg_fn := image_dir / f"{image_fn.stem}.jpg").exists(): + continue + + if image_fn.suffix == ".gif": + gif = Image.open(image_fn) + gif.seek(0) + gif.convert("RGB").save(jpg_fn) + elif image_fn.suffix == ".png": + Image.open(image_fn).convert("RGB").save(jpg_fn) + else: + raise ValueError(f"Unexpected image format `{image_fn.suffix}`") + + +def download_with_progress(url: str, download_dir: Path, chunk_size_bytes: int = 1024) -> Path: + """Utility function for downloading files from the internet, with a handy Rich-based progress bar.""" + overwatch.info(f"Downloading {(dest_path := download_dir / Path(url).name)} from `{url}`", ctx_level=1) + if dest_path.exists(): + return dest_path + + # Otherwise --> fire an HTTP Request, with `stream = True` + response = requests.get(url, stream=True) + + # Download w/ Transfer-Aware Progress + # => Reference: https://github.com/Textualize/rich/blob/master/examples/downloader.py + with Progress( + TextColumn("[bold]{task.description} - {task.fields[fname]}"), + BarColumn(bar_width=None), + "[progress.percentage]{task.percentage:>3.1f}%", + "•", + DownloadColumn(), + "•", + TransferSpeedColumn(), + transient=True, + ) as dl_progress: + dl_tid = dl_progress.add_task( + "Downloading", fname=dest_path.name, 
total=int(content_length) if (content_length := response.headers.get("content-length")) is not None else None
+        )  # Note :: `total=None` renders an indeterminate progress bar when the server omits `content-length`
+        with open(dest_path, "wb") as f:
+            for data in response.iter_content(chunk_size=chunk_size_bytes):
+                dl_progress.advance(dl_tid, f.write(data))
+
+    return dest_path
+
+
+def extract_with_progress(archive_path: Path, download_dir: Path, extract_type: str, cleanup: bool = False) -> Path:
+    """Utility function for extracting compressed archives, with a handy Rich-based progress bar."""
+    assert archive_path.suffix == ".zip", "Only `.zip` compressed archives are supported for now!"
+    overwatch.info(f"Extracting {archive_path.name} to `{download_dir}`", ctx_level=1)
+
+    # Extract w/ Progress
+    with Progress(
+        TextColumn("[bold]{task.description} - {task.fields[aname]}"),
+        BarColumn(bar_width=None),
+        "[progress.percentage]{task.percentage:>3.1f}%",
+        "•",
+        MofNCompleteColumn(),
+        transient=True,
+    ) as ext_progress:
+        with ZipFile(archive_path) as zf:
+            ext_tid = ext_progress.add_task("Extracting", aname=archive_path.name, total=len(members := zf.infolist()))
+            extract_path = Path(zf.extract(members[0], download_dir))
+            if extract_type == "file":
+                assert len(members) == 1, f"Archive `{archive_path}` with extract type `{extract_type}` has > 1 member!"
+            elif extract_type == "directory":
+                for member in members[1:]:
+                    zf.extract(member, download_dir)
+                    ext_progress.advance(ext_tid)
+            else:
+                raise ValueError(f"Extract type `{extract_type}` for archive `{archive_path}` is not defined!")
+
+    # Cleanup (if specified)
+    if cleanup:
+        archive_path.unlink()
+
+    return extract_path
+
+
+def download_extract(dataset_id: str, root_dir: Path) -> None:
+    """Download all files for a given dataset (querying registry above), extracting archives if necessary."""
+    os.makedirs(download_dir := root_dir / "download" / dataset_id, exist_ok=True)
+
+    # Download Files => Single-Threaded, with Progress Bar
+    dl_tasks = [d for d in DATASET_REGISTRY[dataset_id] if not (download_dir / d["name"]).exists()]
+    for dl_task in dl_tasks:
+        dl_path = download_with_progress(dl_task["url"], download_dir)
+
+        # Extract Files (if specified) --> Note (assumes ".zip" ONLY!)
+        if dl_task["extract"]:
+            dl_path = extract_with_progress(dl_path, download_dir, dl_task["extract_type"])
+            dl_path = dl_path.parent if dl_path.is_file() else dl_path
+
+        # Rename Path --> dl_task["name"]
+        if dl_task["do_rename"]:
+            shutil.move(dl_path, download_dir / dl_task["name"])
diff --git a/capvector-oft/prismatic/preprocessing/materialize.py b/capvector-oft/prismatic/preprocessing/materialize.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6605825448e95a8dc30825db2a1e31e9b46efc3
--- /dev/null
+++ b/capvector-oft/prismatic/preprocessing/materialize.py
@@ -0,0 +1,69 @@
+"""
+materialize.py
+
+Factory class for initializing pretraining datasets on a per-VLM basis; provides and exports individual functions for
+clear control flow. 
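+
+Example (hypothetical arguments --> a sketch of the intended call pattern, not a prescribed API):
+    dataset, collator = get_dataset_and_collator(
+        "finetune", dataset_cfg, image_transform, tokenizer, prompt_builder_fn, (3, 224, 224)
+    )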
+""" + +from typing import Tuple, Type + +from torch.utils.data import Dataset +from transformers import PreTrainedTokenizerBase + +from prismatic.conf import DatasetConfig +from prismatic.models.backbones.llm.prompting import PromptBuilder +from prismatic.models.backbones.vision import ImageTransform +from prismatic.preprocessing.datasets import AlignDataset, FinetuneDataset +from prismatic.util.data_utils import PaddedCollatorForLanguageModeling + +# Dataset Initializers =>> Maps Stage --> cls() +DATASET_INITIALIZER = {"align": AlignDataset, "finetune": FinetuneDataset, "full-finetune": FinetuneDataset} + + +def get_dataset_and_collator( + stage: str, + dataset_cfg: DatasetConfig, + image_transform: ImageTransform, + tokenizer: PreTrainedTokenizerBase, + prompt_builder_fn: Type[PromptBuilder], + default_image_resolution: Tuple[int, int, int], + padding_side: str = "right", +) -> Tuple[Dataset, PaddedCollatorForLanguageModeling]: + dataset_cls = DATASET_INITIALIZER[stage] + dataset_root_dir = dataset_cfg.dataset_root_dir + collator = PaddedCollatorForLanguageModeling( + tokenizer.model_max_length, tokenizer.pad_token_id, default_image_resolution, padding_side=padding_side + ) + + # Switch on `stage` + if stage == "align": + annotation_json, image_dir = dataset_cfg.align_stage_components + dataset = dataset_cls( + dataset_root_dir / annotation_json, dataset_root_dir / image_dir, image_transform, tokenizer + ) + return dataset, collator + + elif stage == "finetune": + annotation_json, image_dir = dataset_cfg.finetune_stage_components + dataset = dataset_cls( + dataset_root_dir / annotation_json, + dataset_root_dir / image_dir, + image_transform, + tokenizer, + prompt_builder_fn=prompt_builder_fn, + ) + return dataset, collator + + elif stage == "full-finetune": + annotation_json, image_dir = dataset_cfg.finetune_stage_components + dataset = dataset_cls( + dataset_root_dir / annotation_json, + dataset_root_dir / image_dir, + image_transform, + tokenizer, + prompt_builder_fn=prompt_builder_fn, + ) + return dataset, collator + + else: + raise ValueError(f"Stage `{stage}` is not supported!") diff --git a/capvector-oft/prismatic/training/__init__.py b/capvector-oft/prismatic/training/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c66f906fadd8a0fb31acdb0f7f86f6c393dc68ba --- /dev/null +++ b/capvector-oft/prismatic/training/__init__.py @@ -0,0 +1,2 @@ +from .materialize import get_train_strategy +from .metrics import Metrics, VLAMetrics diff --git a/capvector-oft/prismatic/training/materialize.py b/capvector-oft/prismatic/training/materialize.py new file mode 100644 index 0000000000000000000000000000000000000000..5fefd9fdec35b5df0d66d0669a8b625aea64e7ac --- /dev/null +++ b/capvector-oft/prismatic/training/materialize.py @@ -0,0 +1,66 @@ +""" +materialize.py + +Factory class defining functions for instantiating various Training Strategies, supporting different VLMs, backbones, +and strategy configurations. +""" + +from typing import Callable, Optional + +import torch + +from prismatic.models.vlms import PrismaticVLM +from prismatic.training.strategies import FSDPStrategy, TrainingStrategy + +# Registry =>> Maps ID --> {cls(), kwargs} :: supports FSDP for now, but DDP handler is also implemented! 
+TRAIN_STRATEGIES = { + "fsdp-shard-grad-op": {"cls": FSDPStrategy, "kwargs": {"sharding_strategy": "shard-grad-op"}}, + "fsdp-full-shard": {"cls": FSDPStrategy, "kwargs": {"sharding_strategy": "full-shard"}}, +} + + +def get_train_strategy( + train_strategy: str, + vlm: PrismaticVLM, + device_id: int, + stage: str, + epochs: int, + max_steps: Optional[int], + global_batch_size: int, + per_device_batch_size: int, + learning_rate: float, + weight_decay: float, + max_grad_norm: float, + lr_scheduler_type: str, + warmup_ratio: float, + enable_gradient_checkpointing: bool = True, + enable_mixed_precision_training: bool = True, + reduce_in_full_precision: bool = False, + mixed_precision_dtype: torch.dtype = torch.bfloat16, + worker_init_fn: Optional[Callable[[int], None]] = None, +) -> TrainingStrategy: + if train_strategy in TRAIN_STRATEGIES: + strategy_cfg = TRAIN_STRATEGIES[train_strategy] + strategy = strategy_cfg["cls"]( + vlm=vlm, + device_id=device_id, + stage=stage, + epochs=epochs, + max_steps=max_steps, + global_batch_size=global_batch_size, + per_device_batch_size=per_device_batch_size, + learning_rate=learning_rate, + weight_decay=weight_decay, + max_grad_norm=max_grad_norm, + lr_scheduler_type=lr_scheduler_type, + warmup_ratio=warmup_ratio, + enable_gradient_checkpointing=enable_gradient_checkpointing, + enable_mixed_precision_training=enable_mixed_precision_training, + reduce_in_full_precision=reduce_in_full_precision, + mixed_precision_dtype=mixed_precision_dtype, + worker_init_fn=worker_init_fn, + **strategy_cfg["kwargs"], + ) + return strategy + else: + raise ValueError(f"Train Strategy `{train_strategy}` is not supported!") diff --git a/capvector-oft/prismatic/training/metrics.py b/capvector-oft/prismatic/training/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..6fcc78172e59f9efa6ce2ef4b87d1f19a027821a --- /dev/null +++ b/capvector-oft/prismatic/training/metrics.py @@ -0,0 +1,348 @@ +""" +metrics.py + +Utility classes defining a Metrics container and multiple Trackers to enable model/stage-specific logging to various +endpoints (e.g., JSONL local logs, Weights & Biases). +""" + +import time +from collections import defaultdict, deque +from pathlib import Path +from typing import Any, Dict, Optional, Protocol, Tuple, Union + +import jsonlines +import numpy as np +import torch +import wandb + +from prismatic.overwatch import initialize_overwatch + +# Initialize Overwatch =>> Wraps `logging.Logger` +overwatch = initialize_overwatch(__name__) + + +# === Define Tracker Interface === +class Tracker(Protocol): + def write_hyperparameters(self) -> None: ... + + def write(self, global_step: int, metrics: Dict[str, Union[int, float]]) -> None: ... + + def finalize(self) -> None: ... 
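+
+    # Note :: `Tracker` is a structural `typing.Protocol` --> any class implementing these three methods (like the
+    #   trackers defined below) is accepted wherever a `Tracker` is expected; no explicit inheritance is required.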
+ + +# === Individual Tracker Definitions === +class JSONLinesTracker: + def __init__(self, run_id: str, run_dir: Path, hparams: Dict[str, Any]) -> None: + self.run_id, self.run_dir, self.hparams = run_id, run_dir, hparams + + @overwatch.rank_zero_only + def write_hyperparameters(self) -> None: + with jsonlines.open(self.run_dir / "run-metrics.jsonl", mode="w", sort_keys=True) as js_tracker: + js_tracker.write({"run_id": self.run_id, "hparams": self.hparams}) + + @overwatch.rank_zero_only + def write(self, _: int, metrics: Dict[str, Union[int, float]]) -> None: + with jsonlines.open(self.run_dir / f"{self.run_id}.jsonl", mode="a", sort_keys=True) as js_tracker: + js_tracker.write(metrics) + + def finalize(self) -> None: + return + + +class WeightsBiasesTracker: + def __init__( + self, + run_id: str, + run_dir: Path, + hparams: Dict[str, Any], + project: str = "prismatic", + entity: Optional[str] = None, + group: str = "align", + ) -> None: + self.run_id, self.run_dir, self.hparams = run_id, run_dir, hparams + + # Get W&B-Specific Initialization Parameters + self.project, self.entity, self.group, self.wandb_dir = project, entity, group, self.run_dir + + # Call W&B.init() + self.initialize() + + @overwatch.rank_zero_only + def initialize(self) -> None: + wandb.init( + name=self.run_id, + dir=self.wandb_dir, + config=self.hparams, + project=self.project, + entity=self.entity, + group=self.group, + ) + + @overwatch.rank_zero_only + def write_hyperparameters(self) -> None: + wandb.config = self.hparams + + @overwatch.rank_zero_only + def write(self, global_step: int, metrics: Dict[str, Union[int, float]]) -> None: + wandb.log(metrics, step=global_step) + + @staticmethod + def finalize() -> None: + if overwatch.is_rank_zero(): + wandb.finish() + + # A job gets 210 seconds to get its affairs in order + time.sleep(210) + + +# === Core Metrics Container :: Initializes Trackers => Compiles/Pushes Metrics === + + +class Metrics: + def __init__( + self, + active_trackers: Tuple[str, ...], + run_id: str, + run_dir: Path, + hparams: Dict[str, Any], + stage: str, + wandb_project: str = "prismatic", + wandb_entity: Optional[str] = None, + grad_accumulation_steps: int = 1, + window_size: int = 128, + ) -> None: + self.run_id, self.run_dir, self.hparams, self.stage = run_id, run_dir, hparams, stage + + # Initialize Trackers + self.trackers = [] + for tracker_type in active_trackers: + if tracker_type == "jsonl": + tracker = JSONLinesTracker(run_id, run_dir, hparams) + elif tracker_type == "wandb": + tracker = WeightsBiasesTracker( + run_id, run_dir, hparams, project=wandb_project, entity=wandb_entity, group=self.stage + ) + else: + raise ValueError(f"Tracker with type `{tracker_type} is not supported!") + + # Add Hyperparameters --> add to `self.trackers` + tracker.write_hyperparameters() + self.trackers.append(tracker) + + # Create Universal Metrics Buffers + self.global_step, self.start_time, self.step_start_time = 0, time.time(), time.time() + self.state = { + "loss_raw": deque(maxlen=grad_accumulation_steps), + "loss": deque(maxlen=window_size), + "step_time": deque(maxlen=window_size), + "lr": [], + } + + def log(self, global_step: int, metrics: Dict[str, Union[int, float]]) -> None: + for tracker in self.trackers: + tracker.write(global_step, metrics) + + def get_status(self, loss: Optional[torch.Tensor] = None) -> str: + lr = self.state["lr"][-1] if len(self.state["lr"]) > 0 else 0 + if loss is None: + return f"=>> [Global Step] {self.global_step:06d} =>> LR :: {lr:.6f}" + + # Otherwise, embed `loss` 
in status report! + return f"=>> [Global Step] {self.global_step:06d} =>> LR :: {lr:.6f} -- Loss :: {loss:.4f}" + + def commit( + self, *, global_step: Optional[int] = None, lr: Optional[float] = None, update_step_time: bool = False, **kwargs + ) -> None: + """Update all metrics in `self.state` by iterating through special positional arguments & kwargs.""" + if global_step is not None: + self.global_step = global_step + + # For all other variables --> only track on rank zero! + if not overwatch.is_rank_zero(): + return + + # Special Positional Arguments + if lr is not None: + self.state["lr"].append(lr) + + if update_step_time: + self.state["step_time"].append(time.time() - self.step_start_time) + self.step_start_time = time.time() + + # Generic Keyword Arguments + for key, value in kwargs.items(): + if key == "loss": + loss_val = value.detach() + self.state["loss_raw"].append(loss_val) + self.state["loss"].append(loss_val) + else: + self.state[key].append(value.detach()) + + @overwatch.rank_zero_only + def push(self) -> str: + # Note :: Raw Loss is an Average over Gradient Accumulation Steps --> No Smoothing! + loss_raw = torch.stack(list(self.state["loss_raw"])).mean().item() + loss = torch.stack(list(self.state["loss"])).mean().item() + step_time, lr = np.mean(list(self.state["step_time"])), self.state["lr"][-1] + status = self.get_status(loss) + + # Fire to Trackers + prefix = self.stage.capitalize() + self.log( + self.global_step, + metrics={ + f"{prefix}/Step": self.global_step, + f"{prefix}/Loss": loss, + f"{prefix}/Loss (Raw)": loss_raw, + f"{prefix}/Learning Rate": lr, + f"{prefix}/Step Time": step_time, + }, + ) + return status + + def finalize(self) -> str: + for tracker in self.trackers: + tracker.finalize() + + +class VLAMetrics: + def __init__( + self, + active_trackers: Tuple[str, ...], + run_id: str, + run_dir: Path, + hparams: Dict[str, Any], + wandb_project: str = "openvla", + wandb_entity: Optional[str] = "stanford-voltron", + grad_accumulation_steps: int = 1, + window_size: int = 1, + resume_step: Optional[int] = None, + resume_epoch: Optional[int] = None, + ) -> None: + self.run_id, self.run_dir, self.hparams = run_id, run_dir, hparams + + # Initialize Trackers + self.trackers = [] + for tracker_type in active_trackers: + if tracker_type == "jsonl": + tracker = JSONLinesTracker(run_id, run_dir, hparams) + elif tracker_type == "wandb": + tracker = WeightsBiasesTracker( + run_id, run_dir, hparams, project=wandb_project, entity=wandb_entity, group="vla-train" + ) + else: + raise ValueError(f"Tracker with type `{tracker_type} is not supported!") + + # Add Hyperparameters --> add to `self.trackers` + tracker.write_hyperparameters() + self.trackers.append(tracker) + + # Create Universal Metrics Buffers + self.global_step = 0 if resume_step is None else resume_step + self.epoch = 0 if resume_epoch is None else resume_epoch + self.start_time, self.step_start_time = time.time(), time.time() + self.state = { + "loss_raw": deque(maxlen=grad_accumulation_steps), + "loss": deque(maxlen=window_size), + "l1_loss": deque(maxlen=window_size), + "action_accuracy": deque(maxlen=window_size), + "step_time": deque(maxlen=window_size), + "lr": [], + } + + # Created metrics buffers for individual tracked datasets + self.dataset_trackers = defaultdict(lambda: VLAMetrics([], "", "", {})) + + def log(self, global_step: int, metrics: Dict[str, Union[int, float]]) -> None: + for tracker in self.trackers: + tracker.write(global_step, metrics) + + def get_status(self, loss: Optional[torch.Tensor] = 
None) -> str: + lr = self.state["lr"][-1] if len(self.state["lr"]) > 0 else 0 + if loss is None: + return f"=>> [Epoch {self.epoch:03d}] Global Step {self.global_step:06d} =>> LR :: {lr:.6f}" + + # Otherwise, embed `loss` in status report! + return f"=>> [Epoch {self.epoch:03d}] Global Step {self.global_step:06d} =>> LR :: {lr:.6f} - Loss :: {loss:.4f}" + + def commit( + self, + *, + global_step: Optional[int] = None, + epoch: Optional[int] = None, + lr: Optional[float] = None, + update_step_time: bool = False, + **kwargs, + ) -> None: + """Update all metrics in `self.state` by iterating through special positional arguments & kwargs.""" + if global_step is not None: + self.global_step = global_step + + if epoch is not None: + self.epoch = epoch + + # For all other variables --> only track on rank zero! + if not overwatch.is_rank_zero(): + return + + # Special Positional Arguments + if lr is not None: + self.state["lr"].append(lr) + + if update_step_time: + self.state["step_time"].append(time.time() - self.step_start_time) + self.step_start_time = time.time() + + # Generic Keyword Arguments + for key, value in kwargs.items(): + if key == "loss": + loss_val = value.detach() + self.state["loss_raw"].append(loss_val) + self.state["loss"].append(loss_val) + else: + self.state[key].append(value.detach()) + + def commit_for_dataset(self, dataset_name: str, **kwargs) -> None: + self.dataset_trackers[dataset_name].commit(**kwargs) + + @overwatch.rank_zero_only + def push(self) -> str: + # Note :: Raw Loss is an Average over Gradient Accumulation Steps --> No Smoothing! + loss_raw = torch.stack(list(self.state["loss_raw"])).mean().item() + loss = torch.stack(list(self.state["loss"])).mean().item() + l1_loss = torch.stack(list(self.state["l1_loss"])).mean().item() + action_accuracy = torch.stack(list(self.state["action_accuracy"])).mean().item() + step_time, lr = np.mean(list(self.state["step_time"])), self.state["lr"][-1] + status = self.get_status(loss) + + # Get metrics per dataset + dataset_metrics = {} + for ds, tracker in self.dataset_trackers.items(): + dataset_metrics.update( + { + f"{ds}/L1 Loss": torch.stack(list(tracker.state["l1_loss"])).mean().item(), + f"{ds}/Action Token Accuracy": torch.stack(list(tracker.state["action_accuracy"])).mean().item(), + } + ) + + # Fire to Trackers + prefix = "VLA Train" + self.log( + self.global_step, + metrics={ + f"{prefix}/Step": self.global_step, + f"{prefix}/Epoch": self.epoch, + f"{prefix}/Loss": loss, + f"{prefix}/L1 Loss": l1_loss, + f"{prefix}/Action Token Accuracy": action_accuracy, + f"{prefix}/Loss (Raw)": loss_raw, + f"{prefix}/Learning Rate": lr, + f"{prefix}/Step Time": step_time, + **dataset_metrics, + }, + ) + return status + + def finalize(self) -> str: + for tracker in self.trackers: + tracker.finalize() diff --git a/capvector-oft/prismatic/training/strategies/__init__.py b/capvector-oft/prismatic/training/strategies/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..748155bebe6f45bfb012c341e526d5fa47e589eb --- /dev/null +++ b/capvector-oft/prismatic/training/strategies/__init__.py @@ -0,0 +1,3 @@ +from .base_strategy import TrainingStrategy +from .ddp import DDPStrategy +from .fsdp import FSDPStrategy diff --git a/capvector-oft/prismatic/training/strategies/base_strategy.py b/capvector-oft/prismatic/training/strategies/base_strategy.py new file mode 100644 index 0000000000000000000000000000000000000000..adb8c0fde4af5ce130c2b22b6a366bf32eb5bbd1 --- /dev/null +++ 
+++ b/capvector-oft/prismatic/training/strategies/base_strategy.py
@@ -0,0 +1,417 @@
+"""
+base_strategy.py
+
+Abstract class definition of a (distributed) training strategy, with full annotations of class methods, utility
+functions, and initialization logic.
+
+Training Strategies (DDP, FSDP-Grad, FSDP-Full) tend to have a lot of repeated components; this class does a lot of
+heavy lifting.
+"""
+
+from abc import ABC, abstractmethod
+from pathlib import Path
+from typing import Callable, Optional
+
+import numpy as np
+import torch
+import torch.distributed as dist
+from torch.utils.data import DataLoader, Dataset, DistributedSampler, IterableDataset
+from tqdm import tqdm
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
+from prismatic.models.vlms import PrismaticVLM
+from prismatic.overwatch import initialize_overwatch
+from prismatic.training.metrics import Metrics, VLAMetrics
+from prismatic.training.train_utils import (
+    compute_actions_l1_loss,
+    compute_token_accuracy,
+    get_current_action_mask,
+    get_next_actions_mask,
+)
+from prismatic.util import check_bloat16_supported
+from prismatic.util.batching_utils import SplitModalitySampler
+from prismatic.util.data_utils import PaddedCollatorForActionPrediction, PaddedCollatorForLanguageModeling
+from prismatic.vla.action_tokenizer import ActionTokenizer
+
+# HuggingFace Default / LLaMa-2 IGNORE_INDEX (for labels)
+from prismatic.vla.constants import ACTION_DIM, ACTION_TOKEN_BEGIN_IDX, NUM_ACTIONS_CHUNK, IGNORE_INDEX
+NEWLINE_INDEX = 13  # '\n'
+STOP_INDEX = 2  # '</s>'
+
+# Initialize Overwatch =>> Wraps `logging.Logger`
+overwatch = initialize_overwatch(__name__)
+
+
+# === Abstract Base Class for an arbitrary Training Strategy ===
+class TrainingStrategy(ABC):
+    def __init__(
+        self,
+        vlm: PrismaticVLM,
+        device_id: int,
+        stage: str,
+        epochs: int,
+        max_steps: Optional[int],
+        global_batch_size: int,
+        per_device_batch_size: int,
+        learning_rate: float,
+        weight_decay: float,
+        max_grad_norm: float,
+        lr_scheduler_type: str,
+        warmup_ratio: float,
+        enable_gradient_checkpointing: bool = True,
+        enable_mixed_precision_training: bool = True,
+        reduce_in_full_precision: bool = False,
+        mixed_precision_dtype: torch.dtype = torch.bfloat16,
+        worker_init_fn: Optional[Callable[[int], None]] = None,
+        **_: str,
+    ) -> None:
+        self.vlm, self.device_id, self.stage = vlm, device_id, stage
+
+        # Get relevant VLM instance parameters before they get (potentially) wrapped
+        self.all_module_keys, self.trainable_module_keys = self.vlm.all_module_keys, self.vlm.trainable_module_keys
+        self.llm_transformer_layer_cls = self.vlm.llm_backbone.transformer_layer_cls
+
+        # Optimization Parameters
+        self.epochs, self.max_steps = epochs, max_steps
+        self.global_batch_size, self.per_device_batch_size = global_batch_size, per_device_batch_size
+
+        self.learning_rate, self.weight_decay, self.max_grad_norm = learning_rate, weight_decay, max_grad_norm
+        self.lr_scheduler_type, self.warmup_ratio = lr_scheduler_type, warmup_ratio
+
+        # Generic Strategy Parameters
+        self.enable_gradient_checkpointing = enable_gradient_checkpointing
+        self.enable_mixed_precision_training = enable_mixed_precision_training
+        self.reduce_in_full_precision = reduce_in_full_precision
+        self.mixed_precision_dtype = mixed_precision_dtype
+
+        # DataLoader Parameters
+        self.worker_init_fn = worker_init_fn
+
+        # Optimizers & Scheduler (initialized in `run_setup`)
+        self.optimizer, self.lr_scheduler = None, None
+
+        # Lightweight Validation
+        assert (
+            self.global_batch_size %
self.per_device_batch_size == 0 + ), "Per-device batch size must evenly divide global batch size!" + self.grad_accumulation_steps = self.global_batch_size // self.per_device_batch_size // overwatch.world_size() + if self.enable_mixed_precision_training: + assert self.mixed_precision_dtype == torch.bfloat16, "Only BF16 mixed precision training is supported!" + assert check_bloat16_supported(), "BFloat16 is not supported on this hardware; unset `mixed_precision`" + + @abstractmethod + def save_checkpoint( + self, + run_dir: Path, + global_step: int, + epoch: int, + train_loss: Optional[float] = None, + only_trainable: bool = True, + ) -> None: ... + + @abstractmethod + def run_setup(self, run_dir: Path, n_train_examples: int) -> None: ... + + @abstractmethod + def clip_grad_norm(self) -> None: ... + + def run_training( + self, + dataset: Dataset, + collator: PaddedCollatorForLanguageModeling, + metrics: Metrics, + stage: str = "finetune", + batch_construction_strategy: str = "split-modality", + seed: int = 7, + ) -> None: + """Run the training loop for the given `dataset` and `collator`; log losses, results to `metrics`""" + if "finetune" in stage and batch_construction_strategy == "split-modality": + # Instantiate the split-modality sampler; if you want to extend with other batch construction schemes, + # (e.g., grouping by length) =>> can easily add them here! + modality_lengths = dataset.get_modality_lengths() + sampler = SplitModalitySampler( + dataset, + modality_lengths, + global_batch_size=self.global_batch_size, + num_replicas=overwatch.world_size(), + rank=overwatch.rank(), + seed=seed, + drop_last=False, + ) + + else: + sampler = DistributedSampler( + dataset, + num_replicas=overwatch.world_size(), + rank=overwatch.rank(), + shuffle=True, + seed=seed, + drop_last=False, + ) + + # Create a DataLoader with the initialized sampler, per-device-bsz, and collator + dataloader = DataLoader( + dataset, + batch_size=self.per_device_batch_size, + sampler=sampler, + collate_fn=collator, + num_workers=2, + worker_init_fn=self.worker_init_fn, + ) + + # Max Steps vs. Epochs Computation + steps_per_epoch = len(dataloader) // self.grad_accumulation_steps + if self.max_steps is not None and steps_per_epoch < self.max_steps: + # Just set `epochs` to some large number --> we'll short-circuit based on steps anyway + self.epochs = 100 + + # === Train === + status = metrics.get_status() + with tqdm( + total=( + (self.epochs * (len(dataloader) // self.grad_accumulation_steps)) + if self.max_steps is None + else self.max_steps + ), + desc=status, + leave=False, + disable=not overwatch.is_rank_zero(), + ) as progress: + for epoch in range(self.epochs): + self.vlm.train() + sampler.set_epoch(epoch) + + # Zero-Gradients (just in case) + self.optimizer.zero_grad() + + # Note that we'll unpack batch (and let AMP/FSDP do its thing) in the VLM.forward() call + # => Basically, if we're using mixed precision (or not), autocast()/FSDP will move to device! + for train_idx, batch in enumerate(dataloader): + # [Contract] self.vlm.forward() must automatically compute `loss` and return! 
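+                    # (A concrete sketch with hypothetical values: `global_batch_size=256`, `per_device_batch_size=8`,
+                    # and a world size of 8 give `grad_accumulation_steps = 256 // 8 // 8 = 4`, so four microbatch
+                    # backward passes accumulate before each optimizer step; `normalized_loss` below divides by 4 so
+                    # the summed gradient matches one 256-example step -- modulo the unequal-masked-token caveat
+                    # discussed in the comment below.)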
+                    with torch.autocast(
+                        "cuda",
+                        dtype=self.mixed_precision_dtype,
+                        enabled=self.enable_mixed_precision_training,
+                    ):
+                        output: CausalLMOutputWithPast = self.vlm(
+                            input_ids=batch["input_ids"],
+                            attention_mask=batch["attention_mask"],
+                            pixel_values=batch["pixel_values"],
+                            labels=batch["labels"],
+                            multimodal_indices=batch["multimodal_indices"],
+                        )
+                        loss = output.loss
+
+                    # Commit Loss (Prior to Gradient Accumulation Normalization)
+                    metrics.commit(loss=loss)
+
+                    # Normalize Loss to account for Gradient Accumulation --> Backward!
+                    # [IMPORTANT] Technically speaking, doing gradient accumulation in this way is "incorrect"; this is
+                    #             because in general, each batch has a *different number of masked out tokens* (because
+                    #             we're instruct-tuning). Taking the mean over two unbalanced means != the right thing!
+                    #
+                    #             HOWEVER -- at least at the 7B scale, the "naive" approach is just as performant as
+                    #             the "correct" implementation, without adding extra complexity.
+                    #
+                    # That being said =>> at the 13B scale, *no matter what we tried*, ANY gradient accumulation is just
+                    # really bad for downstream performance. Initial investigation shows that BF16 accumulation
+                    # just really tanks in precision... and we don't have a good/clean way to fix this. Would love for
+                    # someone to PR and fix this (and I'd greatly appreciate it!!!)
+                    normalized_loss = loss / self.grad_accumulation_steps
+                    normalized_loss.backward()
+
+                    # Step =>> Only if Done w/ Gradient Accumulation
+                    if (train_idx + 1) % self.grad_accumulation_steps == 0:
+                        metrics.commit(update_step_time=True)
+
+                        # Clip Gradients --> this is custom, per-strategy because of DDP vs. FSDP locality-assumptions
+                        self.clip_grad_norm()
+
+                        # Optimizer & LR Scheduler Step
+                        self.optimizer.step()
+                        self.lr_scheduler.step()
+                        self.optimizer.zero_grad()
+
+                        # Push Metrics
+                        metrics.commit(global_step=metrics.global_step + 1, lr=self.lr_scheduler.get_last_lr()[0])
+                        status = metrics.push()
+
+                        # Check for Termination & Save Final Checkpoint (in case `max_steps` is not None)
+                        if self.max_steps is not None and metrics.global_step >= self.max_steps:
+                            self.save_checkpoint(metrics.run_dir, metrics.global_step, epoch, loss.item())
+                            dist.barrier()
+
+                            return
+
+                        # Update Progress Bar
+                        progress.update()
+                        progress.set_description(status)
+
+            # Save checkpoint at the end of each epoch (if `self.max_steps` is None)
+            if self.max_steps is None:
+                self.save_checkpoint(metrics.run_dir, metrics.global_step, epoch, loss.item())
+                dist.barrier()
+
+    # === VLA Training ===
+
+    def run_vla_training(
+        self,
+        vla_dataset: IterableDataset,
+        collator: PaddedCollatorForActionPrediction,
+        action_tokenizer: ActionTokenizer,
+        metrics: VLAMetrics,
+        save_interval: int = 2500,
+        save_full_model: bool = True,
+    ) -> None:
+        """Run the VLA training loop for the given `dataset` and `collator`; log losses, action metrics to `metrics`."""
+        assert isinstance(vla_dataset, IterableDataset), "VLA training expects an IterableDataset!"
+        assert self.grad_accumulation_steps == 1, "VLA training does not support gradient accumulation!"
+
+        # Create a DataLoader =>> Set `num_workers` to 0; RLDS loader handles parallelism!
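+        # (The RLDS-backed IterableDataset repeats indefinitely, so `len(dataloader)` reflects a single pass over the
+        # data; termination is governed by `max_steps`, and `epoch` is derived from completed gradient steps below.)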
+ dataloader = DataLoader( + vla_dataset, + batch_size=self.per_device_batch_size, + sampler=None, + collate_fn=collator, + num_workers=0, + worker_init_fn=self.worker_init_fn, + ) + + # === Train === + status = metrics.get_status() + with tqdm( + total=(self.epochs * len(dataloader)) if self.max_steps is None else self.max_steps, + desc=status, + leave=False, + disable=not overwatch.is_rank_zero(), + ) as progress: + self.vlm.train() + + # Zero Gradients (just in case) + self.optimizer.zero_grad() + + # [Contract] DataLoader wraps RLDS Loader (`.as_numpy_iterator() =>> implicit `.repeat()`) + # => This means looping over the DataLoader is basically "infinite" (so no outer loop over epochs). + # Slightly breaks default PyTorch semantics, which is why we adaptively compute `epoch` below. + for batch in dataloader: + # Note that we'll unpack batch (and let AMP/FSDP do its thing) in the VLM.forward() call + # => Basically, if we're using mixed precision (or not), autocast()/FSDP will move to device! + with torch.autocast( + "cuda", dtype=self.mixed_precision_dtype, enabled=self.enable_mixed_precision_training + ): + # [Contract] self.vlm.forward() must automatically compute `loss` and return! + output: CausalLMOutputWithPast = self.vlm( + input_ids=batch["input_ids"], + attention_mask=batch["attention_mask"], + pixel_values=batch["pixel_values"], + labels=batch["labels"], + ) + loss = output.loss + + # Commit Loss =>> Backward! + metrics.commit(loss=loss) + loss.backward() + + # Get predicted and ground-truth token IDs + predicted_token_ids = output.logits[:, self.vlm.vision_backbone.num_patches : -1].argmax(dim=2) + ground_truth_token_ids = batch["labels"][:, 1:].to(predicted_token_ids.device) + + ####################################################################### + # === Compute Current Action Token Accuracy & L1 Loss === + ####################################################################### + + # Get current action mask: Target the first ACTION_DIM non-ignore tokens + current_action_mask = get_current_action_mask(ground_truth_token_ids) + + # Compute Accuracy + action_accuracy = compute_token_accuracy(predicted_token_ids, ground_truth_token_ids, mask=current_action_mask) + + # Compute L1 Loss on Predicted (Continuous) Actions + action_l1_loss = compute_actions_l1_loss(action_tokenizer, predicted_token_ids, ground_truth_token_ids, mask=current_action_mask) + + ####################################################################### + # === Compute Next Actions Token Accuracy & L1 Loss === + ####################################################################### + + # Get next actions mask: Target all tokens after the first ACTION_DIM non-ignore tokens (excluding the last token, which is the stop token) + next_actions_mask = get_next_actions_mask(ground_truth_token_ids) + + # Compute Accuracy + next_actions_accuracy = compute_token_accuracy(predicted_token_ids, ground_truth_token_ids, mask=next_actions_mask) + + # Compute L1 Loss on Predicted (Continuous) Actions + next_actions_l1_loss = compute_actions_l1_loss(action_tokenizer, predicted_token_ids, ground_truth_token_ids, mask=next_actions_mask) + + ####################################################################### + # === Log === + ####################################################################### + + # Commit Metrics + metrics.commit( + action_accuracy=action_accuracy, + l1_loss=action_l1_loss, + next_actions_accuracy=next_actions_accuracy, + next_actions_l1_loss=next_actions_l1_loss, + update_step_time=True, + ) + + # 
Compute metrics per dataset --> only on rank_zero since we don't log them on other workers anyway
+                if overwatch.is_rank_zero():
+                    datasets = set(batch["dataset_names"])
+                    if len(datasets) > 1:
+                        # Mirror the aggregate current-action metrics above; `mask` and `correct_preds` follow the
+                        # same logic as `compute_token_accuracy`, restricted to the current-action mask
+                        mask = current_action_mask
+                        correct_preds = (predicted_token_ids == ground_truth_token_ids) & mask
+                        for ds in datasets:
+                            ds_mask = torch.tensor([elem == ds for elem in batch["dataset_names"]])
+                            action_accuracy_ds = correct_preds[ds_mask].sum().float() / mask[ds_mask].sum().float()
+                            pred_continuous_actions_ds = torch.tensor(
+                                action_tokenizer.decode_token_ids_to_actions(
+                                    predicted_token_ids[ds_mask][mask[ds_mask]].cpu().numpy()
+                                )
+                            )
+                            continuous_actions_gt_ds = torch.tensor(
+                                action_tokenizer.decode_token_ids_to_actions(
+                                    ground_truth_token_ids[ds_mask][mask[ds_mask]].cpu().numpy()
+                                )
+                            )
+                            action_l1_loss_ds = torch.nn.functional.l1_loss(
+                                pred_continuous_actions_ds, continuous_actions_gt_ds
+                            )
+                            metrics.commit_for_dataset(
+                                dataset_name=ds.decode(),
+                                action_accuracy=action_accuracy_ds,
+                                l1_loss=action_l1_loss_ds,
+                                next_actions_accuracy=next_actions_accuracy,
+                                next_actions_l1_loss=next_actions_l1_loss,
+                            )
+
+                # === Gradient Step ===
+
+                # Clip Gradients --> this is custom, per-strategy because of DDP vs. FSDP locality assumptions
+                self.clip_grad_norm()
+
+                # Optimizer & LR Scheduler Step
+                self.optimizer.step()
+                self.lr_scheduler.step()
+                self.optimizer.zero_grad()
+
+                # Compute epoch value using number of completed gradient steps
+                epoch = (metrics.global_step + 1) // (len(vla_dataset) // self.global_batch_size)
+
+                # Push Metrics
+                metrics.commit(global_step=metrics.global_step + 1, epoch=epoch, lr=self.lr_scheduler.get_last_lr()[0])
+                status = metrics.push()
+
+                # Check for Save Interval or Max Steps & Save Checkpoint
+                if (terminate := (self.max_steps is not None and metrics.global_step >= self.max_steps)) or (
+                    (metrics.global_step % save_interval) == 0
+                ):
+                    self.save_checkpoint(
+                        metrics.run_dir, metrics.global_step, epoch, loss.item(), only_trainable=not save_full_model
+                    )
+                    dist.barrier()
+
+                    if terminate:
+                        return
+
+                # Update Progress Bar
+                progress.update()
+                progress.set_description(status)
diff --git a/capvector-oft/prismatic/training/strategies/ddp.py b/capvector-oft/prismatic/training/strategies/ddp.py
new file mode 100644
index 0000000000000000000000000000000000000000..84e685d3fdfc123d2868e5de8201d927d24cceb2
--- /dev/null
+++ b/capvector-oft/prismatic/training/strategies/ddp.py
@@ -0,0 +1,128 @@
+"""
+ddp.py
+
+Core class definition for a strategy implementing Torch native Distributed Data Parallel Training; note that on most
+GPU hardware and LLM backbones >= 5-7B parameters, DDP training will OOM, which is why we opt for FSDP.
+""" + +import shutil +from pathlib import Path +from typing import Optional + +import torch +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import AdamW +from transformers.optimization import get_constant_schedule, get_cosine_schedule_with_warmup + +from prismatic.overwatch import initialize_overwatch +from prismatic.training.strategies.base_strategy import TrainingStrategy + +# Initialize Overwatch =>> Wraps `logging.Logger` +overwatch = initialize_overwatch(__name__) + + +class DDPStrategy(TrainingStrategy): + @overwatch.rank_zero_only + def save_checkpoint( + self, + run_dir: Path, + global_step: int, + epoch: int, + train_loss: Optional[float] = None, + only_trainable: bool = True, + ) -> None: + """Save a checkpoint to the `run_dir` only containing the state_dicts for trainable parameters by default.""" + assert isinstance(self.vlm, DDP), "save_checkpoint assumes VLM is already wrapped in DDP!" + + # Splinter State Dictionary by Top-Level Submodules (or subset, if `only_trainable`) + model_state_dicts = { + mkey: getattr(self.vlm.module, mkey).state_dict() + for mkey in (self.trainable_module_keys if only_trainable else self.all_module_keys) + } + optimizer_state_dict = self.optimizer.state_dict() + + # Set Checkpoint Path =>> Embed *minimal* training statistics! + checkpoint_dir = run_dir / "checkpoints" + if train_loss is None: + checkpoint_path = checkpoint_dir / f"step-{global_step:06d}-epoch-{epoch:02d}-loss=inf.pt" + else: + checkpoint_path = checkpoint_dir / f"step-{global_step:06d}-epoch-{epoch:02d}-loss={train_loss:.4f}.pt" + + # Save Checkpoint & Copy Latest to `latest-checkpoint.pt` + torch.save({"model": model_state_dicts, "optimizer": optimizer_state_dict}, checkpoint_path) + shutil.copy(checkpoint_path, checkpoint_dir / "latest-checkpoint.pt") + + def run_setup(self, run_dir: Path, n_train_examples: int) -> None: + # Gradient Checkpointing Setup + if self.enable_gradient_checkpointing: + # For Gradient Checkpointing --> we make the assumption that the "bulk" of activation memory is taken up + # by the LLM; because we also make the explicit assumption that each LLM is derived from a HF + # pretrained model, the only thing we *need* to do (technically) is call `gradient_checkpoint_enable` + # on `self.llm_backbone`. + # + # What does it actually do? --> runs the *generic* custom_forward + torch.utils.checkpoint.checkpoint logic + # => github.com/huggingface/transformers/.../models/llama/modeling_llama.py#L692-L706 + # + # Additional Reference (to better understand gradient checkpointing in PyTorch writ large) + # => github.com/prigoyal/pytorch_memonger/blob/master/tutorial/Checkpointing_for_PyTorch_models.ipynb + overwatch.info("Enabling Gradient Checkpointing on LLM Backbone", ctx_level=1) + self.vlm.llm_backbone.gradient_checkpointing_enable() + + # Move to Device =>> Note parameters are in full precision (*mixed precision* will only autocast as appropriate) + overwatch.info("Placing Entire VLM (Vision Backbone, LLM Backbone, Projector Weights) on GPU", ctx_level=1) + self.vlm.to(self.device_id) + + # Wrap with Distributed Data Parallel + # => Note: By default, wrapping naively with DDP(self.vlm) will initialize a *separate* buffer on GPU that + # is the same size/dtype as the model parameters; this will *double* GPU memory! 
+ # - stackoverflow.com/questions/68949954/model-takes-twice-the-memory-footprint-with-distributed-data-parallel + overwatch.info("Wrapping VLM with Distributed Data Parallel", ctx_level=1) + self.vlm = DDP(self.vlm, device_ids=[self.device_id], gradient_as_bucket_view=True) + + # Create Optimizer and LR Scheduler =>> note that most of the LR Schedulers we use require `max_steps/epochs` + # => Optimizer should only operate on parameters that are *unfrozen* / trainable! + trainable_params = [param for param in self.vlm.parameters() if param.requires_grad] + if self.max_steps is None: + num_training_steps = (n_train_examples * self.epochs) // self.global_batch_size + else: + num_training_steps = self.max_steps + + if self.lr_scheduler_type == "linear-warmup+cosine-decay": + # Set warmup steps (floor) based on `warmup_ratio` (should be 0.03 - 0.05) + num_warmup_steps = int(num_training_steps * self.warmup_ratio) + + assert self.weight_decay == 0, "DDP training does not currently support `weight_decay` > 0!" + self.optimizer = AdamW(trainable_params, lr=self.learning_rate, weight_decay=self.weight_decay) + self.lr_scheduler = get_cosine_schedule_with_warmup(self.optimizer, num_warmup_steps, num_training_steps) + for param_group in self.optimizer.param_groups: + param_group["lr"] = 0.0 + + elif self.lr_scheduler_type == "constant": + num_warmup_steps = 0 + + assert self.weight_decay == 0, "DDP training does not currently support `weight_decay` > 0!" + self.optimizer = AdamW(trainable_params, lr=self.learning_rate, weight_decay=self.weight_decay) + self.lr_scheduler = get_constant_schedule(self.optimizer) + + else: + raise ValueError(f"Learning Rate Schedule with type `{self.lr_scheduler_type}` is not supported!") + + # Finalize Setup =>> Log + overwatch.info( + "DDP Strategy =>> Finalized Training Setup:\n" + f" |-> Global (Effective) Batch Size = {self.global_batch_size}\n" + f" |-> Per-Device Batch Size = {self.per_device_batch_size}\n" + f" |-> Distributed World Size = {overwatch.world_size()}\n" + f" |-> Gradient Accumulation Steps = {self.grad_accumulation_steps}\n\n" + f" |-> LLM Backbone Gradient Checkpointing = {self.enable_gradient_checkpointing}\n" + f" |-> Use Native AMP = {self.enable_mixed_precision_training} ({self.mixed_precision_dtype})\n\n" + f" |-> Default AdamW LR = {self.learning_rate}\n" + f" |-> AdamW Weight Decay = {self.weight_decay}\n" + f" |-> LR Scheduler Type = {self.lr_scheduler_type}\n" + f" |-> LR Scheduler Warmup Steps (Ratio) = {num_warmup_steps} ({self.warmup_ratio})\n" + f" |-> Dataset Size = {n_train_examples} Examples\n" + f" |-> Max Steps = {num_training_steps}\n" + ) + + def clip_grad_norm(self) -> None: + torch.nn.utils.clip_grad_norm_(self.vlm.parameters(), max_norm=self.max_grad_norm) diff --git a/capvector-oft/prismatic/training/strategies/fsdp.py b/capvector-oft/prismatic/training/strategies/fsdp.py new file mode 100644 index 0000000000000000000000000000000000000000..9d59e41dab6ac5469e55375469fdc58575f24e24 --- /dev/null +++ b/capvector-oft/prismatic/training/strategies/fsdp.py @@ -0,0 +1,270 @@ +""" +fsdp.py + +Core class definition for a strategy implementing Torch native Fully Sharded Data Parallel Training (with support for +fine-grained control over wrapping policies and mixed precision per component). 
+""" + +import math +from collections import OrderedDict +from functools import partial +from pathlib import Path +from typing import Callable, Optional + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + CheckpointImpl, + apply_activation_checkpointing, + checkpoint_wrapper, +) +from torch.distributed.fsdp import ( + FullStateDictConfig, + MixedPrecision, + ShardingStrategy, + StateDictType, +) +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.optim import AdamW +from transformers.optimization import get_constant_schedule, get_cosine_schedule_with_warmup + +from prismatic.models.vlms import PrismaticVLM +from prismatic.overwatch import initialize_overwatch +from prismatic.training.strategies.base_strategy import TrainingStrategy + +# Initialize Overwatch =>> Wraps `logging.Logger` +overwatch = initialize_overwatch(__name__) + + +class FSDPStrategy(TrainingStrategy): + def __init__( + self, + vlm: PrismaticVLM, + device_id: int, + stage: str, + epochs: int, + max_steps: Optional[int], + global_batch_size: int, + per_device_batch_size: int, + learning_rate: float, + weight_decay: float, + max_grad_norm: float, + lr_scheduler_type: str, + warmup_ratio: float, + enable_gradient_checkpointing: bool = True, + enable_mixed_precision_training: bool = True, + reduce_in_full_precision: bool = False, + mixed_precision_dtype: torch.dtype = torch.bfloat16, + worker_init_fn: Optional[Callable[[int], None]] = None, + sharding_strategy: str = "shard-grad-op", + state_dict_type: StateDictType = StateDictType.FULL_STATE_DICT, + ) -> None: + super().__init__( + vlm=vlm, + device_id=device_id, + stage=stage, + epochs=epochs, + max_steps=max_steps, + global_batch_size=global_batch_size, + per_device_batch_size=per_device_batch_size, + learning_rate=learning_rate, + weight_decay=weight_decay, + max_grad_norm=max_grad_norm, + lr_scheduler_type=lr_scheduler_type, + warmup_ratio=warmup_ratio, + enable_gradient_checkpointing=enable_gradient_checkpointing, + enable_mixed_precision_training=enable_mixed_precision_training, + reduce_in_full_precision=reduce_in_full_precision, + mixed_precision_dtype=mixed_precision_dtype, + worker_init_fn=worker_init_fn, + ) + + # FSDP-Specific Parameters + if sharding_strategy == "shard-grad-op": + self.fsdp_sharding_strategy = ShardingStrategy._HYBRID_SHARD_ZERO2 + elif sharding_strategy == "full-shard": + self.fsdp_sharding_strategy = ShardingStrategy.HYBRID_SHARD + else: + raise ValueError(f"FSDP Sharding Strategy {sharding_strategy} is not supported!") + + assert state_dict_type == StateDictType.FULL_STATE_DICT, "Sharded state saving is not yet implemented!" + self.fsdp_state_dict_type = state_dict_type + self.fsdp_save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + + def save_checkpoint( + self, + run_dir: Path, + global_step: int, + epoch: int, + train_loss: Optional[float] = None, + only_trainable: bool = True, + ) -> None: + """Save a checkpoint to the `run_dir` only containing the state_dicts for trainable parameters by default.""" + assert isinstance(self.vlm, FSDP), "FSDPStrategy.save_checkpoint assumes VLM is already wrapped in FSDP!" 
+ + # Summon Full State Dictionary =>> Reconstitute from Shards + with FSDP.state_dict_type(self.vlm, self.fsdp_state_dict_type, self.fsdp_save_policy): + full_vlm_state_dict = self.vlm.state_dict() + model_state_dicts = { + mkey: OrderedDict() for mkey in (self.trainable_module_keys if only_trainable else self.all_module_keys) + } + + # Iterate through `full_vlm_state_dict` and split `mkey.{full_dotted_path}` -> `mkey: {full_dotted_path}` + for key, param in full_vlm_state_dict.items(): + for mkey in model_state_dicts: + if key.startswith(mprefix := f"{mkey}."): + model_state_dicts[mkey][key.removeprefix(mprefix)] = param + + # Save on rank zero *only* + if overwatch.is_rank_zero(): + checkpoint_dir = run_dir / "checkpoints" + if train_loss is None: + checkpoint_path = checkpoint_dir / f"step-{global_step:06d}-epoch-{epoch:02d}-loss=inf.pt" + else: + checkpoint_path = ( + checkpoint_dir / f"step-{global_step:06d}-epoch-{epoch:02d}-loss={train_loss:.4f}.pt" + ) + + # Save Checkpoint & Copy Latest to `latest-checkpoint.pt` + torch.save({"model": model_state_dicts}, checkpoint_path) + + # TODO (siddk) :: This breaks w/ Sagemaker default permissions (root vs. )... skip? + # shutil.copy(checkpoint_path, checkpoint_dir / "latest-checkpoint.pt") + + def run_setup(self, run_dir: Path, n_train_examples: int) -> None: + # Iteratively Assemble FSDP Wrapping Policy by fetching the wrapping policies for each backbone/constituent + vlm_fsdp_wrapping_policy = self.vlm.get_fsdp_wrapping_policy() + + # Assemble the Default FSDP Mixed Precision Policy + if self.enable_mixed_precision_training and self.mixed_precision_dtype == torch.bfloat16: + # MixedPrecision `param_dtype` specifies *compute* dtype (for forward/backward only) + # => Reference: https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.MixedPrecision + reduce_buffer_dtype = torch.bfloat16 if not self.reduce_in_full_precision else torch.float32 + fsdp_precision_policy = MixedPrecision( + param_dtype=torch.bfloat16, reduce_dtype=reduce_buffer_dtype, buffer_dtype=reduce_buffer_dtype + ) + + # When running FSDP with a frozen vision backbone --> move to half precision! + if self.stage not in {"full-finetune", "vla-full-train", "vla-sandwich-train"}: + overwatch.info("Casting Vision Backbone to *Half Precision* via `.to(dtype=...)`") + self.vlm.vision_backbone.to(dtype=self.vlm.vision_backbone.half_precision_dtype) + + else: + # If we're not using mixed precision, everything is in default full precision! + fsdp_precision_policy = MixedPrecision( + param_dtype=torch.float32, reduce_dtype=torch.float32, buffer_dtype=torch.float32 + ) + + # => note that FSDP will automatically take care of device placement (similar to `autocast`) + self.vlm = FSDP( + self.vlm, + auto_wrap_policy=vlm_fsdp_wrapping_policy, + mixed_precision=fsdp_precision_policy, + sharding_strategy=self.fsdp_sharding_strategy, + device_id=torch.cuda.current_device(), + limit_all_gathers=True, + use_orig_params=True, + ) + + # Gradient Checkpoint Setup + if self.enable_gradient_checkpointing: + # For Gradient Checkpointing under FSDP --> we make the same assumption as in the DDP/other strategies; the + # bulk of activation memory is taken up by the LLM activations. However, unlike other strategies, we + # cannot rely on the HF Transformers default `gradient_checkpointing_enable()` --> FSDP breaks semantics! + # + # Instead, we need to write our own *NO-REENTRANT* wrapper, and apply it to the LLM's Transformer Layer. 
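+            #
+            # (The non-reentrant implementation is built on saved-tensor hooks rather than a custom autograd Function,
+            # which is what lets it compose with FSDP and `use_orig_params=True`; `check_fn` below selects the modules
+            # to wrap -- each LLM transformer block -- so per-block activations are recomputed during the backward
+            # pass instead of being stored.)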
+ non_reentrant_wrapper = partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT) + + def check_fn(submodule: nn.Module) -> bool: + return isinstance(submodule, self.llm_transformer_layer_cls) + + # Note that the terms "activation checkpointing" and "gradient checkpointing" are synonymous! + apply_activation_checkpointing(self.vlm, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn) + + # Barrier =>> Sharding takes a minute? + dist.barrier() + + # Create Optimizer and LR Scheduler =>> note that most of the LR Schedulers we use require `max_steps/epochs` + # => Optimizer should only operate on parameters that are *unfrozen* / trainable! + n_train_examples = math.ceil(n_train_examples / self.global_batch_size) * self.global_batch_size + if self.max_steps is None: + num_training_steps = (n_train_examples * self.epochs) // self.global_batch_size + else: + num_training_steps = self.max_steps + + if self.lr_scheduler_type == "linear-warmup+cosine-decay": + # Set warmup steps (floor) based on `warmup_ratio` (should be 0.03 - 0.05) + num_warmup_steps = int(num_training_steps * self.warmup_ratio) + + # Default AdamW w/ specified LR & Linear Warmup / Cosine Decay & Weight Decay + # => Create Parameter Groups --> bias terms, normalization layer parameters shouldn't be decayed! + decay, no_decay = [], [] + for name, param in self.vlm.named_parameters(): + if not param.requires_grad: + continue + + # Check on any parameters with fewer than 2 dimensions or with "bias" in the name + if param.ndim <= 1 or name.endswith(".bias"): + no_decay.append(param) + else: + decay.append(param) + + # Build Parameter Groups + groups = [{"params": decay, "weight_decay": self.weight_decay}, {"params": no_decay, "weight_decay": 0.0}] + + # Create Optimizer & LR Scheduler + self.optimizer = AdamW(groups, lr=self.learning_rate) + self.lr_scheduler = get_cosine_schedule_with_warmup(self.optimizer, num_warmup_steps, num_training_steps) + for param_group in self.optimizer.param_groups: + param_group["lr"] = 0.0 + + elif self.lr_scheduler_type == "constant": + num_warmup_steps = 0 + + # Default AdamW w/ specified LR & Linear Warmup / Cosine Decay & Weight Decay + # => Create Parameter Groups --> bias terms, normalization layer parameters shouldn't be decayed! + decay, no_decay = [], [] + for name, param in self.vlm.named_parameters(): + if not param.requires_grad: + continue + + # Check on any parameters with fewer than 2 dimensions or with "bias" in the name + if param.ndim <= 1 or name.endswith(".bias"): + no_decay.append(param) + else: + decay.append(param) + + # Build Parameter Groups + groups = [{"params": decay, "weight_decay": self.weight_decay}, {"params": no_decay, "weight_decay": 0.0}] + + # Create Optimizer & LR Scheduler + self.optimizer = AdamW(groups, lr=self.learning_rate) + self.lr_scheduler = get_constant_schedule(self.optimizer) + + else: + raise ValueError(f"Learning Rate Schedule with type `{self.lr_scheduler_type}` is not supported!") + + # Finalize Setup =>> Log! 
+        overwatch.info(
+            "FSDP Full-Shard Strategy =>> Finalized Training Setup:\n"
+            f" |-> Global (Effective) Batch Size = {self.global_batch_size}\n"
+            f" |-> Per-Device Batch Size = {self.per_device_batch_size}\n"
+            f" |-> Distributed World Size = {overwatch.world_size()}\n"
+            f" |-> Gradient Accumulation Steps = {self.grad_accumulation_steps}\n\n"
+            f" |-> LLM Backbone FSDP Gradient Checkpointing = {self.enable_gradient_checkpointing}\n"
+            f" |-> Use FSDP Mixed Precision = {self.enable_mixed_precision_training}\n"
+            f" |-> Parameter Precision = {fsdp_precision_policy.param_dtype}\n"
+            f" |-> Reduction Precision = {fsdp_precision_policy.reduce_dtype}\n"
+            f" |-> Buffer Precision = {fsdp_precision_policy.buffer_dtype}\n\n"
+            f" |-> Default AdamW LR = {self.learning_rate}\n"
+            f" |-> AdamW Weight Decay = {self.weight_decay}\n"
+            f" |-> LR Scheduler Type = {self.lr_scheduler_type}\n"
+            f" |-> LR Scheduler Warmup Steps (Ratio) = {num_warmup_steps} ({self.warmup_ratio})\n"
+            f" |-> Dataset Size = {n_train_examples} Examples\n"
+            f" |-> Max Steps = {num_training_steps}\n"
+        )
+
+    def clip_grad_norm(self) -> None:
+        # Note =>> FSDP uses a custom `clip_grad_norm_` function; requires *uniform grad dtype*
+        self.vlm.clip_grad_norm_(max_norm=self.max_grad_norm)
diff --git a/capvector-oft/prismatic/training/train_utils.py b/capvector-oft/prismatic/training/train_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..62fa76a61bad11f41ef9c387aaa691b9e25089b3
--- /dev/null
+++ b/capvector-oft/prismatic/training/train_utils.py
@@ -0,0 +1,56 @@
+"""Utils for training/fine-tuning scripts."""
+
+import torch
+
+from prismatic.vla.constants import ACTION_DIM, ACTION_TOKEN_BEGIN_IDX, IGNORE_INDEX
+
+
+def get_current_action_mask(token_ids):
+    # Create a tensor marking positions that hold real labels (i.e., are *not* IGNORE_INDEX)
+    newline_positions = token_ids != IGNORE_INDEX
+
+    # Calculate cumulative sum to rank each non-ignored position within its sequence
+    cumsum = torch.cumsum(newline_positions, dim=1)
+
+    # Create the mask: the first ACTION_DIM non-ignored positions
+    mask = (1 <= cumsum) & (cumsum <= ACTION_DIM)
+
+    # Extract the action part only
+    action_tokens_only_mask = token_ids > ACTION_TOKEN_BEGIN_IDX
+    mask = action_tokens_only_mask * mask
+
+    return mask
+
+
+def get_next_actions_mask(token_ids):
+    # Create a tensor marking positions that hold real labels (i.e., are *not* IGNORE_INDEX)
+    newline_positions = token_ids != IGNORE_INDEX
+
+    # Calculate cumulative sum to rank each non-ignored position within its sequence
+    cumsum = torch.cumsum(newline_positions, dim=1)
+
+    # Create the mask: everything *after* the first ACTION_DIM non-ignored positions
+    mask = cumsum > ACTION_DIM
+
+    # Extract the action part only
+    action_tokens_only_mask = token_ids > ACTION_TOKEN_BEGIN_IDX
+    mask = action_tokens_only_mask * mask
+
+    return mask
+
+
+def compute_token_accuracy(predicted_token_ids, ground_truth_token_ids, mask):
+    correct_preds = (predicted_token_ids == ground_truth_token_ids) & mask
+    accuracy = correct_preds.sum().float() / mask.sum().float()
+    return accuracy
+
+
+def compute_actions_l1_loss(action_tokenizer, predicted_token_ids, ground_truth_token_ids, mask):
+    pred_continuous_actions = torch.tensor(
+        action_tokenizer.decode_token_ids_to_actions(predicted_token_ids[mask].cpu().numpy())
+    )
+    true_continuous_actions = torch.tensor(
+        action_tokenizer.decode_token_ids_to_actions(ground_truth_token_ids[mask].cpu().numpy())
+    )
+    l1_loss = torch.nn.functional.l1_loss(pred_continuous_actions, true_continuous_actions)
+    return l1_loss
diff --git a/capvector-oft/prismatic/util/__init__.py b/capvector-oft/prismatic/util/__init__.py
new file mode 100644
index
0000000000000000000000000000000000000000..71f1ff62dffae813bf8bc9281b7e61cbaef81ccf --- /dev/null +++ b/capvector-oft/prismatic/util/__init__.py @@ -0,0 +1 @@ +from .torch_utils import check_bloat16_supported, set_global_seed diff --git a/capvector-oft/prismatic/util/batching_utils.py b/capvector-oft/prismatic/util/batching_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..558ebc4ac8536a78af73daaa6bea984d03d3f689 --- /dev/null +++ b/capvector-oft/prismatic/util/batching_utils.py @@ -0,0 +1,212 @@ +""" +batching_utils.py + +Core definitions of (Distributed) Samplers for VLM finetuning; provides functionality for construction and allocating +"split-modality" batches as described in the LLaVa paper; this makes sure that a given device/batch is either entirely +(vision, language) or (language-only) data, which leads to sizeable efficiency gains. +""" + +import math +from typing import Iterator, List, Optional, Tuple + +import numpy as np +import torch +import torch.distributed as dist +from torch.utils.data import Dataset, Sampler + + +# High-Fidelity Bitwise Reproduction of the LLaVa Codebase Sampler Strategy + Per-Rank Allocation Scheme (following +# the default batching behavior of HF's Trainer Class --> derived from `accelerate`). +# +# =>> Reference: https://github.com/haotian-liu/LLaVA/blob/main/llava/train/llava_trainer.py#L60 +# =>> Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/trainer_pt_utils.py#L603 +class SplitModalitySampler(Sampler): + def __init__( + self, + dataset: Dataset, + modality_lengths: List[Tuple[bool, int]], + global_batch_size: int, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + seed: int = 0, + drop_last: bool = False, + ) -> None: + super().__init__() + self.num_replicas = num_replicas if num_replicas is not None else dist.get_world_size() + self.rank = rank if rank is not None else dist.get_rank() + self.seed, self.epoch = seed, 0 + + # Custom Parameters + self.dataset, self.modality_lengths, self.drop_last = dataset, modality_lengths, drop_last + self.global_batch_size = global_batch_size + + # For our purposes, `drop_last` is always False! + assert not self.drop_last, "SplitModalitySampler must set `drop_last = False`!" + self.total_size = math.ceil(len(self.dataset) / self.global_batch_size) * self.global_batch_size + self.num_samples = self.total_size // self.num_replicas + + @staticmethod + def reindex_batch(batch_idxs: List[int], idx2lengths: List[int], n_buckets: int) -> List[List[int]]: + """Re-indexes a batch in a way that is conducive to DistributedSampler + grouping by seqlen per rank.""" + assert len(batch_idxs) % n_buckets == 0, "Batch length is not divisible by `num_replicas`!" + + # Establish initial buckets, capacities, and max number of elements per bucket + n_examples_per_bucket = len(batch_idxs) // n_buckets + bucket_indices = [[] for _ in range(n_buckets)] + bucket_lengths = [0 for _ in range(n_buckets)] + + # Note that `batch_idxs` is already sorted by corresponding length (in descending order) + for idx in batch_idxs: + shortest_bucket_idx = bucket_lengths.index(min(bucket_lengths)) + bucket_indices[shortest_bucket_idx].append(idx) + + # Update `bucket_lengths` --> set length to infinity if at capacity! 
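+            # (Greedy sketch with hypothetical lengths: items [91, 90, 22, 21] into 2 buckets of 2 slots each land
+            # as bucket_0 = [91, 21] and bucket_1 = [90, 22] --> totals of 112 vs. 112, i.e., roughly balanced.)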
+ bucket_lengths[shortest_bucket_idx] += idx2lengths[idx] + if len(bucket_indices[shortest_bucket_idx]) == n_examples_per_bucket: + bucket_lengths[shortest_bucket_idx] = float("inf") + + return bucket_indices + + def get_modality_and_length_grouped_indices(self, generator: torch.Generator) -> List[int]: + """ + Returns a list of indices so that each slice of `global_batch_size` consecutive indices corresponds to elements + of the same modality with each sub-sequence of `per_replica_batch_size` (the batch size each unique device sees + during distributed training) is roughly grouped by sequence length (for training efficiency). + """ + multimodal_indices, multimodal_lengths = zip( + *[(idx, length) for idx, (is_multimodal, length) in enumerate(self.modality_lengths) if is_multimodal] + ) + + # Handle Special Case --> no "unimodal" inputs + unimodal_split = [ + (idx, length) for idx, (is_multimodal, length) in enumerate(self.modality_lengths) if not is_multimodal + ] + if len(unimodal_split) == 0: + unimodal_indices, unimodal_lengths = [], [] + else: + unimodal_indices, unimodal_lengths = zip(*unimodal_split) + + # Create a permutation of indices for each of the multimodal and unimodal data + mm_shuffled_idxs = torch.randperm(len(multimodal_indices), generator=generator) + uni_shuffled_idxs = torch.randperm(len(unimodal_indices), generator=generator) + + # We're going to be running sorting/grouping relative to `self.global_batch_size` and `self.num_replicas` + g_bsz = self.global_batch_size + + # Break each of the permutations into batches of length `global_batch_size` + mm_batch_idxs = [mm_shuffled_idxs[i : i + g_bsz].tolist() for i in range(0, len(mm_shuffled_idxs), g_bsz)] + uni_batch_idxs = [uni_shuffled_idxs[i : i + g_bsz].tolist() for i in range(0, len(uni_shuffled_idxs), g_bsz)] + + # If "last" batch is not of length `g_bsz` --> PAD by stealing indices from the first batch! + if len(mm_batch_idxs[-1]) < g_bsz: + n_missing = g_bsz - len(mm_batch_idxs[-1]) + mm_batch_idxs[-1].extend(mm_batch_idxs[0][:n_missing]) + + if len(uni_batch_idxs) > 0 and len(uni_batch_idxs[-1]) < g_bsz: + n_missing = g_bsz - len(uni_batch_idxs[-1]) + uni_batch_idxs[-1].extend(uni_batch_idxs[0][:n_missing]) + + # Now we're going to sort each batch by length --> this will aid in grouping by length by rank (efficiency!) + mm_sorted_batch_idxs = [sorted(b, key=lambda i: multimodal_lengths[i], reverse=True) for b in mm_batch_idxs] + uni_sorted_batch_idxs = [sorted(b, key=lambda i: unimodal_lengths[i], reverse=True) for b in uni_batch_idxs] + + # IMPORTANT :: At this point, for each modality, we have a list of "batches" (made up of indices) where indices + # are sorted by example sequence length *within* each batch. To make this more concrete, consider the following: + # => World Size (`num_replicas`) = 2 + # => Global Batch Size (`g_bsz`) = 4 + # => `multimodal_indices` = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] + # `multimodal_lengths` = [20, 90, 21, 22, 91, 18, 89, 19, 93, 88, 92, 17] + # + # At this point in the code, `mm_sorted_batch_idxs` might then look like the following (length in parenthesis): + # => `mm_sorted_batch_idxs`: [ + # [4 (91), 3 (21), 0 (20), 5 (18)] => Batch 1 + # [6 (89), 9 (88), 7 (19), 11 (17)] => Batch 2 + # [8 (93), 10 (92), 1 (90), 2 (21)] => Batch 3 + # ] + # + # In practice: `g_bsz` is large (= 128), and for contiguous mini-batch "slices", length variance is low. 
+
+        # PROBLEM :: We want to split these "global batches" into equal-sized pieces, so that each "replica" (GPU)
+        # sees a "mini-batch" of roughly the same sequence lengths; this is super useful for efficient training.
+
+        # HOWEVER :: The default "access pattern" for splitting a large batch into mini-batches by a DistributedSampler
+        # is akin to a "take every k" where `k` is equal to the number of replicas (GPUs) you're training on. Or, in
+        # Python notation --> `rank_k_indices = flatten(mm_sorted_batch_idxs)[k::num_replicas]`.
+        #
+        # Naively translating this to our example means each GPU (in our world of 2 total) sees the following indices
+        # (grouped by "mini-batch" = `g_bsz / num_replicas` = 2 for convenience):
+        #     => `rank_0_indices`: [ [4 (91), 0 (20)] =>> [6 (89), 7 (19)] =>> [8 (93), 1 (90)] ]
+        #     => `rank_1_indices`: [ [3 (21), 5 (18)] =>> [9 (88), 11 (17)] =>> [10 (92), 2 (21)] ]
+        #
+        # We get lucky sometimes, but for the most part, each "mini-batch" has VASTLY DIFFERENT lengths! Bad!
+
+        # FIX :: If we "undo" the access pattern with the following code and re-arrange the way we allocate batches
+        # inside the __iter__ method below, we can allocate indices appropriately. Running the following code gives us
+        # the following indices (grouped by "mini-batch" again for convenience):
+        #     => `rank_0_indices`: [ [4 (91), 3 (21)] =>> [6 (89), 9 (88)] =>> [8 (93), 10 (92)] ]
+        #     => `rank_1_indices`: [ [5 (18), 0 (20)] =>> [11 (17), 7 (19)] =>> [2 (21), 1 (90)] ]
+        #
+        # Much better! As `g_bsz` and `dataset` grow, we're more often than not getting *decent* groupings!
+        mm_length_bucketed_idxs = [
+            self.reindex_batch(batch, multimodal_lengths, self.num_replicas) for batch in mm_sorted_batch_idxs
+        ]
+        uni_length_bucketed_idxs = [
+            self.reindex_batch(batch, unimodal_lengths, self.num_replicas) for batch in uni_sorted_batch_idxs
+        ]
+
+        # Note :: Because of the initial `randperm` --> we're indexing both sets from 0 (we're clobbering the range)
+        #         => Flatten indices --> index into original `{modality}_indices` then re-batch!
+        mm_output_idxs = [idx for batch in mm_length_bucketed_idxs for bucket in batch for idx in bucket]
+        mm_reindexed = [multimodal_indices[idx] for idx in mm_output_idxs]
+        mm_batches = [mm_reindexed[i : i + g_bsz] for i in range(0, len(mm_reindexed), g_bsz)]
+
+        uni_output_idxs = [idx for batch in uni_length_bucketed_idxs for bucket in batch for idx in bucket]
+        uni_reindexed = [unimodal_indices[idx] for idx in uni_output_idxs]
+        uni_batches = [uni_reindexed[i : i + g_bsz] for i in range(0, len(uni_reindexed), g_bsz)]
+
+        # Finally, randomly permute the multimodal & unimodal batches, merging into a single stream of indices
+        merged_batches = mm_batches + uni_batches
+        merge_idxs = torch.randperm(len(merged_batches), generator=generator)
+        all_batches = [merged_batches[idx] for idx in merge_idxs]
+
+        # [Quality of Life] Shift "max length" batch to index 0 --> if we OOM, it happens immediately!
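+        # (E.g., under the hypothetical lengths above, Batch 3 -- max text length 93, plus the 24 * 24 = 576 patch
+        # tokens added for multimodal examples below -- would be the one swapped into index 0.)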
+ all_lengths = [length + ((_n_patches := 24 * 24) if is_mm else 0) for is_mm, length in self.modality_lengths] + all_batches_max_lengths = [] + for batch in all_batches: + all_batches_max_lengths.append(max([all_lengths[idx] for idx in batch])) + + # Identify Batch with "max length" --> Swap into Index 0 + longest_batch_idx = np.argmax(all_batches_max_lengths) + all_batches[0], all_batches[longest_batch_idx] = all_batches[longest_batch_idx], all_batches[0] + + # Flatten & Return all Indices + indices = [idx for batch in all_batches for idx in batch] + return indices + + def __iter__(self) -> Iterator: + """Deterministically shuffle, then split indices by modality and length.""" + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + indices = self.get_modality_and_length_grouped_indices(g) + assert len(set(indices)) == len(self.modality_lengths) == len(self.dataset), "Oops!" + assert (len(indices) % self.global_batch_size == 0) and (len(indices) % self.num_replicas) == 0, "Oops" + + # Note :: We compute per-replica batch size as a function of `global_batch` and `num_replicas` to ensure that + # gradient accumulation doesn't affect what indices are assigned a given rank. + per_replica_batch_size = self.global_batch_size // self.num_replicas + + # Tensorize & Unravel --> rather than yielding via a `take_every` --> we want to partition a global batch + # across replicas by assigning each a contiguous sub-sequence. + indices_t = torch.as_tensor(indices) + per_replica_batch_indices_t = indices_t.reshape(-1, per_replica_batch_size) + replica_indices_t = per_replica_batch_indices_t[self.rank :: self.num_replicas] + + replica_indices = replica_indices_t.flatten().tolist() + return iter(replica_indices) + + def __len__(self) -> int: + return self.num_samples + + def set_epoch(self, epoch: int) -> None: + """To be called *between* epochs, prior to DataLoader instantiation; ensures random order across epochs.""" + self.epoch = epoch diff --git a/capvector-oft/prismatic/util/data_utils.py b/capvector-oft/prismatic/util/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1abdca892b115cc837de316681048bd766eb7b81 --- /dev/null +++ b/capvector-oft/prismatic/util/data_utils.py @@ -0,0 +1,156 @@ +""" +data_utils.py + +General utilities and classes for facilitating data loading and collation. 
+""" + +from dataclasses import dataclass +from typing import Callable, Dict, Sequence, Tuple + +import numpy as np +import torch +from torch.nn.utils.rnn import pad_sequence + +# HuggingFace Default / LLaMa-2 IGNORE_INDEX (for labels) +IGNORE_INDEX = -100 + + +def tree_map(fn: Callable, tree: dict) -> dict: + """Maps a function over a nested dictionary.""" + return {k: tree_map(fn, v) if isinstance(v, dict) else fn(v) for k, v in tree.items()} + + +def tree_map_with_key(fn: Callable, tree: dict, keys: Sequence = ()) -> dict: + """Maps a function over a nested dictionary.""" + return { + k: tree_map_with_key(fn, v, (*keys, k)) if isinstance(v, dict) else fn((*keys, k), v) for k, v in tree.items() + } + + +@dataclass +class PaddedCollatorForLanguageModeling: + model_max_length: int + pad_token_id: int + default_image_resolution: Tuple[int, int, int] + padding_side: str = "right" + pixel_values_dtype: torch.dtype = torch.float32 + + def __post_init__(self) -> None: + self.dummy_pixel_values = torch.zeros(self.default_image_resolution, dtype=self.pixel_values_dtype) + + def __call__(self, instances: Sequence[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]: + input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels")) + pixel_values = [instance["pixel_values"] for instance in instances] + + # For now, we only support Tokenizers with `padding_side = "right"` during Training (but plan to extend!) + # => Handle padding via RNN Utils => `pad_sequence` + input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.pad_token_id) + labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) + + # Truncate (if necessary) + input_ids, labels = input_ids[:, : self.model_max_length], labels[:, : self.model_max_length] + + # Get `attention_mask` by checking for `pad_token_id` + attention_mask = input_ids.ne(self.pad_token_id) + + # === Handle "unimodal" (language-only) vs. 
"multimodal" === + + # Some examples are "language-only" --> build a Tensor of `multimodal_indices` that we can slice into easily + multimodal_indices = torch.tensor( + [idx for idx in range(len(pixel_values)) if pixel_values[idx] is not None], dtype=torch.long + ) + + # Stack all `pixel_values` --> depending on type (torch.Tensor, or Dict[str, torch.Tensor]) & presence of None + if len(multimodal_indices) == 0: + pixel_values = torch.stack([self.dummy_pixel_values for _ in range(len(input_ids))]) + elif isinstance(pv_example := pixel_values[multimodal_indices[0]], torch.Tensor): + pixel_values = torch.stack( + [ + pixel_values[idx] if idx in multimodal_indices else self.dummy_pixel_values + for idx in range(len(input_ids)) + ] + ) + elif isinstance(pv_example, dict): + pixel_values = { + k: torch.stack( + [ + pixel_values[idx][k] if idx in multimodal_indices else self.dummy_pixel_values + for idx in range(len(input_ids)) + ] + ) + for k in pv_example + } + else: + raise ValueError(f"Unsupported `pixel_values` type = {type(pixel_values)}") + + return dict( + pixel_values=pixel_values, + input_ids=input_ids, + attention_mask=attention_mask, + labels=labels, + multimodal_indices=multimodal_indices, + ) + + +@dataclass +class PaddedCollatorForActionPrediction: + model_max_length: int + pad_token_id: int + padding_side: str = "right" + pixel_values_dtype: torch.dtype = torch.float32 + + def __call__(self, instances: Sequence[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]: + input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels")) + pixel_values = [instance["pixel_values"] for instance in instances] + if "dataset_name" in instances[0]: + dataset_names = [instance["dataset_name"] for instance in instances] + else: + dataset_names = None + + # For now, we only support Tokenizers with `padding_side = "right"` during training + # => Handle padding via RNN Utils => `pad_sequence` + assert self.padding_side == "right", f"Invalid Tokenizer `{self.padding_side = }`" + input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.pad_token_id) + labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) + + # Truncate (if necessary) + input_ids, labels = input_ids[:, : self.model_max_length], labels[:, : self.model_max_length] + + # Get `attention_mask` by checking for `pad_token_id` + attention_mask = input_ids.ne(self.pad_token_id) + + # [Contract] For VLA Training =>> No "Unimodal" Data! + assert all([pv is not None for pv in pixel_values]), "Invalid VLA Example with `pixel_values = None`!" 
+ + # Stack all `pixel_values` --> depending on type is torch.Tensor or Dict[str, torch.Tensor] + if isinstance(pixel_values[0], torch.Tensor): + if "pixel_values_wrist" in instances[0]: + pixel_values_wrist = [instance["pixel_values_wrist"] for instance in instances] + pixel_values = torch.cat((torch.stack(pixel_values), torch.stack(pixel_values_wrist)), dim=1) + else: + pixel_values = torch.stack(pixel_values) + else: + raise ValueError(f"Unsupported `pixel_values` type = {type(pixel_values)}") + + # Stack all actions + actions = [torch.from_numpy(np.copy(instance["actions"])) for instance in instances] + actions = torch.stack(actions) + + # Stack proprio + if "proprio" in instances[0]: + proprio = [instance["proprio"] for instance in instances] + proprio = torch.Tensor(np.squeeze(np.stack(proprio))) + else: + proprio = None + + output = dict( + pixel_values=pixel_values, + proprio=proprio, + input_ids=input_ids, + attention_mask=attention_mask, + labels=labels, + actions=actions, + ) + if dataset_names is not None: + output["dataset_names"] = dataset_names + return output diff --git a/capvector-oft/prismatic/util/nn_utils.py b/capvector-oft/prismatic/util/nn_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cb62c6be7d736aef9057bb26a4c2717af830f2f4 --- /dev/null +++ b/capvector-oft/prismatic/util/nn_utils.py @@ -0,0 +1,53 @@ +""" +nn_utils.py + +Utility functions and PyTorch submodule definitions. +""" + +import torch +import torch.nn as nn + + +# === Definitions for Various Projection Modules, with Signature :: [..., in_dim] --> [..., out_dim] === +class LinearProjector(nn.Module): + def __init__(self, vision_dim: int, llm_dim: int) -> None: + super().__init__() + self.projector = nn.Linear(vision_dim, llm_dim, bias=True) + + def forward(self, img_patches: torch.Tensor) -> torch.Tensor: + return self.projector(img_patches) + + +class MLPProjector(nn.Module): + def __init__(self, vision_dim: int, llm_dim: int, mlp_type: str = "gelu-mlp") -> None: + super().__init__() + if mlp_type == "gelu-mlp": + self.projector = nn.Sequential( + nn.Linear(vision_dim, llm_dim, bias=True), + nn.GELU(), + nn.Linear(llm_dim, llm_dim, bias=True), + ) + else: + raise ValueError(f"Projector with `{mlp_type = }` is not supported!") + + def forward(self, img_patches: torch.Tensor) -> torch.Tensor: + return self.projector(img_patches) + + +class FusedMLPProjector(nn.Module): + def __init__(self, fused_vision_dim: int, llm_dim: int, mlp_type: str = "fused-gelu-mlp") -> None: + super().__init__() + self.initial_projection_dim = fused_vision_dim * 4 + if mlp_type == "fused-gelu-mlp": + self.projector = nn.Sequential( + nn.Linear(fused_vision_dim, self.initial_projection_dim, bias=True), + nn.GELU(), + nn.Linear(self.initial_projection_dim, llm_dim, bias=True), + nn.GELU(), + nn.Linear(llm_dim, llm_dim, bias=True), + ) + else: + raise ValueError(f"Fused Projector with `{mlp_type = }` is not supported!") + + def forward(self, fused_img_patches: torch.Tensor) -> torch.Tensor: + return self.projector(fused_img_patches) diff --git a/capvector-oft/prismatic/util/torch_utils.py b/capvector-oft/prismatic/util/torch_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ddef43b290bc55825a687e9ddd69d2ce532ab7d4 --- /dev/null +++ b/capvector-oft/prismatic/util/torch_utils.py @@ -0,0 +1,95 @@ +""" +torch_utils.py + +General utilities for randomness, mixed precision training, and miscellaneous checks in PyTorch. 
+
+Random `set_global_seed` functionality is taken directly from PyTorch-Lightning:
+    > Ref: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/utilities/seed.py
+
+This is pretty important to get right if we're ever randomly generating our masks (or prefix dropout) inside our
+Dataset __getitem__() with multiple workers... if not handled properly, we will get repeated augmentations anytime
+we inject randomness from non-PyTorch sources (e.g., numpy, random)!
+    > Ref: https://tanelp.github.io/posts/a-bug-that-plagues-thousands-of-open-source-ml-projects/
+
+Terminology
+    -> World Size :: Total number of processes distributed over (# nodes x # devices) -- assumed homogeneous!
+    -> Rank :: Integer index of current process in the total world size
+    -> Local Rank :: Local index on given node in [0, Devices per Node]
+"""
+
+import os
+import random
+from typing import Callable, Optional
+
+import numpy as np
+import torch
+
+# === Randomness ===
+
+
+def set_global_seed(seed: int, get_worker_init_fn: bool = False) -> Optional[Callable[[int], None]]:
+    """Sets seed for all randomness libraries (mostly random, numpy, torch) and produces a `worker_init_fn`"""
+    assert np.iinfo(np.uint32).min < seed < np.iinfo(np.uint32).max, "Seed outside the np.uint32 bounds!"
+
+    # Set Seed as an Environment Variable
+    os.environ["EXPERIMENT_GLOBAL_SEED"] = str(seed)
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+
+    return worker_init_function if get_worker_init_fn else None
+
+
+def worker_init_function(worker_id: int) -> None:
+    """
+    Borrowed directly from PyTorch-Lightning; inspired by this issue comment in the PyTorch repo:
+        > Ref: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562
+
+    Intuition: You can think of the seed sequence spawn function as a "janky" torch.Generator() or jax.PRNGKey that
+    you can run iterative splitting on to get new (predictable) randomness.
+
+    :param worker_id: Identifier for the given worker [0, num_workers) for the Dataloader in question.
+    """
+    # Get current `rank` (if running distributed) and `process_seed`
+    global_rank, process_seed = int(os.environ["LOCAL_RANK"]), torch.initial_seed()
+
+    # Back out the "base" (original) seed - the per-worker seed is set in PyTorch:
+    #     > https://pytorch.org/docs/stable/data.html#data-loading-randomness
+    base_seed = process_seed - worker_id
+
+    # "Magic" code --> basically creates a seed sequence that mixes different "sources" and seeds every library...
+    seed_seq = np.random.SeedSequence([base_seed, worker_id, global_rank])
+
+    # Use 128 bits (4 x 32-bit words) to represent seed --> generate_state(k) produces a `k` element array!
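+    # (`np.random.seed` accepts such an array of uint32 words directly, so the call below seeds NumPy with the full
+    # 128-bit state; e.g., hypothetical values SeedSequence([42, 1, 0]) give worker 1 on rank 0 a stream that is
+    # deterministic yet distinct from every other (worker, rank) pair.)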
+    np.random.seed(seed_seq.generate_state(4))
+
+    # Spawn distinct child sequences for PyTorch (reseed) and stdlib random
+    torch_seed_seq, random_seed_seq = seed_seq.spawn(2)
+
+    # Torch Manual seed takes 64 bits (so just specify a dtype of uint64)
+    torch.manual_seed(torch_seed_seq.generate_state(1, dtype=np.uint64)[0])
+
+    # Use 128 Bits for `random`, but express as integer instead of as an array
+    #   =>> `astype(list)` yields an object array of Python ints, so the multiply/sum below runs in
+    #       arbitrary-precision integer arithmetic (no uint64 overflow)
+    random_seed = (random_seed_seq.generate_state(2, dtype=np.uint64).astype(list) * [1 << 64, 1]).sum()
+    random.seed(random_seed)
+
+
+# === BFloat16 Support ===
+
+
+def check_bloat16_supported() -> bool:
+    try:
+        import packaging.version
+        import torch.cuda.nccl as nccl
+        import torch.distributed as dist
+
+        return (
+            (torch.version.cuda is not None)
+            and torch.cuda.is_bf16_supported()
+            and (packaging.version.parse(torch.version.cuda).release >= (11, 0))
+            and dist.is_nccl_available()
+            and (nccl.version() >= (2, 10))
+        )
+
+    except Exception:
+        return False
diff --git a/capvector-oft/prismatic/vla/__init__.py b/capvector-oft/prismatic/vla/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd2de0ce872181a9c7e3e9bfafd30e90381cd3e6
--- /dev/null
+++ b/capvector-oft/prismatic/vla/__init__.py
@@ -0,0 +1 @@
+from .materialize import get_vla_dataset_and_collator
diff --git a/capvector-oft/prismatic/vla/action_tokenizer.py b/capvector-oft/prismatic/vla/action_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbb6ffa4ae191dade3d940c4363f02c92a270b36
--- /dev/null
+++ b/capvector-oft/prismatic/vla/action_tokenizer.py
@@ -0,0 +1,72 @@
+"""
+action_tokenizer.py
+
+Extension class; wraps base LLM/VLM tokenizer with logic to discretize and tokenize continuous robot actions.
+"""
+
+from typing import List, Union
+
+import numpy as np
+from transformers import PreTrainedTokenizerBase
+
+
+class ActionTokenizer:
+    def __init__(
+        self, tokenizer: PreTrainedTokenizerBase, bins: int = 256, min_action: int = -1, max_action: int = 1
+    ) -> None:
+        """
+        Discretizes continuous robot actions into N bins per dimension and maps to the least used tokens.
+
+        NOTE =>> by default, assumes a BPE-style tokenizer akin to the LlamaTokenizer, where *the least used tokens*
+                 appear at the end of the vocabulary!
+
+        :param tokenizer: Base LLM/VLM tokenizer to extend.
+        :param bins: Number of bins for each continuous value; we'll adopt a uniform binning strategy.
+        :param min_action: Minimum action value (for clipping, setting lower bound on bin interval).
+        :param max_action: Maximum action value (for clipping, setting upper bound on bin interval).
+        """
+        self.tokenizer, self.n_bins, self.min_action, self.max_action = tokenizer, bins, min_action, max_action
+
+        # Create Uniform Bins + Compute Bin Centers
+        self.bins = np.linspace(min_action, max_action, self.n_bins)
+        self.bin_centers = (self.bins[:-1] + self.bins[1:]) / 2.0
+
+        # [Contract] Set "action_token_begin_idx" based on `self.tokenizer.vocab_size - (self.n_bins + 1)`
+        #            =>> Assumes we're always overwriting the final `n_bins` tokens of the vocabulary!
+        self.action_token_begin_idx: int = int(self.tokenizer.vocab_size - (self.n_bins + 1))
+
+    def __call__(self, action: np.ndarray) -> Union[str, List[str]]:
+        """Clip & bin actions to *the last `n_bins` tokens* of the vocabulary (e.g., tokenizer.vocab[-256:])."""
+        action = np.clip(action, a_min=float(self.min_action), a_max=float(self.max_action))
+        discretized_action = np.digitize(action, self.bins)
+
+        # Handle single element vs. batch
+        if len(discretized_action.shape) == 1:
+            return self.tokenizer.decode(list(self.tokenizer.vocab_size - discretized_action))
+        else:
+            return self.tokenizer.batch_decode((self.tokenizer.vocab_size - discretized_action).tolist())
+
+    def decode_token_ids_to_actions(self, action_token_ids: np.ndarray) -> np.ndarray:
+        """
+        Returns continuous actions for discrete action token IDs.
+
+        NOTE =>> Because of the way the actions are discretized w.r.t. the bins (and not the bin centers), the
+                 digitization returns bin indices between [1, # bins], inclusive, when there are actually only
+                 (# bins - 1) bin intervals.
+
+                 Therefore, if the digitization returns the last possible index, we map this to the last bin interval.
+
+        EXAMPLE =>> Let's say self.bins has 256 values. Then self.bin_centers has 255 values. Digitization returns
+                    indices between [1, 256]. We subtract 1 from all indices so that they are between [0, 255]. There
+                    is still one index (i==255) that would cause an out-of-bounds error if used to index into
+                    self.bin_centers. Therefore, if i==255, we subtract 1 from it so that it just becomes the index of
+                    the last bin center. We implement this simply via clipping between [0, 255 - 1].
+        """
+        discretized_actions = self.tokenizer.vocab_size - action_token_ids
+        discretized_actions = np.clip(discretized_actions - 1, a_min=0, a_max=self.bin_centers.shape[0] - 1)
+
+        return self.bin_centers[discretized_actions]
+
+    @property
+    def vocab_size(self) -> int:
+        return self.n_bins
diff --git a/capvector-oft/prismatic/vla/constants.py b/capvector-oft/prismatic/vla/constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..81e8f9941218d81cf76cd3ce4fe07182813ed442
--- /dev/null
+++ b/capvector-oft/prismatic/vla/constants.py
@@ -0,0 +1,86 @@
+"""
+Important constants for VLA training and evaluation.
+
+Attempts to automatically identify the correct constants to set based on the Python command used to launch
+training or evaluation. If it is unclear, defaults to using the LIBERO simulation benchmark constants.
+"""
+import sys
+from enum import Enum
+
+# Llama 2 token constants
+IGNORE_INDEX = -100
+ACTION_TOKEN_BEGIN_IDX = 31743
+STOP_INDEX = 2  # '</s>' (the Llama 2 EOS token)
+
+
+# Defines supported normalization schemes for action and proprioceptive state.
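+# (Illustrative sketch, not part of the original file) For BOUNDS_Q99, the forward transform maps the
+# per-dimension [q01, q99] quantile range onto [-1, 1] and clips, roughly:
+#   normalized = np.clip(2 * (action - q01) / (q99 - q01 + 1e-8) - 1.0, -1.0, 1.0)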
+class NormalizationType(str, Enum): + # fmt: off + NORMAL = "normal" # Normalize to Mean = 0, Stdev = 1 + BOUNDS = "bounds" # Normalize to Interval = [-1, 1] + BOUNDS_Q99 = "bounds_q99" # Normalize [quantile_01, ..., quantile_99] --> [-1, ..., 1] + # fmt: on + + +# Define constants for each robot platform +LIBERO_CONSTANTS = { + "NUM_ACTIONS_CHUNK": 8, + "ACTION_DIM": 7, + "PROPRIO_DIM": 8, + "ACTION_PROPRIO_NORMALIZATION_TYPE": NormalizationType.BOUNDS_Q99, +} + +ALOHA_CONSTANTS = { + "NUM_ACTIONS_CHUNK": 25, + "ACTION_DIM": 14, + "PROPRIO_DIM": 14, + "ACTION_PROPRIO_NORMALIZATION_TYPE": NormalizationType.BOUNDS, +} + +BRIDGE_CONSTANTS = { + "NUM_ACTIONS_CHUNK": 5, + "ACTION_DIM": 7, + "PROPRIO_DIM": 7, + "ACTION_PROPRIO_NORMALIZATION_TYPE": NormalizationType.BOUNDS_Q99, +} + + +# Function to detect robot platform from command line arguments +def detect_robot_platform(): + cmd_args = " ".join(sys.argv).lower() + + if "libero" in cmd_args: + return "LIBERO" + elif "aloha" in cmd_args: + return "ALOHA" + elif "bridge" in cmd_args: + return "BRIDGE" + else: + # Default to LIBERO if unclear + return "LIBERO" + + +# Determine which robot platform to use +ROBOT_PLATFORM = detect_robot_platform() + +# Set the appropriate constants based on the detected platform +if ROBOT_PLATFORM == "LIBERO": + constants = LIBERO_CONSTANTS +elif ROBOT_PLATFORM == "ALOHA": + constants = ALOHA_CONSTANTS +elif ROBOT_PLATFORM == "BRIDGE": + constants = BRIDGE_CONSTANTS + +# Assign constants to global variables +NUM_ACTIONS_CHUNK = constants["NUM_ACTIONS_CHUNK"] +ACTION_DIM = constants["ACTION_DIM"] +PROPRIO_DIM = constants["PROPRIO_DIM"] +ACTION_PROPRIO_NORMALIZATION_TYPE = constants["ACTION_PROPRIO_NORMALIZATION_TYPE"] + +# Print which robot platform constants are being used (for debugging) +print(f"Using {ROBOT_PLATFORM} constants:") +print(f" NUM_ACTIONS_CHUNK = {NUM_ACTIONS_CHUNK}") +print(f" ACTION_DIM = {ACTION_DIM}") +print(f" PROPRIO_DIM = {PROPRIO_DIM}") +print(f" ACTION_PROPRIO_NORMALIZATION_TYPE = {ACTION_PROPRIO_NORMALIZATION_TYPE}") +print("If needed, manually set the correct constants in `prismatic/vla/constants.py`!") diff --git a/capvector-oft/prismatic/vla/datasets/__init__.py b/capvector-oft/prismatic/vla/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..343cadd8964350abbd7fe75e6d1286b6756db795 --- /dev/null +++ b/capvector-oft/prismatic/vla/datasets/__init__.py @@ -0,0 +1 @@ +from .datasets import DummyDataset, EpisodicRLDSDataset, RLDSBatchTransform, RLDSDataset diff --git a/capvector-oft/prismatic/vla/datasets/datasets.py b/capvector-oft/prismatic/vla/datasets/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..ddc3deaac3c180a5d862884f09010358de4b149c --- /dev/null +++ b/capvector-oft/prismatic/vla/datasets/datasets.py @@ -0,0 +1,261 @@ +""" +datasets.py + +Lightweight PyTorch Dataset Definition for wrapping RLDS TFDS Pipeline; just defines transform from RLDS default +format to OpenVLA, IterableDataset shim. 
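+
+Also defines `EpisodicRLDSDataset` (full-episode iteration) and `DummyDataset`, a map-style stand-in that emits
+random frames; useful for smoke-testing the training loop before wiring up a real dataset.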
+""" + +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, Tuple, Type + +import numpy as np +import torch +from PIL import Image +from torch.utils.data import Dataset, IterableDataset +from transformers import PreTrainedTokenizerBase + +from prismatic.models.backbones.llm.prompting import PromptBuilder +from prismatic.models.backbones.vision import ImageTransform +from prismatic.util.data_utils import tree_map +from prismatic.vla.action_tokenizer import ActionTokenizer +from prismatic.vla.constants import ACTION_DIM, ACTION_PROPRIO_NORMALIZATION_TYPE, ACTION_TOKEN_BEGIN_IDX, IGNORE_INDEX, NUM_ACTIONS_CHUNK, PROPRIO_DIM, STOP_INDEX +from prismatic.vla.datasets.rlds import make_interleaved_dataset, make_single_dataset +from prismatic.vla.datasets.rlds.oxe import OXE_NAMED_MIXTURES, get_oxe_dataset_kwargs_and_weights + +@dataclass +class RLDSBatchTransform: + action_tokenizer: ActionTokenizer + base_tokenizer: PreTrainedTokenizerBase + image_transform: ImageTransform + prompt_builder_fn: Type[PromptBuilder] + predict_stop_token: bool = True + use_wrist_image: bool = False + use_proprio: bool = False + + def __call__(self, rlds_batch: Dict[str, Any]) -> Dict[str, Any]: + """Converts a RLDS batch to the format expected by the OpenVLA collator/models.""" + dataset_name, current_action = rlds_batch["dataset_name"], rlds_batch["action"][0] + img = Image.fromarray(rlds_batch["observation"]["image_primary"][0]) + lang = rlds_batch["task"]["language_instruction"].decode().lower() + actions = rlds_batch["action"] + + # Construct Chat-based Prompt =>> Input is default query + language instruction, output are the action tokens + prompt_builder = self.prompt_builder_fn("openvla") + + # Get future action chunk + future_actions = rlds_batch["action"][1:] + future_actions_string = ''.join(self.action_tokenizer(future_actions)) + + # Get action chunk string + current_action_string = self.action_tokenizer(current_action) + action_chunk_string = current_action_string + future_actions_string + action_chunk_len = len(action_chunk_string) + + conversation = [ + {"from": "human", "value": f"What action should the robot take to {lang}?"}, + {"from": "gpt", "value": action_chunk_string}, + ] + for turn in conversation: + prompt_builder.add_turn(turn["from"], turn["value"]) + + # Tokenize (w/ `base_tokenizer`) + input_ids = self.base_tokenizer(prompt_builder.get_prompt(), add_special_tokens=True).input_ids + labels = list(input_ids) + + # Tensorize =>> Run Image Transform to get `pixel_values` =>> Return + # =>> IMPORTANT :: IF WE'RE USING HF LLM.forward(..., labels=labels), SHIFTING HAPPENS _INSIDE_ MODEL! + input_ids, labels = torch.tensor(input_ids), torch.tensor(labels) + pixel_values = self.image_transform(img) + + # [CRITICAL] We do not want to take the loss for anything but the predicted action tokens! 
+ labels[: -(action_chunk_len + 1)] = IGNORE_INDEX + if not self.predict_stop_token: + labels[-1] = IGNORE_INDEX + + return_dict = dict(pixel_values=pixel_values, input_ids=input_ids, labels=labels, dataset_name=dataset_name, actions=actions) + + # Add additional inputs + if self.use_wrist_image: + all_wrist_pixels = [] + for k in rlds_batch["observation"].keys(): + if "wrist" in k: + img_wrist = Image.fromarray(rlds_batch["observation"][k][0]) + pixel_values_wrist = self.image_transform(img_wrist) + all_wrist_pixels.append(pixel_values_wrist) + return_dict["pixel_values_wrist"] = torch.cat(all_wrist_pixels, dim=0) + if self.use_proprio and "proprio" in rlds_batch["observation"]: + proprio = rlds_batch["observation"]["proprio"] + return_dict["proprio"] = proprio + + return return_dict + + +class RLDSDataset(IterableDataset): + def __init__( + self, + data_root_dir: Path, + data_mix: str, + batch_transform: RLDSBatchTransform, + resize_resolution: Tuple[int, int], + shuffle_buffer_size: int = 256_000, + train: bool = True, + image_aug: bool = False, + ) -> None: + """Lightweight wrapper around RLDS TFDS Pipeline for use with PyTorch/OpenVLA Data Loaders.""" + self.data_root_dir, self.data_mix, self.batch_transform = data_root_dir, data_mix, batch_transform + + # Configure RLDS Dataset(s) + if self.data_mix in OXE_NAMED_MIXTURES: + mixture_spec = OXE_NAMED_MIXTURES[self.data_mix] + else: + # Assume that passed "mixture" name is actually a single dataset -- create single-dataset "mix" + mixture_spec = [(self.data_mix, 1.0)] + + # fmt: off + if "aloha" in self.data_mix: + load_camera_views = ("primary", "left_wrist", "right_wrist") + else: + load_camera_views = ("primary", "wrist") + + per_dataset_kwargs, weights = get_oxe_dataset_kwargs_and_weights( + self.data_root_dir, + mixture_spec, + load_camera_views=load_camera_views, + load_depth=False, + load_proprio=True, + load_language=True, + action_proprio_normalization_type=ACTION_PROPRIO_NORMALIZATION_TYPE, + ) + rlds_config = dict( + traj_transform_kwargs=dict( + window_size=1, # If we wanted to feed / predict more than one step + future_action_window_size=NUM_ACTIONS_CHUNK-1, # For action chunking + skip_unlabeled=True, # Skip trajectories without language labels + goal_relabeling_strategy="uniform", # Goals are currently unused + ), + frame_transform_kwargs=dict( + resize_size=resize_resolution, + num_parallel_calls=16, # For CPU-intensive ops (decoding, resizing, etc.) 
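+                # (Editorial note) frames are decoded/resized here; `image_augment_kwargs` is attached
+                # below only when `image_aug=True`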
+            ),
+            dataset_kwargs_list=per_dataset_kwargs,
+            shuffle_buffer_size=shuffle_buffer_size,
+            sample_weights=weights,
+            balance_weights=True,
+            traj_transform_threads=len(mixture_spec),
+            traj_read_threads=len(mixture_spec),
+            train=train,
+        )
+
+        # If applicable, enable image augmentations
+        if image_aug:
+            rlds_config["frame_transform_kwargs"].update({"image_augment_kwargs": dict(
+                random_resized_crop=dict(scale=[0.9, 0.9], ratio=[1.0, 1.0]),
+                random_brightness=[0.2],
+                random_contrast=[0.8, 1.2],
+                random_saturation=[0.8, 1.2],
+                random_hue=[0.05],
+                augment_order=[
+                    "random_resized_crop",
+                    "random_brightness",
+                    "random_contrast",
+                    "random_saturation",
+                    "random_hue",
+                ],
+            )})
+        # fmt: on
+
+        # Initialize RLDS Dataset
+        self.dataset, self.dataset_length, self.dataset_statistics = self.make_dataset(rlds_config)
+
+    def make_dataset(self, rlds_config):
+        return make_interleaved_dataset(**rlds_config)
+
+    def __iter__(self) -> Iterator[Dict[str, Any]]:
+        for rlds_batch in self.dataset.as_numpy_iterator():
+            yield self.batch_transform(rlds_batch)
+
+    def __len__(self) -> int:
+        return self.dataset_length
+
+    # === Explicitly Unused ===
+    def __getitem__(self, idx: int) -> None:
+        raise NotImplementedError("IterableDataset does not implement map-style __getitem__; see __iter__ instead!")
+
+
+class EpisodicRLDSDataset(RLDSDataset):
+    """Returns full episodes as list of steps instead of individual transitions (useful for visualizations)."""
+
+    def make_dataset(self, rlds_config):
+        per_dataset_kwargs = rlds_config["dataset_kwargs_list"]
+        assert len(per_dataset_kwargs) == 1, "Only support single-dataset `mixes` for episodic datasets."
+
+        return make_single_dataset(
+            per_dataset_kwargs[0],
+            train=rlds_config["train"],
+            traj_transform_kwargs=rlds_config["traj_transform_kwargs"],
+            frame_transform_kwargs=rlds_config["frame_transform_kwargs"],
+        )
+
+    def __iter__(self) -> Iterator[Dict[str, Any]]:
+        for rlds_batch in self.dataset.as_numpy_iterator():
+            out = [
+                self.batch_transform(tree_map(lambda x: x[i], rlds_batch))  # noqa: B023
+                for i in range(rlds_batch["action"].shape[0])
+            ]
+            yield out
+
+
+class DummyDataset(Dataset):
+    def __init__(
+        self,
+        action_tokenizer: ActionTokenizer,
+        base_tokenizer: PreTrainedTokenizerBase,
+        image_transform: ImageTransform,
+        prompt_builder_fn: Type[PromptBuilder],
+    ) -> None:
+        self.action_tokenizer = action_tokenizer
+        self.base_tokenizer = base_tokenizer
+        self.image_transform = image_transform
+        self.prompt_builder_fn = prompt_builder_fn
+
+        # Note =>> We expect the dataset to store statistics for action de-normalization. Specifically, we store the
+        #          per-dimension 1st and 99th action quantile. The values below correspond to "no normalization" for
+        #          simplicity.
+        self.dataset_statistics = {
+            "dummy_dataset": {
+                "action": {"q01": np.zeros((7,), dtype=np.float32), "q99": np.ones((7,), dtype=np.float32)}
+            }
+        }
+
+    def __len__(self):
+        # TODO =>> Replace with number of elements in your dataset!
+ return 10000 + + def __getitem__(self, idx): + # TODO =>> Load image, action and instruction from disk -- we use dummy values + image = Image.fromarray(np.asarray(np.random.rand(224, 224, 3) * 255.0, dtype=np.uint8)) + action = np.asarray(np.random.rand(7), dtype=np.float32) + instruction = "do something spectacular" + + # Add instruction to VLA prompt + prompt_builder = self.prompt_builder_fn("openvla") + conversation = [ + {"from": "human", "value": f"What action should the robot take to {instruction}?"}, + {"from": "gpt", "value": self.action_tokenizer(action)}, + ] + for turn in conversation: + prompt_builder.add_turn(turn["from"], turn["value"]) + + # Tokenize (w/ `base_tokenizer`) + input_ids = self.base_tokenizer(prompt_builder.get_prompt(), add_special_tokens=True).input_ids + labels = list(input_ids) + + # Tensorize =>> Run Image Transform to get `pixel_values` =>> Return + # =>> IMPORTANT :: IF WE'RE USING HF .forward(..., labels=labels), SHIFTING HAPPENS _INSIDE_ MODEL! + input_ids, labels = torch.tensor(input_ids), torch.tensor(labels) + pixel_values = self.image_transform(image) + + # [CRITICAL] We do not want to take the loss for anything but the predicted action tokens! + labels[: -(len(action) + 1)] = IGNORE_INDEX + + return dict(pixel_values=pixel_values, input_ids=input_ids, labels=labels) diff --git a/capvector-oft/prismatic/vla/datasets/rlds/__init__.py b/capvector-oft/prismatic/vla/datasets/rlds/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4b260fa5b7326012e85e54be203fa4932f35783d --- /dev/null +++ b/capvector-oft/prismatic/vla/datasets/rlds/__init__.py @@ -0,0 +1 @@ +from .dataset import make_interleaved_dataset, make_single_dataset diff --git a/capvector-oft/prismatic/vla/datasets/rlds/dataset.py b/capvector-oft/prismatic/vla/datasets/rlds/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..9ff0424d071c8888c8f75e54b2eeaedc7b9bd9c5 --- /dev/null +++ b/capvector-oft/prismatic/vla/datasets/rlds/dataset.py @@ -0,0 +1,585 @@ +""" +dataset.py + +Core interface script for configuring and initializing RLDS datasets. 
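+
+Exposes two entry points: `make_single_dataset` (a single RLDS dataset of trajectories) and
+`make_interleaved_dataset` (a weighted, frame-level mixture over multiple RLDS datasets).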
+""" + +import copy +import inspect +import json +from functools import partial +from typing import Callable, Dict, List, Optional, Tuple, Union + +import dlimp as dl +import numpy as np +import tensorflow as tf +import tensorflow_datasets as tfds + +from prismatic.overwatch import initialize_overwatch +from prismatic.vla.constants import ACTION_DIM, ACTION_PROPRIO_NORMALIZATION_TYPE, ACTION_TOKEN_BEGIN_IDX, IGNORE_INDEX, NUM_ACTIONS_CHUNK, PROPRIO_DIM, STOP_INDEX +from prismatic.vla.datasets.rlds import obs_transforms, traj_transforms +from prismatic.vla.datasets.rlds.utils import goal_relabeling, task_augmentation +from prismatic.vla.datasets.rlds.utils.data_utils import ( + allocate_threads, + get_dataset_statistics, + normalize_action_and_proprio, + pprint_data_mixture, + tree_map, +) + +# Initialize Overwatch =>> Wraps `logging.Logger` +overwatch = initialize_overwatch(__name__) + + +# Configure Tensorflow with *no GPU devices* (to prevent clobber with PyTorch) +tf.config.set_visible_devices([], "GPU") + + +# ruff: noqa: B006 +def make_dataset_from_rlds( + name: str, + data_dir: str, + *, + train: bool, + standardize_fn: Optional[Callable[[dict], dict]] = None, + shuffle: bool = True, + image_obs_keys: Dict[str, Optional[str]] = {}, + depth_obs_keys: Dict[str, Optional[str]] = {}, + state_obs_keys: List[Optional[str]] = (), + language_key: Optional[str] = None, + action_proprio_normalization_type: ACTION_PROPRIO_NORMALIZATION_TYPE, + dataset_statistics: Optional[Union[dict, str]] = None, + absolute_action_mask: Optional[List[bool]] = None, + action_normalization_mask: Optional[List[bool]] = None, + num_parallel_reads: int = tf.data.AUTOTUNE, + num_parallel_calls: int = tf.data.AUTOTUNE, +) -> Tuple[dl.DLataset, dict]: + """ + This function is responsible for loading a specific RLDS dataset from storage and getting it into a standardized + format. Yields a dataset of trajectories. Does not include CPU-intensive operations. + + If `standardize_fn` is provided, it will be applied to each trajectory. This function should get the trajectory + into a standard format, which includes the keys "observation" and "action". Entry "observation" should be a + dictionary containing some number of additional keys, which will be extracted into an even more standardized format + according to the "*_obs_keys" arguments. + + The `image_obs_keys` and `depth_obs_keys` arguments are mappings from new names to old names, or None in place of an + old name to insert padding. For example, if after `standardize_fn`, your "observation" dict has RGB images called + "workspace" and "wrist", and `image_obs_keys={"primary": "workspace", "secondary": None, "wrist": "wrist"}`, then + the resulting dataset will have an "observation" dict containing the keys "image_primary", "image_secondary", and + "image_wrist", where "image_primary" corresponds to "workspace", "image_secondary" is a padding image, and + "image_wrist" corresponds to "wrist". + + Entry `state_obs_keys` is a list of 1-dimensional proprioceptive keys to concatenate into a single array, which will + be placed in the "proprio" key of the "observation" dict. A single padding element (zero) will be inserted for each + None entry. + + The dataset will also include a "task" dict. If `language_key` is provided, then the "task" dict will contain the + key "language_instruction", extracted from `traj[language_key]`. + + Args: + name (str): The name of the RLDS dataset (usually "name" or "name:version"). + data_dir (str): The path to the data directory. 
+ train (bool): Whether to use the training or validation split. + shuffle (bool, optional): Whether to shuffle the file read order (does NOT fully shuffle the dataset, since one + file usually contains many trajectories)! + standardize_fn (Callable[[dict], dict], optional): A function that, if provided, will be the first + thing applied to each trajectory. + image_obs_keys (Mapping[str, str|None]): Mapping from {new: old} indicating which RGB images to extract from the + "observation" dict. `new_obs = {f"image_{new}": old_obs[old] for new, old in image_obs_keys.items()}`. + If a value of `old` is None, inserts a padding image instead (empty string). + depth_obs_keys (Mapping[str, str|None]): Same as `image_obs_keys`, but for depth images. Keys will be + prefixed with "depth_" instead of "image_". + state_obs_keys (Sequence[str|None]): List of 1-dimensional proprioception keys to be extracted from the + "observation" dict, concatenated, and mapped to "proprio". Inserts 1 element of padding for each None entry. + language_key (str, optional): If provided, the "task" dict will contain the key "language_instruction", + extracted from `traj[language_key]`. + action_proprio_normalization_type (str, optional): The type of normalization to perform on the action, + proprio, or both. Can be "normal" (mean 0, std 1) or "bounds" (normalized to [-1, 1]). + dataset_statistics: (dict|str, optional): dict (or path to JSON file) that contains dataset statistics + for normalization. If `action_proprio_normalization_type` is "normal", this should contain "mean" and + "std" keys. If `action_proprio_normalization_type` is "bounds", this should contain "min" and "max" + keys. May also provide "num_transitions" and "num_trajectories" keys for downstream usage (e.g., for + `make_interleaved_dataset`). If not provided, the statistics will be computed on the fly. + absolute_action_mask (Sequence[bool], optional): By default, all action dimensions are assumed to be + relative. This is important for when `future_action_window_size > 0`: actions that are taken + from beyond the end of the trajectory (or beyond the goal timestep when goal relabeling is used) + need to be made "neutral" to indicate that the task has been completed. For relative actions, + "neutral" means zero, but for absolute actions, "neutral" means repeating the last valid action. + This mask, if provided, indicates which action dimensions are absolute. + action_normalization_mask (Sequence[bool], optional): If provided, indicates which action dimensions + should be normalized. For example, you might not want to normalize the gripper action dimension if + it's always exactly 0 or 1. By default, all action dimensions are normalized. + num_parallel_reads (int): number of parallel read workers. Default to AUTOTUNE. + num_parallel_calls (int): number of parallel calls for traj_map operations. Default to AUTOTUNE. 
+ Returns: + Dataset of trajectories where each step has the following fields: + - observation: + - image_{name1, name2, ...} # RGB image observations + - depth_{name1, name2, ...} # depth image observations + - proprio # 1-dimensional array of proprioceptive observations + - timestep # timestep of each frame + - task: + - language_instruction # language instruction, present if `language_key` is provided + - action # action vector + - dataset_name # name of the dataset + """ + REQUIRED_KEYS = {"observation", "action"} + if language_key is not None: + REQUIRED_KEYS.add(language_key) + + def restructure(traj): + # apply a standardization function, if provided + if standardize_fn is not None: + traj = standardize_fn(traj) + + if not all(k in traj for k in REQUIRED_KEYS): + raise ValueError( + f"Trajectory is missing keys: {REQUIRED_KEYS - set(traj.keys())}. " "Did you write a `standardize_fn`?" + ) + + # extracts images, depth images and proprio from the "observation" dict + traj_len = tf.shape(traj["action"])[0] + old_obs = traj["observation"] + new_obs = {} + for new, old in image_obs_keys.items(): + if old is None: + new_obs[f"image_{new}"] = tf.repeat("", traj_len) # padding + else: + new_obs[f"image_{new}"] = old_obs[old] + + for new, old in depth_obs_keys.items(): + if old is None: + new_obs[f"depth_{new}"] = tf.repeat("", traj_len) # padding + else: + new_obs[f"depth_{new}"] = old_obs[old] + + if state_obs_keys: + new_obs["proprio"] = tf.concat( + [ + ( + tf.zeros((traj_len, 1), dtype=tf.float32) # padding + if key is None + else tf.cast(old_obs[key], tf.float32) + ) + for key in state_obs_keys + ], + axis=1, + ) + + # add timestep info + new_obs["timestep"] = tf.range(traj_len) + + # extracts `language_key` into the "task" dict + task = {} + if language_key is not None: + if traj[language_key].dtype != tf.string: + raise ValueError( + f"Language key {language_key} has dtype {traj[language_key].dtype}, " "but it must be tf.string." + ) + task["language_instruction"] = traj.pop(language_key) + + traj = { + "observation": new_obs, + "task": task, + "action": tf.cast(traj["action"], tf.float32), + "dataset_name": tf.repeat(name, traj_len), + } + + if absolute_action_mask is not None: + if len(absolute_action_mask) != traj["action"].shape[-1]: + raise ValueError( + f"Length of absolute_action_mask ({len(absolute_action_mask)}) " + f"does not match action dimension ({traj['action'].shape[-1]})." 
+ ) + traj["absolute_action_mask"] = tf.tile( + tf.convert_to_tensor(absolute_action_mask, dtype=tf.bool)[None], + [traj_len, 1], + ) + + return traj + + builder = tfds.builder(name, data_dir=data_dir) + + # load or compute dataset statistics + if isinstance(dataset_statistics, str): + with tf.io.gfile.GFile(dataset_statistics, "r") as f: + dataset_statistics = json.load(f) + elif dataset_statistics is None: + full_dataset = dl.DLataset.from_rlds( + builder, split="all", shuffle=False, num_parallel_reads=num_parallel_reads + ).traj_map(restructure, num_parallel_calls) + # tries to load from cache, otherwise computes on the fly + dataset_statistics = get_dataset_statistics( + full_dataset, + hash_dependencies=( + str(builder.info), + str(state_obs_keys), + inspect.getsource(standardize_fn) if standardize_fn is not None else "", + ), + save_dir=builder.data_dir, + ) + dataset_statistics = tree_map(np.array, dataset_statistics) + + # skip normalization for certain action dimensions + if action_normalization_mask is not None: + if len(action_normalization_mask) != dataset_statistics["action"]["mean"].shape[-1]: + raise ValueError( + f"Length of skip_normalization_mask ({len(action_normalization_mask)}) " + f"does not match action dimension ({dataset_statistics['action']['mean'].shape[-1]})." + ) + dataset_statistics["action"]["mask"] = np.array(action_normalization_mask) + + # construct the dataset + split = "train" if train else "val" + + dataset = dl.DLataset.from_rlds(builder, split=split, shuffle=shuffle, num_parallel_reads=num_parallel_reads) + + dataset = dataset.traj_map(restructure, num_parallel_calls) + dataset = dataset.traj_map( + partial( + normalize_action_and_proprio, + metadata=dataset_statistics, + normalization_type=action_proprio_normalization_type, + ), + num_parallel_calls, + ) + + return dataset, dataset_statistics + + +def apply_trajectory_transforms( + dataset: dl.DLataset, + *, + train: bool, + goal_relabeling_strategy: Optional[str] = None, + goal_relabeling_kwargs: dict = {}, + window_size: int = 1, + future_action_window_size: int = 0, + subsample_length: Optional[int] = None, + skip_unlabeled: bool = False, + max_action: Optional[float] = None, + max_proprio: Optional[float] = None, + task_augment_strategy: Optional[str] = None, + task_augment_kwargs: dict = {}, + num_parallel_calls: int = tf.data.AUTOTUNE, +) -> dl.DLataset: + """ + Applies common transforms that happen at a trajectory level. Such transforms are usually some sort of "relabeling" + (e.g., filtering, chunking, adding goals, dropping keys). + + Transforms in this function should have the following properties: + - They require access to an entire trajectory (i.e., they cannot be applied frame-wise). + - They are generally not CPU-intensive, mostly involving moving and copying data. + - They do not require decoded images. + + Args: + dataset (dl.DLataset): The dataset to transform. + train (bool): Whether the dataset is for training (affects subsampling). + goal_relabeling_strategy (str, optional): The goal relabeling strategy to use, or None for + no goal relabeling. See `goal_relabeling.py`. + goal_relabeling_kwargs (dict, optional): Additional keyword arguments to pass to the goal relabeling function. + window_size (int, optional): The length of the snippets that trajectories are chunked into. + future_action_window_size (int, optional): The number of future actions beyond window_size to include + in the chunked actions. 
+        subsample_length (int, optional): If provided, trajectories longer than this will be subsampled to
+            this length (after goal relabeling and chunking).
+        skip_unlabeled (bool, optional): Whether to skip trajectories with no language labels.
+        max_action (float, optional): If provided, trajectories in which *any* action dimension
+            of *any* transition has an absolute value larger than this will be skipped.
+        max_proprio (float, optional): If provided, trajectories in which *any* proprio dimension
+            of *any* transition has an absolute value larger than this will be skipped.
+        task_augment_strategy (str, optional): The task augmentation strategy to use, or None for no task
+            augmentation. See `task_augmentation.py`.
+        task_augment_kwargs (dict, optional): Additional keyword arguments to pass to the task augmentation
+            function.
+        num_parallel_calls (int, optional): number of parallel calls for map operations. Defaults to AUTOTUNE.
+    """
+    if skip_unlabeled:
+        if "language_instruction" not in dataset.element_spec["task"]:
+            raise ValueError("skip_unlabeled=True but dataset does not have language labels.")
+
+        dataset = dataset.filter(lambda x: tf.math.reduce_any(x["task"]["language_instruction"] != ""))
+
+    if max_action is not None:
+        dataset = dataset.filter(lambda x: tf.math.reduce_all(tf.math.abs(x["action"]) <= max_action))
+
+    if max_proprio is not None and "proprio" in dataset.element_spec["observation"]:
+        dataset = dataset.filter(lambda x: tf.math.reduce_all(tf.math.abs(x["observation"]["proprio"]) <= max_proprio))
+
+    # marks which entries of the observation and task dicts are padding
+    dataset = dataset.traj_map(traj_transforms.add_pad_mask_dict, num_parallel_calls)
+
+    # updates the "task" dict
+    if goal_relabeling_strategy is not None:
+        dataset = dataset.traj_map(
+            partial(getattr(goal_relabeling, goal_relabeling_strategy), **goal_relabeling_kwargs),
+            num_parallel_calls,
+        )
+
+    # must run task augmentation before chunking, in case it changes goal timesteps
+    if train and task_augment_strategy is not None:
+        # perform task augmentation (e.g., dropping keys)
+        dataset = dataset.traj_map(
+            partial(
+                getattr(task_augmentation, task_augment_strategy),
+                **task_augment_kwargs,
+            ),
+            num_parallel_calls,
+        )
+
+    # chunks observations and actions, giving them a new axis at index 1 of size `window_size` and
+    # `window_size + future_action_window_size`, respectively
+    dataset = dataset.traj_map(
+        partial(
+            traj_transforms.chunk_act_obs,
+            window_size=window_size,
+            future_action_window_size=future_action_window_size,
+        ),
+        num_parallel_calls,
+    )
+
+    if train and subsample_length is not None:
+        dataset = dataset.traj_map(
+            partial(traj_transforms.subsample, subsample_length=subsample_length),
+            num_parallel_calls,
+        )
+
+    return dataset
+
+
+def apply_per_dataset_frame_transforms(
+    dataset: dl.DLataset,
+    chunk_filter_fn: Optional[Callable] = None,
+):
+    """
+    Optionally applies *per-dataset* transforms that happen at a frame level.
+
+    Args:
+        dataset (dl.DLataset): The dataset to transform.
+        chunk_filter_fn (callable, optional): Filter function for chunks.
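+
+    Example (illustrative, not from the original code) -- randomly keep ~50% of chunks:
+        dataset = apply_per_dataset_frame_transforms(dataset, chunk_filter_fn=lambda f: tf.random.uniform([]) < 0.5)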
+ """ + if chunk_filter_fn: + dataset = dataset.filter(chunk_filter_fn) + return dataset + + +def apply_frame_transforms( + dataset: dl.DLataset, + *, + train: bool, + image_augment_kwargs: Union[Dict, Dict[str, Dict]] = {}, + resize_size: Union[Tuple[int, int], Dict[str, Tuple[int, int]]] = {}, + depth_resize_size: Union[Tuple[int, int], Dict[str, Tuple[int, int]]] = {}, + num_parallel_calls: int = tf.data.AUTOTUNE, +) -> dl.DLataset: + """ + Applies common transforms that happen at a frame level. These transforms are usually more CPU-intensive, (e.g., + decoding or resizing images). + + Args: + train (bool): Whether the dataset is for training (affects image augmentation). + dataset (dl.DLataset): The dataset to transform. + image_augment_kwargs (dict|Mapping[str, dict]): Keyword arguments to pass to the image augmentation + function. See `dlimp.transforms.augment_image` for documentation of these kwargs. If a dict of + dicts is provided, then key "k" will be used for "image_{k}" (names determined by `image_obs_keys` + in `make_dataset_from_rlds`). Augmentation will be skipped for missing keys (so pass an empty dict + to skip augmentation for all images). + resize_size (Tuple[int, int]|Mapping[str, Tuple[int, int]]): If provided, images will be resized to + this size. If a dict of tuples is provided, then key "k" will be used for "image_{k}" (names + determined by `image_obs_keys` in `make_dataset_from_rlds`). Resizing will be skipped for missing + keys (so pass an empty dict to skip resizing for all images). + depth_resize_size (Tuple[int, int]|Mapping[str, Tuple[int, int]]): Same as resize_size, but for depth + images. + num_parallel_calls (int): number of parallel calls for frame_map operations. Default to AUTOTUNE. + """ + + # Convenience wrapper that takes a function that operates on a non-chunked "observation" dict and applies + # it to the chunked "observation" dict as well as the non-chunked "task" dict + def apply_obs_transform(fn: Callable[[Dict], Dict], frame: Dict) -> Dict: + frame["task"] = fn(frame["task"]) + frame["observation"] = dl.vmap(fn)(frame["observation"]) + return frame + + # Decode + resize images (and depth images) + dataset = dataset.frame_map( + partial( + apply_obs_transform, + partial(obs_transforms.decode_and_resize, resize_size=resize_size, depth_resize_size=depth_resize_size), + ), + num_parallel_calls, + ) + + if train: + # Augment all images with the same seed, skipping padding images + def aug(frame: dict): + seed = tf.random.uniform([2], maxval=tf.dtypes.int32.max, dtype=tf.int32) + aug_fn = partial(obs_transforms.augment, seed=seed, augment_kwargs=image_augment_kwargs) + return apply_obs_transform(aug_fn, frame) + + dataset = dataset.frame_map(aug, num_parallel_calls) + + return dataset + + +def make_single_dataset( + dataset_kwargs: dict, + *, + train: bool, + traj_transform_kwargs: dict = {}, + frame_transform_kwargs: dict = {}, +) -> dl.DLataset: + """Creates a single dataset from kwargs. Returns a dataset of trajectories. + + Args: + dataset_kwargs: kwargs passed to `make_dataset_from_rlds` that are dataset-specific. + train: whether this is a training or validation dataset. + traj_transform_kwargs: kwargs passed to 'apply_trajectory_transforms'. + frame_transform_kwargs: kwargs passed to 'get_frame_transforms'. 
+ """ + dataset, dataset_statistics = make_dataset_from_rlds( + **dataset_kwargs, + train=train, + ) + dataset = apply_trajectory_transforms(dataset, **traj_transform_kwargs, train=train) + dataset = apply_frame_transforms(dataset, **frame_transform_kwargs, train=train) + + # this seems to reduce memory usage without affecting speed + dataset = dataset.with_ram_budget(1) + + # save for later + return dataset, dataset_statistics["num_trajectories"], dataset_statistics + + +# === Core Initializer === +def make_interleaved_dataset( + dataset_kwargs_list: List[Dict], + sample_weights: Optional[List[float]] = None, + *, + train: bool, + shuffle_buffer_size: int, + traj_transform_kwargs: Optional[Dict] = None, + frame_transform_kwargs: Optional[Dict] = None, + batch_size: Optional[int] = None, + balance_weights: bool = False, + traj_transform_threads: Optional[int] = None, + traj_read_threads: Optional[int] = None, +) -> dl.DLataset: + """ + Creates an interleaved dataset from list of dataset configs (kwargs). Returns a dataset of batched frames. + + Args: + dataset_kwargs_list: list of kwargs, each element of which is passed to `make_dataset_from_rlds`. + "num_parallel_calls" and "num_parallel_reads" are overridden using `traj_transform_threads` and + `traj_read_threads`, respectively. + sample_weights: sampling weights for each dataset in list. If None, defaults to uniform. + train: whether this is a training or validation dataset. + shuffle_buffer_size: size of the dataset shuffle buffer (in number of frames). + traj_transform_kwargs: kwargs passed to `apply_trajectory_transforms`. "num_parallel_calls" is + overridden using `traj_transform_threads`. + frame_transform_kwargs: kwargs passed to `apply_frame_transforms`. + batch_size: batch size, if not provided output is not batched. + balance_weights: if True, the sample weights are multiplied by the number of frames in each dataset. + This makes it so that, if all the sample weights are equal, one full iteration through the interleaved + dataset will correspond to one full iteration through each individual dataset (only in expectation, + since in practice the sampling is random). + traj_transform_threads: total number of parallel calls for trajectory transforms, distributed across + datasets according to their sampling weights. If None, defaults to AUTOTUNE for every dataset. + traj_read_threads: total number of parallel read workers for trajectory transforms, distributed across + datasets according to their sampling weights. If None, defaults to AUTOTUNE for every dataset. 
+ """ + # Default to uniform sampling (if `sample_weights` is not specified) + if not sample_weights: + sample_weights = [1.0] * len(dataset_kwargs_list) + + if len(sample_weights) != len(dataset_kwargs_list): + raise ValueError(f"sample_weights must be None or have length {len(dataset_kwargs_list)}.") + + # Check valid `traj_transform_kwargs` and `frame_transform_kwargs` + if (traj_transform_kwargs is None) or (frame_transform_kwargs is None): + raise ValueError("Missing `traj_transform_kwargs` and `frame_transform_kwargs`!") + + # Get Dataset Sizes + dataset_sizes, all_dataset_statistics = [], {} + for dataset_kwargs in dataset_kwargs_list: + data_kwargs = copy.deepcopy(dataset_kwargs) + if "dataset_frame_transform_kwargs" in data_kwargs: + data_kwargs.pop("dataset_frame_transform_kwargs") + _, dataset_statistics = make_dataset_from_rlds(**data_kwargs, train=train) + dataset_sizes.append(dataset_statistics["num_transitions"]) + all_dataset_statistics[dataset_kwargs["name"]] = dataset_statistics + + # Get the indices of the "primary" datasets (i.e., datasets with sample_weight == 1.0) + primary_dataset_indices = np.array([idx for idx in range(len(sample_weights)) if sample_weights[idx] == 1.0]) + + # Balance and Normalize Weights + if balance_weights: + sample_weights = np.array(sample_weights) * np.array(dataset_sizes) + sample_weights = np.array(sample_weights) / np.sum(sample_weights) + pprint_data_mixture(dataset_kwargs_list, sample_weights) + + # Effective Dataset Length = Number of samples until each dataset has completed at least one epoch + # =>> Note :: Only counting the "primary" datasets (i.e., datasets with sample_weight == 1.0) + dataset_len = int((np.array(dataset_sizes) / sample_weights)[primary_dataset_indices].max()) + + # Allocate Threads based on Weights + threads_per_dataset = allocate_threads(traj_transform_threads, sample_weights) + reads_per_dataset = allocate_threads(traj_read_threads, sample_weights) + + overwatch.info("Threads per Dataset: %s", threads_per_dataset) + overwatch.info("Reads per Dataset: %s", reads_per_dataset) + + # Construct Datasets + overwatch.info("Constructing datasets...") + datasets = [] + for dataset_kwargs, threads, reads in zip( + dataset_kwargs_list, + threads_per_dataset, + reads_per_dataset, + ): + dataset_frame_transform_kwargs = ( + dataset_kwargs.pop("dataset_frame_transform_kwargs") + if "dataset_frame_transform_kwargs" in dataset_kwargs + else {} + ) + dataset, _ = make_dataset_from_rlds( + **dataset_kwargs, + train=train, + num_parallel_calls=threads, + num_parallel_reads=reads, + dataset_statistics=all_dataset_statistics[dataset_kwargs["name"]], + ) + dataset = apply_trajectory_transforms( + dataset.repeat(), + **traj_transform_kwargs, + num_parallel_calls=threads, + train=train, + ).flatten(num_parallel_calls=threads) + dataset = apply_per_dataset_frame_transforms(dataset, **dataset_frame_transform_kwargs) + datasets.append(dataset) + + # Interleave at the Frame Level + dataset: dl.DLataset = dl.DLataset.sample_from_datasets(datasets, sample_weights) + + # Validation =>> fix a single shuffle buffer of data and cache it in RAM; prevents gradual memory increase! + if not train: + dataset = dataset.take(shuffle_buffer_size).cache() + + # Shuffle the Dataset + # =>> IMPORTANT :: Shuffle AFTER .cache(), or else memory will still leak! 
+ dataset = dataset.shuffle(shuffle_buffer_size) + + # Apply Frame Transforms + overwatch.info("Applying frame transforms on dataset...") + dataset = apply_frame_transforms(dataset, **frame_transform_kwargs, train=train) + + # [Contract] When training VLA Policies, we let the Collator handle Batching! + if batch_size is not None: + dataset = dataset.batch(batch_size) + + # Note =>> Seems to reduce memory usage without affecting speed? + dataset = dataset.with_ram_budget(1) + + # Save for Later + dataset.sample_weights = sample_weights + + return dataset, dataset_len, all_dataset_statistics diff --git a/capvector-oft/prismatic/vla/datasets/rlds/obs_transforms.py b/capvector-oft/prismatic/vla/datasets/rlds/obs_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..f5f537040dbf17651dbaf169c038268d8cc2a1ad --- /dev/null +++ b/capvector-oft/prismatic/vla/datasets/rlds/obs_transforms.py @@ -0,0 +1,99 @@ +""" +obs_transforms.py + +Contains observation-level transforms used in the orca data pipeline. + +These transforms operate on the "observation" dictionary, and are applied at a per-frame level. +""" + +from typing import Dict, Tuple, Union + +import dlimp as dl +import tensorflow as tf +from absl import logging + + +# ruff: noqa: B023 +def augment(obs: Dict, seed: tf.Tensor, augment_kwargs: Union[Dict, Dict[str, Dict]]) -> Dict: + """Augments images, skipping padding images.""" + image_names = {key[6:] for key in obs if key.startswith("image_")} + + # "augment_order" is required in augment_kwargs, so if it's there, we can assume that the user has passed + # in a single augmentation dict (otherwise, we assume that the user has passed in a mapping from image + # name to augmentation dict) + if "augment_order" in augment_kwargs: + augment_kwargs = {name: augment_kwargs for name in image_names} + + for i, name in enumerate(image_names): + if name not in augment_kwargs: + continue + kwargs = augment_kwargs[name] + logging.debug(f"Augmenting image_{name} with kwargs {kwargs}") + obs[f"image_{name}"] = tf.cond( + obs["pad_mask_dict"][f"image_{name}"], + lambda: dl.transforms.augment_image( + obs[f"image_{name}"], + **kwargs, + seed=seed + i, # augment each image differently + ), + lambda: obs[f"image_{name}"], # skip padding images + ) + + return obs + + +def decode_and_resize( + obs: Dict, + resize_size: Union[Tuple[int, int], Dict[str, Tuple[int, int]]], + depth_resize_size: Union[Tuple[int, int], Dict[str, Tuple[int, int]]], +) -> Dict: + """Decodes images and depth images, and then optionally resizes them.""" + image_names = {key[6:] for key in obs if key.startswith("image_")} + depth_names = {key[6:] for key in obs if key.startswith("depth_")} + + if isinstance(resize_size, tuple): + resize_size = {name: resize_size for name in image_names} + if isinstance(depth_resize_size, tuple): + depth_resize_size = {name: depth_resize_size for name in depth_names} + + for name in image_names: + if name not in resize_size: + logging.warning( + f"No resize_size was provided for image_{name}. This will result in 1x1 " + "padding images, which may cause errors if you mix padding and non-padding images." 
+            )
+        image = obs[f"image_{name}"]
+        if image.dtype == tf.string:
+            if tf.strings.length(image) == 0:
+                # this is a padding image
+                image = tf.zeros((*resize_size.get(name, (1, 1)), 3), dtype=tf.uint8)
+            else:
+                image = tf.io.decode_image(image, expand_animations=False, dtype=tf.uint8)
+        elif image.dtype != tf.uint8:
+            raise ValueError(f"Unsupported image dtype: found image_{name} with dtype {image.dtype}")
+        if name in resize_size:
+            image = dl.transforms.resize_image(image, size=resize_size[name])
+        obs[f"image_{name}"] = image
+
+    for name in depth_names:
+        if name not in depth_resize_size:
+            logging.warning(
+                f"No depth_resize_size was provided for depth_{name}. This will result in 1x1 "
+                "padding depth images, which may cause errors if you mix padding and non-padding images."
+            )
+        depth = obs[f"depth_{name}"]
+
+        if depth.dtype == tf.string:
+            if tf.strings.length(depth) == 0:
+                depth = tf.zeros((*depth_resize_size.get(name, (1, 1)), 1), dtype=tf.float32)
+            else:
+                depth = tf.io.decode_image(depth, expand_animations=False, dtype=tf.float32)[..., 0]
+        elif depth.dtype != tf.float32:
+            raise ValueError(f"Unsupported depth dtype: found depth_{name} with dtype {depth.dtype}")
+
+        if name in depth_resize_size:
+            depth = dl.transforms.resize_depth_image(depth, size=depth_resize_size[name])
+
+        obs[f"depth_{name}"] = depth
+
+    return obs
diff --git a/capvector-oft/prismatic/vla/datasets/rlds/oxe/__init__.py b/capvector-oft/prismatic/vla/datasets/rlds/oxe/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae77c4a44e477fea36a3f410dc047bcf0f82ef22
--- /dev/null
+++ b/capvector-oft/prismatic/vla/datasets/rlds/oxe/__init__.py
@@ -0,0 +1,2 @@
+from .materialize import get_oxe_dataset_kwargs_and_weights
+from .mixtures import OXE_NAMED_MIXTURES
diff --git a/capvector-oft/prismatic/vla/datasets/rlds/oxe/configs.py b/capvector-oft/prismatic/vla/datasets/rlds/oxe/configs.py
new file mode 100644
index 0000000000000000000000000000000000000000..5875b34d61afb7e8b53f3690d5509101b3b99527
--- /dev/null
+++ b/capvector-oft/prismatic/vla/datasets/rlds/oxe/configs.py
@@ -0,0 +1,709 @@
+"""
+configs.py
+
+Defines per-dataset configuration (kwargs) for each dataset in Open-X Embodiment.
+
+Configuration adopts the following structure:
+    image_obs_keys:
+        primary: primary external RGB
+        secondary: secondary external RGB
+        wrist: wrist RGB
+
+    depth_obs_keys:
+        primary: primary external depth
+        secondary: secondary external depth
+        wrist: wrist depth
+
+    # Always 8-dim =>> changes based on `StateEncoding`
+    state_obs_keys:
+        StateEncoding.POS_EULER: EEF XYZ (3) + Roll-Pitch-Yaw (3) + <PAD> (1) + Gripper Open/Close (1)
+        StateEncoding.POS_QUAT: EEF XYZ (3) + Quaternion (4) + Gripper Open/Close (1)
+        StateEncoding.JOINT: Joint Angles (7, <PAD> if fewer) + Gripper Open/Close (1)
+
+    state_encoding: Type of `StateEncoding`
+    action_encoding: Type of action encoding (e.g., EEF Position vs. Joint Position)
+"""
+
+from enum import IntEnum
+
+from prismatic.vla.datasets.rlds.oxe.utils.droid_utils import zero_action_filter
+
+
+# Defines Proprioceptive State Encoding Schemes
+class StateEncoding(IntEnum):
+    # fmt: off
+    NONE = -1               # No Proprioceptive State
+    POS_EULER = 1           # EEF XYZ (3) + Roll-Pitch-Yaw (3) + <PAD> (1) + Gripper Open/Close (1)
+    POS_QUAT = 2            # EEF XYZ (3) + Quaternion (4) + Gripper Open/Close (1)
+    JOINT = 3               # Joint Angles (7, <PAD> if fewer) + Gripper Open/Close (1)
+    JOINT_BIMANUAL = 4      # Joint Angles (2 x [ Joint Angles (6) + Gripper Open/Close (1) ])
+    # fmt: on
+
+
+# Defines Action Encoding Schemes
+class ActionEncoding(IntEnum):
+    # fmt: off
+    EEF_POS = 1             # EEF Delta XYZ (3) + Roll-Pitch-Yaw (3) + Gripper Open/Close (1)
+    JOINT_POS = 2           # Joint Delta Position (7) + Gripper Open/Close (1)
+    JOINT_POS_BIMANUAL = 3  # Joint Delta Position (2 x [ Joint Delta Position (6) + Gripper Open/Close (1) ])
+    EEF_R6 = 4              # EEF Delta XYZ (3) + R6 (6) + Gripper Open/Close (1)
+    # fmt: on
+
+
+# === Individual Dataset Configs ===
+OXE_DATASET_CONFIGS = {
+    "fractal20220817_data": {
+        "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
+        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
+        "state_obs_keys": ["base_pose_tool_reached", "gripper_closed"],
+        "state_encoding": StateEncoding.POS_QUAT,
+        "action_encoding": ActionEncoding.EEF_POS,
+    },
+    "kuka": {
+        "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
+        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
+        "state_obs_keys": [
+            "clip_function_input/base_pose_tool_reached",
+            "gripper_closed",
+        ],
+        "state_encoding": StateEncoding.POS_QUAT,
+        "action_encoding": ActionEncoding.EEF_POS,
+    },
+    "bridge_oxe": {  # Version of Bridge V2 in Open X-Embodiment mixture
+        "image_obs_keys": {"primary": "image", "secondary": "image_1", "wrist": None},
+        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
+        "state_obs_keys": ["EEF_state", "gripper_state"],
+        "state_encoding": StateEncoding.POS_EULER,
+        "action_encoding": ActionEncoding.EEF_POS,
+    },
+    "bridge_orig": {  # Original version of Bridge V2 from project website
+        "image_obs_keys": {"primary": "image_0", "secondary": "image_1", "wrist": None},
+        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
+        "state_obs_keys": ["EEF_state", "gripper_state"],
+        "state_encoding": StateEncoding.POS_EULER,
+        "action_encoding": ActionEncoding.EEF_POS,
+    },
+    "bridge_dataset": {  # Original version of Bridge V2 from project website
+        "image_obs_keys": {"primary": "image_0", "secondary": "image_1", "wrist": None},
+        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
+        "state_obs_keys": ["EEF_state", "gripper_state"],
+        "state_encoding": StateEncoding.POS_EULER,
+        "action_encoding": ActionEncoding.EEF_POS,
+    },
+    "taco_play": {
+        "image_obs_keys": {
+            "primary": "rgb_static",
+            "secondary": None,
+            "wrist": "rgb_gripper",
+        },
+        "depth_obs_keys": {
+            "primary": "depth_static",
+            "secondary": None,
+            "wrist": "depth_gripper",
+        },
+        "state_obs_keys": ["state_eef", None, "state_gripper"],
+        "state_encoding": StateEncoding.POS_EULER,
+        "action_encoding": ActionEncoding.EEF_POS,
+    },
+    "jaco_play": {
+        "image_obs_keys": {
+            "primary": "image",
+            "secondary": None,
+            "wrist": "image_wrist",
+        },
+        "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
+        "state_obs_keys": ["state_eef", None, "state_gripper"],
+        "state_encoding": StateEncoding.POS_EULER,
+
"action_encoding": ActionEncoding.EEF_POS, + }, + "berkeley_cable_routing": { + "image_obs_keys": { + "primary": "image", + "secondary": "top_image", + "wrist": "wrist45_image", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["robot_state", None], + "state_encoding": StateEncoding.JOINT, + "action_encoding": ActionEncoding.EEF_POS, + }, + "roboturk": { + "image_obs_keys": {"primary": "front_rgb", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": [None, None, None, None, None, None, None, None], + "state_encoding": StateEncoding.NONE, + "action_encoding": ActionEncoding.EEF_POS, + }, + "nyu_door_opening_surprising_effectiveness": { + "image_obs_keys": {"primary": None, "secondary": None, "wrist": "image"}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": [None, None, None, None, None, None, None, None], + "state_encoding": StateEncoding.NONE, + "action_encoding": ActionEncoding.EEF_POS, + }, + "viola": { + "image_obs_keys": { + "primary": "agentview_rgb", + "secondary": None, + "wrist": "eye_in_hand_rgb", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["joint_states", "gripper_states"], + "state_encoding": StateEncoding.JOINT, + "action_encoding": ActionEncoding.EEF_POS, + }, + "berkeley_autolab_ur5": { + "image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "hand_image", + }, + "depth_obs_keys": {"primary": "depth", "secondary": None, "wrist": None}, + "state_obs_keys": ["state"], + "state_encoding": StateEncoding.POS_QUAT, + "action_encoding": ActionEncoding.EEF_POS, + }, + "toto": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state", None], + "state_encoding": StateEncoding.JOINT, + "action_encoding": ActionEncoding.EEF_POS, + }, + "language_table": { + "image_obs_keys": {"primary": "rgb", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["effector_translation", None, None, None, None, None, None], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "columbia_cairlab_pusht_real": { + "image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "wrist_image", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["robot_state", None, None, None, None, None, None], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "stanford_kuka_multimodal_dataset_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": "depth_image", "secondary": None, "wrist": None}, + "state_obs_keys": ["ee_position", "ee_orientation", None], + "state_encoding": StateEncoding.POS_QUAT, + "action_encoding": ActionEncoding.EEF_POS, + }, + "nyu_rot_dataset_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "stanford_hydra_dataset_converted_externally_to_rlds": { + 
"image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "wrist_image", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "austin_buds_dataset_converted_externally_to_rlds": { + "image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "wrist_image", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state"], + "state_encoding": StateEncoding.JOINT, + "action_encoding": ActionEncoding.EEF_POS, + }, + "nyu_franka_play_dataset_converted_externally_to_rlds": { + "image_obs_keys": { + "primary": "image", + "secondary": "image_additional_view", + "wrist": None, + }, + "depth_obs_keys": { + "primary": "depth", + "secondary": "depth_additional_view", + "wrist": None, + }, + "state_obs_keys": ["eef_state", None, None], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "maniskill_dataset_converted_externally_to_rlds": { + "image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "wrist_image", + }, + "depth_obs_keys": { + "primary": "depth", + "secondary": None, + "wrist": "wrist_depth", + }, + "state_obs_keys": ["tcp_pose", "gripper_state"], + "state_encoding": StateEncoding.POS_QUAT, + "action_encoding": ActionEncoding.EEF_POS, + }, + "furniture_bench_dataset_converted_externally_to_rlds": { + "image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "wrist_image", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state"], + "state_encoding": StateEncoding.POS_QUAT, + "action_encoding": ActionEncoding.EEF_POS, + }, + "cmu_franka_exploration_dataset_converted_externally_to_rlds": { + "image_obs_keys": { + "primary": "highres_image", + "secondary": None, + "wrist": None, + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": [None, None, None, None, None, None, None, None], + "state_encoding": StateEncoding.NONE, + "action_encoding": ActionEncoding.EEF_POS, + }, + "ucsd_kitchen_dataset_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["joint_state", None], + "state_encoding": StateEncoding.JOINT, + "action_encoding": ActionEncoding.EEF_POS, + }, + "ucsd_pick_and_place_dataset_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "austin_sailor_dataset_converted_externally_to_rlds": { + "image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "wrist_image", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state"], + "state_encoding": StateEncoding.POS_QUAT, + "action_encoding": ActionEncoding.EEF_POS, + }, + "austin_sirius_dataset_converted_externally_to_rlds": { + "image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "wrist_image", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state"], + "state_encoding": 
StateEncoding.POS_QUAT, + "action_encoding": ActionEncoding.EEF_POS, + }, + "bc_z": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": [ + "present/xyz", + "present/axis_angle", + None, + "present/sensed_close", + ], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "utokyo_pr2_opening_fridge_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "utokyo_xarm_pick_and_place_converted_externally_to_rlds": { + "image_obs_keys": { + "primary": "image", + "secondary": "image2", + "wrist": "hand_image", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["end_effector_pose", None, None], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "utokyo_xarm_bimanual_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["pose_r", None, None], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "robo_net": { + "image_obs_keys": {"primary": "image", "secondary": "image1", "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "berkeley_mvp_converted_externally_to_rlds": { + "image_obs_keys": {"primary": None, "secondary": None, "wrist": "hand_image"}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["pose", "gripper"], + "state_encoding": StateEncoding.POS_QUAT, + "action_encoding": ActionEncoding.JOINT_POS, + }, + "berkeley_rpt_converted_externally_to_rlds": { + "image_obs_keys": {"primary": None, "secondary": None, "wrist": "hand_image"}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["joint_pos", "gripper"], + "state_encoding": StateEncoding.JOINT, + "action_encoding": ActionEncoding.JOINT_POS, + }, + "kaist_nonprehensile_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state", None], + "state_encoding": StateEncoding.POS_QUAT, + "action_encoding": ActionEncoding.EEF_POS, + }, + "stanford_mask_vit_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": 
ActionEncoding.EEF_POS, + }, + "tokyo_u_lsmo_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "dlr_sara_pour_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state", None, None], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "dlr_sara_grid_clamp_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state", None, None], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "dlr_edan_shared_control_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state", None], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "asu_table_top_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "stanford_robocook_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image_1", "secondary": "image_2", "wrist": None}, + "depth_obs_keys": {"primary": "depth_1", "secondary": "depth_2", "wrist": None}, + "state_obs_keys": ["EEF_state", "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "imperialcollege_sawyer_wrist_cam": { + "image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "wrist_image", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": [None, None, None, None, None, None, None, "state"], + "state_encoding": StateEncoding.NONE, + "action_encoding": ActionEncoding.EEF_POS, + }, + "iamlab_cmu_pickup_insert_converted_externally_to_rlds": { + "image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "wrist_image", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["joint_state", "gripper_state"], + "state_encoding": StateEncoding.JOINT, + "action_encoding": ActionEncoding.EEF_POS, + }, + "uiuc_d3field": { + "image_obs_keys": {"primary": "image_1", "secondary": "image_2", "wrist": None}, + "depth_obs_keys": {"primary": "depth_1", "secondary": "depth_2", "wrist": None}, + "state_obs_keys": [None, None, None, None, None, None, None, None], + "state_encoding": StateEncoding.NONE, + "action_encoding": ActionEncoding.EEF_POS, + }, + "utaustin_mutex": { + "image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "wrist_image", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state"], + "state_encoding": StateEncoding.JOINT, + "action_encoding": ActionEncoding.EEF_POS, + }, + "berkeley_fanuc_manipulation": { + 
"image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "wrist_image", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["joint_state", None, "gripper_state"], + "state_encoding": StateEncoding.JOINT, + "action_encoding": ActionEncoding.EEF_POS, + }, + "cmu_playing_with_food": { + "image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "finger_vision_1", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state", None, None], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "cmu_play_fusion": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state"], + "state_encoding": StateEncoding.JOINT, + "action_encoding": ActionEncoding.EEF_POS, + }, + "cmu_stretch": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "berkeley_gnm_recon": { + "image_obs_keys": {"primary": None, "secondary": None, "wrist": "image"}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state", None, None], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "berkeley_gnm_cory_hall": { + "image_obs_keys": {"primary": None, "secondary": None, "wrist": "image"}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state", None, None], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "berkeley_gnm_sac_son": { + "image_obs_keys": {"primary": None, "secondary": None, "wrist": "image"}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state", None, None], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "droid": { + "image_obs_keys": { + "primary": "exterior_image_1_left", + "secondary": "exterior_image_2_left", + "wrist": "wrist_image_left", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["proprio"], + "state_encoding": StateEncoding.POS_QUAT, + "action_encoding": ActionEncoding.EEF_POS, + "aux_kwargs": { + "dataset_frame_transform_kwargs": { + "chunk_filter_fn": zero_action_filter, + }, + }, + }, + "fmb_dataset": { + "image_obs_keys": { + "primary": "image_side_1", + "secondary": "image_side_2", + "wrist": "image_wrist_1", + }, + "depth_obs_keys": { + "primary": "image_side_1_depth", + "secondary": "image_side_2_depth", + "wrist": "image_wrist_1_depth", + }, + "state_obs_keys": ["proprio"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "dobbe": { + "image_obs_keys": {"primary": "wrist_image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["proprio"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "roboset": { + "image_obs_keys": { + "primary": "image_left", + "secondary": "image_right", + "wrist": "image_wrist", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": 
None}, + "state_obs_keys": ["proprio"], + "state_encoding": StateEncoding.JOINT, + "action_encoding": ActionEncoding.JOINT_POS, + }, + "rh20t": { + "image_obs_keys": { + "primary": "image_front", + "secondary": "image_side_right", + "wrist": "image_wrist", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["proprio"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + ### T-DROID datasets + "tdroid_carrot_in_bowl": { # "put carrot in bowl" task, 50 demos @ 5 Hz control + "image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "tdroid_pour_corn_in_pot": { # "pour corn from red bowl into steel pot" task, 50 demos @ 5 Hz control + "image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "tdroid_flip_pot_upright": { # "flip pot upright" task, 10 demos @ 5 Hz control + "image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "tdroid_move_object_onto_plate": { # "move onto plate" task, 150 demos @ 5 Hz control + "image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "tdroid_knock_object_over": { # "knock over" task, 70 demos @ 5 Hz control + "image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "tdroid_cover_object_with_towel": { # "cover with towel" task, 45 demos @ 5 Hz control + "image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + ### DROID Finetuning datasets + "droid_wipe": { + "image_obs_keys": {"primary": "exterior_image_2_left", "secondary": None, "wrist": "wrist_image_left"}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["proprio"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + ### LIBERO datasets (modified versions) + "libero_spatial_no_noops": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": "wrist_image"}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", 
"gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "libero_object_no_noops": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": "wrist_image"}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "libero_goal_no_noops": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": "wrist_image"}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "libero_10_no_noops": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": "wrist_image"}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "libero_4_task_suites_no_noops": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": "wrist_image"}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + ### ALOHA fine-tuning datasets + "aloha1_fold_shorts_20_demos": { + "image_obs_keys": {"primary": "image", "secondary": None, "left_wrist": "left_wrist_image", "right_wrist": "right_wrist_image"}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state"], + "state_encoding": StateEncoding.JOINT_BIMANUAL, + "action_encoding": ActionEncoding.JOINT_POS_BIMANUAL, + }, + "aloha1_fold_shirt_30_demos": { + "image_obs_keys": {"primary": "image", "secondary": None, "left_wrist": "left_wrist_image", "right_wrist": "right_wrist_image"}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state"], + "state_encoding": StateEncoding.JOINT_BIMANUAL, + "action_encoding": ActionEncoding.JOINT_POS_BIMANUAL, + }, + "aloha1_scoop_X_into_bowl_45_demos": { + "image_obs_keys": {"primary": "image", "secondary": None, "left_wrist": "left_wrist_image", "right_wrist": "right_wrist_image"}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state"], + "state_encoding": StateEncoding.JOINT_BIMANUAL, + "action_encoding": ActionEncoding.JOINT_POS_BIMANUAL, + }, + "aloha1_put_X_into_pot_300_demos": { + "image_obs_keys": {"primary": "image", "secondary": None, "left_wrist": "left_wrist_image", "right_wrist": "right_wrist_image"}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state"], + "state_encoding": StateEncoding.JOINT_BIMANUAL, + "action_encoding": ActionEncoding.JOINT_POS_BIMANUAL, + }, +} diff --git a/capvector-oft/prismatic/vla/datasets/rlds/oxe/materialize.py b/capvector-oft/prismatic/vla/datasets/rlds/oxe/materialize.py new file mode 100644 index 0000000000000000000000000000000000000000..73f3cfaf4b8c2fac59be2f2426ed04132e4baaee --- /dev/null +++ b/capvector-oft/prismatic/vla/datasets/rlds/oxe/materialize.py @@ -0,0 +1,134 @@ +""" +materialize.py + +Factory class for initializing Open-X Embodiment dataset kwargs and other parameters; provides and exports functions for +clear control flow. 
+""" + +from copy import deepcopy +from pathlib import Path +from typing import Any, Dict, List, Tuple + +from prismatic.overwatch import initialize_overwatch +from prismatic.vla.constants import ACTION_DIM, ACTION_PROPRIO_NORMALIZATION_TYPE, ACTION_TOKEN_BEGIN_IDX, IGNORE_INDEX, NUM_ACTIONS_CHUNK, PROPRIO_DIM, STOP_INDEX +from prismatic.vla.datasets.rlds.oxe.configs import OXE_DATASET_CONFIGS, ActionEncoding +from prismatic.vla.datasets.rlds.oxe.transforms import OXE_STANDARDIZATION_TRANSFORMS + +# Initialize Overwatch =>> Wraps `logging.Logger` +overwatch = initialize_overwatch(__name__) + + +def make_oxe_dataset_kwargs( + dataset_name: str, + data_root_dir: Path, + load_camera_views: Tuple[str] = ("primary",), + load_depth: bool = False, + load_proprio: bool = True, + load_language: bool = True, + action_proprio_normalization_type = ACTION_PROPRIO_NORMALIZATION_TYPE, +) -> Dict[str, Any]: + """Generates config (kwargs) for given dataset from Open-X Embodiment.""" + dataset_kwargs = deepcopy(OXE_DATASET_CONFIGS[dataset_name]) + if dataset_kwargs["action_encoding"] not in [ActionEncoding.EEF_POS, ActionEncoding.EEF_R6, ActionEncoding.JOINT_POS_BIMANUAL]: + raise ValueError(f"Cannot load `{dataset_name}`; only EEF_POS & EEF_R6 & JOINT_POS_BIMANUAL actions supported!") + + # [Contract] For EEF_POS & EEF_R6 actions, only the last action dimension (gripper) is absolute! + # Normalize all action dimensions *except* the gripper + if dataset_kwargs["action_encoding"] is ActionEncoding.EEF_POS: + dataset_kwargs["absolute_action_mask"] = [False] * 6 + [True] + dataset_kwargs["action_normalization_mask"] = [True] * 6 + [False] + elif dataset_kwargs["action_encoding"] is ActionEncoding.EEF_R6: + dataset_kwargs["absolute_action_mask"] = [False] * 9 + [True] + dataset_kwargs["action_normalization_mask"] = [True] * 9 + [False] + elif dataset_kwargs["action_encoding"] is ActionEncoding.JOINT_POS_BIMANUAL: + dataset_kwargs["absolute_action_mask"] = [True] * 14 + dataset_kwargs["action_normalization_mask"] = [True] * 14 + dataset_kwargs["action_proprio_normalization_type"] = action_proprio_normalization_type + + # Adjust Loaded Camera Views + if len(missing_keys := (set(load_camera_views) - set(dataset_kwargs["image_obs_keys"]))) > 0: + raise ValueError(f"Cannot load `{dataset_name}`; missing camera views `{missing_keys}`") + + # Filter + dataset_kwargs["image_obs_keys"] = { + k: v for k, v in dataset_kwargs["image_obs_keys"].items() if k in load_camera_views + } + dataset_kwargs["depth_obs_keys"] = { + k: v for k, v in dataset_kwargs["depth_obs_keys"].items() if k in load_camera_views + } + + # Eliminate Unnecessary Keys + dataset_kwargs.pop("state_encoding") + dataset_kwargs.pop("action_encoding") + if not load_depth: + dataset_kwargs.pop("depth_obs_keys") + if not load_proprio: + dataset_kwargs.pop("state_obs_keys") + + # Load Language + if load_language: + dataset_kwargs["language_key"] = "language_instruction" + + # Specify Standardization Transform + dataset_kwargs["standardize_fn"] = OXE_STANDARDIZATION_TRANSFORMS[dataset_name] + + # Add any aux arguments + if "aux_kwargs" in dataset_kwargs: + dataset_kwargs.update(dataset_kwargs.pop("aux_kwargs")) + + return {"name": dataset_name, "data_dir": str(data_root_dir), **dataset_kwargs} + + +def get_oxe_dataset_kwargs_and_weights( + data_root_dir: Path, + mixture_spec: List[Tuple[str, float]], + load_camera_views: Tuple[str] = ("primary",), + load_depth: bool = False, + load_proprio: bool = True, + load_language: bool = True, + 
action_proprio_normalization_type = ACTION_PROPRIO_NORMALIZATION_TYPE,
+) -> Tuple[List[Dict[str, Any]], List[float]]:
+    """
+    Generates dataset kwargs for a given dataset mix from the Open X-Embodiment dataset. The returned kwargs
+    (per-dataset configs) and weights can be passed directly to `make_interleaved_dataset`.
+
+    :param data_root_dir: Base directory containing RLDS/TFDS-formatted datasets (from Open-X)
+    :param mixture_spec: List of (dataset_name, sampling_weight) from `oxe.mixtures.OXE_NAMED_MIXTURES`
+    :param load_camera_views: Camera views to load; see `oxe.configs.py` for available views.
+    :param load_depth: Load depth information in addition to camera RGB.
+    :param load_proprio: Load proprioceptive state.
+    :param load_language: Load language instructions.
+    :param action_proprio_normalization_type: Normalization scheme to use for proprioceptive actions.
+
+    :return: Tuple of (per_dataset_kwargs, sampling_weights)
+    """
+    included_datasets, filtered_mixture_spec = set(), []
+    for d_name, d_weight in mixture_spec:
+        if d_name in included_datasets:
+            overwatch.warning(f"Skipping Duplicate Dataset: `{(d_name, d_weight)}`")
+            continue
+
+        included_datasets.add(d_name)
+        filtered_mixture_spec.append((d_name, d_weight))
+
+    # Assemble Dataset Config (kwargs) and Weights
+    per_dataset_kwargs, sampling_weights = [], []
+    for d_name, d_weight in filtered_mixture_spec:
+        try:
+            per_dataset_kwargs.append(
+                make_oxe_dataset_kwargs(
+                    d_name,
+                    data_root_dir,
+                    load_camera_views,
+                    load_depth,
+                    load_proprio,
+                    load_language,
+                    action_proprio_normalization_type,
+                )
+            )
+            sampling_weights.append(d_weight)
+
+        except ValueError as e:
+            overwatch.warning(f"Skipping `{d_name}` due to Error: {e}")
+
+    return per_dataset_kwargs, sampling_weights
diff --git a/capvector-oft/prismatic/vla/datasets/rlds/oxe/mixtures.py b/capvector-oft/prismatic/vla/datasets/rlds/oxe/mixtures.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f2c3732c4068ddf08b8f94362304799cebaab83
--- /dev/null
+++ b/capvector-oft/prismatic/vla/datasets/rlds/oxe/mixtures.py
@@ -0,0 +1,230 @@
+"""
+mixtures.py
+
+Defines a registry of dataset mixtures and weights for the Open-X Embodiment Datasets. Each dataset is associated with
+a float "sampling weight".
+"""
+
+from typing import Dict, List, Tuple
+
+# fmt: off
+OXE_NAMED_MIXTURES: Dict[str, List[Tuple[str, float]]] = {
+    # === Bridge V2 Dataset ===
+    "bridge": [
+        # ("bridge_oxe", 1.0),                                  # Version of Bridge V2 in Open-X GCP Bucket
+        ("bridge_orig", 1.0),                                   # Original Version of Bridge V2 from Project Website
+    ],
+
+
+    # === [Moderate-Scale] Bridge++ Mixtures ===
+    "bridge_rt_1": [
+        # ("bridge_oxe", 1.0)                                   # Version of Bridge V2 in Open-X GCP Bucket
+        ("bridge_orig", 1.0),                                   # Original Version of Bridge V2 from Project Website
+
+        ("fractal20220817_data", 1.0),                          # Google RT-1 Robot Data (Large-Scale)
+    ],
+
+    # === RT-X Mixtures ===
+    "rtx": [
+        ("fractal20220817_data", 0.54087122203),                # Google RT-1 Robot Data (Large-Scale)
+        ("kuka", 0.8341046294),
+        # ("bridge_oxe", 1.0)                                   # Version of Bridge V2 in Open-X GCP Bucket
+        ("bridge_orig", 1.0),                                   # Original Version of Bridge V2 from Project Website
+        ("taco_play", 2.0),
+        ("jaco_play", 2.0),
+        ("berkeley_cable_routing", 3.0),
+        ("roboturk", 1.0),
+        # ("nyu_door_opening_surprising_effectiveness", 5.0),   # Note --> only contains wrist camera images (skip?)
+ ("viola", 2.0), + ("berkeley_autolab_ur5", 1.0), + ("toto", 1.0), + ], + + "rtx_franka": [ + ("fractal20220817_data", 0.54087122203), # Google RT-1 Robot Data (Large-Scale) + ("kuka", 0.8341046294), + # ("bridge_oxe", 1.0) # Version of Bridge V2 in Open-X GCP Bucket + ("bridge_orig", 1.0), # Original Version of Bridge V2 from Project Website + ("taco_play", 2.0), + ("jaco_play", 2.0), + ("berkeley_cable_routing", 3.0), + ("roboturk", 1.0), + # ("nyu_door_opening_surprising_effectiveness", 5.0), # Note --> only contains wrist camera images (skip?) + ("viola", 2.0), + ("berkeley_autolab_ur5", 1.0), + ("toto", 1.0), + + ("taco_play", 1.0), + ("berkeley_cable_routing", 1.0), + ("viola", 1.0), + ("toto", 1.0), + ("stanford_hydra_dataset_converted_externally_to_rlds", 1.0), + ("austin_buds_dataset_converted_externally_to_rlds", 3.0), + ("nyu_franka_play_dataset_converted_externally_to_rlds", 3.0), + ("maniskill_dataset_converted_externally_to_rlds", 0.1), + ("furniture_bench_dataset_converted_externally_to_rlds", 0.1), + ("cmu_franka_exploration_dataset_converted_externally_to_rlds", 5.0), + ("austin_sailor_dataset_converted_externally_to_rlds", 1.0), + ("austin_sirius_dataset_converted_externally_to_rlds", 1.0), + ("berkeley_rpt_converted_externally_to_rlds", 1.0), + ("kaist_nonprehensile_converted_externally_to_rlds", 3.0), + ("stanford_robocook_converted_externally_to_rlds", 1.0), + ("iamlab_cmu_pickup_insert_converted_externally_to_rlds", 1.0), + ("utaustin_mutex", 1.0), + ("cmu_play_fusion", 1.0), + ], + + # === Open-X Magic Soup === + "oxe_magic_soup": [ + ("fractal20220817_data", 0.54087122203), # Google RT-1 Robot Data (Large-Scale) + ("kuka", 0.8341046294), + # ("bridge_oxe", 1.0) # Version of Bridge V2 in Open-X GCP Bucket + ("bridge_orig", 1.0), # Original Version of Bridge V2 from Project Website + ("taco_play", 2.0), + ("jaco_play", 1.0), + ("berkeley_cable_routing", 1.0), + ("roboturk", 2.0), + # ("nyu_door_opening_surprising_effectiveness", 1.0), # Note --> only contains wrist camera images (skip?) + ("viola", 2.0), + ("berkeley_autolab_ur5", 2.0), + ("toto", 1.0), + ("language_table", 0.1), + ("stanford_hydra_dataset_converted_externally_to_rlds", 2.0), + ("austin_buds_dataset_converted_externally_to_rlds", 1.0), + ("nyu_franka_play_dataset_converted_externally_to_rlds", 3.0), + ("furniture_bench_dataset_converted_externally_to_rlds", 0.1), + ("ucsd_kitchen_dataset_converted_externally_to_rlds", 2.0), + ("austin_sailor_dataset_converted_externally_to_rlds", 1.0), + ("austin_sirius_dataset_converted_externally_to_rlds", 1.0), + # ("bc_z", 0.2), # Note --> raw data is broken! + ("dlr_edan_shared_control_converted_externally_to_rlds", 1.0), + ("iamlab_cmu_pickup_insert_converted_externally_to_rlds", 1.0), + # ("uiuc_d3field", 1.0), # Note --> raw data is broken! 
+ ("utaustin_mutex", 1.0), + ("berkeley_fanuc_manipulation", 2.0), + ("cmu_stretch", 1.0), + ], + + # === Open-X Magic Soup++ === + "oxe_magic_soup_plus": [ + ("fractal20220817_data", 0.54087122203), # Google RT-1 Robot Data (Large-Scale) + ("kuka", 0.8341046294), + ("bridge_orig", 1.0), # Original Version of Bridge V2 from Project Website + ("taco_play", 2.0), + ("jaco_play", 1.0), + ("berkeley_cable_routing", 1.0), + ("roboturk", 2.0), + ("viola", 2.0), + ("berkeley_autolab_ur5", 2.0), + ("toto", 1.0), + ("language_table", 0.1), + ("stanford_hydra_dataset_converted_externally_to_rlds", 2.0), + ("austin_buds_dataset_converted_externally_to_rlds", 1.0), + ("nyu_franka_play_dataset_converted_externally_to_rlds", 3.0), + ("furniture_bench_dataset_converted_externally_to_rlds", 0.1), + ("ucsd_kitchen_dataset_converted_externally_to_rlds", 2.0), + ("austin_sailor_dataset_converted_externally_to_rlds", 1.0), + ("austin_sirius_dataset_converted_externally_to_rlds", 1.0), + ("dlr_edan_shared_control_converted_externally_to_rlds", 1.0), + ("iamlab_cmu_pickup_insert_converted_externally_to_rlds", 1.0), + ("utaustin_mutex", 1.0), + ("berkeley_fanuc_manipulation", 2.0), + ("cmu_stretch", 1.0), + ## New Datasets in MagicSoup++ + ("bc_z", 0.2), # Note: use v0.1.0 --> later versions broken + ("fmb_dataset", 1.0), + ("dobbe", 0.2), + ("droid", 0.06), + ], + + "oxe_magic_soup_plus_minus": [ + ("fractal20220817_data", 1.0), # Google RT-1 Robot Data (Large-Scale) + ("kuka", 0.8341046294), + ("bridge_orig", 1.0), # Original Version of Bridge V2 from Project Website + ("taco_play", 2.0), + ("jaco_play", 1.0), + ("berkeley_cable_routing", 1.0), + ("roboturk", 2.0), + ("viola", 2.0), + ("berkeley_autolab_ur5", 2.0), + ("toto", 1.0), + # ("language_table", 0.1), + ("stanford_hydra_dataset_converted_externally_to_rlds", 2.0), + ("austin_buds_dataset_converted_externally_to_rlds", 1.0), + ("nyu_franka_play_dataset_converted_externally_to_rlds", 3.0), + ("furniture_bench_dataset_converted_externally_to_rlds", 0.1), + ("ucsd_kitchen_dataset_converted_externally_to_rlds", 2.0), + ("austin_sailor_dataset_converted_externally_to_rlds", 1.0), + ("austin_sirius_dataset_converted_externally_to_rlds", 1.0), + ("dlr_edan_shared_control_converted_externally_to_rlds", 1.0), + ("iamlab_cmu_pickup_insert_converted_externally_to_rlds", 1.0), + ("utaustin_mutex", 1.0), + ("berkeley_fanuc_manipulation", 2.0), + ("cmu_stretch", 1.0), + ## New Datasets in MagicSoup++ + ("bc_z", 0.2), # Note: use v0.1.0 --> later versions broken + ("fmb_dataset", 1.0), + ("dobbe", 0.2), + # ("droid", 0.06), + ], + + # === T-DROID Dataset === + "tdroid_carrot_in_bowl": [ + ("tdroid_carrot_in_bowl", 1.0), + ], + "tdroid_pour_corn_in_pot": [ + ("tdroid_pour_corn_in_pot", 1.0), + ], + "tdroid_flip_pot_upright": [ + ("tdroid_flip_pot_upright", 1.0), + ], + "tdroid_move_object_onto_plate": [ + ("tdroid_move_object_onto_plate", 1.0), + ], + "tdroid_knock_object_over": [ + ("tdroid_knock_object_over", 1.0), + ], + "tdroid_cover_object_with_towel": [ + ("tdroid_cover_object_with_towel", 1.0), + ], + + # === DROID Finetuning Datasets === + "droid_wipe": [ + ("droid_wipe", 1.0), + ], + + # === LIBERO Datasets (Modified Versions) === + "libero_spatial_no_noops": [ + ("libero_spatial_no_noops", 1.0), + ], + "libero_object_no_noops": [ + ("libero_object_no_noops", 1.0), + ], + "libero_goal_no_noops": [ + ("libero_goal_no_noops", 1.0), + ], + "libero_10_no_noops": [ + ("libero_10_no_noops", 1.0), + ], + "libero_4_task_suites_no_noops": [ + 
("libero_spatial_no_noops", 1.0), + ("libero_object_no_noops", 1.0), + ("libero_goal_no_noops", 1.0), + ("libero_10_no_noops", 1.0), + ], + + # === ALOHA Fine-Tuning Datasets === + "aloha1_fold_shorts_20_demos": [ + ("aloha1_fold_shorts_20_demos", 1.0), + ], + "aloha1_fold_shirt_30_demos": [ + ("aloha1_fold_shirt_30_demos", 1.0), + ], + "aloha1_scoop_X_into_bowl_45_demos": [ + ("aloha1_scoop_X_into_bowl_45_demos", 1.0), + ], + "aloha1_put_X_into_pot_300_demos": [ + ("aloha1_put_X_into_pot_300_demos", 1.0), + ], +# fmt: on +} diff --git a/capvector-oft/prismatic/vla/datasets/rlds/oxe/transforms.py b/capvector-oft/prismatic/vla/datasets/rlds/oxe/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..f4853abdda82517e34a4b9d806253a56d07d7124 --- /dev/null +++ b/capvector-oft/prismatic/vla/datasets/rlds/oxe/transforms.py @@ -0,0 +1,933 @@ +""" +transforms.py + +Defines a registry of per-dataset standardization transforms for each dataset in Open-X Embodiment. + +Transforms adopt the following structure: + Input: Dictionary of *batched* features (i.e., has leading time dimension) + Output: Dictionary `step` =>> { + "observation": { + + State (in chosen state representation) + }, + "action": Action (in chosen action representation), + "language_instruction": str + } +""" + +from typing import Any, Dict + +import tensorflow as tf + +from prismatic.vla.datasets.rlds.oxe.utils.droid_utils import droid_baseact_transform, droid_finetuning_transform +from prismatic.vla.datasets.rlds.utils.data_utils import ( + binarize_gripper_actions, + invert_gripper_actions, + rel2abs_gripper_actions, + relabel_bridge_actions, +) + + +def bridge_oxe_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + """ + Applies to version of Bridge V2 in Open X-Embodiment mixture. + + Note =>> In original Bridge V2 dataset, the first timestep has an all-zero action, so we remove it! + """ + for key in trajectory.keys(): + if key == "traj_metadata": + continue + elif key in ["observation", "action"]: + for key2 in trajectory[key]: + trajectory[key][key2] = trajectory[key][key2][1:] + else: + trajectory[key] = trajectory[key][1:] + + trajectory["action"] = tf.concat( + ( + trajectory["action"]["world_vector"], + trajectory["action"]["rotation_delta"], + tf.cast(trajectory["action"]["open_gripper"][:, None], tf.float32), + ), + axis=-1, + ) + trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] + trajectory = relabel_bridge_actions(trajectory) + trajectory["observation"]["EEF_state"] = trajectory["observation"]["state"][:, :6] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] + return trajectory + + +def bridge_orig_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + """ + Applies to original version of Bridge V2 from the official project website. + + Note =>> In original Bridge V2 dataset, the first timestep has an all-zero action, so we remove it! 
+ """ + for key in trajectory.keys(): + if key == "traj_metadata": + continue + elif key == "observation": + for key2 in trajectory[key]: + trajectory[key][key2] = trajectory[key][key2][1:] + else: + trajectory[key] = trajectory[key][1:] + + trajectory["action"] = tf.concat( + [ + trajectory["action"][:, :6], + binarize_gripper_actions(trajectory["action"][:, -1])[:, None], + ], + axis=1, + ) + trajectory = relabel_bridge_actions(trajectory) + trajectory["observation"]["EEF_state"] = trajectory["observation"]["state"][:, :6] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] + return trajectory + + +def ppgm_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["action"] = tf.concat( + [ + trajectory["action"][:, :6], + binarize_gripper_actions(trajectory["action"][:, -1])[:, None], + ], + axis=1, + ) + trajectory["observation"]["EEF_state"] = trajectory["observation"]["cartesian_position"][:, :6] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["gripper_position"][:, -1:] + return trajectory + + +def rt1_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + # make gripper action absolute action, +1 = open, 0 = close + gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0] + gripper_action = rel2abs_gripper_actions(gripper_action) + + trajectory["action"] = tf.concat( + ( + trajectory["action"]["world_vector"], + trajectory["action"]["rotation_delta"], + gripper_action[:, None], + ), + axis=-1, + ) + trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] + return trajectory + + +def kuka_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + # make gripper action absolute action, +1 = open, 0 = close + gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0] + gripper_action = rel2abs_gripper_actions(gripper_action) + + trajectory["action"] = tf.concat( + ( + trajectory["action"]["world_vector"], + trajectory["action"]["rotation_delta"], + gripper_action[:, None], + ), + axis=-1, + ) + # decode compressed state + eef_value = tf.io.decode_compressed( + trajectory["observation"]["clip_function_input/base_pose_tool_reached"], + compression_type="ZLIB", + ) + eef_value = tf.io.decode_raw(eef_value, tf.float32) + trajectory["observation"]["clip_function_input/base_pose_tool_reached"] = tf.reshape(eef_value, (-1, 7)) + gripper_value = tf.io.decode_compressed(trajectory["observation"]["gripper_closed"], compression_type="ZLIB") + gripper_value = tf.io.decode_raw(gripper_value, tf.float32) + trajectory["observation"]["gripper_closed"] = tf.reshape(gripper_value, (-1, 1)) + # trajectory["language_instruction"] = tf.fill( + # tf.shape(trajectory["observation"]["natural_language_instruction"]), "" + # ) # delete uninformative language instruction + trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] + return trajectory + + +def taco_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["state_eef"] = trajectory["observation"]["robot_obs"][:, :6] + trajectory["observation"]["state_gripper"] = trajectory["observation"]["robot_obs"][:, 7:8] + trajectory["action"] = trajectory["action"]["rel_actions_world"] + + # invert gripper action + clip, +1 = open, 0 = close + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :6], + tf.clip_by_value(trajectory["action"][:, -1:], 0, 1), + ), + axis=-1, + ) + + 
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] + return trajectory + + +def jaco_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["state_eef"] = trajectory["observation"]["end_effector_cartesian_pos"][:, :6] + trajectory["observation"]["state_gripper"] = trajectory["observation"]["end_effector_cartesian_pos"][:, -1:] + + # make gripper action absolute action, +1 = open, 0 = close + gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0] + gripper_action = rel2abs_gripper_actions(gripper_action) + + trajectory["action"] = tf.concat( + ( + trajectory["action"]["world_vector"], + tf.zeros_like(trajectory["action"]["world_vector"]), + gripper_action[:, None], + ), + axis=-1, + ) + trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] + return trajectory + + +def berkeley_cable_routing_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["action"] = tf.concat( + ( + trajectory["action"]["world_vector"], + trajectory["action"]["rotation_delta"], + tf.zeros_like(trajectory["action"]["world_vector"][:, :1]), + ), + axis=-1, + ) + # trajectory["language_instruction"] = tf.fill( + # tf.shape(trajectory["observation"]["natural_language_instruction"]), "" + # ) # delete uninformative language instruction + trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] + return trajectory + + +def roboturk_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + # invert absolute gripper action, +1 = open, 0 = close + gripper_action = invert_gripper_actions(tf.clip_by_value(trajectory["action"]["gripper_closedness_action"], 0, 1)) + + trajectory["action"] = tf.concat( + ( + trajectory["action"]["world_vector"], + trajectory["action"]["rotation_delta"], + gripper_action, + ), + axis=-1, + ) + # trajectory["language_instruction"] = tf.fill( + # tf.shape(trajectory["observation"]["natural_language_instruction"]), "" + # ) # delete uninformative language instruction + trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] + return trajectory + + +def nyu_door_opening_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + # make gripper action absolute action, +1 = open, 0 = close + gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0] + gripper_action = rel2abs_gripper_actions(gripper_action) + + trajectory["action"] = tf.concat( + ( + trajectory["action"]["world_vector"], + trajectory["action"]["rotation_delta"], + gripper_action[:, None], + ), + axis=-1, + ) + # trajectory["language_instruction"] = tf.fill( + # tf.shape(trajectory["observation"]["natural_language_instruction"]), "" + # ) # delete uninformative language instruction + trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] + return trajectory + + +def viola_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + # make gripper action, +1 = open, 0 = close + gripper_action = trajectory["action"]["gripper_closedness_action"][:, None] + gripper_action = tf.clip_by_value(gripper_action, 0, 1) + gripper_action = invert_gripper_actions(gripper_action) + + trajectory["action"] = tf.concat( + ( + trajectory["action"]["world_vector"], + trajectory["action"]["rotation_delta"], + gripper_action, + ), + axis=-1, + ) + # trajectory["language_instruction"] = tf.fill( + # 
tf.shape(trajectory["observation"]["natural_language_instruction"]), "" + # ) # delete uninformative language instruction + trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] + return trajectory + + +def berkeley_autolab_ur5_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["state"] = trajectory["observation"]["robot_state"][:, 6:14] + trajectory["observation"]["depth"] = trajectory["observation"].pop("image_with_depth") + + # make gripper action absolute action, +1 = open, 0 = close + gripper_action = trajectory["action"]["gripper_closedness_action"] + gripper_action = rel2abs_gripper_actions(gripper_action) + + trajectory["action"] = tf.concat( + ( + trajectory["action"]["world_vector"], + trajectory["action"]["rotation_delta"], + gripper_action[:, None], + ), + axis=-1, + ) + trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] + return trajectory + + +def toto_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["action"] = tf.concat( + ( + trajectory["action"]["world_vector"], + trajectory["action"]["rotation_delta"], + tf.cast(trajectory["action"]["open_gripper"][:, None], tf.float32), + ), + axis=-1, + ) + # trajectory["language_instruction"] = tf.fill( + # tf.shape(trajectory["observation"]["natural_language_instruction"]), "" + # ) # delete uninformative language instruction + trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] + return trajectory + + +def language_table_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + # default to "open" gripper + trajectory["action"] = tf.concat( + ( + trajectory["action"], + tf.zeros_like(trajectory["action"]), + tf.zeros_like(trajectory["action"]), + tf.ones_like(trajectory["action"][:, :1]), + ), + axis=-1, + ) + + # decode language instruction + instruction_bytes = trajectory["observation"]["instruction"] + instruction_encoded = tf.strings.unicode_encode(instruction_bytes, output_encoding="UTF-8") + # Remove trailing padding --> convert RaggedTensor to regular Tensor. 
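+    # (The raw `instruction` feature is a fixed-length vector of Unicode code points, zero-padded on the
+    #  right; `unicode_encode` rebuilds each byte string, and splitting on "\x00" below keeps only the text
+    #  preceding the padding.)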
+ trajectory["language_instruction"] = tf.strings.split(instruction_encoded, "\x00")[:, :1].to_tensor()[:, 0] + return trajectory + + +def pusht_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["action"] = tf.concat( + ( + trajectory["action"]["world_vector"], + trajectory["action"]["rotation_delta"], + trajectory["action"]["gripper_closedness_action"][:, None], + ), + axis=-1, + ) + trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] + return trajectory + + +def stanford_kuka_multimodal_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["depth_image"] = trajectory["observation"]["depth_image"][..., 0] + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :3], + tf.zeros_like(trajectory["action"][:, :3]), + trajectory["action"][:, -1:], + ), + axis=-1, + ) + return trajectory + + +def nyu_rot_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][..., :6] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][..., -1:] + trajectory["action"] = trajectory["action"][..., :7] + return trajectory + + +def stanford_hydra_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + # invert gripper action, +1 = open, 0 = close + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :6], + invert_gripper_actions(trajectory["action"][:, -1:]), + ), + axis=-1, + ) + + trajectory["observation"]["eef_state"] = tf.concat( + ( + trajectory["observation"]["state"][:, :3], + trajectory["observation"]["state"][:, 7:10], + ), + axis=-1, + ) + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -3:-2] + # trajectory["language_instruction"] = tf.fill( + # tf.shape(trajectory["language_instruction"]), "" + # ) # delete uninformative language instruction + return trajectory + + +def austin_buds_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + # invert gripper action + clip, +1 = open, 0 = close + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :6], + invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)), + ), + axis=-1, + ) + + trajectory["observation"]["state"] = trajectory["observation"]["state"][:, :8] + # trajectory["language_instruction"] = tf.fill( + # tf.shape(trajectory["language_instruction"]), "" + # ) # delete uninformative language instruction + return trajectory + + +def nyu_franka_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["depth"] = tf.cast(trajectory["observation"]["depth"][..., 0], tf.float32) + trajectory["observation"]["depth_additional_view"] = tf.cast( + trajectory["observation"]["depth_additional_view"][..., 0], tf.float32 + ) + trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, -6:] + + # clip gripper action, +1 = open, 0 = close + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, -8:-2], + tf.clip_by_value(trajectory["action"][:, -2:-1], 0, 1), + ), + axis=-1, + ) + + # trajectory["language_instruction"] = tf.fill( + # tf.shape(trajectory["language_instruction"]), "" + # ) # delete uninformative language instruction + return trajectory + + +def maniskill_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][..., 7:8] + return trajectory + + +def 
furniture_bench_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + import tensorflow_graphics.geometry.transformation as tft + + trajectory["observation"]["state"] = tf.concat( + ( + trajectory["observation"]["state"][:, :7], + trajectory["observation"]["state"][:, -1:], + ), + axis=-1, + ) + + # invert gripper action + clip, +1 = open, 0 = close + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :3], + tft.euler.from_quaternion(trajectory["action"][:, 3:7]), + invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)), + ), + axis=-1, + ) + return trajectory + + +def cmu_franka_exploration_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["action"] = trajectory["action"][..., :-1] + return trajectory + + +def ucsd_kitchen_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["joint_state"] = trajectory["observation"]["state"][:, :7] + trajectory["action"] = trajectory["action"][..., :-1] + return trajectory + + +def ucsd_pick_place_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :3], + tf.zeros_like(trajectory["action"][:, :3]), + trajectory["action"][:, -1:], + ), + axis=-1, + ) + return trajectory + + +def austin_sailor_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + # invert gripper action + clip, +1 = open, 0 = close + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :6], + invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)), + ), + axis=-1, + ) + + # trajectory["language_instruction"] = tf.fill( + # tf.shape(trajectory["language_instruction"]), "" + # ) # delete uninformative language instruction + return trajectory + + +def austin_sirius_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + # invert gripper action + clip, +1 = open, 0 = close + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :6], + invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)), + ), + axis=-1, + ) + + # trajectory["language_instruction"] = tf.fill( + # tf.shape(trajectory["language_instruction"]), "" + # ) # delete uninformative language instruction + return trajectory + + +def bc_z_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["action"] = tf.concat( + ( + trajectory["action"]["future/xyz_residual"][:, :3], + trajectory["action"]["future/axis_angle_residual"][:, :3], + invert_gripper_actions(tf.cast(trajectory["action"]["future/target_close"][:, :1], tf.float32)), + ), + axis=-1, + ) + trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] + return trajectory + + +def tokyo_pr2_opening_fridge_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] + trajectory["action"] = trajectory["action"][..., :-1] + return trajectory + + +def tokyo_pr2_tabletop_manipulation_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6] + trajectory["observation"]["gripper_state"] = 
trajectory["observation"]["state"][:, -1:] + trajectory["action"] = trajectory["action"][..., :-1] + return trajectory + + +def utokyo_xarm_pick_place_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + return trajectory + + +def utokyo_xarm_bimanual_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["action"] = trajectory["action"][..., -7:] + return trajectory + + +def robo_net_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["eef_state"] = tf.concat( + ( + trajectory["observation"]["state"][:, :4], + tf.zeros_like(trajectory["observation"]["state"][:, :2]), + ), + axis=-1, + ) + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :4], + tf.zeros_like(trajectory["action"][:, :2]), + trajectory["action"][:, -1:], + ), + axis=-1, + ) + return trajectory + + +def berkeley_mvp_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + return trajectory + + +def berkeley_rpt_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + return trajectory + + +def kaist_nonprehensible_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["state"] = trajectory["observation"]["state"][:, -7:] + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :6], + tf.zeros_like(trajectory["action"][:, :1]), + ), + axis=-1, + ) + return trajectory + + +def stanford_mask_vit_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["eef_state"] = tf.concat( + ( + trajectory["observation"]["end_effector_pose"][:, :4], + tf.zeros_like(trajectory["observation"]["end_effector_pose"][:, :2]), + ), + axis=-1, + ) + trajectory["observation"]["gripper_state"] = trajectory["observation"]["end_effector_pose"][:, -1:] + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :4], + tf.zeros_like(trajectory["action"][:, :2]), + trajectory["action"][:, -1:], + ), + axis=-1, + ) + return trajectory + + +def tokyo_lsmo_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] + return trajectory + + +def dlr_sara_pour_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + return trajectory + + +def dlr_sara_grid_clamp_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["state"] = trajectory["observation"]["state"][:, :6] + return trajectory + + +def dlr_edan_shared_control_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + # invert gripper action, +1 = open, 0 = close + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :6], + invert_gripper_actions(trajectory["action"][:, -1:]), + ), + axis=-1, + ) + return trajectory + + +def asu_table_top_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["eef_state"] = trajectory["ground_truth_states"]["EE"] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] + return trajectory + + +def robocook_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] + return trajectory + + +def 
imperial_wristcam_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["action"] = trajectory["action"][..., :-1] + return trajectory + + +def iamlab_pick_insert_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + import tensorflow_graphics.geometry.transformation as tft + + trajectory["observation"]["joint_state"] = trajectory["observation"]["state"][:, :7] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, 7:8] + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :3], + tft.euler.from_quaternion(trajectory["action"][:, 3:7]), + trajectory["action"][:, 7:8], + ), + axis=-1, + ) + return trajectory + + +def uiuc_d3field_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["action"] = tf.concat( + ( + trajectory["action"], + tf.zeros_like(trajectory["action"]), + tf.zeros_like(trajectory["action"][:, :1]), + ), + axis=-1, + ) + return trajectory + + +def utaustin_mutex_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["state"] = trajectory["observation"]["state"][:, :8] + + # invert gripper action + clip, +1 = open, 0 = close + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :6], + invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)), + ), + axis=-1, + ) + + # trajectory["language_instruction"] = tf.fill( + # tf.shape(trajectory["language_instruction"]), "" + # ) # delete uninformative language instruction + return trajectory + + +def berkeley_fanuc_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["joint_state"] = trajectory["observation"]["state"][:, :6] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, 6:7] + + # dataset does not store gripper actions, so use gripper state info, invert so +1 = open, 0 = close + trajectory["action"] = tf.concat( + ( + trajectory["action"], + invert_gripper_actions(trajectory["observation"]["gripper_state"]), + ), + axis=-1, + ) + return trajectory + + +def cmu_playing_with_food_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + import tensorflow_graphics.geometry.transformation as tft + + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :3], + tft.euler.from_quaternion(trajectory["action"][:, 3:7]), + trajectory["action"][:, -1:], + ), + axis=-1, + ) + return trajectory + + +def playfusion_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :3], + trajectory["action"][:, -4:], + ), + axis=-1, + ) + return trajectory + + +def cmu_stretch_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["eef_state"] = tf.concat( + ( + trajectory["observation"]["state"][:, :3], + tf.zeros_like(trajectory["observation"]["state"][:, :3]), + ), + axis=-1, + ) + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] + trajectory["action"] = trajectory["action"][..., :-1] + return trajectory + + +def gnm_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["state"] = tf.concat( + ( + trajectory["observation"]["position"], + tf.zeros_like(trajectory["observation"]["state"][:, :3]), + trajectory["observation"]["yaw"], + ), + axis=-1, + ) + trajectory["action"] = tf.concat( + ( + trajectory["action"], + tf.zeros_like(trajectory["action"]), + tf.zeros_like(trajectory["action"]), + 
tf.zeros_like(trajectory["action"][:, :1]), + ), + axis=-1, + ) + return trajectory + + +def fmb_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + # every input feature is batched, ie has leading batch dimension + trajectory["observation"]["proprio"] = tf.concat( + ( + trajectory["observation"]["eef_pose"], + trajectory["observation"]["state_gripper_pose"][..., None], + ), + axis=-1, + ) + return trajectory + + +def dobbe_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + # every input feature is batched, ie has leading batch dimension + trajectory["observation"]["proprio"] = trajectory["observation"]["state"] + return trajectory + + +def roboset_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + # every input feature is batched, ie has leading batch dimension + trajectory["observation"]["proprio"] = trajectory["observation"]["state"] + + # gripper action is in -1...1 --> clip to 0...1, flip + gripper_action = trajectory["action"][:, -1:] + gripper_action = invert_gripper_actions(tf.clip_by_value(gripper_action, 0, 1)) + + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :7], + gripper_action, + ), + axis=-1, + ) + return trajectory + + +def rh20t_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["action"] = tf.concat( + ( + trajectory["action"]["tcp_base"], + tf.cast(trajectory["action"]["gripper"][:, None], tf.float32), + ), + axis=-1, + ) + trajectory["observation"]["proprio"] = tf.concat( + ( + trajectory["observation"]["tcp_base"], + trajectory["observation"]["gripper_width"][..., None], + ), + axis=-1, + ) + return trajectory + + +def tdroid_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["action"] = tf.concat( + [ + trajectory["action"][:, :6], + binarize_gripper_actions(trajectory["action"][:, -1])[:, None], + ], + axis=1, + ) + trajectory["observation"]["EEF_state"] = trajectory["observation"]["cartesian_position"][:, :6] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["gripper_position"][:, -1:] + return trajectory + + +def libero_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + # gripper action is in -1 (open)...1 (close) --> clip to 0...1, flip --> +1 = open, 0 = close + gripper_action = trajectory["action"][:, -1:] + gripper_action = invert_gripper_actions(tf.clip_by_value(gripper_action, 0, 1)) + + trajectory["action"] = tf.concat( + [ + trajectory["action"][:, :6], + gripper_action, + ], + axis=1, + ) + trajectory["observation"]["EEF_state"] = trajectory["observation"]["state"][:, :6] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -2:] # 2D gripper state + return trajectory + + +def aloha_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + # Don't need to do anything because dataset is already in the correct format + return trajectory + + +# === Registry === +OXE_STANDARDIZATION_TRANSFORMS = { + "bridge_oxe": bridge_oxe_dataset_transform, + "bridge_orig": bridge_orig_dataset_transform, + "bridge_dataset": bridge_orig_dataset_transform, + "ppgm": ppgm_dataset_transform, + "ppgm_static": ppgm_dataset_transform, + "ppgm_wrist": ppgm_dataset_transform, + "fractal20220817_data": rt1_dataset_transform, + "kuka": kuka_dataset_transform, + "taco_play": taco_play_dataset_transform, + "jaco_play": jaco_play_dataset_transform, + "berkeley_cable_routing": berkeley_cable_routing_dataset_transform, + "roboturk": roboturk_dataset_transform, + 
"nyu_door_opening_surprising_effectiveness": nyu_door_opening_dataset_transform, + "viola": viola_dataset_transform, + "berkeley_autolab_ur5": berkeley_autolab_ur5_dataset_transform, + "toto": toto_dataset_transform, + "language_table": language_table_dataset_transform, + "columbia_cairlab_pusht_real": pusht_dataset_transform, + "stanford_kuka_multimodal_dataset_converted_externally_to_rlds": stanford_kuka_multimodal_dataset_transform, + "nyu_rot_dataset_converted_externally_to_rlds": nyu_rot_dataset_transform, + "stanford_hydra_dataset_converted_externally_to_rlds": stanford_hydra_dataset_transform, + "austin_buds_dataset_converted_externally_to_rlds": austin_buds_dataset_transform, + "nyu_franka_play_dataset_converted_externally_to_rlds": nyu_franka_play_dataset_transform, + "maniskill_dataset_converted_externally_to_rlds": maniskill_dataset_transform, + "furniture_bench_dataset_converted_externally_to_rlds": furniture_bench_dataset_transform, + "cmu_franka_exploration_dataset_converted_externally_to_rlds": cmu_franka_exploration_dataset_transform, + "ucsd_kitchen_dataset_converted_externally_to_rlds": ucsd_kitchen_dataset_transform, + "ucsd_pick_and_place_dataset_converted_externally_to_rlds": ucsd_pick_place_dataset_transform, + "austin_sailor_dataset_converted_externally_to_rlds": austin_sailor_dataset_transform, + "austin_sirius_dataset_converted_externally_to_rlds": austin_sirius_dataset_transform, + "bc_z": bc_z_dataset_transform, + "utokyo_pr2_opening_fridge_converted_externally_to_rlds": tokyo_pr2_opening_fridge_dataset_transform, + "utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds": tokyo_pr2_tabletop_manipulation_dataset_transform, + "utokyo_xarm_pick_and_place_converted_externally_to_rlds": utokyo_xarm_pick_place_dataset_transform, + "utokyo_xarm_bimanual_converted_externally_to_rlds": utokyo_xarm_bimanual_dataset_transform, + "robo_net": robo_net_dataset_transform, + "berkeley_mvp_converted_externally_to_rlds": berkeley_mvp_dataset_transform, + "berkeley_rpt_converted_externally_to_rlds": berkeley_rpt_dataset_transform, + "kaist_nonprehensile_converted_externally_to_rlds": kaist_nonprehensible_dataset_transform, + "stanford_mask_vit_converted_externally_to_rlds": stanford_mask_vit_dataset_transform, + "tokyo_u_lsmo_converted_externally_to_rlds": tokyo_lsmo_dataset_transform, + "dlr_sara_pour_converted_externally_to_rlds": dlr_sara_pour_dataset_transform, + "dlr_sara_grid_clamp_converted_externally_to_rlds": dlr_sara_grid_clamp_dataset_transform, + "dlr_edan_shared_control_converted_externally_to_rlds": dlr_edan_shared_control_dataset_transform, + "asu_table_top_converted_externally_to_rlds": asu_table_top_dataset_transform, + "stanford_robocook_converted_externally_to_rlds": robocook_dataset_transform, + "imperialcollege_sawyer_wrist_cam": imperial_wristcam_dataset_transform, + "iamlab_cmu_pickup_insert_converted_externally_to_rlds": iamlab_pick_insert_dataset_transform, + "uiuc_d3field": uiuc_d3field_dataset_transform, + "utaustin_mutex": utaustin_mutex_dataset_transform, + "berkeley_fanuc_manipulation": berkeley_fanuc_dataset_transform, + "cmu_playing_with_food": cmu_playing_with_food_dataset_transform, + "cmu_play_fusion": playfusion_dataset_transform, + "cmu_stretch": cmu_stretch_dataset_transform, + "berkeley_gnm_recon": gnm_dataset_transform, + "berkeley_gnm_cory_hall": gnm_dataset_transform, + "berkeley_gnm_sac_son": gnm_dataset_transform, + "droid": droid_baseact_transform, + "fmb_dataset": fmb_dataset_transform, + "dobbe": dobbe_dataset_transform, + 
"roboset": roboset_dataset_transform, + "rh20t": rh20t_dataset_transform, + ### T-DROID datasets + "tdroid_carrot_in_bowl": tdroid_dataset_transform, + "tdroid_pour_corn_in_pot": tdroid_dataset_transform, + "tdroid_flip_pot_upright": tdroid_dataset_transform, + "tdroid_move_object_onto_plate": tdroid_dataset_transform, + "tdroid_knock_object_over": tdroid_dataset_transform, + "tdroid_cover_object_with_towel": tdroid_dataset_transform, + ### DROID Finetuning datasets + "droid_wipe": droid_finetuning_transform, + ### LIBERO datasets (modified versions) + "libero_spatial_no_noops": libero_dataset_transform, + "libero_object_no_noops": libero_dataset_transform, + "libero_goal_no_noops": libero_dataset_transform, + "libero_10_no_noops": libero_dataset_transform, + "libero_4_task_suites_no_noops": libero_dataset_transform, + ### ALOHA fine-tuning datasets + "aloha1_fold_shorts_20_demos": aloha_dataset_transform, + "aloha1_fold_shirt_30_demos": aloha_dataset_transform, + "aloha1_scoop_X_into_bowl_45_demos": aloha_dataset_transform, + "aloha1_put_X_into_pot_300_demos": aloha_dataset_transform, +} diff --git a/capvector-oft/prismatic/vla/datasets/rlds/oxe/utils/droid_utils.py b/capvector-oft/prismatic/vla/datasets/rlds/oxe/utils/droid_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b98e59bc2fb0f9b498e00eaca189c2379304e5aa --- /dev/null +++ b/capvector-oft/prismatic/vla/datasets/rlds/oxe/utils/droid_utils.py @@ -0,0 +1,178 @@ +"""Episode transforms for DROID dataset.""" + +from typing import Any, Dict + +import tensorflow as tf +import tensorflow_graphics.geometry.transformation as tfg + + +def rmat_to_euler(rot_mat): + return tfg.euler.from_rotation_matrix(rot_mat) + + +def euler_to_rmat(euler): + return tfg.rotation_matrix_3d.from_euler(euler) + + +def invert_rmat(rot_mat): + return tfg.rotation_matrix_3d.inverse(rot_mat) + + +def rotmat_to_rot6d(mat): + """ + Converts rotation matrix to R6 rotation representation (first two rows in rotation matrix). + Args: + mat: rotation matrix + + Returns: 6d vector (first two rows of rotation matrix) + + """ + r6 = mat[..., :2, :] + r6_0, r6_1 = r6[..., 0, :], r6[..., 1, :] + r6_flat = tf.concat([r6_0, r6_1], axis=-1) + return r6_flat + + +def velocity_act_to_wrist_frame(velocity, wrist_in_robot_frame): + """ + Translates velocity actions (translation + rotation) from base frame of the robot to wrist frame. + Args: + velocity: 6d velocity action (3 x translation, 3 x rotation) + wrist_in_robot_frame: 6d pose of the end-effector in robot base frame + + Returns: 9d velocity action in robot wrist frame (3 x translation, 6 x rotation as R6) + + """ + R_frame = euler_to_rmat(wrist_in_robot_frame[:, 3:6]) + R_frame_inv = invert_rmat(R_frame) + + # world to wrist: dT_pi = R^-1 dT_rbt + vel_t = (R_frame_inv @ velocity[:, :3][..., None])[..., 0] + + # world to wrist: dR_pi = R^-1 dR_rbt R + dR = euler_to_rmat(velocity[:, 3:6]) + dR = R_frame_inv @ (dR @ R_frame) + dR_r6 = rotmat_to_rot6d(dR) + return tf.concat([vel_t, dR_r6], axis=-1) + + +def rand_swap_exterior_images(img1, img2): + """ + Randomly swaps the two exterior images (for training with single exterior input). + """ + return tf.cond(tf.random.uniform(shape=[]) > 0.5, lambda: (img1, img2), lambda: (img2, img1)) + + +def droid_baseact_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + """ + DROID dataset transformation for actions expressed in *base* frame of the robot. 
+ """ + dt = trajectory["action_dict"]["cartesian_velocity"][:, :3] + dR = trajectory["action_dict"]["cartesian_velocity"][:, 3:6] + + trajectory["action"] = tf.concat( + ( + dt, + dR, + 1 - trajectory["action_dict"]["gripper_position"], + ), + axis=-1, + ) + trajectory["observation"]["exterior_image_1_left"], trajectory["observation"]["exterior_image_2_left"] = ( + rand_swap_exterior_images( + trajectory["observation"]["exterior_image_1_left"], + trajectory["observation"]["exterior_image_2_left"], + ) + ) + trajectory["observation"]["proprio"] = tf.concat( + ( + trajectory["observation"]["cartesian_position"], + trajectory["observation"]["gripper_position"], + ), + axis=-1, + ) + return trajectory + + +def droid_wristact_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + """ + DROID dataset transformation for actions expressed in *wrist* frame of the robot. + """ + wrist_act = velocity_act_to_wrist_frame( + trajectory["action_dict"]["cartesian_velocity"], trajectory["observation"]["cartesian_position"] + ) + trajectory["action"] = tf.concat( + ( + wrist_act, + trajectory["action_dict"]["gripper_position"], + ), + axis=-1, + ) + trajectory["observation"]["exterior_image_1_left"], trajectory["observation"]["exterior_image_2_left"] = ( + rand_swap_exterior_images( + trajectory["observation"]["exterior_image_1_left"], + trajectory["observation"]["exterior_image_2_left"], + ) + ) + trajectory["observation"]["proprio"] = tf.concat( + ( + trajectory["observation"]["cartesian_position"], + trajectory["observation"]["gripper_position"], + ), + axis=-1, + ) + return trajectory + + +def droid_finetuning_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + """ + DROID dataset transformation for actions expressed in *base* frame of the robot. + """ + dt = trajectory["action_dict"]["cartesian_velocity"][:, :3] + dR = trajectory["action_dict"]["cartesian_velocity"][:, 3:6] + trajectory["action"] = tf.concat( + ( + dt, + dR, + 1 - trajectory["action_dict"]["gripper_position"], + ), + axis=-1, + ) + trajectory["observation"]["proprio"] = tf.concat( + ( + trajectory["observation"]["cartesian_position"], + trajectory["observation"]["gripper_position"], + ), + axis=-1, + ) + return trajectory + + +def zero_action_filter(traj: Dict) -> bool: + """ + Filters transitions whose actions are all-0 (only relative actions, no gripper action). + Note: this filter is applied *after* action normalization, so need to compare to "normalized 0". + """ + DROID_Q01 = tf.convert_to_tensor( + [ + -0.7776297926902771, + -0.5803514122962952, + -0.5795090794563293, + -0.6464047729969025, + -0.7041108310222626, + -0.8895104378461838, + ] + ) + DROID_Q99 = tf.convert_to_tensor( + [ + 0.7597932070493698, + 0.5726242214441299, + 0.7351000607013702, + 0.6705610305070877, + 0.6464948207139969, + 0.8897542208433151, + ] + ) + DROID_NORM_0_ACT = 2 * (tf.zeros_like(traj["action"][:, :6]) - DROID_Q01) / (DROID_Q99 - DROID_Q01 + 1e-8) - 1 + + return tf.reduce_any(tf.math.abs(traj["action"][:, :6] - DROID_NORM_0_ACT) > 1e-5) diff --git a/capvector-oft/prismatic/vla/datasets/rlds/traj_transforms.py b/capvector-oft/prismatic/vla/datasets/rlds/traj_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..9d943abbb532fa1af171aa2dc467d1a3c5114c56 --- /dev/null +++ b/capvector-oft/prismatic/vla/datasets/rlds/traj_transforms.py @@ -0,0 +1,90 @@ +""" +traj_transforms.py + +Contains trajectory transforms used in the orca data pipeline. 
Trajectory transforms operate on a dictionary +that represents a single trajectory, meaning each tensor has the same leading dimension (the trajectory length). +""" + +import logging +from typing import Dict + +import tensorflow as tf + + +def chunk_act_obs(traj: Dict, window_size: int, future_action_window_size: int = 0) -> Dict: + """ + Chunks actions and observations into the given window_size. + + "observation" keys are given a new axis (at index 1) of size `window_size` containing `window_size - 1` + observations from the past and the current observation. "action" is given a new axis (at index 1) of size + `window_size + future_action_window_size` containing `window_size - 1` actions from the past, the current + action, and `future_action_window_size` actions from the future. "pad_mask" is added to "observation" and + indicates whether an observation should be considered padding (i.e. if it had come from a timestep + before the start of the trajectory). + """ + traj_len = tf.shape(traj["action"])[0] + action_dim = traj["action"].shape[-1] + effective_traj_len = traj_len - future_action_window_size + chunk_indices = tf.broadcast_to(tf.range(-window_size + 1, 1), [effective_traj_len, window_size]) + tf.broadcast_to( + tf.range(effective_traj_len)[:, None], [effective_traj_len, window_size] + ) + + action_chunk_indices = tf.broadcast_to( + tf.range(-window_size + 1, 1 + future_action_window_size), + [effective_traj_len, window_size + future_action_window_size], + ) + tf.broadcast_to( + tf.range(effective_traj_len)[:, None], + [effective_traj_len, window_size + future_action_window_size], + ) + + floored_chunk_indices = tf.maximum(chunk_indices, 0) + + goal_timestep = tf.fill([effective_traj_len], traj_len - 1) + + floored_action_chunk_indices = tf.minimum(tf.maximum(action_chunk_indices, 0), goal_timestep[:, None]) + + traj["observation"] = tf.nest.map_structure(lambda x: tf.gather(x, floored_chunk_indices), traj["observation"]) + traj["action"] = tf.gather(traj["action"], floored_action_chunk_indices) + + # indicates whether an entire observation is padding + traj["observation"]["pad_mask"] = chunk_indices >= 0 + + # Truncate other elements of the trajectory dict + traj["task"] = tf.nest.map_structure(lambda x: tf.gather(x, tf.range(effective_traj_len)), traj["task"]) + traj["dataset_name"] = tf.gather(traj["dataset_name"], tf.range(effective_traj_len)) + traj["absolute_action_mask"] = tf.gather(traj["absolute_action_mask"], tf.range(effective_traj_len)) + + return traj + + +def subsample(traj: Dict, subsample_length: int) -> Dict: + """Subsamples trajectories to the given length.""" + traj_len = tf.shape(traj["action"])[0] + if traj_len > subsample_length: + indices = tf.random.shuffle(tf.range(traj_len))[:subsample_length] + traj = tf.nest.map_structure(lambda x: tf.gather(x, indices), traj) + + return traj + + +def add_pad_mask_dict(traj: Dict) -> Dict: + """ + Adds a dictionary indicating which elements of the observation/task should be treated as padding. 
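+    String-typed entries (e.g., an empty "language_instruction") are marked as padding when the string is empty;
+    all other entries are always marked valid.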
+ =>> traj["observation"|"task"]["pad_mask_dict"] = {k: traj["observation"|"task"][k] is not padding} + """ + traj_len = tf.shape(traj["action"])[0] + + for key in ["observation", "task"]: + pad_mask_dict = {} + for subkey in traj[key]: + # Handles "language_instruction", "image_*", and "depth_*" + if traj[key][subkey].dtype == tf.string: + pad_mask_dict[subkey] = tf.strings.length(traj[key][subkey]) != 0 + + # All other keys should not be treated as padding + else: + pad_mask_dict[subkey] = tf.ones([traj_len], dtype=tf.bool) + + traj[key]["pad_mask_dict"] = pad_mask_dict + + return traj diff --git a/capvector-oft/prismatic/vla/datasets/rlds/utils/__init__.py b/capvector-oft/prismatic/vla/datasets/rlds/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/capvector-oft/prismatic/vla/datasets/rlds/utils/data_utils.py b/capvector-oft/prismatic/vla/datasets/rlds/utils/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..df49bd6f8defc3ed431dd3cfd5054646f771c0f1 --- /dev/null +++ b/capvector-oft/prismatic/vla/datasets/rlds/utils/data_utils.py @@ -0,0 +1,321 @@ +""" +data_utils.py + +Additional RLDS-specific data utilities. +""" + +import hashlib +import json +import os +from typing import Any, Callable, Dict, List, Optional, Tuple + +import dlimp as dl +import numpy as np +import tensorflow as tf +from tqdm import tqdm + +from prismatic.overwatch import initialize_overwatch +from prismatic.vla.constants import NormalizationType + +# Initialize Overwatch =>> Wraps `logging.Logger` +overwatch = initialize_overwatch(__name__) + + +def tree_map(fn: Callable, tree: Dict) -> Dict: + return {k: tree_map(fn, v) if isinstance(v, dict) else fn(v) for k, v in tree.items()} + + +def tree_merge(*trees: Dict) -> Dict: + merged = {} + for tree in trees: + for k, v in tree.items(): + if isinstance(v, dict): + merged[k] = tree_merge(merged.get(k, {}), v) + else: + merged[k] = v + return merged + + +def to_padding(tensor: tf.Tensor) -> tf.Tensor: + if tf.debugging.is_numeric_tensor(tensor): + return tf.zeros_like(tensor) + elif tensor.dtype == tf.string: + return tf.fill(tf.shape(tensor), "") + else: + raise ValueError(f"Cannot generate padding for tensor of type {tensor.dtype}.") + + +# === State / Action Processing Primitives === + + +# ruff: noqa: B023 +def normalize_action_and_proprio(traj: Dict, metadata: Dict, normalization_type: NormalizationType): + """Normalizes the action and proprio fields of a trajectory using the given metadata.""" + keys_to_normalize = {"action": "action", "proprio": "observation/proprio"} + + if normalization_type == NormalizationType.NORMAL: + for key, traj_key in keys_to_normalize.items(): + mask = metadata[key].get("mask", tf.ones_like(metadata[key]["mean"], dtype=tf.bool)) + traj = dl.transforms.selective_tree_map( + traj, + match=lambda k, _: k == traj_key, + map_fn=lambda x: tf.where(mask, (x - metadata[key]["mean"]) / (metadata[key]["std"] + 1e-8), x), + ) + + return traj + + elif normalization_type in [NormalizationType.BOUNDS, NormalizationType.BOUNDS_Q99]: + for key, traj_key in keys_to_normalize.items(): + if normalization_type == NormalizationType.BOUNDS: + low = metadata[key]["min"] + high = metadata[key]["max"] + elif normalization_type == NormalizationType.BOUNDS_Q99: + low = metadata[key]["q01"] + high = metadata[key]["q99"] + mask = metadata[key].get("mask", tf.ones_like(metadata[key]["min"], dtype=tf.bool)) + traj = 
dl.transforms.selective_tree_map(
+                traj,
+                match=lambda k, _: k == traj_key,
+                map_fn=lambda x: tf.where(
+                    mask,
+                    tf.clip_by_value(2 * (x - low) / (high - low + 1e-8) - 1, -1, 1),
+                    x,
+                ),
+            )
+
+            # Note (Moo Jin): Map unused action dimensions (i.e., dimensions where min == max) to all 0s.
+            zeros_mask = metadata[key]["min"] == metadata[key]["max"]
+            traj = dl.transforms.selective_tree_map(
+                traj, match=lambda k, _: k == traj_key, map_fn=lambda x: tf.where(zeros_mask, 0.0, x)
+            )
+
+        return traj
+
+    raise ValueError(f"Unknown Normalization Type {normalization_type}")
+
+
+def binarize_gripper_actions(actions: tf.Tensor) -> tf.Tensor:
+    """
+    Converts gripper actions from continuous to binary values (0 and 1).
+
+    We exploit the fact that most of the time, the gripper is fully open (near 1.0) or fully closed (near 0.0). As it
+    transitions between the two, it sometimes passes through a few intermediate values. We relabel those intermediate
+    values based on the state that is reached _after_ those intermediate values.
+
+    In the edge case that the trajectory ends with an intermediate value, we give up on binarizing and relabel that
+    chunk of intermediate values as the last action in the trajectory.
+
+    The `scan_fn` implements the following logic:
+        new_actions = np.empty_like(actions)
+        carry = actions[-1]
+        for i in reversed(range(actions.shape[0])):
+            if in_between_mask[i]:
+                carry = carry
+            else:
+                carry = float(open_mask[i])
+            new_actions[i] = carry
+    """
+    open_mask, closed_mask = actions > 0.95, actions < 0.05
+    in_between_mask = tf.logical_not(tf.logical_or(open_mask, closed_mask))
+    is_open_float = tf.cast(open_mask, tf.float32)
+
+    def scan_fn(carry, i):
+        return tf.cond(in_between_mask[i], lambda: tf.cast(carry, tf.float32), lambda: is_open_float[i])
+
+    return tf.scan(scan_fn, tf.range(tf.shape(actions)[0]), actions[-1], reverse=True)
+
+
+def invert_gripper_actions(actions: tf.Tensor) -> tf.Tensor:
+    return 1 - actions
+
+
+def rel2abs_gripper_actions(actions: tf.Tensor) -> tf.Tensor:
+    """
+    Converts relative gripper actions (+1 for closing, -1 for opening) to absolute actions (0 = closed; 1 = open).
+
+    Assumes that the first relative gripper action is not redundant (i.e., closing when already closed)!
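+
+    Worked example (illustrative):
+        relative actions: [0, 0, +1 (close), 0, -1 (open), 0]
+        absolute actions: [1, 1,  0,         0,  1,        1]
+    The leading 1s come from the assumption above: since the first real action is a close, the gripper must have
+    started open.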
+ """ + # Note =>> -1 for closing, 1 for opening, 0 for no change + opening_mask, closing_mask = actions < -0.1, actions > 0.1 + thresholded_actions = tf.where(opening_mask, 1, tf.where(closing_mask, -1, 0)) + + def scan_fn(carry, i): + return tf.cond(thresholded_actions[i] == 0, lambda: carry, lambda: thresholded_actions[i]) + + # If no relative grasp, assumes open for whole trajectory + start = -1 * thresholded_actions[tf.argmax(thresholded_actions != 0, axis=0)] + start = tf.cond(start == 0, lambda: 1, lambda: start) + + # Note =>> -1 for closed, 1 for open + new_actions = tf.scan(scan_fn, tf.range(tf.shape(actions)[0]), start) + new_actions = tf.cast(new_actions, tf.float32) / 2 + 0.5 + + return new_actions + + +# === Bridge-V2 =>> Dataset-Specific Transform === +def relabel_bridge_actions(traj: Dict[str, Any]) -> Dict[str, Any]: + """Relabels actions to use reached proprioceptive state; discards last timestep (no-action).""" + movement_actions = traj["observation"]["state"][1:, :6] - traj["observation"]["state"][:-1, :6] + traj_truncated = tf.nest.map_structure(lambda x: x[:-1], traj) + traj_truncated["action"] = tf.concat([movement_actions, traj["action"][:-1, -1:]], axis=1) + + return traj_truncated + + +# === RLDS Dataset Initialization Utilities === +def pprint_data_mixture(dataset_kwargs_list: List[Dict[str, Any]], dataset_weights: List[int]) -> None: + print("\n######################################################################################") + print(f"# Loading the following {len(dataset_kwargs_list)} datasets (incl. sampling weight):{'': >24} #") + for dataset_kwargs, weight in zip(dataset_kwargs_list, dataset_weights): + pad = 80 - len(dataset_kwargs["name"]) + print(f"# {dataset_kwargs['name']}: {weight:=>{pad}f} #") + print("######################################################################################\n") + + +def get_dataset_statistics( + dataset: dl.DLataset, + hash_dependencies: Tuple[str, ...], + save_dir: Optional[str] = None, +) -> Dict: + """ + Either computes the statistics of a dataset or loads them from a cache file if this function has been called before + with the same `hash_dependencies`. + + Currently, the statistics include the min/max/mean/std of the actions and proprio as well as the number of + transitions and trajectories in the dataset. 
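+
+    The returned metadata has the following structure (matching the construction below):
+        {
+            "action":  {"mean": [...], "std": [...], "max": [...], "min": [...], "q01": [...], "q99": [...]},
+            "proprio": {"mean": [...], "std": [...], "max": [...], "min": [...], "q01": [...], "q99": [...]},
+            "num_transitions": int,
+            "num_trajectories": int,
+        }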
+ """ + unique_hash = hashlib.sha256("".join(hash_dependencies).encode("utf-8"), usedforsecurity=False).hexdigest() + + # Fallback local path for when data_dir is not writable or not provided + local_path = os.path.expanduser(os.path.join("~", ".cache", "orca", f"dataset_statistics_{unique_hash}.json")) + if save_dir is not None: + path = tf.io.gfile.join(save_dir, f"dataset_statistics_{unique_hash}.json") + else: + path = local_path + + # check if cache file exists and load + if tf.io.gfile.exists(path): + overwatch.info(f"Loading existing dataset statistics from {path}.") + with tf.io.gfile.GFile(path, "r") as f: + metadata = json.load(f) + return metadata + + if os.path.exists(local_path): + overwatch.info(f"Loading existing dataset statistics from {local_path}.") + with open(local_path, "r") as f: + metadata = json.load(f) + return metadata + + dataset = dataset.traj_map( + lambda traj: { + "action": traj["action"], + "proprio": ( + traj["observation"]["proprio"] if "proprio" in traj["observation"] else tf.zeros_like(traj["action"]) + ), + } + ) + + cardinality = dataset.cardinality().numpy() + if cardinality == tf.data.INFINITE_CARDINALITY: + raise ValueError("Cannot compute dataset statistics for infinite datasets.") + + overwatch.info("Computing dataset statistics. This may take a bit, but should only need to happen once.") + actions, proprios, num_transitions, num_trajectories = [], [], 0, 0 + for traj in tqdm(dataset.iterator(), total=cardinality if cardinality != tf.data.UNKNOWN_CARDINALITY else None): + actions.append(traj["action"]) + proprios.append(traj["proprio"]) + num_transitions += traj["action"].shape[0] + num_trajectories += 1 + + actions, proprios = np.concatenate(actions), np.concatenate(proprios) + metadata = { + "action": { + "mean": actions.mean(0).tolist(), + "std": actions.std(0).tolist(), + "max": actions.max(0).tolist(), + "min": actions.min(0).tolist(), + "q01": np.quantile(actions, 0.01, axis=0).tolist(), + "q99": np.quantile(actions, 0.99, axis=0).tolist(), + }, + "proprio": { + "mean": proprios.mean(0).tolist(), + "std": proprios.std(0).tolist(), + "max": proprios.max(0).tolist(), + "min": proprios.min(0).tolist(), + "q01": np.quantile(proprios, 0.01, axis=0).tolist(), + "q99": np.quantile(proprios, 0.99, axis=0).tolist(), + }, + "num_transitions": num_transitions, + "num_trajectories": num_trajectories, + } + + try: + with tf.io.gfile.GFile(path, "w") as f: + json.dump(metadata, f) + except tf.errors.PermissionDeniedError: + overwatch.warning(f"Could not write dataset statistics to {path}. 
Writing to {local_path} instead.") + os.makedirs(os.path.dirname(local_path), exist_ok=True) + with open(local_path, "w") as f: + json.dump(metadata, f) + + return metadata + + +def save_dataset_statistics(dataset_statistics, run_dir): + """Saves a `dataset_statistics.json` file.""" + out_path = run_dir / "dataset_statistics.json" + with open(out_path, "w") as f_json: + for _, stats in dataset_statistics.items(): + for k in stats["action"].keys(): + if isinstance(stats["action"][k], np.ndarray): + stats["action"][k] = stats["action"][k].tolist() + if "proprio" in stats: + for k in stats["proprio"].keys(): + if isinstance(stats["proprio"][k], np.ndarray): + stats["proprio"][k] = stats["proprio"][k].tolist() + if "num_trajectories" in stats: + if isinstance(stats["num_trajectories"], np.ndarray): + stats["num_trajectories"] = stats["num_trajectories"].item() + if "num_transitions" in stats: + if isinstance(stats["num_transitions"], np.ndarray): + stats["num_transitions"] = stats["num_transitions"].item() + json.dump(dataset_statistics, f_json, indent=2) + overwatch.info(f"Saved dataset statistics file at path {out_path}") + + +def allocate_threads(n: Optional[int], weights: np.ndarray): + """ + Allocates an integer number of threads across datasets based on weights. + + The final array sums to `n`, but each element is no less than 1. If `n` is None, then every dataset is assigned a + value of AUTOTUNE. + """ + if n is None: + return np.array([tf.data.AUTOTUNE] * len(weights)) + + assert np.all(weights >= 0), "Weights must be non-negative" + assert len(weights) <= n, "Number of threads must be at least as large as length of weights" + weights = np.array(weights) / np.sum(weights) + + allocation = np.zeros_like(weights, dtype=int) + while True: + # Give the remaining elements that would get less than 1 a 1 + mask = (weights * n < 1) & (weights > 0) + if not mask.any(): + break + n -= mask.sum() + allocation += mask.astype(int) + + # Recompute the distribution over the remaining elements + weights[mask] = 0 + weights = weights / weights.sum() + + # Allocate the remaining elements + fractional, integral = np.modf(weights * n) + allocation += integral.astype(int) + n -= integral.sum() + for i in np.argsort(fractional)[::-1][: int(n)]: + allocation[i] += 1 + + return allocation diff --git a/capvector-oft/prismatic/vla/datasets/rlds/utils/goal_relabeling.py b/capvector-oft/prismatic/vla/datasets/rlds/utils/goal_relabeling.py new file mode 100644 index 0000000000000000000000000000000000000000..c8a394955f7c1d3c2aad2ea8b157e6e06b60ae6b --- /dev/null +++ b/capvector-oft/prismatic/vla/datasets/rlds/utils/goal_relabeling.py @@ -0,0 +1,32 @@ +""" +goal_relabeling.py + +Contains simple goal relabeling logic for BC use-cases where rewards and next_observations are not required. +Each function should add entries to the "task" dict. 
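+For example, `uniform` below relabels each transition with a uniformly sampled future observation as its goal.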
+""" + +from typing import Dict + +import tensorflow as tf + +from prismatic.vla.datasets.rlds.utils.data_utils import tree_merge + + +def uniform(traj: Dict) -> Dict: + """Relabels with a true uniform distribution over future states.""" + traj_len = tf.shape(tf.nest.flatten(traj["observation"])[0])[0] + + # Select a random future index for each transition i in the range [i + 1, traj_len) + rand = tf.random.uniform([traj_len]) + low = tf.cast(tf.range(traj_len) + 1, tf.float32) + high = tf.cast(traj_len, tf.float32) + goal_idxs = tf.cast(rand * (high - low) + low, tf.int32) + + # Sometimes there are floating-point errors that cause an out-of-bounds + goal_idxs = tf.minimum(goal_idxs, traj_len - 1) + + # Adds keys to "task" mirroring "observation" keys (`tree_merge` to combine "pad_mask_dict" properly) + goal = tf.nest.map_structure(lambda x: tf.gather(x, goal_idxs), traj["observation"]) + traj["task"] = tree_merge(traj["task"], goal) + + return traj diff --git a/capvector-oft/prismatic/vla/datasets/rlds/utils/task_augmentation.py b/capvector-oft/prismatic/vla/datasets/rlds/utils/task_augmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..f0d0c8e9785917cabb95742dd21efd4587976aa0 --- /dev/null +++ b/capvector-oft/prismatic/vla/datasets/rlds/utils/task_augmentation.py @@ -0,0 +1,57 @@ +""" +task_augmentation.py + +Contains basic logic for randomly zeroing out keys in the task specification. +""" + +from typing import Dict + +import tensorflow as tf + +from prismatic.vla.datasets.rlds.utils.data_utils import to_padding + + +def delete_task_conditioning(traj: Dict, keep_image_prob: float) -> Dict: + """ + Randomly drops out either the goal images or the language instruction. Only does something if both of + these are present. + + Args: + traj: A dictionary containing trajectory data. Should have a "task" key. + keep_image_prob: The probability of keeping the goal images. The probability of keeping the language + instruction is 1 - keep_image_prob. 
+ """ + if "language_instruction" not in traj["task"]: + return traj + + image_keys = {key for key in traj["task"].keys() if key.startswith("image_") or key.startswith("depth_")} + if not image_keys: + return traj + + traj_len = tf.shape(traj["action"])[0] + should_keep_images = tf.random.uniform([traj_len]) < keep_image_prob + should_keep_images |= ~traj["task"]["pad_mask_dict"]["language_instruction"] + + for key in image_keys | {"language_instruction"}: + should_keep = should_keep_images if key in image_keys else ~should_keep_images + # pad out the key + traj["task"][key] = tf.where( + should_keep, + traj["task"][key], + to_padding(traj["task"][key]), + ) + # zero out the pad mask dict for the key + traj["task"]["pad_mask_dict"][key] = tf.where( + should_keep, + traj["task"]["pad_mask_dict"][key], + tf.zeros_like(traj["task"]["pad_mask_dict"][key]), + ) + + # when no goal images are present, the goal timestep becomes the final timestep + traj["task"]["timestep"] = tf.where( + should_keep_images, + traj["task"]["timestep"], + traj_len - 1, + ) + + return traj diff --git a/capvector-oft/prismatic/vla/materialize.py b/capvector-oft/prismatic/vla/materialize.py new file mode 100644 index 0000000000000000000000000000000000000000..6b267bc24c27e778234f03ee58c48f6d41b34148 --- /dev/null +++ b/capvector-oft/prismatic/vla/materialize.py @@ -0,0 +1,56 @@ +""" +materialize.py + +Factory class for initializing Open-X RLDS-backed datasets, given specified data mixture parameters; provides and +exports individual functions for clear control flow. +""" + +from pathlib import Path +from typing import Tuple, Type + +from torch.utils.data import Dataset +from transformers import PreTrainedTokenizerBase + +from prismatic.models.backbones.llm.prompting import PromptBuilder +from prismatic.models.backbones.vision import ImageTransform +from prismatic.util.data_utils import PaddedCollatorForActionPrediction +from prismatic.vla.action_tokenizer import ActionTokenizer +from prismatic.vla.datasets import EpisodicRLDSDataset, RLDSBatchTransform, RLDSDataset + + +def get_vla_dataset_and_collator( + data_root_dir: Path, + data_mix: str, + image_transform: ImageTransform, + tokenizer: PreTrainedTokenizerBase, + prompt_builder_fn: Type[PromptBuilder], + default_image_resolution: Tuple[int, int, int], + padding_side: str = "right", + predict_stop_token: bool = True, + shuffle_buffer_size: int = 100_000, + train: bool = True, + episodic: bool = False, + image_aug: bool = False, +) -> Tuple[Dataset, ActionTokenizer, PaddedCollatorForActionPrediction]: + """Initialize RLDS Dataset (wraps TFDS), ActionTokenizer, and initialize transform/collation functions.""" + action_tokenizer = ActionTokenizer(tokenizer) + batch_transform = RLDSBatchTransform( + action_tokenizer, tokenizer, image_transform, prompt_builder_fn, predict_stop_token=predict_stop_token + ) + collator = PaddedCollatorForActionPrediction( + tokenizer.model_max_length, tokenizer.pad_token_id, padding_side=padding_side + ) + + # Build RLDS Iterable Dataset + cls = RLDSDataset if not episodic else EpisodicRLDSDataset + dataset = cls( + data_root_dir, + data_mix, + batch_transform, + resize_resolution=default_image_resolution[1:], + shuffle_buffer_size=shuffle_buffer_size, + train=train, + image_aug=image_aug, + ) + + return dataset, action_tokenizer, collator diff --git a/capvector-oft/scripts/extern/verify_prismatic.py b/capvector-oft/scripts/extern/verify_prismatic.py new file mode 100644 index 
0000000000000000000000000000000000000000..3717bfb96250117ac754529c18e090b62a0aa4c4 --- /dev/null +++ b/capvector-oft/scripts/extern/verify_prismatic.py @@ -0,0 +1,134 @@ +""" +verify_prismatic.py + +Given an HF-exported Prismatic model, attempt to load via AutoClasses, and verify forward() and generate(). +""" + +import time + +import requests +import torch +from PIL import Image +from transformers import AutoModelForVision2Seq, AutoProcessor + +# === Verification Arguments === +MODEL_PATH = "TRI-ML/prismatic-siglip-224px-7b" +DEFAULT_IMAGE_URL = ( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png" +) + +if "-prism-" in MODEL_PATH: + SAMPLE_PROMPTS_FOR_GENERATION = [ + "In: What is sitting in the coffee?\nOut:", + "In: What's the name of the food on the plate?\nOut:", + "In: caption.\nOut:", + "In: how many beinets..?\nOut:", + "In: Can you give me a lyrical description of the scene\nOut:", + ] +else: + SYSTEM_PROMPT = ( + "A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions." + ) + SAMPLE_PROMPTS_FOR_GENERATION = [ + f"{SYSTEM_PROMPT} USER: What is sitting in the coffee? ASSISTANT:", + f"{SYSTEM_PROMPT} USER: What's the name of the food on the plate? ASSISTANT:", + f"{SYSTEM_PROMPT} USER: caption. ASSISTANT:", + f"{SYSTEM_PROMPT} USER: how many beinets..? ASSISTANT:", + f"{SYSTEM_PROMPT} USER: Can you give me a lyrical description of the scene ASSISTANT:", + ] + + +@torch.inference_mode() +def verify_prismatic() -> None: + print(f"[*] Verifying PrismaticForConditionalGeneration using Model `{MODEL_PATH}`") + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + + # Load Processor & VLM + print("[*] Instantiating Processor and Pretrained VLM") + processor = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True) + + # === AUTOCAST MODE === + # print("[*] Loading in BF16 Autocast Mode") + # vlm = AutoModelForVision2Seq.from_pretrained(MODEL_PATH, low_cpu_mem_usage=True, trust_remote_code=True).to( + # device, dtype=torch.bfloat16 + # ) + + # === NATIVE BFLOAT16 MODE === + # print("[*] Loading in BF16") + # vlm = AutoModelForVision2Seq.from_pretrained( + # MODEL_PATH, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, trust_remote_code=True + # ).to(device) + + # === BFLOAT16 + FLASH-ATTN MODE :: [~14GB of VRAM Passive || 18GB of VRAM Active] === + print("[*] Loading in BF16 with Flash-Attention Enabled") + vlm = AutoModelForVision2Seq.from_pretrained( + MODEL_PATH, + attn_implementation="flash_attention_2", + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, + trust_remote_code=True, + ).to(device) + + # === 8-BIT QUANTIZATION MODE (`pip install bitsandbytes`) :: [~9GB of VRAM Passive || 10GB of VRAM Active] === + # print("[*] Loading in 8-Bit Quantization Mode") + # vlm = AutoModelForVision2Seq.from_pretrained( + # MODEL_PATH, + # attn_implementation="flash_attention_2", + # torch_dtype=torch.float16, + # quantization_config=BitsAndBytesConfig(load_in_8bit=True), + # low_cpu_mem_usage=True, + # trust_remote_code=True, + # ) + + # === 4-BIT QUANTIZATION MODE (`pip install bitsandbytes`) :: [~6GB of VRAM Passive || 7GB of VRAM Active] === + # print("[*] Loading in 4-Bit Quantization Mode") + # vlm = AutoModelForVision2Seq.from_pretrained( + # MODEL_PATH, + # attn_implementation="flash_attention_2", + # torch_dtype=torch.float16, + # 
quantization_config=BitsAndBytesConfig(load_in_4bit=True),
+    #     low_cpu_mem_usage=True,
+    #     trust_remote_code=True,
+    # )
+
+    # Iterate over Sample Prompts =>> Generate
+    image = Image.open(requests.get(DEFAULT_IMAGE_URL, stream=True).raw).convert("RGB")
+    num_tokens, total_time = 0, 0.0
+
+    print("[*] Iterating over Sample Prompts\n===\n")
+    for idx, prompt in enumerate(SAMPLE_PROMPTS_FOR_GENERATION):
+        # === AUTOCAST MODE (Reproduces Prismatic `scripts/generate.py`) ===
+        # inputs = processor(prompt, image).to(device)
+        #
+        # # Using "autocast" to evaluate bit-wise equivalence to `scripts/generate.py`
+        # # =>> Running in native BF16 is also fine (but leads to slightly different generations)
+        # with torch.autocast("cuda", dtype=torch.bfloat16, enabled=True):
+        #     gen_ids = vlm.generate(**inputs, do_sample=False, min_length=1, max_length=512)
+
+        # === BFLOAT16 MODE ===
+        inputs = processor(prompt, image).to(device, dtype=torch.bfloat16)
+
+        # === 8-BIT/4-BIT QUANTIZATION MODE ===
+        # inputs = processor(prompt, image).to(device, dtype=torch.float16)
+
+        # Run Inference
+        gen_ids = None
+        for _ in range(5):
+            start_time = time.time()
+            gen_ids = vlm.generate(**inputs, do_sample=False, min_length=1, max_length=512)
+            total_time += time.time() - start_time
+
+        gen_ids = gen_ids[0, inputs.input_ids.shape[1] :]
+        num_tokens += len(gen_ids)
+
+        # ===
+        gen_text = processor.decode(gen_ids, skip_special_tokens=True).strip()
+        print(f"[{idx + 1}] Input Prompt => {prompt}\n    Generated => {gen_text}\n")
+
+    # Compute Tokens / Second
+    print(f"[*] Generated Tokens per Second = {num_tokens / total_time} w/ {num_tokens = } and {total_time = }")
+
+
+if __name__ == "__main__":
+    verify_prismatic()
diff --git a/capvector-oft/vla-scripts/deploy.py b/capvector-oft/vla-scripts/deploy.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7d743bf132cabff28c75ea0342913f8aeae2c52
--- /dev/null
+++ b/capvector-oft/vla-scripts/deploy.py
@@ -0,0 +1,156 @@
+"""
+deploy.py
+
+Starts a VLA server which the client can query to get robot actions.
+"""
+
+import os.path
+
+# ruff: noqa: E402
+import json_numpy
+
+json_numpy.patch()
+import json
+import logging
+import numpy as np
+import traceback
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Dict, Optional, Union
+
+import draccus
+import torch
+import uvicorn
+from fastapi import FastAPI
+from fastapi.responses import JSONResponse
+from PIL import Image
+from transformers import AutoModelForVision2Seq, AutoProcessor
+
+from experiments.robot.openvla_utils import (
+    get_vla,
+    get_vla_action,
+    get_action_head,
+    get_processor,
+    get_proprio_projector,
+)
+from experiments.robot.robot_utils import (
+    get_image_resize_size,
+)
+from prismatic.vla.constants import ACTION_DIM, ACTION_TOKEN_BEGIN_IDX, IGNORE_INDEX, NUM_ACTIONS_CHUNK, PROPRIO_DIM, STOP_INDEX
+
+
+def get_openvla_prompt(instruction: str, openvla_path: Union[str, Path]) -> str:
+    return f"In: What action should the robot take to {instruction.lower()}?\nOut:"
+
+
+# === Server Interface ===
+class OpenVLAServer:
+    def __init__(self, cfg) -> None:
+        """
+        A simple server for OpenVLA models; exposes `/act` to predict an action for a given observation + instruction.
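+
+        Example client call (a sketch; the observation fields expected by `get_vla_action` depend on the policy
+        configuration, e.g. camera images and optionally proprio state):
+            import json_numpy, requests
+            json_numpy.patch()
+            action = requests.post(
+                "http://0.0.0.0:8777/act",
+                json={**observation, "instruction": "pick up the cup"},
+            ).json()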
+        """
+        self.cfg = cfg
+
+        # Load model
+        self.vla = get_vla(cfg)
+
+        # Load proprio projector
+        self.proprio_projector = None
+        if cfg.use_proprio:
+            self.proprio_projector = get_proprio_projector(cfg, self.vla.llm_dim, PROPRIO_DIM)
+
+        # Load continuous action head
+        self.action_head = None
+        if cfg.use_l1_regression or cfg.use_diffusion:
+            self.action_head = get_action_head(cfg, self.vla.llm_dim)
+
+        # Check that the model contains the action un-normalization key
+        assert cfg.unnorm_key in self.vla.norm_stats, f"Action un-norm key {cfg.unnorm_key} not found in VLA `norm_stats`!"
+
+        # Get Hugging Face processor
+        self.processor = get_processor(cfg)
+
+        # Get expected image dimensions
+        self.resize_size = get_image_resize_size(cfg)
+
+    def get_server_action(self, payload: Dict[str, Any]) -> Union[JSONResponse, str]:
+        try:
+            if double_encode := "encoded" in payload:
+                # Support cases where `json_numpy` is hard to install, and numpy arrays are "double-encoded" as strings
+                assert len(payload.keys()) == 1, "Only uses encoded payload!"
+                payload = json.loads(payload["encoded"])
+
+            observation = payload
+            instruction = observation["instruction"]
+
+            action = get_vla_action(
+                self.cfg,
+                self.vla,
+                self.processor,
+                observation,
+                instruction,
+                action_head=self.action_head,
+                proprio_projector=self.proprio_projector,
+                use_film=self.cfg.use_film,
+            )
+
+            if double_encode:
+                return JSONResponse(json_numpy.dumps(action))
+            else:
+                return JSONResponse(action)
+        except:  # noqa: E722
+            logging.error(traceback.format_exc())
+            logging.warning(
+                "Your request threw an error; make sure your request complies with the expected format:\n"
+                "{'instruction': str, plus the observation fields expected by the policy}\n"
+            )
+            return "error"
+
+    def run(self, host: str = "0.0.0.0", port: int = 8777) -> None:
+        self.app = FastAPI()
+        self.app.post("/act")(self.get_server_action)
+        uvicorn.run(self.app, host=host, port=port)
+
+
+@dataclass
+class DeployConfig:
+    # fmt: off
+
+    # Server Configuration
+    host: str = "0.0.0.0"                            # Host IP Address
+    port: int = 8777                                 # Host Port
+
+    #################################################################################################################
+    # Model-specific parameters
+    #################################################################################################################
+    model_family: str = "openvla"                    # Model family
+    pretrained_checkpoint: Union[str, Path] = ""     # Pretrained checkpoint path
+
+    use_l1_regression: bool = True                   # If True, uses continuous action head with L1 regression objective
+    use_diffusion: bool = False                      # If True, uses continuous action head with diffusion modeling objective (DDIM)
+    num_diffusion_steps_train: int = 50              # (When `diffusion==True`) Number of diffusion steps used for training
+    num_diffusion_steps_inference: int = 50          # (When `diffusion==True`) Number of diffusion steps used for inference
+    use_film: bool = False                           # If True, uses FiLM to infuse language inputs into visual features
+    num_images_in_input: int = 3                     # Number of images in the VLA input (default: 3)
+    use_proprio: bool = True                         # Whether to include proprio state in input
+
+    center_crop: bool = True                         # Center crop? (if trained w/ random crop image aug)
+
+    lora_rank: int = 32                              # Rank of LoRA weight matrix (MAKE SURE THIS MATCHES TRAINING!)
+ + unnorm_key: Union[str, Path] = "" # Action un-normalization key + use_relative_actions: bool = False # Whether to use relative actions (delta joint angles) + + load_in_8bit: bool = False # (For OpenVLA only) Load with 8-bit quantization + load_in_4bit: bool = False # (For OpenVLA only) Load with 4-bit quantization + + ################################################################################################################# + # Utils + ################################################################################################################# + seed: int = 7 # Random Seed (for reproducibility) + # fmt: on + + +@draccus.wrap() +def deploy(cfg: DeployConfig) -> None: + server = OpenVLAServer(cfg) + server.run(cfg.host, port=cfg.port) + + +if __name__ == "__main__": + deploy() diff --git a/capvector-pi05/src/vggt/utils/pose_enc.py b/capvector-pi05/src/vggt/utils/pose_enc.py new file mode 100644 index 0000000000000000000000000000000000000000..a6ddccbe71f3587860e91f9e3eb2a9a3c3af1ab6 --- /dev/null +++ b/capvector-pi05/src/vggt/utils/pose_enc.py @@ -0,0 +1,124 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from .rotation import quat_to_mat, mat_to_quat + + +def extri_intri_to_pose_encoding( + extrinsics, intrinsics, image_size_hw=None, pose_encoding_type="absT_quaR_FoV" # e.g., (256, 512) +): + """Convert camera extrinsics and intrinsics to a compact pose encoding. + + This function transforms camera parameters into a unified pose encoding format, + which can be used for various downstream tasks like pose prediction or representation. + + Args: + extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4, + where B is batch size and S is sequence length. + In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world transformation. + The format is [R|t] where R is a 3x3 rotation matrix and t is a 3x1 translation vector. + intrinsics (torch.Tensor): Camera intrinsic parameters with shape BxSx3x3. + Defined in pixels, with format: + [[fx, 0, cx], + [0, fy, cy], + [0, 0, 1]] + where fx, fy are focal lengths and (cx, cy) is the principal point + image_size_hw (tuple): Tuple of (height, width) of the image in pixels. + Required for computing field of view values. For example: (256, 512). + pose_encoding_type (str): Type of pose encoding to use. Currently only + supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view). + + Returns: + torch.Tensor: Encoded camera pose parameters with shape BxSx9. 
+ For "absT_quaR_FoV" type, the 9 dimensions are: + - [:3] = absolute translation vector T (3D) + - [3:7] = rotation as quaternion quat (4D) + - [7:] = field of view (2D) + """ + + # extrinsics: BxSx3x4 + # intrinsics: BxSx3x3 + + if pose_encoding_type == "absT_quaR_FoV": + R = extrinsics[:, :, :3, :3] # BxSx3x3 + T = extrinsics[:, :, :3, 3] # BxSx3 + + quat = mat_to_quat(R) + # Note the order of h and w here + H, W = image_size_hw + fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1]) + fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0]) + pose_encoding = torch.cat([T, quat, fov_h[..., None], fov_w[..., None]], dim=-1).float() + else: + raise NotImplementedError + + return pose_encoding + + +def pose_encoding_to_extri_intri( + pose_encoding, image_size_hw=None, pose_encoding_type="absT_quaR_FoV", build_intrinsics=True # e.g., (256, 512) +): + """Convert a pose encoding back to camera extrinsics and intrinsics. + + This function performs the inverse operation of extri_intri_to_pose_encoding, + reconstructing the full camera parameters from the compact encoding. + + Args: + pose_encoding (torch.Tensor): Encoded camera pose parameters with shape BxSx9, + where B is batch size and S is sequence length. + For "absT_quaR_FoV" type, the 9 dimensions are: + - [:3] = absolute translation vector T (3D) + - [3:7] = rotation as quaternion quat (4D) + - [7:] = field of view (2D) + image_size_hw (tuple): Tuple of (height, width) of the image in pixels. + Required for reconstructing intrinsics from field of view values. + For example: (256, 512). + pose_encoding_type (str): Type of pose encoding used. Currently only + supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view). + build_intrinsics (bool): Whether to reconstruct the intrinsics matrix. + If False, only extrinsics are returned and intrinsics will be None. + + Returns: + tuple: (extrinsics, intrinsics) + - extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4. + In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world + transformation. The format is [R|t] where R is a 3x3 rotation matrix and t is + a 3x1 translation vector. + - intrinsics (torch.Tensor or None): Camera intrinsic parameters with shape BxSx3x3, + or None if build_intrinsics is False. Defined in pixels, with format: + [[fx, 0, cx], + [0, fy, cy], + [0, 0, 1]] + where fx, fy are focal lengths and (cx, cy) is the principal point, + assumed to be at the center of the image (W/2, H/2). 
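+
+    Example (illustrative round trip; the rebuilt intrinsics match the originals only if the original principal
+    point was at the image center):
+        >>> enc = extri_intri_to_pose_encoding(extrinsics, intrinsics, image_size_hw=(256, 512))
+        >>> extr, intr = pose_encoding_to_extri_intri(enc, image_size_hw=(256, 512))
+        >>> # extr recovers the input extrinsics; intr is rebuilt from the FoV values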
+ """ + + intrinsics = None + + if pose_encoding_type == "absT_quaR_FoV": + T = pose_encoding[..., :3] + quat = pose_encoding[..., 3:7] + fov_h = pose_encoding[..., 7] + fov_w = pose_encoding[..., 8] + + R = quat_to_mat(quat) + extrinsics = torch.cat([R, T[..., None]], dim=-1) + + if build_intrinsics: + H, W = image_size_hw + fy = (H / 2.0) / torch.tan(fov_h / 2.0) + fx = (W / 2.0) / torch.tan(fov_w / 2.0) + intrinsics = torch.zeros(pose_encoding.shape[:2] + (3, 3), device=pose_encoding.device) + intrinsics[..., 0, 0] = fx + intrinsics[..., 1, 1] = fy + intrinsics[..., 0, 2] = W / 2 + intrinsics[..., 1, 2] = H / 2 + intrinsics[..., 2, 2] = 1.0 # Set the homogeneous coordinate to 1 + else: + raise NotImplementedError + + return extrinsics, intrinsics diff --git a/capvector-pi05/src/vggt/utils/rotation.py b/capvector-pi05/src/vggt/utils/rotation.py new file mode 100644 index 0000000000000000000000000000000000000000..494f176450cbf1fd4dd3ef21787201b12f357843 --- /dev/null +++ b/capvector-pi05/src/vggt/utils/rotation.py @@ -0,0 +1,132 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Modified from PyTorch3D, https://github.com/facebookresearch/pytorch3d + +import torch +import numpy as np +import torch.nn.functional as F + + +def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor: + """ + Quaternion Order: XYZW or say ijkr, scalar-last + + Convert rotations given as quaternions to rotation matrices. + Args: + quaternions: quaternions with real part last, + as tensor of shape (..., 4). + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + i, j, k, r = torch.unbind(quaternions, -1) + # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`. + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as rotation matrices to quaternions. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + quaternions with real part last, as tensor of shape (..., 4). + Quaternion Order: XYZW or say ijkr, scalar-last + """ + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + + batch_dim = matrix.shape[:-2] + m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1) + + q_abs = _sqrt_positive_part( + torch.stack( + [1.0 + m00 + m11 + m22, 1.0 + m00 - m11 - m22, 1.0 - m00 + m11 - m22, 1.0 - m00 - m11 + m22], dim=-1 + ) + ) + + # we produce the desired quaternion multiplied by each of r, i, j, k + quat_by_rijk = torch.stack( + [ + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. 
+
+
+def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
+    """
+    Returns torch.sqrt(torch.max(0, x))
+    but with a zero subgradient where x is 0.
+    """
+    ret = torch.zeros_like(x)
+    positive_mask = x > 0
+    if torch.is_grad_enabled():
+        ret[positive_mask] = torch.sqrt(x[positive_mask])
+    else:
+        ret = torch.where(positive_mask, torch.sqrt(x), ret)
+    return ret
+
+
+def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
+    """
+    Convert a unit quaternion to a standard form: one in which the real
+    part is non-negative.
+
+    Args:
+        quaternions: Quaternions with real part last,
+            as tensor of shape (..., 4).
+
+    Returns:
+        Standardized quaternions as tensor of shape (..., 4).
+    """
+    return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions)
diff --git a/capvector-pi05/src/vggt/utils/visual_track.py b/capvector-pi05/src/vggt/utils/visual_track.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4e0a27dedb261173ccd5f87d04fbe472657bb96
--- /dev/null
+++ b/capvector-pi05/src/vggt/utils/visual_track.py
@@ -0,0 +1,239 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+
+import cv2
+import numpy as np
+import torch
+
+
+def color_from_xy(x, y, W, H, cmap_name="hsv"):
+    """
+    Map (x, y) -> color in (R, G, B).
+    1) Normalize x, y to [0, 1].
+    2) Combine them into a single scalar c in [0, 1].
+    3) Use matplotlib's colormap to convert c -> (R, G, B).
+
+    You can customize step 2, e.g., c = (x + y) / 2, or some other function of (x, y).
+    """
+    import matplotlib.cm
+
+    x_norm = x / max(W - 1, 1)
+    y_norm = y / max(H - 1, 1)
+    # Simple combination:
+    c = (x_norm + y_norm) / 2.0
+
+    cmap = matplotlib.cm.get_cmap(cmap_name)
+    # cmap(c) -> (r, g, b, a) in [0, 1]
+    rgba = cmap(c)
+    r, g, b = rgba[0], rgba[1], rgba[2]
+    return (r, g, b)  # in [0, 1], RGB order
+
+
+def get_track_colors_by_position(tracks_b, vis_mask_b=None, image_width=None, image_height=None, cmap_name="hsv"):
+    """
+    Given all tracks in one sample (b), compute an (N, 3) array of RGB color values
+    in [0, 255]. The color is determined by the (x, y) position in the first
+    visible frame for each track.
+
+    Args:
+        tracks_b: Tensor of shape (S, N, 2). (x, y) for each track in each frame.
+        vis_mask_b: (S, N) boolean mask; if None, assume all are visible.
+        image_width, image_height: used for normalizing (x, y).
+        cmap_name: for matplotlib (e.g., 'hsv', 'rainbow', 'jet').
+
+    Returns:
+        track_colors: np.ndarray of shape (N, 3), each row is (R, G, B) in [0, 255].
+    """
+    S, N, _ = tracks_b.shape
+    track_colors = np.zeros((N, 3), dtype=np.uint8)
+
+    if vis_mask_b is None:
+        # Treat all tracks as visible in every frame
+        vis_mask_b = torch.ones(S, N, dtype=torch.bool, device=tracks_b.device)
+
+    for i in range(N):
+        # Find the first visible frame for track i
+        visible_frames = torch.where(vis_mask_b[:, i])[0]
+        if len(visible_frames) == 0:
+            # Track is never visible; assign black
+            track_colors[i] = (0, 0, 0)
+            continue
+
+        first_s = int(visible_frames[0].item())
+        # Use that frame's (x, y)
+        x, y = tracks_b[first_s, i].tolist()
+
+        # Map (x, y) -> (R, G, B) in [0, 1]
+        r, g, b = color_from_xy(x, y, W=image_width, H=image_height, cmap_name=cmap_name)
+        # Scale to [0, 255]
+        track_colors[i] = (int(r * 255), int(g * 255), int(b * 255))
+
+    return track_colors
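+
+# Illustrative usage: a minimal sketch where the shapes are arbitrary demo
+# values (S=8 frames, N=16 tracks on a 512x256 image). For example:
+#
+#     tracks = torch.rand(8, 16, 2) * torch.tensor([512.0, 256.0])
+#     colors = get_track_colors_by_position(tracks, image_width=512, image_height=256)
+#     print(colors.shape, colors.dtype)  # (16, 3) uint8, RGB in [0, 255]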
+
+
+def visualize_tracks_on_images(
+    images,
+    tracks,
+    track_vis_mask=None,
+    out_dir="track_visuals_concat_by_xy",
+    image_format="CHW",  # "CHW" or "HWC"
+    normalize_mode="[0,1]",
+    cmap_name="hsv",  # e.g., "hsv", "rainbow", "jet"
+    frames_per_row=4,  # number of frames per row in the grid layout
+    save_grid=True,  # whether to also save all frames in one grid image
+):
+    """
+    Visualize frames in a grid layout with a specified number of frames per row.
+    Each track's color is determined by its (x, y) position in its first visible
+    frame (or frame 0 if it is always visible). The BGR drawing buffer is
+    converted back to RGB before saving. Each individual frame is also saved
+    as a separate PNG file.
+
+    Args:
+        images: torch.Tensor (S, 3, H, W) if CHW or (S, H, W, 3) if HWC.
+        tracks: torch.Tensor (S, N, 2), last dim = (x, y).
+        track_vis_mask: torch.Tensor (S, N) or None.
+        out_dir: folder to save visualizations.
+        image_format: "CHW" or "HWC".
+        normalize_mode: "[0,1]", "[-1,1]", or None to treat raw values as 0..255.
+        cmap_name: a matplotlib colormap name for color_from_xy.
+        frames_per_row: number of frames to display in each row of the grid.
+        save_grid: whether to save all frames in one grid image.
+
+    Returns:
+        None (saves images in out_dir).
+    """
+    # Drop a leading batch dimension of size 1, if present
+    if len(tracks.shape) == 4:
+        tracks = tracks.squeeze(0)
+        images = images.squeeze(0)
+        if track_vis_mask is not None:
+            track_vis_mask = track_vis_mask.squeeze(0)
+
+    import matplotlib
+
+    matplotlib.use("Agg")  # non-interactive backend
+
+    os.makedirs(out_dir, exist_ok=True)
+
+    S = images.shape[0]
+    _, N, _ = tracks.shape  # (S, N, 2)
+
+    # Move everything to CPU
+    images = images.cpu().clone()
+    tracks = tracks.cpu().clone()
+    if track_vis_mask is not None:
+        track_vis_mask = track_vis_mask.cpu().clone()
+
+    # Infer H, W from the images' shape
+    if image_format == "CHW":
+        # e.g., images[s].shape = (3, H, W)
+        H, W = images.shape[2], images.shape[3]
+    else:
+        # e.g., images[s].shape = (H, W, 3)
+        H, W = images.shape[1], images.shape[2]
+
+    # Pre-compute the color for each track i based on its first visible position
+    track_colors_rgb = get_track_colors_by_position(
+        tracks,  # shape (S, N, 2)
+        vis_mask_b=track_vis_mask,
+        image_width=W,
+        image_height=H,
+        cmap_name=cmap_name,
+    )
+
+    # We'll accumulate each frame's drawn image in a list
+    frame_images = []
+
+    for s in range(S):
+        # shape => either (3, H, W) or (H, W, 3)
+        img = images[s]
+
+        # Convert to (H, W, 3)
+        if image_format == "CHW":
+            img = img.permute(1, 2, 0)  # (H, W, 3)
+        # else "HWC": already in the right layout
+
+        img = img.numpy().astype(np.float32)
+
+        # Scale to [0, 255] if needed
+        if normalize_mode == "[0,1]":
+            img = np.clip(img, 0, 1) * 255.0
+        elif normalize_mode == "[-1,1]":
+            img = (img + 1.0) * 0.5 * 255.0
+            img = np.clip(img, 0, 255.0)
+        # else: no normalization; assume values are already in 0..255
+
+        # Convert to uint8
+        img = img.astype(np.uint8)
+
+        # For drawing with OpenCV, convert to BGR
+        img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+
+        # Draw each visible track
+        cur_tracks = tracks[s]  # shape (N, 2)
+        if track_vis_mask is not None:
+            valid_indices = torch.where(track_vis_mask[s])[0]
+        else:
+            valid_indices = range(N)
+
+        cur_tracks_np = cur_tracks.numpy()
+        for i in valid_indices:
+            x, y = cur_tracks_np[i]
+            pt = (int(round(x)), int(round(y)))
+
+            # track_colors_rgb[i] is (R, G, B); OpenCV's circle expects BGR
+            R, G, B = track_colors_rgb[i]
+            color_bgr = (int(B), int(G), int(R))
+            cv2.circle(img_bgr, pt, radius=3, color=color_bgr, thickness=-1)
+
+        # Convert back to RGB for consistent final saving
+        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
+
+        # Save the individual frame (imwrite expects BGR)
+        frame_path = os.path.join(out_dir, f"frame_{s:04d}.png")
+        frame_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
+        cv2.imwrite(frame_path, frame_bgr)
+
+        frame_images.append(img_rgb)
+
+    # Only create and save the grid image if save_grid is True
+    if save_grid:
+        # Calculate grid dimensions (ceiling division)
+        num_rows = (S + frames_per_row - 1) // frames_per_row
+
+        # Build the grid row by row
+        grid_rows = []
+        for row in range(num_rows):
+            start_idx = row * frames_per_row
+            end_idx = min(start_idx + frames_per_row, S)
+
+            # Concatenate this row horizontally
+            row_img = np.concatenate(frame_images[start_idx:end_idx], axis=1)
+
+            # If this row has fewer than frames_per_row images, pad with black
+            if end_idx - start_idx < frames_per_row:
+                padding_width = (frames_per_row - (end_idx - start_idx)) * W
+                padding = np.zeros((H, padding_width, 3), dtype=np.uint8)
+                row_img = np.concatenate([row_img, padding], axis=1)
+
+            grid_rows.append(row_img)
+
+        grid_img = np.concatenate(grid_rows, axis=0)
+
+        out_path = os.path.join(out_dir, "tracks_grid.png")
+        # Convert back to BGR for OpenCV imwrite
+        grid_img_bgr = cv2.cvtColor(grid_img, cv2.COLOR_RGB2BGR)
+        cv2.imwrite(out_path, grid_img_bgr)
+        print(f"[INFO] Saved color-by-XY track visualization grid -> {out_path}")
+
+    print(f"[INFO] Saved {S} individual frames to {out_dir}/frame_*.png")
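+
+
+# --- Illustrative end-to-end demo: a minimal sketch with synthetic data, not
+# part of the tracking pipeline itself. It assumes [0,1]-normalized CHW frames
+# and requires opencv-python plus matplotlib; all shapes below are arbitrary
+# demo values.
+if __name__ == "__main__":
+    S, N, H, W = 8, 16, 128, 256
+    demo_images = torch.rand(S, 3, H, W)  # synthetic frames in [0, 1], CHW
+    demo_tracks = torch.rand(S, N, 2) * torch.tensor([float(W), float(H)])
+    demo_vis = torch.rand(S, N) > 0.2  # roughly 80% of points visible
+    visualize_tracks_on_images(demo_images, demo_tracks, demo_vis, out_dir="demo_track_visuals")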