| import datasets |
| import torch |
| import re |
| import os |
| import subprocess |
| from llava.datasets.builder import DATASETS |
|
|
| from typing import Dict, Optional, Sequence, List |
| from llava.datasets.data_cfgs import data_configs |
| from llava.datasets.base_dataset import ImageTaskDataset |
| from llava.constants import DEFAULT_IMAGE_TOKEN |
| from llava.datasets.data_cfgs import data_configs |
| from llava.utils import master_print |
|
|
class LKImageDataset(ImageTaskDataset):
    """Image + conversation dataset for the ``lk_image`` data source.

    Each annotation record is read with an ``'image'`` entry (passed to
    ``self.vis_preprocess``) and a ``'conversations'`` entry (returned
    unchanged). An optional ``'id'`` entry is forwarded when present.
    """

    def __init__(self, anno_path=None, data_args=None, aux_args=None, name='lk_image'):
        """Initialise the dataset.

        Args:
            anno_path: annotation source, forwarded to ``ImageTaskDataset``.
            data_args: data arguments, forwarded to ``ImageTaskDataset``.
            aux_args: auxiliary dataset configuration (``lk_image`` passes its
                ``data_configs`` entry here). It is not forwarded to the base
                class. It is stored on ``self.aux_args`` so the value is not
                silently dropped.
            name: dataset name, forwarded to ``ImageTaskDataset``.
        """
        super().__init__(anno_path=anno_path,
                         data_args=data_args,
                         name=name)
        # Previously this argument was accepted and thrown away; keep it
        # reachable for subclasses / callers that need the raw config.
        self.aux_args = aux_args

    def __len__(self):
        """Return the number of annotation records in ``self.annotation``."""
        # NOTE(review): ``self.annotation`` is presumably populated by
        # ``ImageTaskDataset.__init__`` from ``anno_path`` -- confirm.
        return len(self.annotation)

    def text_preprocess(self, item) -> List[Dict[str, str]]:
        """Return the record's ``'conversations'`` entry unchanged."""
        return item['conversations']

    def __getitem__(self, i) -> Dict[str, object]:
        """Build the sample for annotation record ``i``.

        Returns:
            A dict with ``'images'`` (output of ``self.vis_preprocess`` on the
            record's ``'image'`` entry), ``'conversations'`` (output of
            ``text_preprocess``) and, only if the record has one, ``'id'``.
            The values are not all tensors, so the annotation is
            ``Dict[str, object]`` rather than ``Dict[str, torch.Tensor]``.

        Raises:
            KeyError: if the record lacks ``'image'`` or ``'conversations'``.
        """
        item = self.annotation[i]
        vis_path = item['image']
        ret = {
            'images': self.vis_preprocess(vis_path),
            'conversations': self.text_preprocess(item)
        }
        if 'id' in item:
            ret['id'] = item['id']
        return ret
|
|
@DATASETS.register_obj
def lk_image(data_args):
    """Construct the ``lk_image`` dataset from its ``data_configs`` entry.

    The entry's ``'train_data_path'`` is used as the annotation path, and the
    whole entry is handed to the dataset as ``aux_args``.
    """
    cfg = data_configs['lk_image']
    train_path = cfg['train_data_path']
    return LKImageDataset(train_path, data_args, aux_args=cfg)
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |