import logging

import torch
import datasets
import cv2

import numpy as np
from base64 import b64decode
from io import BytesIO
from PIL import Image
from torch.utils.data import ConcatDataset
from llava.datasets.builder import DATASETS
from typing import Dict, Optional, Sequence, List
from llava.datasets.data_cfgs import data_configs
from llava.datasets.base_dataset import ImageTaskDataset
from llava.constants import DEFAULT_IMAGE_TOKEN, DEFAULT_VIDEO_TOKEN
from llava.utils import master_print
|
|
class M3ITDataset(ImageTaskDataset):
    def __init__(self, anno_path, data_args=None, name='m3it', selected_tasks=None):
        super().__init__(anno_path, data_args, name)

        self.selected_tasks = selected_tasks
        dataset_list = [
            datasets.load_dataset("MMInstruction/M3IT", i, num_proc=16) for i in selected_tasks
        ]

        target_dataset_list = []
        master_print('#' * 50)
        # Keep only the tasks that actually ship a train split; iterate together
        # with the task names so a missing split can still be reported.
        for task, d in zip(selected_tasks, dataset_list):
            try:
                target_dataset_list.append(d['train'])
                master_print(f"TASK {d['train']._info.config_name}, SIZE {len(d['train'])}")
            except KeyError:
                print(f"{task} has no train set.")
        self.dataset = ConcatDataset(target_dataset_list)
        master_print(f"Finished loading dataset {name}: {len(self.dataset)} samples.")
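
    # Note (illustrative, inferred from how fields are consumed below): each
    # underlying M3IT record is expected to expose at least
    #   'instruction', 'inputs', 'outputs', 'image_base64_str'
    # where 'image_base64_str' is a list of base64-encoded images (one entry
    # per frame for video-style tasks).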
|
|
    def __len__(self):
        return len(self.dataset)
|
|
    def text_preprocess(self, item, is_video=False) -> List[Dict[str, str]]:
        instruction = item['instruction']
        question = item['inputs']
        answer = item['outputs']

        query = f"{instruction} {DEFAULT_IMAGE_TOKEN if not is_video else DEFAULT_VIDEO_TOKEN}"
        if len(question) > 0:
            query += question

        conversations = [
            {
                'from': 'human',
                'value': query
            },
            {
                'from': 'model',
                'value': answer
            }
        ]

        return conversations
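
    # Worked example (illustrative; DEFAULT_IMAGE_TOKEN is the image
    # placeholder from llava.constants): for an image sample with
    #   instruction = "Answer the question based on the image."
    #   inputs      = "What color is the car?"
    #   outputs     = "Red."
    # text_preprocess() returns
    #   [{'from': 'human',
    #     'value': f"Answer the question based on the image. {DEFAULT_IMAGE_TOKEN}What color is the car?"},
    #    {'from': 'model', 'value': "Red."}]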
|
|
    def bin2image(self, image_base64_str):
        img = Image.open(BytesIO(b64decode(image_base64_str))).convert("RGB")
        img = np.array(img)

        # convert("RGB") above should already yield a 3-channel array; the guard
        # below only covers the single-channel case without risking an IndexError.
        if img.ndim == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)

        img = Image.fromarray(img).convert('RGB')
        img = self.preprocess_image(img)

        return img
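
    # Illustrative round-trip helper (not part of the original module): shows
    # the string format bin2image() expects, for quick local testing. PNG
    # encoding is an assumption; M3IT may store images with another codec.
    @staticmethod
    def _image_to_base64_str(img: Image.Image) -> str:
        from base64 import b64encode
        buf = BytesIO()
        img.convert("RGB").save(buf, format="PNG")
        return b64encode(buf.getvalue()).decode("utf-8")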
|
|
    def vis_preprocess(self, image_base64_str_list) -> Optional[List]:
        try:
            images = list(map(self.bin2image, image_base64_str_list))
            formatted_images = []
            for image in images:
                if isinstance(image, list):
                    formatted_images.extend(image)
                else:
                    formatted_images.append(image)
            return formatted_images
        except Exception as e:
            print(f"Failed to preprocess images: {e}")
            return None
|
|
    def __getitem__(self, i) -> Optional[Dict[str, torch.Tensor]]:
        item = self.dataset[i]

        img_data = item['image_base64_str']

        images = self.vis_preprocess(img_data)
        if images is None:
            return None

        # Multi-frame samples are treated as video so that text_preprocess
        # picks the video placeholder token instead of the image one.
        is_video = len(images) > 1

        ret = {
            'images': images,
            'conversations': self.text_preprocess(item, is_video)
        }

        return ret
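
    # Shape of a returned sample (illustrative): a dict with
    #   'images'        -> list of preprocessed frames (whatever
    #                      self.preprocess_image produces, e.g. tensors)
    #   'conversations' -> the two-turn human/model exchange built above,
    # or None when image decoding fails, which callers need to handle.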
|
|
@DATASETS.register_obj
def m3it(data_args):
    tasks = data_configs['m3it']['default_tasks']
    if 'tasks' in data_args.external_args:
        tasks = data_args.external_args['tasks']

    return M3ITDataset(anno_path=None, data_args=data_args, selected_tasks=tasks)
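
# Usage sketch (assumption, not from this repo's configs): the registered
# builder above can be pointed at a subset of M3IT tasks by setting something
# like
#   data_args.external_args = {'tasks': ['coco', 'vqa-v2']}
# before dataset construction; otherwise it falls back to
# data_configs['m3it']['default_tasks']. Task names must match the
# MMInstruction/M3IT configs available on the Hugging Face Hub.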
|
|