import json
import os
import random
import re
from pathlib import Path
from typing import Dict, List, Optional, Sequence

import datasets
import torch

from llava.constants import DEFAULT_IMAGE_TOKEN
from llava.datasets.base_dataset import ImageTaskDataset
from llava.datasets.builder import DATASETS
from llava.datasets.data_cfgs import data_configs
from llava.datasets.prompts import cc_sbu_prompt
from llava.utils import master_print


class TextCapsDataset(ImageTaskDataset):
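    """TextCaps image-captioning dataset.

    Loads a TextCaps annotation JSON and yields one image together with a
    single-round conversation: a randomly sampled captioning prompt and the
    reference caption (see text_preprocess below).
    """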
    def __init__(self, anno_path=None, data_args=None, aux_args=None, name='TextCaps'):
        # TextCaps annotation files keep the sample list under the 'data' key.
        with open(anno_path) as f:
            self.annotation = json.load(f)['data']
        # Image paths in the annotations are relative to the annotation file's directory.
        self.dataset_dir = Path(anno_path).parent
        super().__init__(anno_path=anno_path,
                         data_args=data_args,
                         name=name)

    def __len__(self):
        return len(self.annotation)

    def text_preprocess(self, item) -> List[Dict[str, str]]:
        # Single-round conversation: a randomly chosen captioning prompt
        # (prefixed with the image token) and the ground-truth caption as the reply.
        conversations = [
            {
                'from': 'human',
                'value': DEFAULT_IMAGE_TOKEN + random.choice(cc_sbu_prompt)
            },
            {
                'from': 'model',
                'value': item['caption_str']
            }
        ]

        return conversations
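    # Illustrative shape of the returned conversation (assuming DEFAULT_IMAGE_TOKEN
    # is the '<image>' placeholder; the prompt and caption below are made-up examples):
    #   [{'from': 'human', 'value': '<image>Describe this image in detail.'},
    #    {'from': 'model', 'value': 'a street sign that reads MAIN ST'}]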
|
|
|
|
    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        item = self.annotation[i]
        vis_path = self.dataset_dir / item['image_path']
        ret = {
            'images': self.vis_preprocess(str(vis_path)),
            'conversations': self.text_preprocess(item)
        }
        # Keep the sample id when the annotation provides one.
        if 'id' in item:
            ret['id'] = item['id']

        return ret
|
|
@DATASETS.register_obj
def TextCaps(data_args):
    # Factory registered with the dataset builder; resolves the TextCaps config
    # and returns a dataset over its training split.
    data_cfg = data_configs['text_caps']
    return TextCapsDataset(data_cfg['train_data_path'], data_args)
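# Minimal usage sketch (assumes `data_args` carries the image-preprocessing
# settings that ImageTaskDataset expects and that data_configs['text_caps']
# is defined):
#
#   dataset = TextCaps(data_args)
#   sample = dataset[0]   # {'images': ..., 'conversations': [...], 'id': ...}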
|
|
if __name__ == '__main__':
    # Ad-hoc sanity check: scan a TextCaps split and collect the samples that
    # ship without any questions.
    with open('/mnt/bn/yukunfeng-nasdrive/xiangchen/dataset/TextCaps/TextCaps_0.1_train.json') as f:
        data = json.load(f)
    res = []
    for value in data:
        if len(value['questions']) == 0:
            print(1)
            res.append(value)
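    # Illustrative summary of the scan above.
    print(f'{len(res)} of {len(data)} samples have no questions')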
|
|