| import json |
| import os |
| import random |
|
|
| import numpy as np |
| import torch |
| import torchvision.transforms as transforms |
| from PIL import Image |
| from torch.utils.data.dataset import Dataset |
|
|
|
|
class CC15M(Dataset):
    """Image/caption dataset driven by a JSON annotation file.

    The annotation file is loaded fully into memory. Each entry is indexed
    by integer position and is read for two keys: ``file_path`` (location of
    the image) and ``text`` (its caption).
    """

    def __init__(
        self,
        json_path,
        video_folder=None,
        resolution=512,
        enable_bucket=False,
        max_retries=None,
    ):
        """Load the annotations and build the pixel transform pipeline.

        Args:
            json_path: Path to the JSON annotation file.
            video_folder: Optional root directory joined in front of every
                entry's ``file_path``. When ``None`` the ``file_path`` is
                used as-is.
            resolution: Either an int (used for both sides) or an iterable
                that is converted to a tuple and passed to ``CenterCrop``;
                its first element is also passed to ``Resize``.
            enable_bucket: When true, ``__getitem__`` skips the transform
                pipeline and returns the image as a NumPy array instead.
            max_retries: Maximum number of replacement samples to try in a
                single ``__getitem__`` call after a load failure. ``None``
                (the default) retries without limit.
        """
        print(f"loading annotations from {json_path} ...")
        # Use a context manager so the annotation file is closed right away
        # instead of lingering until garbage collection.
        with open(json_path, 'r') as f:
            self.dataset = json.load(f)
        self.length = len(self.dataset)
        print(f"data scale: {self.length}")

        self.enable_bucket = enable_bucket
        self.video_folder = video_folder
        self.max_retries = max_retries

        resolution = tuple(resolution) if not isinstance(resolution, int) else (resolution, resolution)
        self.pixel_transforms = transforms.Compose([
            transforms.Resize(resolution[0]),
            transforms.CenterCrop(resolution),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])

    def get_batch(self, idx):
        """Load one sample.

        Args:
            idx: Integer position in the annotation list.

        Returns:
            A ``(image, caption)`` pair where ``image`` is a PIL image
            converted to RGB and ``caption`` is the entry's ``text`` value.

        Raises:
            Whatever the lookup or image decoding raises (e.g. ``KeyError``
            for a missing key, ``OSError`` for an unreadable file).
        """
        video_dict = self.dataset[idx]
        video_id, name = video_dict['file_path'], video_dict['text']

        if self.video_folder is None:
            video_dir = video_id
        else:
            video_dir = os.path.join(self.video_folder, video_id)

        # convert() yields an independent in-memory image, so the underlying
        # file can be closed deterministically once it has been produced.
        with Image.open(video_dir) as img:
            pixel_values = img.convert("RGB")
        return pixel_values, name

    def __len__(self):
        """Return the number of annotation entries."""
        return self.length

    def __getitem__(self, idx):
        """Return a sample dict with ``pixel_values`` and ``text``.

        If loading ``idx`` fails, the error is printed and a uniformly random
        replacement index is tried instead, so one broken file does not abort
        iteration.

        Returns:
            ``dict(pixel_values=..., text=...)``. ``pixel_values`` is the
            output of the transform pipeline, or a NumPy array of the RGB
            image when ``enable_bucket`` is set.

        Raises:
            RuntimeError: If ``max_retries`` is set and that many replacement
                attempts have also failed.
        """
        # getattr keeps instances that never ran this __init__ working with
        # the original unbounded-retry behaviour.
        max_retries = getattr(self, "max_retries", None)
        failures = 0
        while True:
            try:
                pixel_values, name = self.get_batch(idx)
                break
            except Exception as e:
                print(e)
                failures += 1
                if max_retries is not None and failures > max_retries:
                    raise RuntimeError(
                        f"giving up after {failures} failed attempts to load a sample"
                    ) from e
                idx = random.randint(0, self.length - 1)

        if not self.enable_bucket:
            pixel_values = self.pixel_transforms(pixel_values)
        else:
            pixel_values = np.array(pixel_values)

        sample = dict(pixel_values=pixel_values, text=name)
        return sample
|
|
if __name__ == "__main__":
    # Smoke test: build the dataset from a local annotation file and print
    # the batched tensor shape and caption count for every batch.
    dataset = CC15M(
        # The constructor's parameter is `json_path`; the previous
        # `csv_path=` keyword raised TypeError before anything ran.
        json_path="/mnt_wg/zhoumo.xjq/CCUtils/cc15m_add_index.json",
        resolution=512,
    )

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=0)
    for batch in dataloader:
        print(batch["pixel_values"].shape, len(batch["text"]))