| from packaging import version |
| from PIL import Image |
| from torchvision import transforms |
| import os |
| import PIL |
| from torch.utils.data import Dataset |
| import torchvision |
| import numpy as np |
| import torch |
| import random |
| import albumentations as A |
| import copy |
| import cv2 |
| import pandas as pd |
|
|
|
|
# Caption templates; ``{}`` is filled with the placeholder token or an object
# category name. The order is kept stable so seeded ``random.choice`` calls
# select the same template as before.
imagenet_templates_small = [
    "a photo of a {}",
    "a rendering of a {}",
    "a cropped photo of the {}",
    "the photo of a {}",
    "a photo of a clean {}",
    "a photo of a dirty {}",
    "a dark photo of the {}",
    "a photo of my {}",
    "a photo of the cool {}",
    "a close-up photo of a {}",
    "a bright photo of the {}",
    "a cropped photo of a {}",
    "a photo of the {}",
    "a good photo of the {}",
    "a photo of one {}",
    "a close-up photo of the {}",
    "a rendition of the {}",
    "a photo of the clean {}",
    "a rendition of a {}",
    "a photo of a nice {}",
    "a good photo of a {}",
    "a photo of the nice {}",
    "a photo of the small {}",
    "a photo of the weird {}",
    "a photo of the large {}",
    "a photo of a cool {}",
    "a photo of a small {}",
]
|
|
|
|
def _build_pil_interpolation():
    """Map interpolation names to the resampling filters of the installed Pillow.

    The lookup location is chosen by comparing the base version of
    ``PIL.__version__`` with 9.1.0: at or above it the filters are read from
    ``PIL.Image.Resampling``, below it from the constants on ``PIL.Image``.
    """
    installed = version.parse(version.parse(PIL.__version__).base_version)
    if installed >= version.parse("9.1.0"):
        filters = PIL.Image.Resampling
        return {
            "linear": filters.BILINEAR,
            "bilinear": filters.BILINEAR,
            "bicubic": filters.BICUBIC,
            "lanczos": filters.LANCZOS,
            "nearest": filters.NEAREST,
        }
    legacy = PIL.Image
    return {
        "linear": legacy.LINEAR,
        "bilinear": legacy.BILINEAR,
        "bicubic": legacy.BICUBIC,
        "lanczos": legacy.LANCZOS,
        "nearest": legacy.NEAREST,
    }


PIL_INTERPOLATION = _build_pil_interpolation()
|
|
def is_image(file):
    """Return True if *file* has a ``.jpg``, ``.jpeg`` or ``.png`` extension.

    The comparison is case-insensitive and looks only at the end of the name,
    so names that merely contain "jpg" or "png" somewhere (for example
    ``jpg_notes.txt`` or ``photo.png.bak``) are rejected.

    Args:
        file: File name or path as a string.

    Returns:
        bool: Whether the name ends with one of the accepted extensions.
    """
    return file.lower().endswith(('.jpg', '.jpeg', '.png'))
