|
|
import json
import logging
import os
import pickle
import random
from pathlib import Path
from typing import Dict, Optional, Sequence, List

import numpy as np
import torch
from PIL import Image

from llava.constants import DEFAULT_VIDEO_TOKEN
from llava.datasets.base_dataset import FramesTaskDataset
from llava.datasets.builder import DATASETS
from llava.datasets.data_cfgs import data_configs
from llava.datasets.data_cfgs import data_configs
from llava.datasets.prompts import tt_caption_prompt, internvid_prompt
|
|
|
|
class GPT4VPublicDataset(FramesTaskDataset):
    """Frame-folder video dataset built from GPT-4V style JSON annotations.

    Each annotation entry is expanded into one sample per requested task type
    (``summary``, ``detail`` or ``qa_pairs``). For ``qa_pairs`` with
    ``conv_type='single'``, every question/answer pair becomes its own sample.
    """

    # Frame folders longer than this are first reduced to this many frames
    # using ``sample_method`` before the ``num_segments`` subsampling.
    _PRESAMPLE_FRAMES = 10

    def __init__(self, anno_path=None, data_args=None, fps=1.0, conv_type='single',
                 task_types=None, sample_method='uniform', name='gpt4v_public'):
        """Validate options, then load and expand the annotation file.

        Args:
            anno_path: Path to the JSON annotation file (a list of entries).
            data_args: Data arguments, forwarded to the base class.
            fps: Frame rate, stored and forwarded to the base class.
            conv_type: ``'single'`` (one QA pair per sample) or ``'multi'``
                (all QA pairs of an entry in one sample).
            task_types: Iterable of entry keys to build samples from; required.
            sample_method: ``'uniform'`` or ``'sequential'``; controls the
                initial reduction of long frame folders.
            name: Dataset name, forwarded to the base class.
        """
        self.default_fps = 1.0
        self.fps = fps
        self.conv_type = conv_type
        self.task_types = task_types
        self.sample_method = sample_method
        # Validate before parsing the (possibly large) annotation file.
        assert self.conv_type in ('single', 'multi'), "gpt4v_public conv type must in single/multi"
        assert self.sample_method in ('sequential', 'uniform'), "gpt4v_public sample method must in sequential/uniform"
        assert self.task_types is not None, "gpt4v_public task_types must be provided"
        self.annotation = self.get_dataset(anno_path)
        # NOTE(review): the base class also receives anno_path; if it parses the
        # annotations too, the file is read twice — confirm against FramesTaskDataset.
        super().__init__(anno_path=anno_path,
                         data_args=data_args,
                         fps=fps,
                         name=name)

    def __len__(self):
        """Return the number of expanded samples."""
        return len(self.annotation)

    @staticmethod
    def _qa_pairs_from_conversation(conversation):
        """Pair up question/answer turns from a flat list of conversation strings.

        A turn containing ``'C'`` that is immediately followed by a turn
        containing ``'A'`` becomes one ``[question, answer]`` pair; every other
        turn (blank ones included) is skipped.
        """
        # NOTE(review): these are substring tests, so a 'C'/'A' anywhere in a turn
        # matches. Presumably turns are prefixed like "C:"/"A:" — confirm the
        # annotation format before tightening this to startswith().
        pairs = []
        index = 0
        while index < len(conversation):
            turn = conversation[index]
            has_answer = index + 1 < len(conversation) and 'A' in conversation[index + 1]
            if 'C' in turn and has_answer:
                pairs.append([turn, conversation[index + 1]])
                index += 2
            else:
                index += 1
        return pairs

    @staticmethod
    def _non_empty_qa_pairs(qa_pairs):
        """Drop QA pairs whose question or answer is an empty string."""
        return [qa for qa in qa_pairs if len(qa[0]) != 0 and len(qa[1]) != 0]

    def get_dataset(self, anno_path):
        """Parse the annotation JSON and expand it into per-task samples.

        Every entry gets a ``qa_pairs`` list. It is either the entry's own pairs
        minus empty ones, or pairs rebuilt from its ``conversation`` turns. Each
        returned sample is a shallow copy of its entry plus a ``task_type`` key.
        A task type whose value is missing or empty on an entry is skipped for
        that entry.
        """
        with Path(anno_path).open('rb') as f:
            data = json.load(f)

        dataset = []
        for info in data:
            if 'qa_pairs' in info:
                info['qa_pairs'] = self._non_empty_qa_pairs(info['qa_pairs'])
            else:
                # Store the rebuilt pairs so the 'qa_pairs' task can use entries
                # that only carry raw conversation turns.
                info['qa_pairs'] = self._qa_pairs_from_conversation(info.get('conversation', []))

            for task_type in self.task_types:
                if not info.get(task_type):
                    continue
                if task_type == 'qa_pairs' and self.conv_type == 'single':
                    # One sample per QA pair.
                    for qa_pair in info[task_type]:
                        dataset.append({**info, task_type: [qa_pair], 'task_type': task_type})
                else:
                    dataset.append({**info, 'task_type': task_type})
        return dataset

    @staticmethod
    def _first_non_blank(value):
        """Return ``value``, or for a list its first non-blank string ('' if none)."""
        if isinstance(value, list):
            return next((s for s in value if len(s.strip()) != 0), '')
        return value

    @staticmethod
    def _turn_pair(question, answer):
        """Build one human/model exchange in the conversation format."""
        return [
            {'from': 'human', 'value': question},
            {'from': 'model', 'value': answer},
        ]

    def text_preprocess(self, item) -> List[Dict[str, str]]:
        """Build the flattened conversation for one sample.

        ``summary`` and ``detail`` samples get one exchange with a randomly
        chosen prompt. Any other task type uses the sample's ``qa_pairs``. The
        exchanges are shuffled, and the video token is prepended to the first
        human turn.
        """
        task_type = item['task_type']
        if task_type == 'summary':
            answer = self._first_non_blank(item['summary'])
            all_convs = [self._turn_pair(random.choice(internvid_prompt), answer)]
        elif task_type == 'detail':
            answer = self._first_non_blank(item['detail'])
            all_convs = [self._turn_pair(random.choice(tt_caption_prompt), answer)]
        else:
            all_convs = [self._turn_pair(qa[0], qa[1]) for qa in item['qa_pairs']]

        random.shuffle(all_convs)
        conversations = []
        for idx, (human_turn, model_turn) in enumerate(all_convs):
            if idx == 0:
                human_turn['value'] = DEFAULT_VIDEO_TOKEN + human_turn['value']
            conversations.extend((human_turn, model_turn))
        return conversations

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Return the preprocessed frames and conversation for sample ``i``.

        The sample's ``id`` is included when present.
        """
        item = self.annotation[i]
        ret = {
            'images': self.vis_preprocess(item['vis_path']),
            'conversations': self.text_preprocess(item),
        }
        if 'id' in item:
            ret['id'] = item['id']
        return ret

    def _sample_frames(self, frames, num_segments, preprocess=False):
        """Pick ``num_segments`` items from ``frames``.

        With ``preprocess=True``, ``self.sample_method`` decides: ``'uniform'``
        takes evenly spaced indices and ``'sequential'`` keeps the first
        ``num_segments``. Without it, sampling is always uniform.

        Raises:
            NotImplementedError: if ``preprocess`` is set and the sample method
                is unknown.
        """
        if preprocess and self.sample_method == 'sequential':
            # Honour num_segments instead of a hard-coded count; slicing also
            # cannot run past the end of a short list.
            return frames[:num_segments]
        if preprocess and self.sample_method != 'uniform':
            raise NotImplementedError
        indices = np.linspace(start=0, stop=len(frames) - 1, num=num_segments).astype(int)
        return [frames[ind] for ind in indices]

    def vis_preprocess(self, vis_path):
        """Load, subsample and preprocess the ``.jpeg`` frames in ``vis_path``.

        Frames are ordered by the integer after the last ``_`` in their file
        name. Folders with more than ``_PRESAMPLE_FRAMES`` frames are first
        reduced with ``sample_method``. They are then reduced uniformly to
        ``self.num_segments`` if that is positive and smaller. Frames that fail
        to load are skipped with a warning.
        """
        suffix = '.jpeg'
        image_files = []
        for file_name in os.listdir(vis_path):
            if file_name.endswith(suffix):
                frame_idx = int(file_name.split('_')[-1][:-len(suffix)])
                image_files.append((frame_idx, file_name))
        image_files.sort(key=lambda entry: entry[0])

        if len(image_files) > self._PRESAMPLE_FRAMES:
            image_files = self._sample_frames(image_files, self._PRESAMPLE_FRAMES, preprocess=True)
        if self.num_segments > 0 and len(image_files) > self.num_segments:
            image_files = self._sample_frames(image_files, self.num_segments)

        images = []
        for _, file_name in image_files:
            frame_path = os.path.join(vis_path, file_name)
            try:
                # Close the handle Image.open keeps open lazily. convert() returns
                # an independent, fully loaded image.
                with Image.open(frame_path) as img:
                    images.append(img.convert('RGB'))
            except Exception as exc:  # best-effort: skip unreadable frames, but say so
                logging.getLogger(__name__).warning(
                    'Skipping unreadable frame %s: %s', frame_path, exc)

        formatted_images = []
        for image in images:
            processed = self.preprocess_image(image)
            if isinstance(processed, list):
                formatted_images.extend(processed)
            else:
                formatted_images.append(processed)
        return formatted_images
|
|
|
|
@DATASETS.register_obj
def gpt4v_public(data_args):
    """Build a ``GPT4VPublicDataset`` from ``data_args.external_args``.

    Required external args: ``fps``, ``conv_type``, ``task_types``.
    Optional external args:
        ``train_data_path``: overrides ``data_configs['gpt4v_public']['train_data_path']``.
        ``sample_method``: defaults to ``'uniform'``.
    """
    external_args = data_args.external_args
    # Resolve the override locally rather than writing it into the shared
    # module-level data_configs entry, so it cannot leak into later calls.
    if 'train_data_path' in external_args:
        anno_path = external_args['train_data_path']
    else:
        anno_path = data_configs['gpt4v_public']['train_data_path']
    fps = external_args['fps']
    conv_type = external_args['conv_type']
    task_types = external_args['task_types']
    if 'sample_method' in external_args:
        sample_method = external_args['sample_method']
    else:
        sample_method = 'uniform'
    return GPT4VPublicDataset(anno_path, data_args, fps, conv_type, task_types, sample_method)
|
|
|
|
# Executing this module directly does nothing. Importing it is what applies
# @DATASETS.register_obj to gpt4v_public.
if __name__ == '__main__':
    pass
| |
| |
| |
| |
| |
| |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |