| """ |
| modeling_prismatic.py |
| |
| Core HuggingFace-style PrismaticPreTrainedModel and PrismaticForConditionalGeneration class definitions, inheriting |
| from the default `transformers.PretrainedModel`. Meant to be standalone and self-contained, but exactly replicate the |
| logic in `prismatic.models.vlms.prismatic.py`. |
| |
| Note =>> for the time being, not adding the custom HF "docstring" formatting. |
| |
| References [LLaVa, IDEFICS-2]: |
| => https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/modeling_llava.py |
| => https://github.com/huggingface/transformers/blob/main/src/transformers/models/idefics2/modeling_idefics2.py |
| """ |
|
|
| import logging |
| from dataclasses import dataclass |
| from functools import partial |
| from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Union |
| from functools import cached_property |
| |
|
|
| import numpy as np |
| import timm |
| import tokenizers |
| import torch |
| import torch.nn as nn |
| import transformers |
| from timm.models.vision_transformer import LayerScale |
| from transformers import AutoModelForCausalLM, PretrainedConfig, PreTrainedModel, AutoModel, AutoConfig |
| from transformers.modeling_outputs import ModelOutput |
| import collections |
| import math |
| from barrel.pipes.vlams.extern.prismatic_config import OpenVLAConfig, PrismaticConfig , TrajectoryVLAConfig, WaypointTokenizer |
| |
| from barrel.pipes.vlams.extern.datatypes import * |
| from barrel.pipes.vlams.extern.detr import * |
| from IPython import embed |
| import os |
| from PIL import Image |
| from pathlib import Path |
| from torch.amp.autocast_mode import autocast |
| from scipy.spatial.transform import Rotation as R |
| |
|
|
# Hugging Face access token: read from a local `.hf_token` file when it exists, otherwise fall
# back to the `HF_TOKEN` environment variable (None when neither is available). Importing this
# module therefore no longer crashes on machines that do not have the token file.
hf_token_path = Path(".hf_token")
ht_token_path = hf_token_path  # original (misspelled) module-level name, kept as an alias
HF_TOKEN = hf_token_path.read_text().strip() if hf_token_path.is_file() else os.environ.get("HF_TOKEN")

logger = logging.getLogger(__name__)

# Force deterministic cuDNN kernels and disable the cuDNN auto-tuner.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Label value written over image-patch positions when labels are spliced in `forward()`.
IGNORE_INDEX = -100
|
|
|
|
| |
def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
    """Wrap `fn` so that a tuple result is collapsed to its first element.

    Args:
        fn: Any callable; its arguments are forwarded unchanged.

    Returns:
        A wrapper that returns `fn(...)[0]` when `fn` returns a tuple and the raw result otherwise.
        The wrapper carries over `fn`'s metadata (name, docstring, ...) where available.
    """
    from functools import wraps  # local import: only `partial`/`cached_property` are imported at file level

    @wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        result = fn(*args, **kwargs)
        return result[0] if isinstance(result, tuple) else result

    return wrapper
|
|
|
|
| |
| |
| |
def _ls_new_forward(self, x: torch.Tensor) -> torch.Tensor:
    """Scale `x` by `self.scale_factor`, mutating `x` in place when `self.inplace` is truthy."""
    if self.inplace:
        return x.mul_(self.scale_factor)
    return x * self.scale_factor
|
|
|
|
def ls_apply_patch(ls_module: LayerScale):
    """Patch a timm `LayerScale` module in place so it scales by `scale_factor` instead of `gamma`.

    A copy of `gamma` is registered as the new parameter `scale_factor`, `forward` is rebound to
    `_ls_new_forward`, and the original `gamma` parameter is removed from the module.
    NOTE(review): presumably done to avoid checkpoint-key renaming of `gamma` -- confirm.
    """
    gamma_copy = ls_module.gamma.clone()
    ls_module.scale_factor = nn.Parameter(gamma_copy)
    ls_module.forward = _ls_new_forward.__get__(ls_module, LayerScale)
    del ls_module.gamma