|
|
class CustomDatasetWithBG(Dataset):
    """Images from a single folder, each paired with a ``<stem>_bg.png`` mask.

    Every item is a dict with:
      * ``text`` / ``input_ids`` / ``index``: the formatted template, its
        tokenization, and the 1-based word position of the placeholder.
      * ``pixel_values``: the image resized to ``size`` x ``size`` and scaled
        to [-1, 1], channels first.
      * ``pixel_values_clip`` / ``pixel_values_obj``: the full image and the
        mask-multiplied image at 224x224, normalized by ``get_tensor_clip``.
      * ``pixel_values_seg``: the binarized mask resized to 128x128.

    Args:
        data_root: Folder holding the images and their ``*_bg.png`` masks.
        tokenizer: Callable tokenizer exposing ``model_max_length`` and
            returning an object with ``input_ids``.
        size: Side length of ``pixel_values``.
        interpolation: One of "linear", "bilinear", "bicubic", "lanczos";
            any other value raises ``KeyError``.
        placeholder_token: Token substituted into ``template``.
        template: Caption template containing one ``{}``.
    """

    def __init__(
            self,
            data_root,
            tokenizer,
            size=512,
            interpolation="bicubic",
            placeholder_token="*",
            template="a photo of a {}",
    ):
        self.data_root = data_root
        self.tokenizer = tokenizer
        self.size = size
        self.placeholder_token = placeholder_token

        # Masks live next to the images, so any name containing 'bg' is
        # treated as a mask and left out of the image list.
        self.image_paths = sorted(
            os.path.join(self.data_root, file_path)
            for file_path in os.listdir(self.data_root)
            if is_image(file_path) and 'bg' not in file_path
        )

        self.num_images = len(self.image_paths)
        self._length = self.num_images

        # "nearest" is intentionally not offered here, as in the original table.
        self.interpolation = {
            name: PIL_INTERPOLATION[name]
            for name in ("linear", "bilinear", "bicubic", "lanczos")
        }[interpolation]

        self.template = template

    def __len__(self):
        """Return the number of images found in ``data_root``."""
        return self._length

    @staticmethod
    def _mask_path_for(image_path):
        """Return the mask path for *image_path*: its stem plus ``_bg.png``.

        Only the final extension is stripped, so directory names containing
        ".jpg" and extensions in any letter case are handled correctly.
        """
        return os.path.splitext(image_path)[0] + '_bg.png'

    def get_tensor_clip(self, normalize=True, toTensor=True):
        """Build the transform used for the 224x224 reference images.

        Args:
            normalize: Append per-channel mean/std normalization.
            toTensor: Prepend ``ToTensor``.

        Returns:
            A ``torchvision.transforms.Compose`` of the requested steps.
        """
        transform_list = []
        if toTensor:
            transform_list.append(torchvision.transforms.ToTensor())
        if normalize:
            transform_list.append(torchvision.transforms.Normalize(
                (0.48145466, 0.4578275, 0.40821073),
                (0.26862954, 0.26130258, 0.27577711)))
        return torchvision.transforms.Compose(transform_list)

    def process(self, image):
        """Resize an HxWxC array to ``size`` x ``size`` and scale it to [-1, 1].

        Returns:
            A float32 tensor laid out as (C, size, size).
        """
        img = cv2.resize(image, (self.size, self.size), interpolation=cv2.INTER_CUBIC)
        img = img.astype(np.float32) / 127.5 - 1.0
        return torch.from_numpy(img).permute(2, 0, 1)

    def __getitem__(self, i):
        """Return the example dict for item ``i`` (taken modulo the image count)."""
        example = {}

        placeholder_string = self.placeholder_token
        text = self.template.format(placeholder_string)
        example["text"] = text

        # Position of the last word equal to the placeholder, shifted by one;
        # presumably the shift accounts for a start-of-sequence token -- confirm
        # against the consumer of "index". Stays 0 when the placeholder is absent.
        placeholder_index = 0
        for idx, word in enumerate(text.strip().split(' ')):
            if word == placeholder_string:
                placeholder_index = idx + 1
        example["index"] = torch.tensor(placeholder_index)

        example["input_ids"] = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids[0]

        image_path = self.image_paths[i % self.num_images]
        # Context managers close the file handles as soon as pixels are copied.
        with Image.open(image_path) as image:
            image_np = np.array(image.convert("RGB"))
        with Image.open(self._mask_path_for(image_path)) as mask_image:
            # Converting to RGB lets single-channel / palette masks broadcast
            # against the RGB image; RGB masks are unchanged by this.
            mask = np.array(mask_image.convert("RGB"))

        # Binarize: any non-zero mask value counts as foreground.
        mask = np.where(mask > 0, 1, 0)

        object_tensor = image_np * mask
        example["pixel_values"] = self.process(image_np)

        ref_object_tensor = Image.fromarray(object_tensor.astype('uint8')).resize(
            (224, 224), resample=self.interpolation)
        ref_image_tenser = Image.fromarray(image_np.astype('uint8')).resize(
            (224, 224), resample=self.interpolation)
        to_clip = self.get_tensor_clip()
        example["pixel_values_obj"] = to_clip(ref_object_tensor)
        example["pixel_values_clip"] = to_clip(ref_image_tenser)

        ref_seg_tensor = Image.fromarray(mask.astype('uint8') * 255)
        ref_seg_tensor = self.get_tensor_clip(normalize=False)(ref_seg_tensor)
        example["pixel_values_seg"] = torch.nn.functional.interpolate(
            ref_seg_tensor.unsqueeze(0), size=(128, 128), mode='nearest').squeeze(0)

        return example
