| """ |
| materialize.py |
| |
| Factory functions for initializing Open-X RLDS-backed datasets, given specified data mixture parameters; exports |
| individual functions (rather than a class) for clear control flow. |
| """ |
|
|
| from pathlib import Path |
| from typing import Tuple, Type |
|
|
| from torch.utils.data import Dataset |
| from transformers import PreTrainedTokenizerBase |
|
|
| from prismatic.models.backbones.llm.prompting import PromptBuilder |
| from prismatic.models.backbones.vision import ImageTransform |
| from prismatic.util.data_utils import PaddedCollatorForActionPrediction |
| from prismatic.vla.action_tokenizer import ActionTokenizer |
| from prismatic.vla.datasets import EpisodicRLDSDataset, RLDSBatchTransform, RLDSDataset |
|
|
|
|
def get_vla_dataset_and_collator(
    data_root_dir: Path,
    data_mix: str,
    image_transform: ImageTransform,
    tokenizer: PreTrainedTokenizerBase,
    prompt_builder_fn: Type[PromptBuilder],
    default_image_resolution: Tuple[int, int, int],
    padding_side: str = "right",
    predict_stop_token: bool = True,
    shuffle_buffer_size: int = 100_000,
    train: bool = True,
    episodic: bool = False,
    image_aug: bool = False,
) -> Tuple[Dataset, ActionTokenizer, PaddedCollatorForActionPrediction]:
    """Build the RLDS-backed VLA dataset together with its action tokenizer and collator.

    Args:
        data_root_dir: Root directory passed positionally to the dataset constructor.
        data_mix: Data-mixture name passed positionally to the dataset constructor.
        image_transform: Image transform forwarded to ``RLDSBatchTransform``.
        tokenizer: Base LLM tokenizer. It is wrapped by ``ActionTokenizer``, passed to
            ``RLDSBatchTransform``, and supplies ``model_max_length`` / ``pad_token_id``
            to the collator.
        prompt_builder_fn: Prompt-builder class forwarded to ``RLDSBatchTransform``.
        default_image_resolution: Three-element resolution tuple. Only the entries after
            the first are forwarded as ``resize_resolution`` (presumably (C, H, W) -> (H, W);
            TODO confirm).
        padding_side: Forwarded to ``PaddedCollatorForActionPrediction``.
        predict_stop_token: Forwarded to ``RLDSBatchTransform``.
        shuffle_buffer_size: Forwarded to the dataset constructor.
        train: Forwarded to the dataset constructor.
        episodic: If truthy, an ``EpisodicRLDSDataset`` is built; otherwise an ``RLDSDataset``.
        image_aug: Forwarded to the dataset constructor.

    Returns:
        A ``(dataset, action_tokenizer, collator)`` tuple.
    """
    # The batch transform needs both the action-level wrapper and the raw tokenizer.
    action_tok = ActionTokenizer(tokenizer)
    transform = RLDSBatchTransform(
        action_tok,
        tokenizer,
        image_transform,
        prompt_builder_fn,
        predict_stop_token=predict_stop_token,
    )

    # Collator is configured from the tokenizer's max length and pad token id.
    pad_collator = PaddedCollatorForActionPrediction(
        tokenizer.model_max_length,
        tokenizer.pad_token_id,
        padding_side=padding_side,
    )

    # Drop the leading entry of the resolution tuple (presumably channels -- confirm).
    resize_to = default_image_resolution[1:]
    dataset_cls = EpisodicRLDSDataset if episodic else RLDSDataset
    vla_dataset = dataset_cls(
        data_root_dir,
        data_mix,
        transform,
        resize_resolution=resize_to,
        shuffle_buffer_size=shuffle_buffer_size,
        train=train,
        image_aug=image_aug,
    )

    return vla_dataset, action_tok, pad_collator
|
|