|
|
|
|
| |
class PrismaticVisionBackbone(nn.Module):
    """timm-based vision featurizer returning patch features from the second-to-last block.

    In fused mode two featurizers (`dino_featurizer`, `siglip_featurizer`) are run on their own
    inputs and their patch features are concatenated along the last (channel) dimension.
    In non-fused mode a single `featurizer` is used.
    """

    def __init__(
        self,
        use_fused_vision_backbone: bool,
        image_sizes: List[int],
        timm_model_ids: List[str],
        timm_override_act_layers: List[Optional[str]],
    ) -> None:
        """
        Args:
            use_fused_vision_backbone: Whether to build two featurizers and fuse their outputs.
            image_sizes: Input image size per featurizer (passed to timm as `img_size`).
            timm_model_ids: timm model identifier per featurizer (one, or two when fused).
            timm_override_act_layers: Activation override per featurizer (passed as `act_layer`).

        Raises:
            ValueError: If fused mode is requested without exactly two model ids.
        """
        super().__init__()
        self.use_fused_vision_backbone = use_fused_vision_backbone

        assert len(timm_model_ids) <= 2, "Prismatic models only support up to 2 (fused) vision backbones!"

        if use_fused_vision_backbone:
            if len(timm_model_ids) != 2:
                raise ValueError("A fused vision backbone requires exactly 2 entries in `timm_model_ids`!")

            self.dino_featurizer = self._build_featurizer(
                timm_model_ids[0], image_sizes[0], timm_override_act_layers[0]
            )
            self.siglip_featurizer = self._build_featurizer(
                timm_model_ids[1], image_sizes[1], timm_override_act_layers[1]
            )
            self.embed_dim = self.dino_featurizer.embed_dim + self.siglip_featurizer.embed_dim
        else:
            # Previously no `featurizer` was ever created, so the non-fused `forward` path always failed.
            self.featurizer = self._build_featurizer(timm_model_ids[0], image_sizes[0], timm_override_act_layers[0])
            self.embed_dim = self.featurizer.embed_dim

    @staticmethod
    def _build_featurizer(model_id: str, image_size: int, act_layer: Optional[str]) -> nn.Module:
        """Create a pretrained timm model in eval mode whose `forward` yields second-to-last block features."""
        featurizer = timm.create_model(
            model_id,
            pretrained=True,
            num_classes=0,
            img_size=image_size,
            act_layer=act_layer,
        )
        featurizer.eval()

        # Route `forward` through `forward_intermediates`, keeping only the second-to-last block's
        # un-normalized patch tokens (no prefix tokens) in 'NLC' layout; the call returns a list.
        featurizer.forward = partial(
            featurizer.forward_intermediates,
            indices=[len(featurizer.blocks) - 2],
            return_prefix_tokens=False,
            norm=False,
            stop_early=True,
            output_fmt='NLC',
            intermediates_only=True,
        )
        return featurizer

    def forward(self, pixel_values) -> torch.Tensor:
        """Run image (`pixel_values`) through the featurizer(s).

        Args:
            pixel_values: In fused mode, a mapping with keys 'dino' and 'siglip' holding the input
                for each featurizer; otherwise the input passed directly to the single featurizer.

        Returns:
            Patch features; in fused mode both feature tensors concatenated along dim 2.
        """
        if not self.use_fused_vision_backbone:
            # The patched `forward` returns a list of intermediates; take its single entry.
            return self.featurizer(pixel_values)[0]

        dino_patches = self.dino_featurizer(pixel_values['dino'])[0]
        siglip_patches = self.siglip_featurizer(pixel_values['siglip'])[0]

        return torch.cat([dino_patches, siglip_patches], dim=2)
|
|
|
|
|
|
class PrismaticProjector(nn.Module):
    """Three-layer GELU MLP mapping vision patch features to the LLM embedding width."""

    def __init__(self, use_fused_vision_backbone, vision_dim: int, llm_dim: int) -> None:
        """
        Args:
            use_fused_vision_backbone: Accepted for interface compatibility; not used here.
            vision_dim: Feature size of the incoming patch features.
            llm_dim: Feature size of the produced embeddings.
        """
        super().__init__()
        hidden_dim = vision_dim * 4
        self.initial_projection_dim = hidden_dim

        # vision_dim -> 4 * vision_dim -> llm_dim -> llm_dim, with GELU between the linear layers
        stages = [
            nn.Linear(vision_dim, hidden_dim, bias=True),
            nn.GELU(),
            nn.Linear(hidden_dim, llm_dim, bias=True),
            nn.GELU(),
            nn.Linear(llm_dim, llm_dim, bias=True),
        ]
        self.projector = nn.Sequential(*stages)

    def forward(self, fused_img_patches: torch.Tensor) -> torch.Tensor:
        """Project patch features (last dim `vision_dim`) to last dim `llm_dim`."""
        projected = self.projector(fused_img_patches)
        return projected
|
|
| |
@dataclass
class PrismaticCausalLMOutputWithPast(ModelOutput):
    """Base class for Prismatic causal (visually-conditioned) language model outputs; also exposes visual features."""

    # Fields copied from the wrapped language model's output in `forward()`.
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None

    # Projected patch embeddings (projector output); None when no image went through the projector.
    projector_features: Optional[torch.FloatTensor] = None
|
|
|
|
class PrismaticPreTrainedModel(PreTrainedModel):
    """Shared `PreTrainedModel` base for the Prismatic models in this file: HF flags and weight init."""

    # Was a bare annotation (`config_class: PrismaticConfig`), which never assigned the attribute.
    config_class: PretrainedConfig = PrismaticConfig
    base_model_prefix: str = "model"
    supports_gradient_checkpointing: bool = True

    _no_split_modules: ClassVar[List[str]] = ["PrismaticProjector"]
    _skip_keys_device_placement: str = "past_key_values"
    _supports_flash_attn_2: bool = True

    def _init_weights(self, module: nn.Module) -> None:
        """Initialize `module` with N(0, std) weights; biases and the padding embedding row are zeroed.

        `std` is `config.initializer_range` when present, else `config.text_config.initializer_range`.
        """
        std = (
            self.config.initializer_range
            if hasattr(self.config, "initializer_range")
            else self.config.text_config.initializer_range
        )

        if hasattr(module, "class_embedding"):
            module.class_embedding.data.normal_(mean=0.0, std=std)

        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    @property
    def _supports_sdpa(self) -> bool:
        """Check LLM supports SDPA Attention"""
        # Subclasses in this file keep the LLM at `llm_backbone.llm`; `language_model` is only
        # consulted as a fallback (the original lookup), so a missing LLM still raises AttributeError.
        llm_backbone = getattr(self, "llm_backbone", None)
        language_model = llm_backbone.llm if llm_backbone is not None else self.language_model
        return language_model._supports_sdpa
