| from operator import length_hint |
| import random |
| import bisect |
| import copy |
| import torch |
| import transformers |
| from torch.utils.data import get_worker_info |
| from omegaconf import OmegaConf |
| import torchvision.transforms.functional as F |
| from dataclasses import dataclass, field |
| from typing import Dict, Optional, Sequence, List |
| from torch.utils.data import Dataset, ConcatDataset |
|
|
| from llava.datasets.registry import build_from_cfg |
| from llava.datasets.builder import DATASETS |
| from llava.datasets.data_cfgs import data_configs |
| from llava.train.arguments import DataArguments |
| from llava.model.preprocessor import preprocess_multimodal, preprocess |
| from llava.constants import IGNORE_INDEX |
| from llava.utils import DatasetIter, get_world_size, get_rank, master_print |
| from transformers import CLIPImageProcessor, SiglipImageProcessor |
|
|
class LazySupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning that mixes several sub-datasets.

    Every ``__getitem__`` call picks a sub-dataset at random, weighted by its
    normalized ``sample_ratio``, and asks that sub-dataset's ``DatasetIter``
    for the next sample index. The index ``i`` passed to ``__getitem__`` does
    not select the sample; it only decides (via ``isinstance(i, int)``)
    whether the tokenized result is wrapped/unwrapped as a single example.
    """

    def __init__(self, data_cfg: str,
                 tokenizer: transformers.PreTrainedTokenizer,
                 data_args: DataArguments,
                 num_workers: int):
        """Build every sub-dataset listed in the OmegaConf file ``data_cfg``.

        Args:
            data_cfg: Path to an OmegaConf file with a top-level ``datasets``
                mapping. Each entry's fields are forwarded to ``build_from_cfg``
                as ``default_args``; an optional ``sample_ratio`` field sets
                the entry's sampling weight.
            tokenizer: Tokenizer handed to ``preprocess``.
            data_args: Data arguments; each sub-dataset builder receives a
                deep-copied snapshot of its attributes.
            num_workers: DataLoader worker count, forwarded to every
                ``DatasetIter``.

        Raises:
            ValueError: If datasets exist but their sample ratios do not sum
                to a positive number.
        """
        super(LazySupervisedDataset, self).__init__()
        dataset_config = OmegaConf.load(data_cfg)

        self.tokenizer = tokenizer
        self.data_args = data_args

        self.datasets, self.sample_ratios = list(), list()
        for ds in list(dataset_config.datasets.keys()):
            ds_cfg = dataset_config.datasets[ds]
            external_args = {key: value for key, value in ds_cfg.items()}
            # Builders get a throwaway class whose attributes are a deep copy of
            # vars(data_args), so changes a builder makes do not touch data_args.
            args_ = copy.deepcopy(vars(data_args))
            data_args_copy = type('DataArguments', (object,), args_)
            dataset = build_from_cfg(ds, data_args_copy, DATASETS, default_args=external_args)
            self.datasets.append(dataset)
            if 'sample_ratio' in ds_cfg:
                self.sample_ratios.append(ds_cfg.sample_ratio)

        # Unless every dataset declares a ratio, fall back to uniform weights:
        # exactly one weight per dataset so random.choices and zip() stay aligned.
        if len(self.sample_ratios) != len(self.datasets):
            self.sample_ratios = [1.0] * len(self.datasets)

        total_ratio = float(sum(self.sample_ratios))
        if self.datasets and total_ratio <= 0:
            raise ValueError("sample_ratio values must sum to a positive number")
        self.sample_ratios = [float(ratio) / total_ratio for ratio in self.sample_ratios]

        self.ds_iters = [DatasetIter(len(dataset), get_world_size(), get_rank(), num_workers)
                         for dataset in self.datasets]

    def __len__(self):
        """Return the virtual epoch length.

        This is the largest ``int(len(ds) / ratio)`` over the sub-datasets, so
        the expected number of draws from that sub-dataset matches its size.
        Sub-datasets with a non-positive ratio are never sampled and are
        ignored. Returns 0 when no sub-dataset can be sampled.
        """
        return max((int(len(ds) / ratio)
                    for ds, ratio in zip(self.datasets, self.sample_ratios)
                    if ratio > 0),
                   default=0)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Draw one weighted-random sample and return its tokenized form."""
        worker_info = get_worker_info()
        # get_worker_info() returns None in single-process loading (num_workers=0).
        # NOTE(review): assumes DatasetIter.increment accepts worker id 0 there — confirm.
        worker_id = worker_info.id if worker_info is not None else 0

        ds_idx = random.choices(range(len(self.datasets)), self.sample_ratios, k=1)[0]

        # Keep drawing from the same sub-dataset until it yields a non-None
        # sample (this loops forever if it never does).
        item = None
        while item is None:
            item_id = self.ds_iters[ds_idx].increment(worker_id)
            item = self.datasets[ds_idx].__getitem__(item_id)

        return self._build_data_dict(item, i)

    def _placeholder_image_size(self):
        """Return the side length of the all-zero image used for image-less samples."""
        image_processor = self.data_args.image_processor
        if isinstance(image_processor, SiglipImageProcessor):
            return image_processor.size['height']
        if isinstance(image_processor, CLIPImageProcessor):
            return image_processor.crop_size['height']
        return image_processor.img_size

    def _build_data_dict(self, item, i):
        """Tokenize one raw sample and attach its images (or a dummy image).

        Args:
            item: Raw sample with a ``'conversations'`` entry and optionally
                an ``'images'`` entry.
            i: Index passed to ``__getitem__``; when it is an ``int`` the
                sample is wrapped in a list for ``preprocess`` and the first
                row of ``input_ids``/``labels`` is returned.

        Returns:
            Dict with ``input_ids``, ``labels`` and, when images are present
            or the model is multimodal, ``images``.
        """
        images = None  # stays None for samples without an 'images' key
        sources = item
        if isinstance(i, int):
            sources = [sources]
        assert len(sources) == 1, "Don't know why it is wrapped to a list"
        if 'images' in sources[0]:
            images = sources[0]['images']
            conversations = copy.deepcopy([e['conversations'] for e in sources])
            sources = preprocess_multimodal(conversations, self.data_args)
        else:
            sources = copy.deepcopy([e["conversations"] for e in sources])

        data_dict = preprocess(
            sources,
            self.tokenizer,
            has_image=('images' in item))

        if isinstance(i, int):
            data_dict = dict(input_ids=data_dict["input_ids"][0],
                             labels=data_dict["labels"][0])

        if images is not None and len(images) > 0:
            data_dict["images"] = images
        elif self.data_args.is_multimodal:
            # Multimodal model but no usable image: supply an all-zero image.
            img_size = self._placeholder_image_size()
            if getattr(self.data_args, 'image_aspect_ratio', 'square') == 'anyres':
                data_dict['images'] = [torch.zeros(1, 3, img_size, img_size)]
            else:
                data_dict['images'] = [torch.zeros(3, img_size, img_size)]
            # NOTE(review): this masks the loss of every sample reaching this
            # branch (pure-text samples and samples with an empty image list) — confirm intended.
            data_dict['labels'][:] = IGNORE_INDEX
        return data_dict
