| """
|
| data_utils.py
|
|
|
| General utilities and classes for facilitating data loading and collation.
|
| """
|
|
|
| from dataclasses import dataclass
|
| from typing import Callable, Dict, Sequence, Tuple
|
|
|
| import numpy as np
|
| import torch
|
| from torch.nn.utils.rnn import pad_sequence
|
|
|
|
|
| IGNORE_INDEX = -100
|
|
|
|
|
| def tree_map(fn: Callable, tree: dict) -> dict:
|
| """Maps a function over a nested dictionary."""
|
| return {k: tree_map(fn, v) if isinstance(v, dict) else fn(v) for k, v in tree.items()}
|
|
|
|
|
| def tree_map_with_key(fn: Callable, tree: dict, keys: Sequence = ()) -> dict:
|
| """Maps a function over a nested dictionary."""
|
| return {
|
| k: tree_map_with_key(fn, v, (*keys, k)) if isinstance(v, dict) else fn((*keys, k), v) for k, v in tree.items()
|
| }
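
# Illustrative examples (not executed): `tree_map` applies `fn` to every leaf value, while
# `tree_map_with_key` additionally passes the tuple of keys leading to that leaf:
#
#   >>> tree_map(lambda v: v * 2, {"a": 1, "b": {"c": 2}})
#   {'a': 2, 'b': {'c': 4}}
#   >>> tree_map_with_key(lambda path, v: "/".join(path), {"a": 1, "b": {"c": 2}})
#   {'a': 'a', 'b': {'c': 'b/c'}}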


@dataclass
class PaddedCollatorForLanguageModeling:
    model_max_length: int
    pad_token_id: int
    default_image_resolution: Tuple[int, int, int]
    padding_side: str = "right"
    pixel_values_dtype: torch.dtype = torch.float32

    def __post_init__(self) -> None:
        # Dummy image tensor substituted for language-only examples so batches stack cleanly
        self.dummy_pixel_values = torch.zeros(self.default_image_resolution, dtype=self.pixel_values_dtype)

    def __call__(self, instances: Sequence[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
        pixel_values = [instance["pixel_values"] for instance in instances]

        # Pad `input_ids` and `labels` to the longest sequence in the batch (right padding only)
        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.pad_token_id)
        labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)

        # Truncate to `model_max_length` (if necessary)
        input_ids, labels = input_ids[:, : self.model_max_length], labels[:, : self.model_max_length]

        # Derive `attention_mask` by checking for `pad_token_id`
        attention_mask = input_ids.ne(self.pad_token_id)

        # Some examples are language-only (`pixel_values is None`); track the indices of the
        # multimodal examples so downstream code can slice into the batch easily
        multimodal_indices = torch.tensor(
            [idx for idx in range(len(pixel_values)) if pixel_values[idx] is not None], dtype=torch.long
        )

        # Stack `pixel_values`, substituting `dummy_pixel_values` for language-only examples;
        # handle both single-tensor and dict-of-tensors image formats
        if len(multimodal_indices) == 0:
            pixel_values = torch.stack([self.dummy_pixel_values for _ in range(len(input_ids))])
        elif isinstance(pv_example := pixel_values[multimodal_indices[0]], torch.Tensor):
            pixel_values = torch.stack(
                [
                    pixel_values[idx] if idx in multimodal_indices else self.dummy_pixel_values
                    for idx in range(len(input_ids))
                ]
            )
        elif isinstance(pv_example, dict):
            pixel_values = {
                k: torch.stack(
                    [
                        pixel_values[idx][k] if idx in multimodal_indices else self.dummy_pixel_values
                        for idx in range(len(input_ids))
                    ]
                )
                for k in pv_example
            }
        else:
            raise ValueError(f"Unsupported `pixel_values` type = {type(pv_example)}")

        return dict(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            multimodal_indices=multimodal_indices,
        )
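
# Usage sketch for `PaddedCollatorForLanguageModeling` (illustrative; `tokenizer` and
# `dataset` are hypothetical stand-ins for a HF-style tokenizer and a dataset whose
# examples carry `input_ids`, `labels`, and `pixel_values`):
#
#   collator = PaddedCollatorForLanguageModeling(
#       model_max_length=tokenizer.model_max_length,
#       pad_token_id=tokenizer.pad_token_id,
#       default_image_resolution=(3, 224, 224),  # assumed resolution, purely illustrative
#   )
#   loader = torch.utils.data.DataLoader(dataset, batch_size=8, collate_fn=collator)
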
@dataclass
class PaddedCollatorForActionPrediction:
    model_max_length: int
    pad_token_id: int
    padding_side: str = "right"
    pixel_values_dtype: torch.dtype = torch.float32

    def __call__(self, instances: Sequence[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
        pixel_values = [instance["pixel_values"] for instance in instances]
        if "dataset_name" in instances[0]:
            dataset_names = [instance["dataset_name"] for instance in instances]
        else:
            dataset_names = None

        # For now, we only support tokenizers with right padding during training
        assert self.padding_side == "right", f"Invalid Tokenizer `{self.padding_side = }`"
        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.pad_token_id)
        labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)

        # Truncate to `model_max_length` (if necessary)
        input_ids, labels = input_ids[:, : self.model_max_length], labels[:, : self.model_max_length]

        # Derive `attention_mask` by checking for `pad_token_id`
        attention_mask = input_ids.ne(self.pad_token_id)

        # Unlike the language-modeling collator, every example here must be multimodal
        assert all([pv is not None for pv in pixel_values]), "Invalid VLA Example with `pixel_values = None`!"

        # Stack images; if wrist-camera views are present, concatenate them with the
        # primary views along the channel dimension
        if isinstance(pixel_values[0], torch.Tensor):
            if "pixel_values_wrist" in instances[0]:
                pixel_values_wrist = [instance["pixel_values_wrist"] for instance in instances]
                pixel_values = torch.cat((torch.stack(pixel_values), torch.stack(pixel_values_wrist)), dim=1)
            else:
                pixel_values = torch.stack(pixel_values)
        else:
            raise ValueError(f"Unsupported `pixel_values` type = {type(pixel_values[0])}")

        # Stack continuous action targets (copy to guard against non-writable NumPy arrays)
        actions = [torch.from_numpy(np.copy(instance["actions"])) for instance in instances]
        actions = torch.stack(actions)

        # Stack proprioceptive state (if present)
        if "proprio" in instances[0]:
            proprio = [instance["proprio"] for instance in instances]
            proprio = torch.Tensor(np.squeeze(np.stack(proprio)))
        else:
            proprio = None

        output = dict(
            pixel_values=pixel_values,
            proprio=proprio,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            actions=actions,
        )
        if dataset_names is not None:
            output["dataset_names"] = dataset_names
        return output
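
# Usage sketch for `PaddedCollatorForActionPrediction` (illustrative; `tokenizer` and
# `vla_dataset` are hypothetical stand-ins for a HF-style tokenizer and an action-prediction
# dataset whose examples carry `input_ids`, `labels`, `pixel_values`, and `actions`):
#
#   collator = PaddedCollatorForActionPrediction(
#       model_max_length=tokenizer.model_max_length,
#       pad_token_id=tokenizer.pad_token_id,
#   )
#   loader = torch.utils.data.DataLoader(vla_dataset, batch_size=16, collate_fn=collator)
#   batch = next(iter(loader))
#   # batch keys: pixel_values, proprio, input_ids, attention_mask, labels, actions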
|