|
|
class LLMBackbone(nn.Module):
    """Holder for the language model (`llm`, assigned by the owner) and its tokenizer."""

    def __init__(self, config):
        """
        Args:
            config: Mapping read via `config['hf_model_id']` and `config['llm_max_length']`.
        """
        super().__init__()
        self.config = config
        self.llm: AutoModelForCausalLM  # declaration only; the attribute is set after construction
        self.tokenizer = self._create_tokenizer()

    def _create_tokenizer(self) -> transformers.PreTrainedTokenizerBase:
        """Load the tokenizer for `config['hf_model_id']` and verify that it prepends a BOS token."""
        print("Loading (Fast) Tokenizer via the AutoTokenizer API")
        model_id = self.config['hf_model_id']
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            model_id,
            model_max_length=self.config['llm_max_length'],
            token=HF_TOKEN,
            padding_side="right",
        )

        # Model ids that are exempt from the BOS-prefix check below.
        bos_check_exempt = {"microsoft/phi-2"}
        if model_id in bos_check_exempt:
            return tokenizer

        def first_token_id(add_special_tokens: bool):
            return tokenizer("Test 123", add_special_tokens=add_special_tokens).input_ids[0]

        # BOS must lead the sequence when special tokens are requested, and only then.
        assert (first_token_id(True) == tokenizer.bos_token_id) and (
            first_token_id(False) != tokenizer.bos_token_id
        ), f"Default Tokenizer of type `{type(tokenizer)}` does not automatically prefix inputs with BOS token!\n"

        return tokenizer