| |
|
|
|
|
|
|
@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate tokenized examples into one padded training batch.

    ``input_ids`` are padded with the tokenizer's pad token id and ``labels``
    with ``IGNORE_INDEX``. Both are then truncated to
    ``tokenizer.model_max_length``, and ``attention_mask`` flags the non-pad
    positions of ``input_ids``.

    If the first instance carries ``images``, every instance's images are
    stacked into one tensor per instance. When they are not all non-None with
    a common shape, None entries are dropped first, then any image whose shape
    differs from the first remaining one.
    """
    tokenizer: transformers.PreTrainedTokenizer

    @staticmethod
    def _stack_images(imgs):
        """Stack one instance's images, discarding None and odd-shaped entries."""
        if all(img is not None and img.shape == imgs[0].shape for img in imgs):
            return torch.stack(imgs)
        kept = [img for img in imgs if img is not None]
        return torch.stack([img for img in kept if img.shape == kept[0].shape])

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        id_seqs = [example["input_ids"] for example in instances]
        label_seqs = [example["labels"] for example in instances]

        pad_sequence = torch.nn.utils.rnn.pad_sequence
        padded_ids = pad_sequence(id_seqs, batch_first=True,
                                  padding_value=self.tokenizer.pad_token_id)
        padded_labels = pad_sequence(label_seqs, batch_first=True,
                                     padding_value=IGNORE_INDEX)

        limit = self.tokenizer.model_max_length
        padded_ids = padded_ids[:, :limit]
        padded_labels = padded_labels[:, :limit]

        batch = {
            "input_ids": padded_ids,
            "labels": padded_labels,
            "attention_mask": padded_ids.ne(self.tokenizer.pad_token_id),
        }

        if 'images' in instances[0]:
            per_instance = [example['images'] for example in instances]
            batch["images"] = [self._stack_images(imgs) for imgs in per_instance]

        if len(batch.get('images', [])) == 0:
            print("images not in batch")

        return batch