|
|
|
|
class OpenImagesDataset(Dataset):
    """Bounding-box crops read from a bbox annotation CSV under ``data_root``.

    One item is produced per annotated box whose normalized area lies in
    [0.02, 0.8]. Each item is a dict with:
      * ``text`` / ``input_ids`` / ``index``: a random caption, its
        tokenization, and the 1-based word position of the placeholder.
      * ``pixel_values``: the padded crop resized to ``size`` x ``size`` and
        scaled to [-1, 1], channels first.
      * ``pixel_values_clip``: a randomly augmented 224x224 version of the
        crop, normalized by ``get_tensor_clip``.

    If loading or cropping fails, zero tensors are returned instead and the
    error message is appended to ``error.txt``.

    Args:
        data_root: Root folder containing ``annotations/`` and ``images/<set>/``.
        tokenizer: Callable tokenizer exposing ``model_max_length`` and
            returning an object with ``input_ids``.
        size: Side length of ``pixel_values``.
        interpolation: One of "linear", "bilinear", "bicubic", "lanczos";
            any other value raises ``KeyError``.
        set: "train", "validation", or anything else for the test split.
        placeholder_token: Token substituted into the caption templates.
    """

    def __init__(
            self,
            data_root,
            tokenizer,
            size=512,
            interpolation="bicubic",
            set="train",
            placeholder_token="*",
    ):
        self.data_root = data_root
        self.tokenizer = tokenizer
        self.size = size
        self.placeholder_token = placeholder_token
        self.set_type = set

        self.random_trans = A.Compose([
            A.Resize(height=224, width=224),
            A.HorizontalFlip(p=0.5),
            A.Rotate(limit=20),
            A.Blur(p=0.3),
            A.ElasticTransform(p=0.3)
        ])

        self.bbox_path_list = []
        if set == "train":
            annotation_file = 'oidv6-train-annotations-bbox.csv'
        elif set == "validation":
            annotation_file = 'validation-annotations-bbox.csv'
        else:
            annotation_file = 'test-annotations-bbox.csv'
        bboxs_path = os.path.join(data_root, 'annotations', annotation_file)

        df_val_bbox = pd.read_csv(bboxs_path)
        bbox_groups = df_val_bbox.groupby(df_val_bbox.LabelName)

        # Boxes are collected label by label, in order of first appearance.
        bbox_full = []
        for label_name in df_val_bbox['LabelName'].unique():
            bboxs = bbox_groups.get_group(label_name)[
                ['XMin', 'XMax', 'YMin', 'YMax', 'LabelName', 'ImageID',
                 'IsOccluded', 'IsTruncated', 'IsGroupOf', 'IsInside']].values.tolist()
            for x_min, x_max, y_min, y_max, label, image_id, *_ in bboxs:
                # Coordinates are multiplied by the image size later, so the
                # product here is the box area as a fraction of the image.
                area = (x_max - x_min) * (y_max - y_min)
                if not (area > 0.8 or area < 0.02):
                    bbox_full.append([x_min, x_max, y_min, y_max, label, image_id])

        self.bboxs_full = bbox_full
        self.num_images = len(bbox_full)

        print('{}: total {} images ...'.format(set, self.num_images))

        self._length = self.num_images

        # "nearest" is intentionally not offered here, as in the original table.
        self.interpolation = {
            name: PIL_INTERPOLATION[name]
            for name in ("linear", "bilinear", "bicubic", "lanczos")
        }[interpolation]

        self.templates = imagenet_templates_small

    def __len__(self):
        """Return the number of retained bounding boxes."""
        return self._length

    def get_tensor_clip(self, normalize=True, toTensor=True):
        """Build the transform used for the 224x224 reference images.

        Args:
            normalize: Append per-channel mean/std normalization.
            toTensor: Prepend ``ToTensor``.

        Returns:
            A ``torchvision.transforms.Compose`` of the requested steps.
        """
        transform_list = []
        if toTensor:
            transform_list.append(torchvision.transforms.ToTensor())
        if normalize:
            transform_list.append(torchvision.transforms.Normalize(
                (0.48145466, 0.4578275, 0.40821073),
                (0.26862954, 0.26130258, 0.27577711)))
        return torchvision.transforms.Compose(transform_list)

    def process(self, image):
        """Resize an HxWxC image to ``size`` x ``size`` and scale it to [-1, 1].

        Returns:
            A float32 tensor laid out as (C, size, size).
        """
        img = cv2.resize(np.array(image), (self.size, self.size), interpolation=cv2.INTER_CUBIC)
        img = img.astype(np.float32) / 127.5 - 1.0
        return torch.from_numpy(img).permute(2, 0, 1)

    def obtain_text(self, add_caption, object_category=None):
        """Build a random caption and tokenize it.

        A template is drawn at random and filled with ``object_category`` (or
        the placeholder token when it is None). The template's leading article
        ("a", "an" or "the") is then replaced by ``add_caption``; a template
        without a leading article gets ``add_caption`` prepended instead.

        Args:
            add_caption: Text that opens the caption, e.g. "a" or "a far away".
            object_category: Optional word used in place of the placeholder.

        Returns:
            Tuple ``(input_ids, index, text)`` where ``index`` is a 0-d tensor
            holding the 1-based word position of the placeholder (0 if absent).
        """
        if object_category is None:
            placeholder_string = self.placeholder_token
        else:
            placeholder_string = object_category

        text = random.choice(self.templates).format(placeholder_string)

        # Replace the whole leading article, not just one character: dropping a
        # single character turned "the photo ..." into "ahe photo ...".
        first_word, separator, remainder = text.partition(' ')
        if separator and first_word.lower() in ('a', 'an', 'the'):
            text = f"{add_caption} {remainder}"
        else:
            text = f"{add_caption} {text}"

        # Position of the last word equal to the placeholder, shifted by one;
        # presumably the shift accounts for a start-of-sequence token -- confirm.
        placeholder_index = 0
        for idx, word in enumerate(text.strip().split(' ')):
            if word == placeholder_string:
                placeholder_index = idx + 1

        index = torch.tensor(placeholder_index)

        input_ids = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids[0]
        return input_ids, index, text

    def __getitem__(self, i):
        """Return the example dict for box ``i`` (taken modulo the box count)."""
        example = {}

        input_ids, index, text = self.obtain_text('a')
        example["input_ids"] = input_ids
        example["index"] = index
        example["text"] = text

        # Shallow copy so the stored annotation is never modified.
        bbox_sample = copy.copy(self.bboxs_full[i % self.num_images])

        file_name = bbox_sample[-1] + '.jpg'
        img_path = os.path.join(self.data_root, 'images', self.set_type, file_name)

        try:
            # The context manager closes the file once the pixels are copied.
            with Image.open(img_path) as img_file:
                img_p_np = np.array(img_file.convert("RGB"))
            height, width = img_p_np.shape[:2]

            # Normalized box -> pixel coordinates.
            x_min = bbox_sample[0] * width
            x_max = bbox_sample[1] * width
            y_min = bbox_sample[2] * height
            y_max = bbox_sample[3] * height

            # Grow the box by up to 10 px per side without leaving the image.
            left = int(x_min - min(10, x_min))
            right = int(x_max + min(10, width - x_max))
            top = int(y_min - min(10, y_min))
            bottom = int(y_max + min(10, height - y_max))

            image_tensor = img_p_np[top:bottom, left:right, :]
            example["pixel_values"] = self.process(image_tensor)

            ref_image_tensor = self.random_trans(image=image_tensor)
            ref_image_tensor = Image.fromarray(ref_image_tensor["image"])
            example["pixel_values_clip"] = self.get_tensor_clip()(ref_image_tensor)

        except Exception as e:
            # Best-effort: substitute a blank sample and record the failure.
            # The blank must follow self.size so it batches with real samples.
            example["pixel_values"] = torch.zeros((3, self.size, self.size))
            example["pixel_values_clip"] = torch.zeros((3, 224, 224))
            with open('error.txt', 'a+') as f:
                f.write(str(e) + '\n')

        return example
