| """ |
| datasets.py |
| |
| PyTorch Dataset Definitions for Prismatic models; supports processing for both the `align` and `finetune` stages, with |
| utilities for formatting conversations during the `finetune` stage subject to the given LLM backbone's expected |
| formatting (e.g., SYS_PROMPT + USER: ... ASSISTANT: ... for Vicuña v1.5 Chat models). |
| |
| We currently only support Map-style Datasets; assumes that all files (annotations, images) are on local disk, and that |
| random access image reading is relatively cheap/fast. |
| """ |
import copy
import json
from pathlib import Path
from typing import Dict, List, Tuple, Type

import torch
from PIL import Image
from torch.utils.data import Dataset
from transformers import CodeGenTokenizerFast, LlamaTokenizerFast, PreTrainedTokenizerBase

from prismatic.models.backbones.llm.prompting import PromptBuilder
from prismatic.models.backbones.vision import ImageTransform

# Default ignore index for labels (matches the HuggingFace / PyTorch CrossEntropyLoss convention); positions
# with this label are excluded from the language-modeling loss.
IGNORE_INDEX = -100


class AlignDataset(Dataset[Dict[str, torch.Tensor]]):
    def __init__(
        self,
        chat_json: Path,
        image_dir: Path,
        image_transform: ImageTransform,
        tokenizer: PreTrainedTokenizerBase,
    ) -> None:
        super().__init__()
        self.chat_json, self.image_dir = chat_json, image_dir
        self.image_transform, self.tokenizer = image_transform, tokenizer
        self.dataset_type = "align"

        # Create the (trivial) prompt template --> during alignment we directly predict "{caption}{eos_token}"
        self.prompt_template = "{caption}" + self.tokenizer.eos_token

        # Load Chat JSON (list of examples, each with an "image" path and a two-turn "conversations" list)
        with open(self.chat_json, "r") as f:
            self.examples = json.load(f)
    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Following the *actual* code executed from the LLaVa codebase, during the "align" phase, we discard the
        "prompt" from the human and instead directly predict the caption from the image.

        As a concrete example, given the "raw data" for the first example:
            example = self.examples[0]["conversations"] = [
                {"from": "human", "value": "Render a clear and concise summary of the photo.\n<image>"},
                {"from": "gpt", "value": "select luxury furniture 3 - inch gel memory foam mattress topper"}
            ]

        Return =>> self.tokenizer("<image> select luxury furniture 3 - inch gel memory foam mattress topper\n")

        :param idx: Index to retrieve from the dataset.

        :return: Dictionary of {"pixel_values": torch.Tensor, "input_ids": torch.Tensor, "labels": torch.Tensor}
        """
        image_path, conversation = Path(self.examples[idx]["image"]), self.examples[idx]["conversations"]
        assert (len(conversation) == 2) and ("<image>" not in conversation[-1]["value"]), "Unexpected text!"

        # Format Caption --> "{caption}{eos_token}"
        caption = self.prompt_template.format(caption=conversation[-1]["value"].strip())

        # Tokenize the caption; `labels` starts as a copy of `input_ids`. Note that when training with the HF
        # `forward(..., labels=labels)` API, the shift for next-token prediction happens *inside* the model.
        input_ids = self.tokenizer(caption, truncation=True, return_tensors="pt").input_ids[0]
        labels = copy.deepcopy(input_ids)

        # Ignore the <BOS> token's label (the image patch embeddings are inserted immediately after <BOS>)
        labels[0] = IGNORE_INDEX

        # Process Image --> "pixel_values" (either a torch.Tensor or a Dict[str, torch.Tensor], per the transform)
        pixel_values = self.image_transform(Image.open(self.image_dir / image_path).convert("RGB"))

        return dict(pixel_values=pixel_values, input_ids=input_ids, labels=labels)
    def get_modality_lengths(self, n_image_patches: int) -> List[Tuple[bool, int]]:
        """Get a list of modalities (unimodal / text-only vs. multimodal) and length of conversations per example."""
        modality_lengths = []
        for example in self.examples:
            is_multimodal = "image" in example
            n_words = sum([len(turn["value"].replace("<image>", "").split()) for turn in example["conversations"]])
            modality_lengths.append((is_multimodal, (n_image_patches + n_words) if is_multimodal else n_words))
        return modality_lengths
    def __len__(self) -> int:
        return len(self.examples)


class FinetuneDataset(Dataset[Dict[str, torch.Tensor]]):
    def __init__(
        self,
        instruct_json: Path,
        image_dir: Path,
        image_transform: ImageTransform,
        tokenizer: PreTrainedTokenizerBase,
        prompt_builder_fn: Type[PromptBuilder],
    ) -> None:
        super().__init__()
        self.instruct_json, self.image_dir = instruct_json, image_dir
        self.image_transform, self.tokenizer = image_transform, tokenizer
        self.prompt_builder_fn = prompt_builder_fn
        self.dataset_type = "finetune"

        # Load Instruct JSON (list of examples, each with an optional "image" path and a "conversations" list)
        with open(self.instruct_json, "r") as f:
            self.examples = json.load(f)
    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Unlike the *align* stage handling, for the *finetune* stage, we actually need to handle multiple "turns" of
        dialog grounded in a single image.

        To do this, we leverage the `prompt_builder_fn` which instantiates a PromptBuilder object. By calling the
        methods for adding turns and getting a prompt, we ensure proper formatting and consistency for each example.

        :param idx: Index to retrieve from the dataset.

        :return: Dictionary of {"pixel_values": torch.Tensor, "input_ids": torch.Tensor, "labels": torch.Tensor}
        """
        conversation = self.examples[idx]["conversations"]

        # Create Prompt Builder --> add each conversation turn sequentially
        prompt_builder, input_ids, labels = self.prompt_builder_fn(model_family="prismatic"), [], []
        for turn_idx, turn in enumerate(conversation):
            # Get the "effective" string this turn adds to the prompt
            msg = prompt_builder.add_turn(turn["from"], turn["value"])

            # Strip trailing whitespace for LlamaTokenizerFast (it otherwise yields a spurious trailing token)
            if isinstance(self.tokenizer, LlamaTokenizerFast):
                msg = msg.rstrip()

            # CodeGenTokenizerFast (used by, e.g., Phi-2) needs no special handling
            elif isinstance(self.tokenizer, CodeGenTokenizerFast):
                pass

            else:
                raise ValueError(f"Tokenizer of type `{type(self.tokenizer)}` is not explicitly handled!")

            # Tokenize the turn --> only add special tokens (e.g., <BOS>) for the very first turn
            turn_input_ids = self.tokenizer(msg, add_special_tokens=turn_idx == 0).input_ids

            # [CRITICAL] Ignore "human" turns (even indices) --> we only take the loss on model responses
            turn_labels = (
                [IGNORE_INDEX for _ in range(len(turn_input_ids))] if (turn_idx % 2) == 0 else list(turn_input_ids)
            )

            # Add to trackers
            input_ids.extend(turn_input_ids)
            labels.extend(turn_labels)

        # Tensorize =>> note that when training with the HF `forward(..., labels=labels)` API, the shift for
        # next-token prediction happens *inside* the model, so no manual shifting is needed here.
        input_ids, labels = torch.tensor(input_ids), torch.tensor(labels)

        # Handle truncation (if necessary)
        input_ids, labels = input_ids[: self.tokenizer.model_max_length], labels[: self.tokenizer.model_max_length]

| if "image" in self.examples[idx]: |
| image_path = Path(self.examples[idx]["image"]) |
|
|
| |
| labels[0] = IGNORE_INDEX |
|
|
| |
| pixel_values = self.image_transform(Image.open(self.image_dir / image_path).convert("RGB")) |
|
|
| return dict(pixel_values=pixel_values, input_ids=input_ids, labels=labels) |
|
|
| else: |
| |
| return dict(pixel_values=None, input_ids=input_ids, labels=labels) |
|
|
    def get_modality_lengths(self) -> List[Tuple[bool, int]]:
        """Get a list of modalities (unimodal / text-only vs. multimodal) and length of conversations per example."""
        modality_lengths = []
        for example in self.examples:
            is_multimodal = "image" in example
            n_words = sum([len(turn["value"].split()) for turn in example["conversations"]])
            modality_lengths.append((is_multimodal, n_words))
        return modality_lengths
    def __len__(self) -> int:
        return len(self.examples)
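
# Minimal usage sketch (illustrative only): `image_transform`, `tokenizer`, and `prompt_builder_cls` are assumed
# to come from the corresponding Prismatic vision/LLM backbone classes, and the paths are hypothetical placeholders.
#
#   dataset = FinetuneDataset(
#       Path("path/to/instruct.json"), Path("path/to/images"), image_transform, tokenizer, prompt_builder_cls
#   )
#   example = dataset[0]  # dict with "pixel_values", "input_ids", "labels"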