|
|
|
|
def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
                                data_args,
                                num_workers) -> Dict:
    """Build the ratio-mixed training dataset and its collator.

    Args:
        tokenizer: Tokenizer shared by the dataset and the collator.
        data_args: Data arguments; ``data_args.dataset_config`` is the path of
            the OmegaConf dataset file.
        num_workers: DataLoader worker count, forwarded to the dataset.

    Returns:
        Dict with ``train_dataset`` (a ``LazySupervisedDataset``),
        ``eval_dataset`` (always None) and ``data_collator``.
    """
    train_dataset = LazySupervisedDataset(
        data_cfg=data_args.dataset_config,
        tokenizer=tokenizer,
        data_args=data_args,
        num_workers=num_workers,
    )

    # Report how many passes over each sub-dataset one virtual epoch amounts to.
    for sub_dataset, weight in zip(train_dataset.datasets, train_dataset.sample_ratios):
        master_print(f"==> Real epoch of {sub_dataset.name} is "
                     f"{round(len(train_dataset) * weight / len(sub_dataset), 2)} epochs.")

    collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return {
        "train_dataset": train_dataset,
        "eval_dataset": None,
        "data_collator": collator,
    }
|
|
|
|
|
|
class SupervisedConcatDataset(ConcatDataset):
    r"""Dataset as a concatenation of multiple datasets, with SFT tokenization.

    A global index is mapped to (sub-dataset, local index) through
    ``cumulative_sizes``; the raw sample is then tokenized with
    ``preprocess`` and its images (or an all-zero placeholder) are attached.

    Args:
        datasets (sequence): List of datasets to be concatenated
        tokenizer: Tokenizer handed to ``preprocess``.
        data_args: Data arguments used for multimodal preprocessing and for
            sizing the placeholder image.
    """

    datasets: List[Dataset]
    cumulative_sizes: List[int]

    def __init__(self, datasets: List[Dataset],
                 tokenizer: transformers.PreTrainedTokenizer,
                 data_args: DataArguments) -> None:
        super().__init__(datasets)
        self.tokenizer = tokenizer
        self.data_args = data_args

    def _locate(self, idx):
        """Map a non-negative global index to ``(dataset_idx, sample_idx)``."""
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return dataset_idx, sample_idx

    @property
    def modality_lengths(self):
        """Estimated length of every sample, in global-index order.

        Each estimate is the whitespace word count of the sample's
        conversation turns plus ``num_token_per_image`` (default 32). For
        sub-datasets whose ``type`` is not ``'images'``, the image-token
        allowance is multiplied by ``data_args.num_segments``.
        """
        length_list = []
        token_per_image = getattr(self.data_args, 'num_token_per_image', 32)
        for idx in range(len(self)):
            dataset_idx, sample_idx = self._locate(idx)
            dataset = self.datasets[dataset_idx]
            conversations = dataset.text_preprocess(dataset.annotation[sample_idx])
            cur_len = sum(len(conv['value'].split()) for conv in conversations)
            if dataset.type == 'images':
                cur_len += token_per_image
            else:
                cur_len += token_per_image * self.data_args.num_segments
            length_list.append(cur_len)
        return length_list

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        """Return the tokenized sample at global index ``idx``.

        Raises:
            ValueError: If a negative ``idx`` exceeds the dataset length.
        """
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = len(self) + idx
        dataset_idx, sample_idx = self._locate(idx)
        item = self.datasets[dataset_idx][sample_idx]
        return self._build_data_dict(item, idx)

    def _placeholder_image_size(self):
        """Return the side length of the all-zero image used for image-less samples."""
        image_processor = self.data_args.image_processor
        if isinstance(image_processor, SiglipImageProcessor):
            return image_processor.size['height']
        if isinstance(image_processor, CLIPImageProcessor):
            return image_processor.crop_size['height']
        return image_processor.img_size

    def _build_data_dict(self, item, idx):
        """Tokenize one raw sample and attach its images (or a dummy image).

        Args:
            item: Raw sample with a ``'conversations'`` entry and optionally
                an ``'images'`` entry.
            idx: Normalized global index; when it is an ``int`` the sample is
                wrapped in a list for ``preprocess`` and the first row of
                ``input_ids``/``labels`` is returned.
        """
        images = None  # stays None for samples without an 'images' key
        sources = item
        if isinstance(idx, int):
            sources = [sources]
        assert len(sources) == 1, "Don't know why it is wrapped to a list"
        if 'images' in sources[0]:
            images = sources[0]['images']
            conversations = copy.deepcopy([e['conversations'] for e in sources])
            sources = preprocess_multimodal(conversations, self.data_args)
        else:
            sources = copy.deepcopy([e["conversations"] for e in sources])

        data_dict = preprocess(
            sources,
            self.tokenizer,
            has_image=('images' in item))

        if isinstance(idx, int):
            data_dict = dict(input_ids=data_dict["input_ids"][0],
                             labels=data_dict["labels"][0])

        if images is not None and len(images) > 0:
            data_dict["images"] = images
        elif self.data_args.is_multimodal:
            # Multimodal model but no usable image: supply an all-zero image.
            img_size = self._placeholder_image_size()
            if getattr(self.data_args, 'image_aspect_ratio', 'square') == 'anyres':
                data_dict['images'] = [torch.zeros(1, 3, img_size, img_size)]
            else:
                data_dict['images'] = [torch.zeros(3, img_size, img_size)]
            # NOTE(review): this masks the loss of every sample reaching this
            # branch (pure-text samples and samples with an empty image list) — confirm intended.
            data_dict['labels'][:] = IGNORE_INDEX
        return data_dict