|
|
|
|
class OpenImagesDatasetWithMask(OpenImagesDataset):
    """Object crops with segmentation masks, built on ``OpenImagesDataset``.

    The parent ``__init__`` is deliberately not called because this class
    reads different annotation files; ``get_tensor_clip``, ``process`` and
    ``obtain_text`` are inherited.

    Each item is a dict with ``pixel_values``, ``mask_values``, ``input_ids``,
    ``index``, ``text``, ``pixel_values_clip``, ``pixel_values_obj`` and
    ``pixel_values_seg``. If loading fails, zero tensors of matching shape are
    returned for all tensor fields and the error is appended to ``error.txt``.

    Args:
        data_root: Root folder containing ``annotations/``, ``images/<set>/``
            and ``segs/``.
        tokenizer: Callable tokenizer exposing ``model_max_length`` and
            returning an object with ``input_ids``.
        size: Side length of ``pixel_values``.
        interpolation: One of "linear", "bilinear", "bicubic", "lanczos";
            any other value raises ``KeyError``.
        set: "train", "validation", or anything else for the test split.
        placeholder_token: Token substituted into the caption templates.
    """

    def __init__(self,
                 data_root,
                 tokenizer,
                 size=512,
                 interpolation="bicubic",
                 set="train",
                 placeholder_token="*"):
        self.data_root = data_root
        self.tokenizer = tokenizer
        self.size = size
        self.placeholder_token = placeholder_token
        self.set = set
        # Mirrors the attribute name used by the parent class.
        self.set_type = set

        class_anno_path = os.path.join(data_root, 'annotations', 'oidv6-class-descriptions.csv')
        anno_files = pd.read_csv(class_anno_path)
        # Built once instead of one groupby lookup per box; keeping the first
        # row per label matches the previous "first DisplayName" behaviour.
        display_names = (
            anno_files.drop_duplicates('LabelName')
            .set_index('LabelName')['DisplayName']
            .to_dict()
        )

        if set == "train":
            split_prefix = 'train'
        elif set == "validation":
            split_prefix = 'validation'
        else:
            split_prefix = 'test'
        bboxs_path = os.path.join(
            data_root, 'annotations', f'{split_prefix}-annotations-object-segmentation.csv')
        dict_path = os.path.join(data_root, 'segs', f'{split_prefix}_bbox_dict.npy')

        # NOTE(security): allow_pickle=True executes pickled code on load; only
        # use .npy files from a trusted source.
        bbox_dict = np.load(dict_path, allow_pickle=True).item()

        df_val_bbox = pd.read_csv(bboxs_path)
        bbox_groups = df_val_bbox.groupby(df_val_bbox.LabelName)

        # Boxes are collected label by label, in order of first appearance.
        bboxes_full = []
        for label_name in df_val_bbox['LabelName'].unique():
            bboxs = bbox_groups.get_group(label_name)[
                ['BoxXMin', 'BoxXMax', 'BoxYMin', 'BoxYMax', 'LabelName', 'MaskPath']].values.tolist()
            for box in bboxs:
                mask_path = box[-1]
                if mask_path not in bbox_dict:
                    continue
                bbox_data = bbox_dict[mask_path]

                # bbox_data[1:5] are two (start, end) pairs; __getitem__ uses the
                # first as the row range and the second as the column range.
                # Presumably pixel coordinates -- confirm against the script that
                # produced the .npy file.
                first_extent = bbox_data[2] - bbox_data[1]
                second_extent = bbox_data[4] - bbox_data[3]

                # Skip small objects (either side under 100) ...
                if first_extent < 100 or second_extent < 100:
                    continue
                # ... and boxes whose aspect ratio is beyond 1:2.
                if not (first_extent / second_extent < 0.5 or second_extent / first_extent < 0.5):
                    class_name = display_names[box[4]]
                    bboxes_full.append(
                        [mask_path, bbox_data[1], bbox_data[2], bbox_data[3], bbox_data[4], class_name])

        self.bboxes_full = bboxes_full
        self.num_images = len(bboxes_full)

        print('{}: total {} images ...'.format(set, self.num_images))

        self._length = self.num_images
        # "nearest" is intentionally not offered here, as in the original table.
        self.interpolation = {
            name: PIL_INTERPOLATION[name]
            for name in ("linear", "bilinear", "bicubic", "lanczos")
        }[interpolation]

        self.templates = imagenet_templates_small

    def __len__(self):
        """Return the number of retained masks."""
        return self._length

    @staticmethod
    def _to_signed_unit_range(pil_image):
        """Convert a PIL image to a float32 array scaled from [0, 255] to [-1, 1]."""
        pixels = np.array(pil_image).astype(np.uint8)
        return (pixels / 127.5 - 1.0).astype(np.float32)

    def custom_aug(self, instance_image):
        """Randomly rescale the crop and report the scale in a caption prefix.

        A target side length is drawn (two times out of three from
        [size // 3, size], otherwise from [1.2 * size, 1.4 * size)) and rounded
        up to an even number. Then:
          * below 0.6 * size: the crop is shrunk and pasted at a random place
            on a zero canvas; the mask marks the pasted area, inset by one cell;
          * above size: the crop is enlarged and a random size x size window
            is cut out; the mask is all ones;
          * otherwise: the crop is resized to size x size; the mask is all ones.

        Args:
            instance_image: HxWx3 uint8 array.

        Returns:
            Tuple ``(image, mask, caption_prefix)``: a (3, size, size) float32
            tensor in [-1, 1], a (1, size // 8, size // 8) float64 tensor, and
            the text that should open the caption.
        """
        instance_image = Image.fromarray(instance_image)
        size = self.size

        if np.random.randint(0, 3) < 2:
            random_scale = np.random.randint(size // 3, size + 1)
        else:
            random_scale = np.random.randint(int(1.2 * size), int(1.4 * size))

        # Force an even side so the half-width arithmetic below is exact.
        if random_scale % 2 == 1:
            random_scale += 1
        half = random_scale // 2

        if random_scale < 0.6 * size:
            add_to_caption = np.random.choice(["a far away", "very small"])
            cx = np.random.randint(half, size - half + 1)
            cy = np.random.randint(half, size - half + 1)

            shrunk = self._to_signed_unit_range(
                instance_image.resize((random_scale, random_scale), resample=self.interpolation))

            instance_image = np.zeros((size, size, 3), dtype=np.float32)
            instance_image[cx - half: cx + half, cy - half: cy + half, :] = shrunk

            # Mask is kept at 1/8 resolution and inset by one cell on each side.
            mask = np.zeros((size // 8, size // 8))
            mask[(cx - half) // 8 + 1: (cx + half) // 8 - 1,
                 (cy - half) // 8 + 1: (cy + half) // 8 - 1] = 1.

        elif random_scale > size:
            add_to_caption = np.random.choice(["zoomed in", "close up"])
            cx = np.random.randint(size // 2, random_scale - size // 2 + 1)
            cy = np.random.randint(size // 2, random_scale - size // 2 + 1)

            enlarged = self._to_signed_unit_range(
                instance_image.resize((random_scale, random_scale), resample=self.interpolation))
            instance_image = enlarged[cx - size // 2: cx + size // 2,
                                      cy - size // 2: cy + size // 2, :]
            mask = np.ones((size // 8, size // 8))
        else:
            add_to_caption = "a"
            instance_image = self._to_signed_unit_range(
                instance_image.resize((size, size), resample=self.interpolation))
            mask = np.ones((size // 8, size // 8))

        return (torch.from_numpy(instance_image).permute(2, 0, 1),
                torch.from_numpy(mask[:, :, None]).permute(2, 0, 1),
                add_to_caption)

    def aug_cv2(self, img, seg):
        """Apply the same random geometric augmentation to an image and its mask.

        Each of three steps runs with probability 1/2: upscale to a random side
        in [224, 256] followed by a random 224x224 crop; rotation by a random
        angle in [-30, 30] degrees about (112, 112); translation by up to 60 px
        per axis. The mask always uses nearest-neighbour interpolation.

        Args:
            img: HxWx3 uint8 image array.
            seg: HxWx3 uint8 mask array.

        Returns:
            Tuple ``(image, mask)`` of PIL images.
        """
        img_auged = np.array(img).copy()
        seg_auged = np.array(seg).copy()

        if random.choice([0, 1]) == 0:
            new_size = random.randint(224, 256)
            img_auged = cv2.resize(img_auged, (new_size, new_size), interpolation=cv2.INTER_CUBIC)
            seg_auged = cv2.resize(seg_auged, (new_size, new_size), interpolation=cv2.INTER_NEAREST)

            start_x, start_y = random.randint(0, new_size - 224), random.randint(0, new_size - 224)
            img_auged = img_auged[start_x:start_x + 224, start_y:start_y + 224, :]
            seg_auged = seg_auged[start_x:start_x + 224, start_y:start_y + 224, :]

        h, w = img_auged.shape[:2]

        if random.choice([0, 1]) == 0:
            angle = random.randint(-30, 30)
            M = cv2.getRotationMatrix2D((112, 112), angle, 1)
            img_auged = cv2.warpAffine(img_auged, M, (w, h), flags=cv2.INTER_CUBIC)
            seg_auged = cv2.warpAffine(seg_auged, M, (w, h), flags=cv2.INTER_NEAREST)

        if random.choice([0, 1]) == 0:
            trans_x = random.randint(-60, 60)
            trans_y = random.randint(-60, 60)
            H = np.float32([[1, 0, trans_x],
                            [0, 1, trans_y]])
            img_auged = cv2.warpAffine(img_auged, H, (w, h), flags=cv2.INTER_CUBIC)
            seg_auged = cv2.warpAffine(seg_auged, H, (w, h), flags=cv2.INTER_NEAREST)

        img_auged = Image.fromarray(img_auged)
        seg_auged = Image.fromarray(seg_auged)

        return img_auged, seg_auged

    def __getitem__(self, i):
        """Return the example dict for mask ``i`` (taken modulo the mask count)."""
        example = {}

        record = self.bboxes_full[i % self.num_images]
        seg_name = record[0]
        # The image id is the part of the mask file name before the first '_'.
        file_name = seg_name.split('_')[0] + '.jpg'
        img_path = os.path.join(self.data_root, 'images', self.set, file_name)
        seg_path = os.path.join(self.data_root, 'segs', self.set, seg_name)

        try:
            bbox_sample = record[1:]
            # cv2.imread returns None for a missing file; the resulting error
            # is handled by the except branch below.
            img_p_np = cv2.imread(img_path)
            img_p_np = cv2.cvtColor(img_p_np, cv2.COLOR_BGR2RGB)
            seg_p_np = cv2.imread(seg_path).astype('float')
            # Bring the mask to the image resolution (cv2 sizes are (w, h)).
            seg_p_np = cv2.resize(seg_p_np, img_p_np.shape[:2][::-1], interpolation=cv2.INTER_NEAREST)

            # Pad the box by a random 10-19 px; the lower bounds are clamped at
            # zero, the upper bounds rely on NumPy clipping slices at the edge.
            bbox_pad = copy.copy(bbox_sample)
            pad_size = random.choice(range(10, 20))
            bbox_pad[0] = int(bbox_pad[0] - min(pad_size, bbox_pad[0]))
            bbox_pad[1] = int(bbox_pad[1] + pad_size)
            bbox_pad[2] = int(bbox_pad[2] - min(pad_size, bbox_pad[2]))
            bbox_pad[3] = int(bbox_pad[3] + pad_size)

            image_tensor = img_p_np[bbox_pad[0]:bbox_pad[1], bbox_pad[2]:bbox_pad[3], :]
            seg_tensor = seg_p_np[bbox_pad[0]:bbox_pad[1], bbox_pad[2]:bbox_pad[3], :]

            augged_image, augged_mask, add_caption = self.custom_aug(image_tensor)
            input_ids, index, text = self.obtain_text(add_caption)

            example["pixel_values"] = augged_image
            example["mask_values"] = augged_mask
            example["input_ids"] = input_ids
            example["index"] = index
            example["text"] = text

            # Keep only the object pixels (mask values are 0 or 255).
            object_tensor = image_tensor * (seg_tensor / 255)
            ref_object_tensor = cv2.resize(object_tensor, (224, 224), interpolation=cv2.INTER_CUBIC)
            ref_image_tenser = cv2.resize(image_tensor, (224, 224), interpolation=cv2.INTER_CUBIC)
            ref_seg_tensor = cv2.resize(seg_tensor, (224, 224), interpolation=cv2.INTER_NEAREST)

            ref_object_tensor, ref_seg_tensor = self.aug_cv2(
                ref_object_tensor.astype('uint8'), ref_seg_tensor.astype('uint8'))
            example["pixel_values_clip"] = self.get_tensor_clip()(Image.fromarray(ref_image_tenser))
            example["pixel_values_obj"] = self.get_tensor_clip()(ref_object_tensor)
            example["pixel_values_seg"] = self.get_tensor_clip(normalize=False)(ref_seg_tensor)

        except Exception as e:
            # Best-effort: substitute a blank sample with the same keys, shapes
            # and dtypes as a real one so batches still collate, then record
            # the failure.
            example["pixel_values"] = torch.zeros((3, self.size, self.size))
            example["mask_values"] = torch.zeros(
                (1, self.size // 8, self.size // 8), dtype=torch.float64)
            example["pixel_values_obj"] = torch.zeros((3, 224, 224))
            example["pixel_values_clip"] = torch.zeros((3, 224, 224))
            example["pixel_values_seg"] = torch.zeros((3, 224, 224))

            input_ids, index, text = self.obtain_text("a")
            example["input_ids"] = input_ids
            example["index"] = index
            example["text"] = text

            with open('error.txt', 'a+') as f:
                f.write(str(e) + '\n')

        return example
|
|