| """
|
| materialize.py
|
|
|
| Factory class for initializing pretraining datasets on a per-VLM basis; provides and exports individual functions for
|
| clear control flow.
|
| """
|
|
|
| from typing import Tuple, Type
|
|
|
| from torch.utils.data import Dataset
|
| from transformers import PreTrainedTokenizerBase
|
|
|
| from prismatic.conf import DatasetConfig
|
| from prismatic.models.backbones.llm.prompting import PromptBuilder
|
| from prismatic.models.backbones.vision import ImageTransform
|
| from prismatic.preprocessing.datasets import AlignDataset, FinetuneDataset
|
| from prismatic.util.data_utils import PaddedCollatorForLanguageModeling
|
|
|
|
|
# Registry mapping each supported training stage name to the dataset class used to build its data.
# Both "finetune" and "full-finetune" share the same dataset class.
DATASET_INITIALIZER = {
    "align": AlignDataset,
    "finetune": FinetuneDataset,
    "full-finetune": FinetuneDataset,
}
|
|
|
|
|
def get_dataset_and_collator(
    stage: str,
    dataset_cfg: DatasetConfig,
    image_transform: ImageTransform,
    tokenizer: PreTrainedTokenizerBase,
    prompt_builder_fn: Type[PromptBuilder],
    default_image_resolution: Tuple[int, int, int],
    padding_side: str = "right",
) -> Tuple[Dataset, PaddedCollatorForLanguageModeling]:
    """Build the dataset and the padding collator for the requested training stage.

    Args:
        stage: Training stage name; must be a key of `DATASET_INITIALIZER`
            ("align", "finetune", or "full-finetune").
        dataset_cfg: Dataset configuration providing `dataset_root_dir` plus the
            `(annotation_json, image_dir)` pair for the stage
            (`align_stage_components` or `finetune_stage_components`).
        image_transform: Image transform forwarded to the dataset constructor.
        tokenizer: Tokenizer forwarded to the dataset; its `model_max_length` and
            `pad_token_id` also configure the collator.
        prompt_builder_fn: Prompt builder class forwarded to the dataset for the
            finetuning stages only (the "align" dataset is built without it).
        default_image_resolution: Forwarded unchanged to the collator.
        padding_side: Forwarded to the collator as `padding_side`; defaults to "right".

    Returns:
        A `(dataset, collator)` tuple for the given stage.

    Raises:
        ValueError: If `stage` is not a supported stage.
    """
    # Validate up front: indexing `DATASET_INITIALIZER[stage]` directly would raise a bare
    # KeyError for an unknown stage, so the descriptive ValueError below was never reached.
    dataset_cls = DATASET_INITIALIZER.get(stage)
    if dataset_cls is None:
        raise ValueError(f"Stage `{stage}` is not supported!")

    dataset_root_dir = dataset_cfg.dataset_root_dir
    collator = PaddedCollatorForLanguageModeling(
        tokenizer.model_max_length, tokenizer.pad_token_id, default_image_resolution, padding_side=padding_side
    )

    # Resolve the stage-specific components and any extra dataset keyword arguments.
    if stage == "align":
        annotation_json, image_dir = dataset_cfg.align_stage_components
        extra_kwargs = {}
    elif stage in ("finetune", "full-finetune"):
        # Both finetuning stages read the same components and use the prompt builder.
        annotation_json, image_dir = dataset_cfg.finetune_stage_components
        extra_kwargs = {"prompt_builder_fn": prompt_builder_fn}
    else:
        # Guards against a stage that is registered in `DATASET_INITIALIZER` but not handled here.
        raise ValueError(f"Stage `{stage}` is not supported!")

    # NOTE(review): `dataset_root_dir` is presumably a `pathlib.Path` (joined with `/`) -- confirm in DatasetConfig.
    dataset = dataset_cls(
        dataset_root_dir / annotation_json,
        dataset_root_dir / image_dir,
        image_transform,
        tokenizer,
        **extra_kwargs,
    )
    return dataset, collator
|
|
|