|
|
|
|
def make_supervised_data_module_concatdataset(tokenizer: transformers.PreTrainedTokenizer,
                                              data_args,
                                              num_workers) -> Dict:
    """Build a concatenated (index-ordered) training dataset and its collator.

    Args:
        tokenizer: Tokenizer shared by the dataset and the collator.
        data_args: Data arguments; ``data_args.dataset_config`` is the path of
            the OmegaConf file whose ``datasets`` mapping lists the
            sub-datasets to build.
        num_workers: Accepted for signature parity; not used by this builder.

    Returns:
        Dict with ``train_dataset`` (a ``SupervisedConcatDataset``),
        ``eval_dataset`` (always None) and ``data_collator``.
    """
    sub_datasets = []
    dataset_config = OmegaConf.load(data_args.dataset_config)
    for ds_name in list(dataset_config.datasets.keys()):
        ds_cfg = dataset_config.datasets[ds_name]
        builder_kwargs = dict(ds_cfg.items())
        # Each builder receives a class whose attributes are a deep copy of
        # vars(data_args), so nothing it changes reaches the shared data_args.
        snapshot = type('DataArguments', (object,), copy.deepcopy(vars(data_args)))
        sub_datasets.append(
            build_from_cfg(ds_name, snapshot, DATASETS, default_args=builder_kwargs))

    train_dataset = SupervisedConcatDataset(datasets=sub_datasets,
                                            tokenizer=tokenizer,
                                            data_args=data_args)

    collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return {
        "train_dataset": train_dataset,
        "eval_dataset": None,
        "data_collator": collator,
    }
|
|
|
|