| """ |
| data_utils.py |
| |
| General utilities and classes for facilitating data loading and collation. |
| """ |
|
|
| from dataclasses import dataclass |
| from typing import Callable, Dict, Sequence, Tuple |
|
|
| import numpy as np |
| import torch |
| from torch.nn.utils.rnn import pad_sequence |
|
|
| |
| IGNORE_INDEX = -100 |
|
|
|
|


def tree_map(fn: Callable, tree: dict) -> dict:
    """Map a function over every leaf value of a nested dictionary."""
    return {k: tree_map(fn, v) if isinstance(v, dict) else fn(v) for k, v in tree.items()}
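
# Example usage (illustrative): `tree_map` applies `fn` to every non-dict leaf, preserving structure.
#   >>> tree_map(lambda v: v * 2, {"a": 1, "b": {"c": 2}})
#   {'a': 2, 'b': {'c': 4}}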


def tree_map_with_key(fn: Callable, tree: dict, keys: Sequence = ()) -> dict:
    """Map a function over every leaf of a nested dictionary, passing the leaf's full key path to `fn`."""
    return {
        k: tree_map_with_key(fn, v, (*keys, k)) if isinstance(v, dict) else fn((*keys, k), v) for k, v in tree.items()
    }
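
# Example usage (illustrative): `fn` receives the tuple of keys leading to each leaf.
#   >>> tree_map_with_key(lambda ks, v: f"{'.'.join(ks)}={v}", {"a": 1, "b": {"c": 2}})
#   {'a': 'a=1', 'b': {'c': 'b.c=2'}}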


@dataclass
class PaddedCollatorForLanguageModeling:
    model_max_length: int
    pad_token_id: int
    default_image_resolution: Tuple[int, int, int]
    padding_side: str = "right"
    pixel_values_dtype: torch.dtype = torch.float32

    def __post_init__(self) -> None:
        # Placeholder image tensor substituted for language-only (unimodal) examples
        self.dummy_pixel_values = torch.zeros(self.default_image_resolution, dtype=self.pixel_values_dtype)

    def __call__(self, instances: Sequence[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
        pixel_values = [instance["pixel_values"] for instance in instances]

        # Pad `input_ids` and `labels` out to the longest sequence in the batch
        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.pad_token_id)
        labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)

        # Truncate to `model_max_length` (if necessary)
        input_ids, labels = input_ids[:, : self.model_max_length], labels[:, : self.model_max_length]

        # Derive `attention_mask` by marking all non-pad positions
        attention_mask = input_ids.ne(self.pad_token_id)

        # Identify the multimodal examples in the batch (those with non-None `pixel_values`)
        multimodal_indices = torch.tensor(
            [idx for idx in range(len(pixel_values)) if pixel_values[idx] is not None], dtype=torch.long
        )

        # Stack `pixel_values`, substituting `dummy_pixel_values` for language-only examples
        if len(multimodal_indices) == 0:
            pixel_values = torch.stack([self.dummy_pixel_values for _ in range(len(input_ids))])
        elif isinstance(pv_example := pixel_values[multimodal_indices[0]], torch.Tensor):
            pixel_values = torch.stack(
                [
                    pixel_values[idx] if idx in multimodal_indices else self.dummy_pixel_values
                    for idx in range(len(input_ids))
                ]
            )
        elif isinstance(pv_example, dict):
            # `pixel_values` may be a dict of tensors (e.g., one entry per vision backbone);
            # language-only examples reuse `dummy_pixel_values` for every key
            pixel_values = {
                k: torch.stack(
                    [
                        pixel_values[idx][k] if idx in multimodal_indices else self.dummy_pixel_values
                        for idx in range(len(input_ids))
                    ]
                )
                for k in pv_example
            }
        else:
            raise ValueError(f"Unsupported `pixel_values` type = {type(pv_example)}")

        return dict(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            multimodal_indices=multimodal_indices,
        )
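
# Example usage (illustrative sketch; `tokenizer` and `dataset` are assumed to exist, with the
# dataset yielding dicts of "input_ids", "labels", and optional "pixel_values"):
#   >>> collator = PaddedCollatorForLanguageModeling(
#   ...     model_max_length=2048,
#   ...     pad_token_id=tokenizer.pad_token_id,
#   ...     default_image_resolution=(3, 224, 224),
#   ... )
#   >>> loader = torch.utils.data.DataLoader(dataset, batch_size=8, collate_fn=collator)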


@dataclass
class PaddedCollatorForActionPrediction:
    model_max_length: int
    pad_token_id: int
    padding_side: str = "right"
    pixel_values_dtype: torch.dtype = torch.float32

    def __call__(self, instances: Sequence[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
        pixel_values = [instance["pixel_values"] for instance in instances]
        if "dataset_name" in instances[0]:
            dataset_names = [instance["dataset_name"] for instance in instances]
        else:
            dataset_names = None

        # Only right-padding is supported; pad `input_ids` and `labels` out to the longest sequence
        assert self.padding_side == "right", f"Invalid Tokenizer `{self.padding_side = }`"
        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.pad_token_id)
        labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)

        # Truncate to `model_max_length` (if necessary)
        input_ids, labels = input_ids[:, : self.model_max_length], labels[:, : self.model_max_length]

        # Derive `attention_mask` by marking all non-pad positions
        attention_mask = input_ids.ne(self.pad_token_id)

        # Every VLA example must carry an image observation
        assert all(pv is not None for pv in pixel_values), "Invalid VLA Example with `pixel_values = None`!"

        # Stack primary-camera images; if wrist-camera images are present, concatenate them with the
        # primary images along dim 1
        if isinstance(pixel_values[0], torch.Tensor):
            if "pixel_values_wrist" in instances[0]:
                pixel_values_wrist = [instance["pixel_values_wrist"] for instance in instances]
                pixel_values = torch.cat((torch.stack(pixel_values), torch.stack(pixel_values_wrist)), dim=1)
            else:
                pixel_values = torch.stack(pixel_values)
        else:
            raise ValueError(f"Unsupported `pixel_values` type = {type(pixel_values[0])}")

        # Collate (optional) trigger images the same way
        trigger_pixel_values = None
        if "trigger_pixel_values" in instances[0]:
            trigger_pixel_values_list = [instance["trigger_pixel_values"] for instance in instances]
            if "trigger_pixel_values_wrist" in instances[0]:
                trigger_pixel_values_wrist = [instance["trigger_pixel_values_wrist"] for instance in instances]
                trigger_pixel_values = torch.cat(
                    (torch.stack(trigger_pixel_values_list), torch.stack(trigger_pixel_values_wrist)), dim=1
                )
            else:
                trigger_pixel_values = torch.stack(trigger_pixel_values_list)

        # Stack ground-truth actions (copying so the tensors don't alias the underlying NumPy arrays)
        actions = [torch.from_numpy(np.copy(instance["actions"])) for instance in instances]
        actions = torch.stack(actions)

        # Stack (optional) proprioceptive state
        if "proprio" in instances[0]:
            proprio = [instance["proprio"] for instance in instances]
            proprio = torch.Tensor(np.squeeze(np.stack(proprio)))
        else:
            proprio = None

        output = dict(
            pixel_values=pixel_values,
            proprio=proprio,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            actions=actions,
        )
        if trigger_pixel_values is not None:
            output["trigger_pixel_values"] = trigger_pixel_values
        if dataset_names is not None:
            output["dataset_names"] = dataset_names
        return output
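
# Example usage (illustrative sketch; `tokenizer` and a VLA `dataset` are assumed to exist, with
# examples providing "input_ids", "labels", "pixel_values", and "actions"):
#   >>> collator = PaddedCollatorForActionPrediction(
#   ...     model_max_length=2048, pad_token_id=tokenizer.pad_token_id
#   ... )
#   >>> batch = collator([dataset[i] for i in range(4)])
#   >>> batch["actions"].shape[0]  # 4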


@dataclass
class PaddedCollatorForDebug:
    """No-op collator that drops all instances and returns an empty batch (for debugging dataloading)."""

    def __call__(self, instances: Sequence[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        return dict()