|
|
| |
class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
    """Vision-language model: vision backbone -> projector -> causal LLM (`llm_backbone.llm`)."""

    config_class: PretrainedConfig = PrismaticConfig

    def __init__(self, config: PrismaticConfig) -> None:
        """Build the vision backbone, projector and LLM backbone described by `config`.

        Raises:
            ValueError: If `config.use_fused_vision_backbone` is None.
        """
        super().__init__(config)

        if config.use_fused_vision_backbone is None:
            raise ValueError("Missing config field `use_fused_vision_backbone`")

        self.vision_backbone = PrismaticVisionBackbone(
            config.use_fused_vision_backbone, config.image_sizes, config.timm_model_ids, config.timm_override_act_layers
        )

        self.projector = PrismaticProjector(
            config.use_fused_vision_backbone,
            vision_dim=self.vision_backbone.embed_dim,
            llm_dim=config.text_config.hidden_size,
        )

        self.llm_backbone = LLMBackbone(
            {
                'hf_model_id': config.hf_llm_id,
                'llm_max_length': config.llm_max_length,
                "pad_token_id": 32000,
                "pad_to_multiple_of": 64,
            }
        )

        # NOTE(review): the LLM id is hard-coded while the tokenizer uses `config.hf_llm_id` --
        # confirm the two are always meant to agree.
        self.llm_backbone.llm = AutoModelForCausalLM.from_pretrained(
            'meta-llama/Llama-2-7b-hf',
            token=HF_TOKEN,
            attn_implementation='flash_attention_2',
            do_sample=False,
            temperature=1.0,
            use_cache=False,
            top_p=1.0,
        )

        # Add a dedicated padding token and grow the embedding matrix to cover it.
        self.llm_backbone.tokenizer.add_special_tokens({"pad_token": "<PAD>"})
        self.llm_backbone.llm.config.pad_token_id = self.llm_backbone.tokenizer.pad_token_id
        self.llm_backbone.llm.resize_token_embeddings(len(self.llm_backbone.tokenizer), pad_to_multiple_of=64)

        self.vocab_size = config.text_config.vocab_size
        self.pad_token_id = config.pad_token_id

        self.post_init()

    # === `PreTrainedModel` accessors: all delegate to the wrapped LLM ===
    def get_input_embeddings(self) -> nn.Module:
        return self.llm_backbone.llm.get_input_embeddings()

    def set_input_embeddings(self, value: nn.Module) -> None:
        self.llm_backbone.llm.set_input_embeddings(value)

    def get_output_embeddings(self) -> nn.Module:
        return self.llm_backbone.llm.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
        self.llm_backbone.llm.set_output_embeddings(new_embeddings)

    def get_decoder(self) -> nn.Module:
        return self.llm_backbone.llm.get_decoder()

    def set_decoder(self, decoder: nn.Module) -> None:
        self.llm_backbone.llm.set_decoder(decoder)

    def tie_weights(self) -> None:
        self.llm_backbone.llm.tie_weights()

    @staticmethod
    def _has_no_images(pixel_values: Any) -> bool:
        """Return True when `pixel_values` is None or an empty mapping (i.e. a text-only call)."""
        if pixel_values is None:
            return True
        return (not hasattr(pixel_values, "shape")) and len(pixel_values) == 0

    @staticmethod
    def _image_batch_size(pixel_values: Any) -> int:
        """Return the image batch size from a tensor, or from the 'dino' entry of a mapping."""
        images = pixel_values if hasattr(pixel_values, "shape") else pixel_values['dino']
        return images.shape[0]

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        pixel_values: Optional[Dict[str, torch.Tensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_projector_features: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Any,
    ) -> Union[Tuple, PrismaticCausalLMOutputWithPast]:
        """Run a forward pass through the VLM.

        Three call patterns are handled: cached generation (single new token), language-only
        (no `pixel_values`), and multimodal (projected image patches spliced in after the first token).

        Returns:
            With `return_dict` truthy: a 3-tuple `(PrismaticCausalLMOutputWithPast, patch_features,
            multimodal_attention_mask)`; the last two are None on the cached and language-only paths.
            Otherwise: the LLM's tuple output, with the projected patch embeddings appended when
            `output_projector_features` is set and images were processed.

        Raises:
            ValueError: If images are given without `input_ids`, or the text and image batch sizes differ.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        output_projector_features = output_projector_features if output_projector_features is not None else False
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Caching is only honored outside of training.
        use_cache = use_cache and not self.training

        # Defined up front so that every branch can reach the final `return` without a NameError.
        projected_patch_embeddings = None
        patch_features = None
        multimodal_attention_mask = None

        # === Cached generation: a single new token plus `past_key_values` ===
        if input_ids is not None and input_ids.shape[1] == 1:
            assert input_ids.shape[0] == 1, "Generation is only currently supported for batch size of 1!"
            assert past_key_values is not None, "You must provide `past_key_values` during cached generation!"
            assert labels is None, "Unexpected key `labels` provided during cached generation!"

            language_model_output = self.llm_backbone.llm(
                input_ids=input_ids,
                attention_mask=None,
                position_ids=None,
                past_key_values=past_key_values,
                inputs_embeds=None,
                labels=None,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        # === Language-only forward ===
        elif self._has_no_images(pixel_values):
            assert (input_ids is not None) and (inputs_embeds is None), "Missing `input_ids` in language-only forward!"
            assert past_key_values is None, "Unexpected key `past_key_values` provided during language-only forward!"

            language_model_output = self.llm_backbone.llm(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=None,
                past_key_values=None,
                inputs_embeds=None,
                labels=labels,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        # === Images without token ids: the multimodal path embeds `input_ids`, so this is unsupported ===
        elif input_ids is None:
            raise ValueError(
                "Invalid PrismaticForConditionalGeneration `forward()` call with provided arguments:\n"
                f"=> `input_ids` = {input_ids is not None}\n"
                f"=> `attention_mask` = {attention_mask is not None}\n"
                f"=> `pixel_values` = {pixel_values is not None}\n"
                f"=> `labels` = {labels is not None}\n"
                f"=> `input_embeds` = {inputs_embeds is not None}\n"
                f"=> `past_key_values` = {past_key_values is not None}\n"
                f"=> `use_cache` = {use_cache}"
            )

        elif input_ids.shape[0] != self._image_batch_size(pixel_values):
            raise ValueError("Non-homogenous batch of (text, image) input -- forward() does not support mixed batches!")

        # === Multimodal forward ===
        else:
            assert past_key_values is None, "Unexpected key `past_key_values` provided during multimodal forward!"

            # Visual features -> projected patch embeddings (with an all-True attention mask)
            patch_features = self.vision_backbone(pixel_values)
            projected_patch_embeddings = self.projector(patch_features)
            projected_patch_attention_mask = None
            if attention_mask is not None:
                projected_patch_attention_mask = torch.full(
                    (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
                    fill_value=True,
                    dtype=attention_mask.dtype,
                    device=attention_mask.device,
                )

            input_embeddings = self.get_input_embeddings()(input_ids)

            # Splice the patch embeddings in right after the first text token
            multimodal_embeddings = torch.cat(
                [input_embeddings[:, :1, :], projected_patch_embeddings, input_embeddings[:, 1:, :]], dim=1
            )
            if attention_mask is not None:
                multimodal_attention_mask = torch.cat(
                    [attention_mask[:, :1], projected_patch_attention_mask, attention_mask[:, 1:]], dim=1
                )

            # Patch positions get IGNORE_INDEX labels
            multimodal_labels = None
            if labels is not None:
                projected_patch_labels = torch.full(
                    (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
                    fill_value=IGNORE_INDEX,
                    dtype=labels.dtype,
                    device=labels.device,
                )
                multimodal_labels = torch.cat([labels[:, :1], projected_patch_labels, labels[:, 1:]], dim=1)

            language_model_output = self.llm_backbone.llm(
                input_ids=None,
                attention_mask=multimodal_attention_mask,
                position_ids=None,
                past_key_values=None,
                inputs_embeds=multimodal_embeddings,
                labels=multimodal_labels,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        # === Tuple output ===
        if not return_dict:
            if output_projector_features and (projected_patch_embeddings is not None):
                return (*language_model_output, projected_patch_embeddings)

            return language_model_output

        # === Structured output, plus the raw patch features and fused attention mask ===
        return (
            PrismaticCausalLMOutputWithPast(
                loss=language_model_output.loss,
                logits=language_model_output.logits,
                past_key_values=language_model_output.past_key_values,
                hidden_states=language_model_output.hidden_states,
                attentions=language_model_output.attentions,
                projector_features=projected_patch_embeddings,
            ),
            patch_features,
            multimodal_attention_mask,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs: Any,
    ) -> Dict[str, torch.Tensor]:
        """Borrowed from `LlamaForCausalLM` and simplified for batch size = 1; mirrors original PrismaticVLM logic.

        Raises:
            ValueError: If `input_ids` or `inputs_embeds` has a batch size greater than 1.
        """
        if ((input_ids is not None) and (input_ids.shape[0] > 1)) or (
            (inputs_embeds is not None) and (inputs_embeds.shape[0] > 1)
        ):
            raise ValueError("Generation with batch size > 1 is not currently supported!")

        # With a cache present only the newest token is fed to the model
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]

        # `inputs_embeds` is only usable on the first (uncached) step; the key must match the
        # `forward()` parameter name (it was previously misspelled "input_embeds").
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "attention_mask": attention_mask,
                "pixel_values": pixel_values,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
            }
        )

        return model_inputs

    def _reorder_cache(self, *args, **kwargs) -> Any:
        """Delegate cache reordering to the wrapped LLM."""
        return self.llm_backbone.llm._reorder_cache(*args, **kwargs)
|
|
|
|
class TokenProjectorConfig(PretrainedConfig):
    """Layer sizes and image-token mode for `TokenProjector`."""

    # Layer sizes (input size first) of each projection MLP; an empty list means no layers.
    vit_tokens_layers: List[int] = []
    llm_image_tokens_layers: List[int] = []
    control_tokens_layers: List[int] = []

    # One of 'vit', 'llm', 'skip', 'none'
    image_tokens_mode: str

    def __post_init__(self):
        """Check that the configured layer lists are compatible with `image_tokens_mode`."""
        super().__post_init__()

        mode = self.image_tokens_mode

        if mode == 'none':
            assert len(self.vit_tokens_layers) == 0
            assert len(self.llm_image_tokens_layers) == 0
            return

        # NOTE(review): 'llm' is checked against the ViT layers exactly like 'vit' -- confirm intended.
        if mode in ('vit', 'llm'):
            assert len(self.vit_tokens_layers) > 0 or len(self.control_tokens_layers) > 0
            return

        if mode == 'skip':
            assert len(self.vit_tokens_layers) > 0 or len(self.llm_image_tokens_layers) > 0
            return

        raise NotImplementedError(f"Unknown image tokens mode {self.image_tokens_mode}")
|
|
class TokenProjector(nn.Module):
    """Project and pack VLM output tokens"""

    def __init__(self, config):
        """
        Args:
            config: Mapping with keys 'vit_tokens_layers', 'llm_image_tokens_layers',
                'control_tokens_layers' and 'image_tokens_mode'.
        """
        super().__init__()
        projector_config = TokenProjectorConfig()
        projector_config.vit_tokens_layers = config['vit_tokens_layers']
        projector_config.llm_image_tokens_layers = config['llm_image_tokens_layers']
        projector_config.control_tokens_layers = config['control_tokens_layers']
        projector_config.image_tokens_mode = config['image_tokens_mode']
        self.config = projector_config

        self.vit_tokens_proj = self._make_token_proj_module(projector_config.vit_tokens_layers)
        self.llm_image_tokens_proj = self._make_token_proj_module(projector_config.llm_image_tokens_layers)
        self.control_tokens_proj = self._make_token_proj_module(projector_config.control_tokens_layers)

    def _make_token_proj_module(self, layer_sizes: List[int]) -> torch.nn.Module:
        """Build a stack of Linear -> ReLU -> LayerNorm stages, or `Identity` for an empty size list."""
        if len(layer_sizes) == 0:
            return torch.nn.Identity()

        assert len(layer_sizes) > 1, "Need to provide input and output layer sizes at least"

        stages = []
        for in_features, out_features in zip(layer_sizes[:-1], layer_sizes[1:]):
            stage = torch.nn.Sequential(
                collections.OrderedDict(
                    [
                        ('linear', torch.nn.Linear(in_features, out_features)),
                        ('act', torch.nn.ReLU()),
                        ('norm', torch.nn.LayerNorm(out_features)),
                    ]
                )
            )
            stages.append(stage)
        return torch.nn.Sequential(*stages)

    def forward(self, inputs: WaypointerInput) -> torch.Tensor:
        """
        Args:
            inputs: Contains VLM outputs
        Returns:
            torch.Tensor that always contains the projected control tokens and, depending on
            `image_tokens_mode`, the projected image tokens prepended along dim 1
        """
        # All three projections are always evaluated, whichever mode is active.
        vit_tokens = self.vit_tokens_proj(inputs.vit_tokens)
        control_tokens = self.control_tokens_proj(inputs.control_tokens)
        llm_image_tokens = self.llm_image_tokens_proj(inputs.llm_image_tokens)

        mode = self.config.image_tokens_mode
        if mode == 'none':
            return control_tokens

        if mode == 'vit':
            image_tokens = vit_tokens
        elif mode == 'llm':
            image_tokens = llm_image_tokens
        elif mode == 'skip':
            # Sum of the LLM and ViT image tokens
            image_tokens = llm_image_tokens + vit_tokens
        else:
            raise NotImplementedError(f"Unknown image tokens mode {self.config.image_tokens_mode}")

        return torch.cat([image_tokens, control_tokens], dim=1)
|
|
class NeRFPositionalEmbedding(torch.nn.Module):
    """Sin/cos positional encoding at frequencies `2**k * pi` for `k = 0 .. proj_scale - 1`."""

    def __init__(self, proj_scale: int):
        """
        Args:
            proj_scale: Number of frequencies, same as L parameter in the NeRF paper
        """
        super().__init__()
        self.proj_scale = proj_scale

        exponents = torch.arange(self.proj_scale, dtype=torch.float32)
        self.register_buffer('freq', 2 ** exponents * math.pi)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Encode every input value with sin and cos at each frequency.
        Args:
            inputs: torch.Tensor whose last dimension has size N
        Returns: torch.Tensor reshaped to [N, -1]; for an input of shape [1, N] each row holds
            the L sine values followed by the L cosine values of one input element
        """
        broadcast_shape = [1] * inputs.ndim + [-1]
        spectrum = self.freq.view(*broadcast_shape) * inputs.unsqueeze(-1)

        sin_part = torch.sin(spectrum)
        cos_part = torch.cos(spectrum)
        stacked = torch.stack([sin_part, cos_part], dim=-2)

        return stacked.view(inputs.shape[-1], -1)
|
|
class TimestepProjModuleConfig(PretrainedConfig):
    # Settings container populated by `TimestepProjModule.__init__`.

    # Number of frequencies handed to `NeRFPositionalEmbedding`
    pos_embed_scale: int
    # Hidden layer sizes of the timestep projection MLP
    proj_layers: List[int]
    # Multiplier applied to the timestep index in `TimestepProjModule.time_deltas_sec`
    time_delta_sec: float = 0.25
    # Number of tokens produced per timestep
    num_tokens: int = 3
|
|
|
|
class TimestepProjModule(nn.Module):
    """Produce one group of query tokens per control timestep from positionally-encoded time offsets."""

    def __init__(self, config: Dict[str, Any], num_timesteps: int, token_size: int):
        """
        Args:
            config: Mapping with keys 'pos_embed_scale', 'proj_layers', 'time_delta_sec', 'num_tokens'
            num_timesteps: Number of control timesteps
            token_size: Single token size
        """
        super().__init__()
        self.config = TimestepProjModuleConfig()
        self.config.pos_embed_scale = config['pos_embed_scale']
        self.config.proj_layers = config['proj_layers']
        self.config.time_delta_sec = config['time_delta_sec']
        self.config.num_tokens = config['num_tokens']

        self.num_timesteps = num_timesteps
        self.token_size = token_size

        # One sin and one cos value per frequency for the single scalar time offset
        input_size = 2 * self.config.pos_embed_scale

        self.pos_embed = NeRFPositionalEmbedding(self.config.pos_embed_scale)

        # Each timestep is projected to `num_tokens` tokens of size `token_size`
        feature_size = self.config.num_tokens * self.token_size

        self.timestep_proj = self._make_timestep_proj(in_features=int(input_size), out_features=int(feature_size))

    def _make_timestep_proj(self, in_features: int, out_features: int) -> torch.nn.Module:
        """Build the Linear -> ReLU -> LayerNorm stack `in_features -> proj_layers... -> out_features`."""
        layer_sizes = [in_features] + list(self.config.proj_layers) + [out_features]
        module = torch.nn.Sequential(
            *[
                torch.nn.Sequential(
                    collections.OrderedDict(
                        {
                            'linear': torch.nn.Linear(layer_in_features, layer_out_features),
                            'act': torch.nn.ReLU(),
                            'norm': torch.nn.LayerNorm(layer_out_features),
                        }
                    )
                )
                for layer_in_features, layer_out_features in zip(layer_sizes[:-1], layer_sizes[1:])
            ]
        )
        return module

    def forward(self) -> torch.Tensor:
        """
        Returns:
            torch.Tensor of sequence of timestep tokens, shape [1, num_timesteps * num_tokens, token_size]
        """
        first_linear = self.timestep_proj[0].linear
        device = first_linear.weight.device

        # [1, num_timesteps] normalized time offsets on the module's device
        time_deltas_norm = self.time_deltas_norm.view(1, self.num_timesteps)
        time_deltas_norm = time_deltas_norm.to(device=device)

        # [num_timesteps, 2 * pos_embed_scale] positional encoding
        timesteps_embed = self.pos_embed(time_deltas_norm)
        timesteps_embed = timesteps_embed.view(self.num_timesteps, -1)

        # The time offsets are float32; match the projection weights' dtype so that a module cast
        # to e.g. bfloat16 also works outside of autocast (a no-op when the dtypes already agree).
        timesteps_embed = timesteps_embed.to(dtype=first_linear.weight.dtype)

        timesteps_tokens = self.timestep_proj(timesteps_embed)

        timesteps_tokens = timesteps_tokens.view(
            1, self.num_timesteps * self.config.num_tokens, self.token_size
        )

        return timesteps_tokens

    @cached_property
    def time_deltas_sec(self) -> torch.Tensor:
        """Time offsets `0, dt, 2*dt, ...` for the `num_timesteps` steps (`dt = time_delta_sec`)."""
        return torch.arange(0, self.num_timesteps, 1, dtype=torch.float32) * self.config.time_delta_sec

    @cached_property
    def time_deltas_norm(self) -> torch.Tensor:
        """Time offsets divided by their maximum; left unscaled when that maximum is zero."""
        time_deltas_sec = self.time_deltas_sec
        if time_deltas_sec.shape[0] == 1:
            # A single timestep has offset 0; nothing to normalize
            return time_deltas_sec.detach()

        max_delta = time_deltas_sec.max()
        if max_delta == 0:
            # All offsets are zero (time_delta_sec == 0); dividing would produce NaNs
            return time_deltas_sec.detach()
        return (time_deltas_sec / max_delta).detach()
|
|
|
|
| |
| |
class TrajectoryVLA(PrismaticForConditionalGeneration):
    """Prismatic VLM with a DETR head that decodes per-timestep translation, rotation and gripper values."""

    config_class: PretrainedConfig = TrajectoryVLAConfig

    def __init__(self, config: TrajectoryVLAConfig) -> None:
        """Build the underlying Prismatic model from `config.prismatic_config`, then the waypoint head."""
        super().__init__(config.prismatic_config)

        self.control_tokenizer = WaypointTokenizer(self.llm_backbone.tokenizer)
        self.timestep_proj = TimestepProjModule(
            config.timestep_proj_config,
            num_timesteps=config.num_timesteps,
            token_size=config.token_size,
        )
        self.num_timesteps = config.num_timesteps
        self.token_proj = TokenProjector(config.token_proj_config)
        self.transformer = DETR(config.transformer_config)
        self.token_size = config.token_size
        self.rotation_components = config.rotation_components

        # Read by `get_action_dim` / `get_action_stats`; previously never set, so both always failed.
        self.norm_stats = getattr(config, "norm_stats", None) or {}

        # One small MLP head per output group
        self.translation_proj = self._make_output_head(config.token_size, 3)
        self.rotation_proj = self._make_output_head(config.token_size, config.rotation_components)
        self.gripper_proj = self._make_output_head(config.token_size, 1)

    @staticmethod
    def _make_output_head(token_size: int, out_features: int) -> torch.nn.Module:
        """Build a Linear -> ReLU -> Linear head mapping `token_size` to `out_features`."""
        return torch.nn.Sequential(
            torch.nn.Linear(in_features=token_size, out_features=token_size // 2),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=token_size // 2, out_features=out_features),
        )

    def _pack_waypointer_input(
        self, input_ids: torch.Tensor, vlm_output: PrismaticCausalLMOutputWithPast, vit_tokens, fused_attention_mask
    ) -> WaypointerInput:
        """Bundle the VLM's last hidden states and ViT tokens into a `WaypointerInput`."""
        # Last-layer hidden states of the LLM
        projected_tokens = vlm_output.hidden_states[-1]

        control_tokens = self._extract_control_tokens(input_ids, projected_tokens)

        # Image patches sit right after the first token in the fused sequence
        num_image_tokens = vit_tokens.shape[1]
        llm_image_tokens = projected_tokens[..., 1 : 1 + num_image_tokens, :]

        return WaypointerInput(
            vit_tokens=vit_tokens,
            llm_image_tokens=llm_image_tokens,
            control_tokens=control_tokens,
            llm_tokens=projected_tokens,
            attn_mask=fused_attention_mask,
        )

    def predict_tracks(self, inputs):
        """Run the VLM and the waypoint head on `inputs`.

        Args:
            inputs: Keyword arguments for `PrismaticForConditionalGeneration.forward`; must contain 'input_ids'.

        Returns:
            Tuple `(translation, euler_angles, gripper)` as produced by `process_output`.
        """
        vlm_output, vit_tokens, fused_attention_mask = super().forward(
            **inputs, output_hidden_states=True, output_attentions=True, return_dict=True
        )
        waypointer_input = self._pack_waypointer_input(
            inputs['input_ids'], vlm_output, vit_tokens, fused_attention_mask
        )
        waypoint_output = self._waypointer_forward(waypointer_input)
        translation, rotation, gripper = torch.split(
            waypoint_output, [3, self.rotation_components, 1], dim=-1
        )
        translation, rotation, gripper = self.process_output(translation, rotation, gripper)
        return translation, rotation, gripper

    def process_output(self, translation, rotation, gripper):
        """Convert raw head outputs to NumPy values.

        Returns:
            translation: squeezed NumPy array.
            euler_angles: list with one 'xyz' Euler-angle array (radians) per entry of `rotation[0]`,
                each entry being reshaped to a 3x3 matrix first.
            gripper: sigmoid of the gripper output, rounded, as a squeezed NumPy array.
        """
        # NOTE(review): only the first batch element of `rotation` is converted, while translation
        # and gripper keep every element -- confirm batch size 1 is the only supported case.
        euler_angles = [
            R.from_matrix(matrix.view(3, 3).detach().cpu().float().numpy().squeeze()).as_euler(
                'xyz', degrees=False
            )
            for matrix in rotation[0]
        ]

        translation = translation.detach().cpu().float().numpy().squeeze()
        gripper = np.round(torch.sigmoid(gripper).detach().cpu().float().numpy().squeeze())
        return translation, euler_angles, gripper

    def _extract_control_tokens(self, input_ids: torch.Tensor, output_tokens: torch.Tensor) -> torch.Tensor:
        """
        Extract the action tokens from the LLM output sequence. Assumes the following order
            [image_tokens, language_tokens, action_tokens, padding]

        Args:
            input_ids: IDs of the tokens in text input sequence; shape [B, S]
            output_tokens: Token sequence output from LLM; shape [B, L, token_size]. Note the length is
                different from input_ids as it also contains image tokens
        Returns:
            torch.Tensor of shape [B, num_control_tokens, token_size] containing only action tokens
        Raises:
            RuntimeError: If any row does not contain exactly `num_control_tokens` control tokens
        """
        assert input_ids.ndim == 2
        assert output_tokens.ndim == 3
        batch, in_seq_len, out_seq_len = *input_ids.shape, output_tokens.shape[1]

        device = input_ids.device

        num_control_tokens = self.control_tokenizer.num_control_tokens

        control_token_ids = torch.from_numpy(self.control_tokenizer.control_token_ids)
        control_token_ids = control_token_ids.to(dtype=input_ids.dtype, device=input_ids.device)

        # [B, S] mask: True where an input id equals any control-token id
        is_control_token = torch.any(
            input_ids.unsqueeze(-1) == control_token_ids.view(1, 1, -1),
            dim=-1,
        )
        if not torch.all(mask := is_control_token.sum(dim=-1) == num_control_tokens):
            raise RuntimeError(
                f"Can't properly detect control tokens with ids {control_token_ids} of len="
                f"{len(control_token_ids)} in input_ids {input_ids}. Rows mask: {mask}"
            )

        # The output is longer than the input by the inserted image tokens; left-pad the mask
        # with False so that text positions line up with the end of the output sequence.
        tokens_mask = torch.cat(
            [
                torch.zeros(batch, out_seq_len - in_seq_len, dtype=torch.bool, device=device),
                is_control_token.to(torch.bool),
            ],
            dim=1,
        )

        control_tokens = output_tokens[tokens_mask]
        control_tokens = control_tokens.view(batch, num_control_tokens, output_tokens.shape[-1])

        return control_tokens

    def _waypointer_forward(self, inputs: WaypointerInput):
        """Decode waypoints: timestep queries attend to the projected VLM tokens through the DETR module.

        Returns:
            torch.Tensor with `3 + rotation_components + 1` values per timestep on the last dim
            (translation, rotation, gripper).
        """
        timesteps_tokens = self.timestep_proj()

        llm_tokens = self.token_proj(inputs)

        output_tokens = self.transformer(
            feature_tokens=llm_tokens, query_tokens=timesteps_tokens, attn_mask=None
        )

        # Regroup to one row per timestep, holding that timestep's three tokens side by side
        output_tokens = output_tokens.view(-1, self.num_timesteps, 3 * self.token_size)

        translation_tokens, rotation_tokens, gripper_tokens = torch.split(
            output_tokens, [self.token_size] * 3, dim=-1
        )

        translation = self.translation_proj(translation_tokens)
        rotation = self.rotation_proj(rotation_tokens)
        gripper = self.gripper_proj(gripper_tokens)

        return torch.cat([translation, rotation, gripper], dim=-1)

    @staticmethod
    def _check_unnorm_key(norm_stats: Dict[str, Dict[str, Any]], unnorm_key: Optional[str]) -> str:
        """Resolve `unnorm_key` against `norm_stats`, defaulting to the only key when there is exactly one.

        Raises:
            ValueError: If `norm_stats` is empty, if no key is given and several datasets are
                available, or if the given key is unknown.
        """
        if not norm_stats:
            raise ValueError("No action normalization statistics are available (`norm_stats` is empty).")

        if unnorm_key is None and len(norm_stats) != 1:
            raise ValueError(
                f"Your model was trained on more than one dataset. "
                f"Please pass a `unnorm_key` from the following options to choose the statistics used for "
                f"de-normalizing actions: {norm_stats.keys()}"
            )

        # With a single dataset, its key is the default
        unnorm_key = unnorm_key if unnorm_key is not None else next(iter(norm_stats.keys()))
        if unnorm_key not in norm_stats:
            raise ValueError(
                f"The `unnorm_key` you chose ({unnorm_key = }) is not in the available statistics. "
                f"Please choose from: {norm_stats.keys()}"
            )

        return unnorm_key

    def get_action_dim(self, unnorm_key: Optional[str] = None) -> int:
        """Get the dimensionality of the policy's action space."""
        unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
        return len(self.norm_stats[unnorm_key]["action"]["q01"])

    def get_action_stats(self, unnorm_key: Optional[str] = None) -> Dict[str, Any]:
        """Get all the logged statistics for the given dataset."""
        unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
        return self.norm_stats[unnorm_key]["action"]
|
|
def remove_waypointer_prefix(ckpt):
    """Return a copy of `ckpt` in which a leading 'waypointer.' is removed from every key.

    Keys without the prefix are kept unchanged; insertion order is preserved and, when two
    keys map to the same name, the later entry wins.
    """
    prefix = 'waypointer.'
    return {
        (name[len(prefix):] if name.startswith(prefix) else name): weights
        for name, weights in ckpt.items()
    }
|
|
def image_processor(image):
    """Resize `image` to 224x224 with Lanczos resampling and return the resized image.

    Args:
        image: Object exposing a PIL-style `resize(size, resample=...)` method.

    Returns:
        The result of `image.resize(...)` (previously it was computed but never returned).
    """
    image_resolution = (3, 224, 224)  # the last two entries are the size handed to `resize`
    return image.resize(image_resolution[1:], resample=Image.Resampling.LANCZOS)
|
|
def read_pt(pt_path, map_location=None):
    """Load and return the object stored in a `torch.save` file.

    Args:
        pt_path: Path (or file-like object) handed to `torch.load`.
        map_location: Optional device remapping forwarded to `torch.load` (e.g. "cpu"); the
            default `None` keeps `torch.load`'s own default behavior.

    Returns:
        Whatever object the file contains.
    """
    # NOTE(security): `torch.load` unpickles the file -- only use it on trusted checkpoints.
    return torch.load(pt_path, map_location=map_location)
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
if __name__ == "__main__":
    # Script entry point: register the custom classes with the Auto* factories, assemble the
    # configuration, then build the model in bfloat16 on CUDA in eval mode.
    AutoConfig.register("prismatic", PrismaticConfig)
    AutoConfig.register("trajectoryvla", TrajectoryVLAConfig)
    # `AutoModel.register` takes the config *class* (not the model-type string) as its first argument.
    AutoModel.register(PrismaticConfig, PrismaticForConditionalGeneration)
    AutoModel.register(TrajectoryVLAConfig, TrajectoryVLA)

    prismatic_config_dict = {
        "vision_backbone_id": "dinosiglip-vit-so-224px",
        "llm_backbone_id": "llama2-7b-pure",
        "arch_specifier": "no-align+gelu-mlp",
        "use_fused_vision_backbone": True,
        "image_resize_strategy": "letterbox",
        "text_config": None,
        "llm_max_length": 2048,
        "pad_token_id": 32000,
        "pad_to_multiple_of": 64,
        "output_projector_states": False,
        "return_dict": False,
    }

    # Layer sizes of the token projections (input size first)
    token_proj_config = {
        "vit_tokens_layers": [2176, 1024],
        "control_tokens_layers": [4096, 2048, 1024],
        "image_tokens_mode": 'vit',
        'llm_image_tokens_layers': [],
    }
    timestep_proj_config = {
        "pos_embed_scale": 8,
        "proj_layers": [128, 512, 1024],
        "time_delta_sec": 0.1,
        "num_tokens": 3,
    }

    # Settings handed to the DETR module
    pos_embed_config = {
        "num_embeddings": 300,
        "embedding_dim": 1024,
    }
    encoder_block_config = {
        "feature_size": 1024,
        "head_dim": 64,
        "num_heads": 16,
    }
    decoder_block_config = {
        "feature_size": 1024,
        "head_dim": 64,
        "num_heads": 16,
        "dropout": 0.0,
    }
    transformer_config = {
        "pos_embed_config": pos_embed_config,
        "encoder_block_config": encoder_block_config,
        "decoder_block_config": decoder_block_config,
        "num_blocks": 2,
    }

    TrajectoryVlaConfig_config = {
        "prismatic_config": prismatic_config_dict,
        "token_size": 1024,
        "cheat": False,
        "num_timesteps": 6,
        "rotation_components": 9,
        "seperate_control_proj": True,
        "timestep_proj_config": timestep_proj_config,
        "token_proj_config": token_proj_config,
        "transformer_config": transformer_config,
        "num_timestep_tokens": 3,
    }

    model_config = TrajectoryVLAConfig(**TrajectoryVlaConfig_config)

    model = TrajectoryVLA(model_config)
    model = model.to(dtype=torch.bfloat16)
    model = model.to('cuda')
    model.eval()
|
|
| |
| |
| |
|
|
|
|
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
|
|
| |
|
|
|
|