Hang Zhou commited on
Commit
0103f17
·
verified ·
1 Parent(s): 5b662d1

Upload folder using huggingface_hub

Browse files
datasets/__init__.py ADDED
File without changes
datasets/base.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+ import albumentations as A
4
+ from torch.utils.data import Dataset
5
+ from .data_utils import *
6
+
7
class BaseDataset(Dataset):
    """Shared base class for the collage-construction datasets.

    Subclasses implement ``_get_sample`` and ``__len__``; this class provides
    the patch extraction, augmentation and collage-assembly helpers they share.
    """

    def __init__(self):
        # Placeholder sample store; subclasses manage their own sample lists.
        self.data = []

    def __getitem__(self, idx):
        # Delegate to the subclass-specific sample builder.
        item = self._get_sample(idx)
        return item

    def _get_sample(self, idx):
        # Implemented for each specific dataset
        pass

    def __len__(self):
        # We adjust the ratio of different dataset by setting the length.
        pass

    def aug_data_mask(self, image, mask):
        """Jointly augment an image and its mask (brightness/contrast jitter plus
        a rotation of up to 30 degrees with constant-border fill)."""
        transform = A.Compose([
            A.RandomBrightnessContrast(p=0.5),
            A.Rotate(limit=30, border_mode=cv2.BORDER_CONSTANT),
        ])

        transformed = transform(image=image.astype(np.uint8), mask=mask)
        transformed_image = transformed["image"]
        transformed_mask = transformed["mask"]
        return transformed_image, transformed_mask

    # def aug_patch(self, patch):
    #     transform = A.Compose([
    #         A.HorizontalFlip(p=0.2),
    #         A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.3),
    #         A.Rotate(limit=15, border_mode=cv2.BORDER_REPLICATE, p=0.5),
    #     ])

    #     return transform(image=patch)["image"]

    def aug_patch(self, patch):
        """Augment an object patch while keeping its background pure white.

        Pixels with gray value < 250 are treated as foreground; after the
        flip/contrast/rotate transforms the background is re-filled with 255 so
        border-replication artifacts do not leak into the patch.
        """
        gray = cv2.cvtColor(patch, cv2.COLOR_RGB2GRAY)
        mask = (gray < 250).astype(np.float32)[:, :, None]

        transform = A.Compose([
            A.HorizontalFlip(p=0.2),
            A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.3),
            A.Rotate(limit=15, border_mode=cv2.BORDER_REPLICATE, p=0.5),
        ])

        transformed = transform(image=patch.astype(np.uint8), mask=mask)
        aug_img = transformed["image"]
        aug_mask = transformed["mask"]
        # Keep the augmented foreground, force the background back to white.
        final_img = aug_img * aug_mask + 255 * (1 - aug_mask)

        return final_img.astype(np.uint8)

    def sample_timestep(self, max_step=1000):
        """Sample a timestep in [0, max_step); 70% of draws are biased into the
        lower half of the range. Returns a length-1 numpy array."""
        if np.random.rand() < 0.3:
            step = np.random.randint(0, max_step)
        else:
            step = np.random.randint(0, max_step // 2)
        return np.array([step])

    def get_patch(self, ref_image, ref_mask):
        '''
        extract compact patch and convert to 224x224 RGBA.
        ref_mask: [0, 1]
        '''

        # 1. Get the outline Box of the reference image
        y1, y2, x1, x2 = get_bbox_from_mask(ref_mask) # y1y2x1x2, obtain location from ref patch

        # 2. Background is set to white (255)
        ref_mask_3 = np.stack([ref_mask, ref_mask, ref_mask], -1)
        masked_ref_image = ref_image * ref_mask_3 + np.ones_like(ref_image) * 255 * (1 - ref_mask_3)

        # 3. Crop based on bounding boxes
        masked_ref_image = masked_ref_image[y1:y2, x1:x2, :]
        ref_mask_crop = ref_mask[y1:y2, x1:x2] # obtain a tight mask

        # 4. Dilate the patch and mask
        # Random margin factor in [1.1, 1.4].
        ratio = np.random.randint(11, 15) / 10
        masked_ref_image, ref_mask_crop = expand_image_mask(masked_ref_image, ref_mask_crop, ratio=ratio)

        # augmentation
        # masked_ref_image, ref_mask_crop = self.aug_data_mask(masked_ref_image, ref_mask_crop)

        # 5. Padding & Resize
        masked_ref_image = pad_to_square(masked_ref_image, pad_value=255)
        masked_ref_image = cv2.resize(masked_ref_image.astype(np.uint8), (224, 224))

        # Alpha channel: tight mask scaled to 0/255, squared and resized with
        # nearest-neighbor so it stays binary.
        m_local = ref_mask_crop[:, :, None] * 255
        m_local = pad_to_square(m_local, pad_value=0)
        m_local = cv2.resize(m_local.astype(np.uint8), (224, 224), interpolation=cv2.INTER_NEAREST)

        # Stack RGB + mask into a 4-channel RGBA patch.
        rgba_image = np.dstack((masked_ref_image.astype(np.uint8), m_local))

        return rgba_image

    def _construct_collage(self, image, object_0, object_1, mask_0, mask_1):
        """Assemble one training item from a scene image, two object patches and
        their masks.

        Returns a dict with:
          'jpg'      source image, 512x512x3, values in [-1, 1]
          'ref0'/'ref1' object patches, 224x224x3, values in [0, 1]
          'hint'     masked background + union-mask channel, 512x512x4
          'mask0'/'mask1' per-object region masks, 512x512
          'time_steps', 'object_num'
        """
        background = image.copy()
        image = pad_to_square(image, pad_value = 0, random = False).astype(np.uint8)
        image = cv2.resize(image.astype(np.uint8), (512,512)).astype(np.float32)
        image = image / 127.5 - 1.0
        item = {}
        item.update({'jpg': image.copy()}) # source image (checked) [-1, 1], 512x512x3

        # Patch 0: random white margin, augmentation, square pad, resize, normalize.
        ratio = np.random.randint(11, 15) / 10
        object_0 = expand_image(object_0, ratio=ratio)
        object_0 = self.aug_patch(object_0)
        object_0 = pad_to_square(object_0, pad_value = 255, random = False) # pad to square
        object_0 = cv2.resize(object_0.astype(np.uint8), (224,224) ).astype(np.uint8) # check 1
        object_0 = object_0 / 255
        item.update({'ref0': object_0.copy()}) # patch 0 (checked) [0, 1], 224x224x3

        # Patch 1: same pipeline as patch 0.
        ratio = np.random.randint(11, 15) / 10
        object_1 = expand_image(object_1, ratio=ratio)
        object_1 = self.aug_patch(object_1)
        object_1 = pad_to_square(object_1, pad_value = 255, random = False) # pad to square
        object_1 = cv2.resize(object_1.astype(np.uint8), (224,224) ).astype(np.uint8) # check 1
        object_1 = object_1 / 255
        item.update({'ref1': object_1.copy()}) # patch 1 (checked) [0, 1], 224x224x3

        # Zero canvases (same shape as the scene) for per-object and union masks.
        background_mask0 = background.copy() * 0.0
        background_mask1 = background.copy() * 0.0
        background_mask = background.copy() * 0.0

        # Black out the (randomly expanded) region around object 0 and mark it.
        box_yyxx = get_bbox_from_mask(mask_0)
        box_yyxx = expand_bbox(mask_0, box_yyxx, ratio=[1.1, 1.2]) #1.1 1.3
        y1, y2, x1, x2 = box_yyxx
        background[y1:y2, x1:x2,:] = 0
        background_mask0[y1:y2, x1:x2, :] = 1.0
        background_mask[y1:y2, x1:x2, :] = 1.0

        # Same for object 1.
        box_yyxx = get_bbox_from_mask(mask_1)
        box_yyxx = expand_bbox(mask_1, box_yyxx, ratio=[1.1, 1.2]) #1.1 1.3
        y1, y2, x1, x2 = box_yyxx
        background[y1:y2, x1:x2,:] = 0
        background_mask1[y1:y2, x1:x2, :] = 1.0
        background_mask[y1:y2, x1:x2, :] = 1.0

        # Square-pad and resize; masks are padded with the sentinel value 2 so the
        # padded area can be distinguished from real background after resizing.
        background = pad_to_square(background, pad_value = 0, random = False).astype(np.uint8)
        background = cv2.resize(background.astype(np.uint8), (512,512)).astype(np.float32)
        background_mask0 = pad_to_square(background_mask0, pad_value = 2, random = False).astype(np.uint8)
        background_mask1 = pad_to_square(background_mask1, pad_value = 2, random = False).astype(np.uint8)
        background_mask = pad_to_square(background_mask, pad_value = 2, random = False).astype(np.uint8)
        background_mask0 = cv2.resize(background_mask0.astype(np.uint8), (512,512), interpolation = cv2.INTER_NEAREST).astype(np.float32)
        background_mask1 = cv2.resize(background_mask1.astype(np.uint8), (512,512), interpolation = cv2.INTER_NEAREST).astype(np.float32)
        background_mask = cv2.resize(background_mask.astype(np.uint8), (512,512), interpolation = cv2.INTER_NEAREST).astype(np.float32)

        # Remap the padding sentinel to -1.
        background_mask0[background_mask0 == 2] = -1
        background_mask1[background_mask1 == 2] = -1
        background_mask[background_mask == 2] = -1

        # Single-channel 0/1 region masks for the two objects.
        # NOTE(review): these assignments alias background_mask0/1 and mutate them
        # in place; harmless here since the originals are not used afterwards.
        background_mask0_ = background_mask0
        background_mask0_[background_mask0_ == -1] = 0
        background_mask0_ = background_mask0_[:, :, 0]

        background_mask1_ = background_mask1
        background_mask1_[background_mask1_ == -1] = 0
        background_mask1_ = background_mask1_[:, :, 0]

        # Hint = normalized masked background + the union mask as a 4th channel.
        background = background / 127.5 - 1.0
        background = np.concatenate([background, background_mask[:,:,:1]] , -1)
        item.update({'hint': background.copy()})

        item.update({'mask0': background_mask0_.copy()})
        item.update({'mask1': background_mask1_.copy()})

        sampled_time_steps = self.sample_timestep()
        item['time_steps'] = sampled_time_steps
        item['object_num'] = 2

        return item
datasets/bdd100k.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import cv2
3
+ import numpy as np
4
+ import os
5
+ from .data_utils import *
6
+ from .base import BaseDataset
7
+ from util.box_ops import compute_iou_matrix, draw_bboxes
8
+ from pathlib import Path
9
+ from pycocotools import mask as mask_utils
10
+ import shutil
11
+
12
+ IS_VERIFY = False
13
+
14
class BDD100KDataset(BaseDataset):
    """BDD100K two-object dataset.

    ``_intersect_2_obj`` scans raw BDD100K instance-segmentation labels and, for
    each image containing two sufficiently large objects with overlapping boxes,
    exports the image, the two object masks and their 224x224 RGBA patches into
    ``construct_dataset_dir``. ``_get_sample``/``__len__`` then serve those
    exported samples as collage training items.
    """

    def __init__(self, construct_dataset_dir, obj_thr=20, area_ratio=0.02):
        # obj_thr: object-count threshold (kept for interface parity; unused here).
        self.obj_thr = obj_thr
        self.construct_dataset_dir = construct_dataset_dir
        os.makedirs(Path(self.construct_dataset_dir), exist_ok=True)
        # area_ratio: minimum mask area as a fraction of the full image area.
        self.area_ratio = area_ratio
        self.sample_list = os.listdir(self.construct_dataset_dir)

    def _export_object(self, image, labels, label_index, image_name, stem):
        """Decode the RLE mask of ``labels[label_index]``, write ``<stem>_mask.png``
        and the RGBA patch ``<stem>.png`` under the sample folder, and return the
        decoded binary mask for reuse (e.g. verification overlays)."""
        out_dir = Path(self.construct_dataset_dir) / image_name[:-4]
        mask = mask_utils.decode(labels[label_index]['rle'])
        cv2.imwrite(str(out_dir / f"{stem}_mask.png"), 255 * mask)
        patch = self.get_patch(cv2.cvtColor(image, cv2.COLOR_BGR2RGB), mask)
        patch = cv2.cvtColor(patch, cv2.COLOR_RGB2BGR)
        cv2.imwrite(str(out_dir / f"{stem}.png"), patch)
        return mask

    @staticmethod
    def _overlay(canvas, mask, channel):
        """Alpha-blend a solid color (on BGR ``channel``) over ``canvas`` where
        ``mask`` is set; used only for verification renderings."""
        mask_color = np.stack([mask * 255] * 3, axis=-1).astype(np.uint8)
        highlight = np.zeros_like(canvas)
        highlight[:, :, channel] = 255
        alpha = 0.5
        return np.where(mask_color == 255, cv2.addWeighted(canvas, 1 - alpha, highlight, alpha, 0), canvas)

    def _intersect_2_obj(self, image_dir, samples, idx):
        """Export image/masks/patches for the pair of large objects with maximal
        bounding-box IoU, or skip the image (with a log line) when fewer than two
        large objects exist or no two boxes overlap."""
        self.image_dir = image_dir
        sample = samples[idx]
        image_name = sample['name']
        image_path = os.path.join(image_dir, image_name)
        image = cv2.imread(image_path)
        h, w = image.shape[0:2]
        image_area = h * w

        labels = sample['labels']

        # Keep only objects whose decoded mask covers enough of the image.
        obj_ids = []
        obj_bbox = []
        for i, obj in enumerate(labels):
            bbox = [obj['box2d']['x1'], obj['box2d']['y1'], obj['box2d']['x2'], obj['box2d']['y2']]
            mask = mask_utils.decode(obj['rle'])
            if np.sum(mask) > image_area * self.area_ratio:
                obj_ids.append(i)
                obj_bbox.append(bbox)

        if len(obj_bbox) < 2:
            print(f"[Info] Skip image index {image_name[:-4]} due to insufficient bbox.")
            return

        os.makedirs(Path(self.construct_dataset_dir) / image_name[:-4], exist_ok=True)
        bbox_xyxy = np.array(obj_bbox)

        if IS_VERIFY:
            image_with_boxes = draw_bboxes(image, bbox_xyxy)
            cv2.imwrite(str(Path(self.construct_dataset_dir) / image_name[:-4] / "bboxes_image.png"), image_with_boxes)

        # Pick the pair of filtered objects whose boxes overlap the most.
        iou_matrix = compute_iou_matrix(bbox_xyxy)
        np.fill_diagonal(iou_matrix, -1)  # Exclude self-comparisons (i.e., each box with itself)

        max_index = np.unravel_index(np.argmax(iou_matrix), iou_matrix.shape)
        index0, index1 = max_index[0], max_index[1]
        max_iou = iou_matrix[index0, index1]

        if max_iou <= 0:
            print(f"[Info] Skip image index {image_name[:-4]} due to no overlapping bboxes.")
            return

        dst = Path(self.construct_dataset_dir) / image_name[:-4] / "image.jpg"
        dst.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy(image_path, dst)

        # BUG FIX: the original scan incremented its ``counter`` only inside the
        # matching branch, so the target object was found only when its index was
        # 0. ``obj_ids`` already stores direct indices into ``labels``, so the
        # counter machinery is unnecessary — index directly.
        mask0 = self._export_object(image, labels, obj_ids[index0], image_name, "object_0")
        mask1 = self._export_object(image, labels, obj_ids[index1], image_name, "object_1")

        if IS_VERIFY:
            image_with_boxes = self._overlay(image_with_boxes, mask0, channel=2)  # red: object 0
            image_with_boxes = self._overlay(image_with_boxes, mask1, channel=0)  # blue: object 1
            cv2.imwrite(str(Path(self.construct_dataset_dir) / image_name[:-4] / "highlighted_image.png"), image_with_boxes)

    def _get_sample(self, idx):
        """Load one exported sample from disk and build the collage training item."""
        sample_path = os.path.join(self.construct_dataset_dir, self.sample_list[idx])
        image = cv2.cvtColor(cv2.imread(os.path.join(sample_path, "image.jpg")), cv2.COLOR_BGR2RGB)
        object_0 = cv2.cvtColor(cv2.imread(os.path.join(sample_path, "object_0.png")), cv2.COLOR_BGR2RGB)
        object_1 = cv2.cvtColor(cv2.imread(os.path.join(sample_path, "object_1.png")), cv2.COLOR_BGR2RGB)
        mask_0 = cv2.imread(os.path.join(sample_path, "object_0_mask.png"), cv2.IMREAD_GRAYSCALE)
        mask_1 = cv2.imread(os.path.join(sample_path, "object_1_mask.png"), cv2.IMREAD_GRAYSCALE)
        collage = self._construct_collage(image, object_0, object_1, mask_0, mask_1)
        return collage

    def __len__(self):
        # One exported sample folder per dataset item.
        return len(os.listdir(self.construct_dataset_dir))
128
+
129
+
130
if __name__ == "__main__":
    '''
    two-object case: train/test: 1012/371
    '''
    import argparse

    parser = argparse.ArgumentParser(description="BDD100KDataset Analysis")
    parser.add_argument("--dataset_dir", type=str, required=True, help="Path to the dataset directory.")
    parser.add_argument("--construct_dataset_dir", type=str, default='bin', help="Path to the debug bin directory.")
    parser.add_argument("--dataset_name", type=str, default='bdd100k', help="Dataset name.")
    parser.add_argument('--is_train', action='store_true', help="Train/Test")
    parser.add_argument('--is_build_data', action='store_true', help="Build data")
    parser.add_argument('--is_multiple', action='store_true', help="Multiple/Two objects")
    parser.add_argument("--area_ratio", type=float, default=0.01171, help="Area ratio for filtering out small objects.")
    parser.add_argument("--obj_thr", type=int, default=20, help="Object threshold for filtering.")
    parser.add_argument("--index", type=int, default=0, help="Index of the sample to test.")
    args = parser.parse_args()

    # Pick the split-specific image folder, RLE label file and scan budget.
    if args.is_train:
        image_dir = Path(args.dataset_dir) / args.dataset_name / "images" / "10k" / "train"
        json_path = Path(args.dataset_dir) / args.dataset_name / "labels" / "ins_seg" / "rles" / "ins_seg_train.json"
        max_num = 7000
    else:
        image_dir = Path(args.dataset_dir) / args.dataset_name / "images" / "10k" / "val"
        json_path = Path(args.dataset_dir) / args.dataset_name / "labels" / "ins_seg" / "rles" / "ins_seg_val.json"
        max_num = 1000

    dataset = BDD100KDataset(
        construct_dataset_dir = args.construct_dataset_dir,
        obj_thr = args.obj_thr,
        area_ratio = args.area_ratio,
    )

    with open(json_path) as data_file:
        label = json.load(data_file)
    samples = label["frames"]

    # Either export the two-object samples, or sanity-check the exported folder
    # by building a collage from every sample.
    if args.is_build_data:
        if not args.is_multiple:
            for index in range(max_num):
                dataset._intersect_2_obj(image_dir, samples, index)
    else:
        for index in range(len(os.listdir(args.construct_dataset_dir))):
            collage = dataset._get_sample(index)
datasets/cityscapes.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import os
4
+ from PIL import Image
5
+ from .data_utils import *
6
+ from .base import BaseDataset
7
+ import PIL.ImageDraw as ImageDraw
8
+ from util.box_ops import mask_to_bbox_xywh, compute_iou_matrix, draw_bboxes
9
+ from util.cityscapes_ops import Annotation, name2label
10
+ from pathlib import Path
11
+ import shutil
12
+
13
+ IS_VERIFY = False
14
+
15
class CityscapesDataset(BaseDataset):
    """Cityscapes two-object dataset.

    ``_intersect_2_obj`` rasterizes polygon annotations for car/truck/bus
    objects and, for every image containing two large objects with overlapping
    bounding boxes, exports the image, the two object masks and their 224x224
    RGBA patches into ``construct_dataset_dir``. ``_get_sample``/``__len__``
    then serve those exported samples as collage training items.
    """

    def __init__(self, construct_dataset_dir, obj_thr=20, area_ratio=0.02):
        # obj_thr: object-count threshold (kept for interface parity; unused here).
        self.obj_thr = obj_thr
        self.construct_dataset_dir = construct_dataset_dir
        os.makedirs(Path(self.construct_dataset_dir), exist_ok=True)
        # area_ratio: minimum mask area as a fraction of the full image area.
        self.area_ratio = area_ratio
        self.sample_list = os.listdir(self.construct_dataset_dir)

    def _intersect_2_obj(self, image_dir, json_dir, idx):
        """Export image/masks/patches for the pair of large car/truck/bus objects
        with maximal bounding-box IoU, or skip the image with a log line."""
        json_list = os.listdir(json_dir)
        # Annotation files are named '<stem>_gtFine_polygons.json' (21-char suffix).
        image_name = json_list[idx][:-21]
        image_path = os.path.join(image_dir, image_name+'_leftImg8bit.png')
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        json_path = os.path.join(json_dir, image_name+'_gtFine_polygons.json')
        annotation = Annotation()
        annotation.fromJsonFile(json_path)
        size = (annotation.imgWidth, annotation.imgHeight)
        image_area = size[0]*size[1]

        # the background
        background = name2label['unlabeled'].color

        obj_ids = []
        obj_areas = []
        obj_bbox = []
        counter = 0
        # loop over all objects
        for obj in annotation.objects:
            label = obj.label
            polygon = obj.polygon

            # 'xxxgroup' labels fall back to their base label.
            if (not label in name2label) and label.endswith('group'):
                label = label[:-len('group')]

            # only get car/truck/bus class
            if name2label[label].id !=26 and name2label[label].id !=27 and name2label[label].id !=28:
                continue

            # Rasterize the polygon to obtain the object mask.
            labelImg = Image.new("RGBA", size, background)
            drawer = ImageDraw.Draw(labelImg)
            drawer.polygon(polygon, fill=(255, 255, 255))
            mask = np.array(labelImg)[:, :, 0]
            area = np.sum(mask/255)
            bbox = mask_to_bbox_xywh(mask)

            # Keep only objects whose mask covers enough of the image.
            if area > image_area * self.area_ratio:
                obj_ids.append(counter)
                obj_areas.append(area)
                obj_bbox.append(bbox)

            # counter indexes the class-filtered objects only.
            counter += 1

        if len(obj_bbox) < 2:
            print(f"[Info] Skip image index {image_name} due to insufficient bbox.")
            return

        # filter by IOU
        bbox_xyxy = []
        for box in obj_bbox:
            x, y, w, h = box
            bbox_xyxy.append([x, y, x + w, y + h])
        bbox_xyxy = np.array(bbox_xyxy) # shape: [N, 4]

        if IS_VERIFY:
            os.makedirs(Path(self.construct_dataset_dir) / image_name, exist_ok=True)
            image_with_boxes = draw_bboxes(image, bbox_xyxy)
            cv2.imwrite(str(Path(self.construct_dataset_dir) / image_name / "bboxes_image.png"), cv2.cvtColor(image_with_boxes, cv2.COLOR_RGB2BGR))

        # Pick the pair of filtered objects whose boxes overlap the most.
        iou_matrix = compute_iou_matrix(bbox_xyxy)
        np.fill_diagonal(iou_matrix, -1) # Exclude self-comparisons (i.e., each box with itself)

        max_index = np.unravel_index(np.argmax(iou_matrix), iou_matrix.shape)
        index0, index1 = max_index[0], max_index[1]
        max_iou = iou_matrix[index0, index1]

        if max_iou <= 0:
            print(f"[Info] Skip image index {image_name} due to no overlapping bboxes.")
            return

        os.makedirs(Path(self.construct_dataset_dir) / image_name, exist_ok=True)
        dst = Path(self.construct_dataset_dir) / image_name / "image.jpg"
        dst.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy(image_path, dst)

        # Re-scan the class-filtered objects to locate object 0; re-rasterize its
        # polygon and export the mask and the RGBA patch.
        counter = 0
        for obj in annotation.objects:
            label = obj.label
            polygon = obj.polygon

            if (not label in name2label) and label.endswith('group'):
                label = label[:-len('group')]

            # only get car/truck/bus class
            if name2label[label].id !=26 and name2label[label].id !=27 and name2label[label].id !=28:
                continue

            if counter == obj_ids[index0]:
                labelImg = Image.new("RGBA", size, background)
                drawer = ImageDraw.Draw(labelImg)
                drawer.polygon(polygon, fill=(255, 255, 255))
                mask = np.array(labelImg)[:, :, 0]/255
                cv2.imwrite(str(Path(self.construct_dataset_dir) / image_name / "object_0_mask.png"), 255*mask)
                patch = self.get_patch(image, mask)
                patch = cv2.cvtColor(patch, cv2.COLOR_RGB2BGR)
                cv2.imwrite(str(Path(self.construct_dataset_dir) / image_name / "object_0.png"), patch)
                break
            counter += 1

        if IS_VERIFY:
            # Overlay object 0 in red on the bbox rendering.
            mask_color = np.stack([mask * 255]*3, axis=-1).astype(np.uint8)
            highlight = np.zeros_like(image)
            highlight[:, :, 2] = 255 # red channel
            alpha = 0.5
            image_with_boxes = np.where(mask_color == 255, cv2.addWeighted(image_with_boxes, 1 - alpha, highlight, alpha, 0), image_with_boxes)


        # Same scan for object 1.
        counter = 0
        for obj in annotation.objects:
            label = obj.label
            polygon = obj.polygon

            if (not label in name2label) and label.endswith('group'):
                label = label[:-len('group')]

            # only get car/truck/bus class
            if name2label[label].id !=26 and name2label[label].id !=27 and name2label[label].id !=28:
                continue

            if counter == obj_ids[index1]:
                labelImg = Image.new("RGBA", size, background)
                drawer = ImageDraw.Draw(labelImg)
                drawer.polygon(polygon, fill=(255, 255, 255))
                mask = np.array(labelImg)[:, :, 0]/255
                cv2.imwrite(str(Path(self.construct_dataset_dir) / image_name / "object_1_mask.png"), 255*mask)
                patch = self.get_patch(image, mask)
                patch = cv2.cvtColor(patch, cv2.COLOR_RGB2BGR)
                cv2.imwrite(str(Path(self.construct_dataset_dir) / image_name / "object_1.png"), patch)
                break
            counter += 1

        if IS_VERIFY:
            # Overlay object 1 in blue and write the combined rendering.
            mask_color = np.stack([mask * 255]*3, axis=-1).astype(np.uint8)
            highlight = np.zeros_like(image)
            highlight[:, :, 0] = 255 # blue channel
            alpha = 0.5
            image_with_boxes = np.where(mask_color == 255, cv2.addWeighted(image_with_boxes, 1 - alpha, highlight, alpha, 0), image_with_boxes)
            cv2.imwrite(str(Path(self.construct_dataset_dir) / image_name / "highlighted_image.png"), cv2.cvtColor(image_with_boxes, cv2.COLOR_RGB2BGR))

    def _get_sample(self, idx):
        """Load one exported sample from disk and build the collage training item."""
        sample_path = os.path.join(self.construct_dataset_dir, self.sample_list[idx])
        image = cv2.cvtColor(cv2.imread(os.path.join(sample_path, "image.jpg")), cv2.COLOR_BGR2RGB)
        object_0 = cv2.cvtColor(cv2.imread(os.path.join(sample_path, "object_0.png")), cv2.COLOR_BGR2RGB)
        object_1 = cv2.cvtColor(cv2.imread(os.path.join(sample_path, "object_1.png")), cv2.COLOR_BGR2RGB)
        mask_0 = cv2.imread(os.path.join(sample_path, "object_0_mask.png"), cv2.IMREAD_GRAYSCALE)
        mask_1 = cv2.imread(os.path.join(sample_path, "object_1_mask.png"), cv2.IMREAD_GRAYSCALE)
        collage = self._construct_collage(image, object_0, object_1, mask_0, mask_1)
        return collage

    def __len__(self):
        # One exported sample folder per dataset item.
        return len(os.listdir(self.construct_dataset_dir))
177
+
178
+
179
if __name__ == "__main__":
    '''
    two-object case: train/test: 536/78
    '''
    import argparse

    parser = argparse.ArgumentParser(description="CityscapesDataset Analysis")
    parser.add_argument("--dataset_dir", type=str, required=True, help="Path to the dataset directory.")
    parser.add_argument("--construct_dataset_dir", type=str, default='bin', help="Path to the debug bin directory.")
    parser.add_argument("--dataset_name", type=str, default='Cityscapes', help="Dataset name.")
    parser.add_argument('--is_train', action='store_true', help="Train/Test")
    parser.add_argument('--is_build_data', action='store_true', help="Build data")
    parser.add_argument('--is_multiple', action='store_true', help="Multiple/Two objects")
    parser.add_argument("--area_ratio", type=float, default=0.01171, help="Area ratio for filtering out small objects.")
    parser.add_argument("--obj_thr", type=int, default=20, help="Object threshold for filtering.")
    parser.add_argument("--index", type=int, default=0, help="Index of the sample to test.")
    args = parser.parse_args()

    # Pick the split-specific image/json folders and scan budget.
    if args.is_train:
        image_dir = Path(args.dataset_dir) / args.dataset_name / "train" / "images"
        json_dir = Path(args.dataset_dir) / args.dataset_name / "train" / "jsons"
        max_num = 2975
    else:
        image_dir = Path(args.dataset_dir) / args.dataset_name / "val" / "images"
        json_dir = Path(args.dataset_dir) / args.dataset_name / "val" / "jsons"
        max_num = 500

    dataset = CityscapesDataset(
        construct_dataset_dir = args.construct_dataset_dir,
        obj_thr = args.obj_thr,
        area_ratio = args.area_ratio,
    )

    # Either export the two-object samples, or sanity-check the exported folder
    # by building a collage from every sample.
    if args.is_build_data:
        if not args.is_multiple:
            for index in range(max_num):
                dataset._intersect_2_obj(image_dir, json_dir, index)
    else:
        for index in range(len(os.listdir(args.construct_dataset_dir))):
            collage = dataset._get_sample(index)
219
+
220
+
221
+
datasets/data_utils.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+
4
def resize_and_pad(image, box):
    """Fit *image* inside the (y1, y2, x1, x2) box region, keeping its aspect
    ratio and filling the leftover area with white (255)."""
    y1, y2, x1, x2 = box
    box_h, box_w = y2 - y1, x2 - x1
    img_h, img_w = image.shape[0], image.shape[1]

    if box_w / box_h >= img_w / img_h:
        # Box is relatively wider than the image: match heights, pad columns.
        new_h = box_h
        new_w = int(img_w * box_h / img_h)
        image = cv2.resize(image, (new_w, new_h))
        left = (box_w - new_w) // 2
        right = box_w - new_w - left
        image = np.pad(image, ((0, 0), (left, right), (0, 0)), 'constant', constant_values=255)
    else:
        # Box is relatively taller: match widths, pad rows.
        new_w = box_w
        new_h = int(img_h * box_w / img_w)
        image = cv2.resize(image, (new_w, new_h))
        top = (box_h - new_h) // 2
        bottom = box_h - new_h - top
        image = np.pad(image, ((top, bottom), (0, 0), (0, 0)), 'constant', constant_values=255)
    return image
30
+
31
+
32
+
33
def expand_image_mask(image, mask, ratio=1.4):
    """Grow the canvas of *image* (white fill) and *mask* (zero fill) by *ratio*,
    keeping the original content centered."""
    src_h, src_w = image.shape[0], image.shape[1]
    dst_h, dst_w = int(src_h * ratio), int(src_w * ratio)
    top = int((dst_h - src_h) // 2)
    bottom = dst_h - src_h - top
    left = int((dst_w - src_w) // 2)
    right = dst_w - src_w - left

    image = np.pad(image, ((top, bottom), (left, right), (0, 0)), 'constant', constant_values=255)
    mask = np.pad(mask, ((top, bottom), (left, right)), 'constant', constant_values=0)
    return image, mask
46
+
47
+
48
def expand_image(image, ratio=1.4):
    """Grow the canvas of *image* by *ratio* with a white border, keeping the
    original content centered."""
    src_h, src_w = image.shape[0], image.shape[1]
    dst_h, dst_w = int(src_h * ratio), int(src_w * ratio)
    top = int((dst_h - src_h) // 2)
    bottom = dst_h - src_h - top
    left = int((dst_w - src_w) // 2)
    right = dst_w - src_w - left

    return np.pad(image, ((top, bottom), (left, right), (0, 0)), 'constant', constant_values=255)
59
+
60
+
61
def expand_bbox(mask, yyxx, ratio=[1.2, 2.0], min_crop=0):
    """Grow a (y1, y2, x1, x2) box about its center by a random factor drawn
    from *ratio* (0.1 granularity, upper bound exclusive), enforcing a minimum
    side of *min_crop* and clipping to the extent of *mask*."""
    y1, y2, x1, x2 = yyxx
    scale = np.random.randint(ratio[0] * 10, ratio[1] * 10) / 10
    height, width = mask.shape[0], mask.shape[1]
    cx, cy = 0.5 * (x1 + x2), 0.5 * (y1 + y2)
    new_h = max(scale * (y2 - y1 + 1), min_crop)
    new_w = max(scale * (x2 - x1 + 1), min_crop)

    y1 = max(0, int(cy - new_h * 0.5))
    y2 = min(height, int(cy + new_h * 0.5))
    x1 = max(0, int(cx - new_w * 0.5))
    x2 = min(width, int(cx + new_w * 0.5))
    return (y1, y2, x1, x2)
81
+
82
+
83
def box2squre(image, box):
    """Symmetrically widen the shorter side of a (y1, y2, x1, x2) box toward a
    square about its center, clipped to the image extent. (Odd side lengths may
    leave the result one pixel short of square because of integer halving.)"""
    img_h, img_w = image.shape[0], image.shape[1]
    y1, y2, x1, x2 = box
    cx = (x1 + x2) // 2
    cy = (y1 + y2) // 2
    h, w = y2 - y1, x2 - x1

    if h >= w:
        x1, x2 = cx - h // 2, cx + h // 2
    else:
        y1, y2 = cy - w // 2, cy + w // 2
    return (max(0, y1), min(img_h, y2), max(0, x1), min(img_w, x2))
101
+
102
+
103
def pad_to_square(image, pad_value = 255, random = False):
    """Pad the shorter dimension of a HxWxC image with *pad_value* so it becomes
    square. The split between the two sides is even by default, or drawn at
    random when *random* is true. Already-square input is returned untouched."""
    h, w = image.shape[0], image.shape[1]
    if h == w:
        return image

    total = abs(h - w)
    first = int(np.random.randint(0, total)) if random else int(total / 2)
    second = total - first

    if h > w:
        spec = ((0, 0), (first, second), (0, 0))
    else:
        spec = ((first, second), (0, 0), (0, 0))
    return np.pad(image, spec, 'constant', constant_values=pad_value)
122
+
123
def get_bbox_from_mask(mask):
    """Tight (y1, y2, x1, x2) box around the nonzero pixels of *mask*, with y2/x2
    inclusive. Near-empty masks (sum < 10) yield the full image extent."""
    h, w = mask.shape[0], mask.shape[1]

    if mask.sum() < 10:
        return 0, h, 0, w
    ys = np.where(np.any(mask, axis=1))[0]
    xs = np.where(np.any(mask, axis=0))[0]
    return (ys[0], ys[-1], xs[0], xs[-1])
133
+
134
def box_in_box(small_box, big_box):
    """Express *small_box* (y1, y2, x1, x2) relative to the top-left corner of
    *big_box*."""
    y1, y2, x1, x2 = small_box
    y_off, _, x_off, _ = big_box
    return (y1 - y_off, y2 - y_off, x1 - x_off, x2 - x_off)
139
+
140
def crop_back(pred, tar_image, extra_sizes, tar_box_yyxx_crop, tar_box_yyxx_crop2, is_masked=False):
    """Strip the square-padding from *pred* (using the pre-pad size H1xW1 from
    *extra_sizes*) and, when *is_masked*, paste only the two (y, x, y, x) target
    boxes back onto a copy of *tar_image*; otherwise return the unpadded *pred*."""
    H1, W1, H2, W2 = extra_sizes
    y1, x1, y2, x2 = tar_box_yyxx_crop
    y1_, x1_, y2_, x2_ = tar_box_yyxx_crop2
    m = 0  # margin pixels kept around the pasted region

    # Remove the symmetric padding added when the crop was squared.
    if H1 < W1:
        cut1 = int((W1 - H1) / 2)
        cut2 = W1 - H1 - cut1
        pred = pred[cut1:-cut2, :, :]
    elif H1 > W1:
        cut1 = int((H1 - W1) / 2)
        cut2 = H1 - W1 - cut1
        pred = pred[:, cut1:-cut2, :]

    if not is_masked:
        return pred

    # Copy the prediction into the target image only inside the two boxes.
    gen_image = tar_image.copy()
    gen_image[y1 + m:y2 - m, x1 + m:x2 - m, :] = pred[y1 + m:y2 - m, x1 + m:x2 - m, :]
    gen_image[y1_ + m:y2_ - m, x1_ + m:x2_ - m, :] = pred[y1_ + m:y2_ - m, x1_ + m:x2_ - m, :]
    return gen_image
datasets/lvis.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import os
4
+ from .data_utils import *
5
+ from .base import BaseDataset
6
+ from lvis import LVIS
7
+ from pathlib import Path
8
+ from util.box_ops import compute_iou_matrix, draw_bboxes
9
+ import shutil
10
+
11
+ IS_VERIFY = False
12
+
13
class LVISDataset(BaseDataset):
    """Builds a two/three-overlapping-object dataset from LVIS annotations.

    ``_intersect_2_obj`` / ``_intersect_3_obj`` scan LVIS images, pick the
    objects whose bounding boxes overlap the most, and write per-image
    folders (image.jpg, object_k.png, object_k_mask.png) under
    ``construct_dataset_dir``.  ``_get_sample`` later reads those folders
    back and builds a collage via ``BaseDataset._construct_collage``.
    """

    def __init__(self, construct_dataset_dir, obj_thr=20, area_ratio=0.02):
        # obj_thr: object-count threshold (stored but not used in this class body).
        self.obj_thr = obj_thr
        # Root folder that receives one sub-folder per constructed sample.
        self.construct_dataset_dir = construct_dataset_dir
        os.makedirs(Path(self.construct_dataset_dir), exist_ok=True)
        # Objects smaller than area_ratio * image_area are discarded.
        self.area_ratio = area_ratio
        # Snapshot of already-constructed samples (folder names).
        self.sample_list = os.listdir(self.construct_dataset_dir)

    def _get_image_path(self, file_name):
        """Return the first existing path for ``file_name`` among ``self.image_dir``.

        Raises FileNotFoundError when the file exists in none of the dirs.
        """
        for img_dir in self.image_dir:
            path = img_dir / file_name
            if path.exists():
                return str(path)
        raise FileNotFoundError(f"File {file_name} not found in any of the image_dir.")

    def _intersect_2_obj(self, image_dir, lvis_api, imgs_info, annos, idx):
        """Extract the two most-overlapping large objects of image ``idx``.

        Writes image.jpg, object_{0,1}.png and object_{0,1}_mask.png into a
        per-image folder, or returns early (with an [Info] message) when
        fewer than two sufficiently large boxes exist or none overlap.
        """
        self.image_dir = image_dir
        # LVIS stores the COCO file name only inside coco_url.
        image_name = imgs_info[idx]['coco_url'].split('/')[-1]
        image_path = self._get_image_path(image_name)
        image = cv2.imread(image_path)

        h, w = image.shape[0:2]
        image_area = h*w

        anno = annos[idx]

        # filter by area
        obj_ids = []
        obj_areas = []
        obj_bbox = []
        for i in range(len(anno)):
            obj = anno[i]
            area = obj['area']
            bbox = obj['bbox'] # xyhw
            if area > image_area * self.area_ratio:
                obj_ids.append(i)
                obj_areas.append(area)
                obj_bbox.append(bbox)

        if len(obj_bbox) < 2:
            print(f"[Info] Skip image index {image_name[:-4]} due to insufficient bbox.")
            return

        # filter by IOU
        # Convert xywh boxes to xyxy for the IoU computation.
        bbox_xyxy = []
        for box in obj_bbox:
            x, y, w, h = box
            bbox_xyxy.append([x, y, x + w, y + h])
        bbox_xyxy = np.array(bbox_xyxy) # shape: [N, 4]

        if IS_VERIFY:
            os.makedirs(Path(self.construct_dataset_dir) / image_name[:-4], exist_ok=True)
            image_with_boxes = draw_bboxes(image, bbox_xyxy)
            cv2.imwrite(str(Path(self.construct_dataset_dir) / image_name[:-4] / "bboxes_image.png"), image_with_boxes)

        iou_matrix = compute_iou_matrix(bbox_xyxy)
        np.fill_diagonal(iou_matrix, -1) # Exclude self-comparisons (i.e., each box with itself)

        # Pick the pair with the largest IoU.
        max_index = np.unravel_index(np.argmax(iou_matrix), iou_matrix.shape)
        index0, index1 = max_index[0], max_index[1]
        max_iou = iou_matrix[index0, index1]

        if max_iou <= 0:
            print(f"[Info] Skip image index {image_name[:-4]} due to no overlapping bboxes.")
            return

        os.makedirs(Path(self.construct_dataset_dir) / image_name[:-4], exist_ok=True)
        dst = Path(self.construct_dataset_dir) / image_name[:-4] / "image.jpg"
        dst.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy(image_path, dst)

        # NOTE(review): anno_id actually holds the full annotation record,
        # which is what lvis_api.ann_to_mask expects here.
        anno_id = anno[obj_ids[index0]]
        mask = lvis_api.ann_to_mask(anno_id)
        cv2.imwrite(str(Path(self.construct_dataset_dir) / image_name[:-4] / "object_0_mask.png"), 255*mask)
        # get_patch is defined on BaseDataset and works in RGB space.
        patch = self.get_patch(cv2.cvtColor(image, cv2.COLOR_BGR2RGB), mask)
        patch = cv2.cvtColor(patch, cv2.COLOR_RGB2BGR)
        cv2.imwrite(str(Path(self.construct_dataset_dir) / image_name[:-4] / "object_0.png"), patch)

        if IS_VERIFY:
            mask_color = np.stack([mask * 255]*3, axis=-1).astype(np.uint8)
            highlight = np.zeros_like(image)
            highlight[:, :, 2] = 255 # red channel
            alpha = 0.5
            image_with_boxes = np.where(mask_color == 255, cv2.addWeighted(image_with_boxes, 1 - alpha, highlight, alpha, 0), image_with_boxes)

        anno_id = anno[obj_ids[index1]]
        mask = lvis_api.ann_to_mask(anno_id)
        cv2.imwrite(str(Path(self.construct_dataset_dir) / image_name[:-4] / "object_1_mask.png"), 255*mask)
        patch = self.get_patch(cv2.cvtColor(image, cv2.COLOR_BGR2RGB), mask)
        patch = cv2.cvtColor(patch, cv2.COLOR_RGB2BGR)
        cv2.imwrite(str(Path(self.construct_dataset_dir) / image_name[:-4] / "object_1.png"), patch)

        if IS_VERIFY:
            mask_color = np.stack([mask * 255]*3, axis=-1).astype(np.uint8)
            highlight = np.zeros_like(image)
            highlight[:, :, 0] = 255 # blue channel
            alpha = 0.5
            image_with_boxes = np.where(mask_color == 255, cv2.addWeighted(image_with_boxes, 1 - alpha, highlight, alpha, 0), image_with_boxes)
            cv2.imwrite(str(Path(self.construct_dataset_dir) / image_name[:-4] / "highlighted_image.png"), image_with_boxes)

    def _intersect_3_obj(self, image_dir, lvis_api, imgs_info, annos, idx):
        """Extract three mutually-overlapping large objects of image ``idx``.

        An anchor object is chosen as the one with the largest sum of
        positive IoUs; its two best-overlapping partners complete the
        triple.  Output layout matches ``_intersect_2_obj`` plus an
        object_2 pair.
        """
        self.image_dir = image_dir
        image_name = imgs_info[idx]['coco_url'].split('/')[-1]
        image_path = self._get_image_path(image_name)
        image = cv2.imread(image_path)

        h, w = image.shape[0:2]
        image_area = h * w

        anno = annos[idx]

        # filter by area
        obj_ids = []
        obj_areas = []
        obj_bbox = []
        for i, obj in enumerate(anno):
            area = obj['area']
            bbox = obj['bbox'] # xywh
            if area > image_area * self.area_ratio:
                obj_ids.append(i)
                obj_areas.append(area)
                obj_bbox.append(bbox)

        if len(obj_bbox) < 3:
            print(f"[Info] Skip image index {image_name[:-4]} due to insufficient bbox (need >=3, got {len(obj_bbox)}).")
            return

        # calculate IOU matrix
        bbox_xyxy = []
        for box in obj_bbox:
            x, y, w_box, h_box = box
            bbox_xyxy.append([x, y, x + w_box, y + h_box])
        bbox_xyxy = np.array(bbox_xyxy) # shape: [N, 4]

        if IS_VERIFY:
            os.makedirs(Path(self.construct_dataset_dir) / image_name[:-4], exist_ok=True)
            image_with_boxes = draw_bboxes(image, bbox_xyxy)
            cv2.imwrite(str(Path(self.construct_dataset_dir) / image_name[:-4] / "bboxes_image.png"), image_with_boxes)

        iou_matrix = compute_iou_matrix(bbox_xyxy)
        np.fill_diagonal(iou_matrix, -1) # Exclude self-comparisons

        # find 3 overlapped objects
        positive_iou = np.where(iou_matrix > 0, iou_matrix, 0.0)
        row_sums = positive_iou.sum(axis=1)
        anchor = int(np.argmax(row_sums))

        # Partners ranked by IoU with the anchor, keeping only positive overlaps.
        partner_candidates = np.argsort(iou_matrix[anchor])[::-1]
        partners = [int(p) for p in partner_candidates if iou_matrix[anchor, p] > 0]

        if len(partners) < 2:
            print(f"[Info] Skip image index {image_name[:-4]} due to not enough overlapping bboxes for 3 objects.")
            return

        index0 = anchor
        index1 = partners[0]
        index2 = partners[1]

        max_iou_pair = max(iou_matrix[index0, index1], iou_matrix[index0, index2], iou_matrix[index1, index2])
        if max_iou_pair <= 0:
            print(f"[Info] Skip image index {image_name[:-4]} due to no overlapping bboxes.")
            return

        # copy original image
        out_dir = Path(self.construct_dataset_dir) / image_name[:-4]
        out_dir.mkdir(parents=True, exist_ok=True)
        dst = out_dir / "image.jpg"
        shutil.copy(image_path, dst)

        # first object
        anno_id = anno[obj_ids[index0]]
        mask0 = lvis_api.ann_to_mask(anno_id)
        cv2.imwrite(str(out_dir / "object_0_mask.png"), 255 * mask0)
        patch0 = self.get_patch(cv2.cvtColor(image, cv2.COLOR_BGR2RGB), mask0)
        patch0 = cv2.cvtColor(patch0, cv2.COLOR_RGB2BGR)
        cv2.imwrite(str(out_dir / "object_0.png"), patch0)

        if IS_VERIFY:
            mask_color = np.stack([mask0 * 255] * 3, axis=-1).astype(np.uint8)
            highlight = np.zeros_like(image)
            highlight[:, :, 2] = 255 # red channel
            alpha = 0.5
            image_with_boxes = np.where(
                mask_color == 255,
                cv2.addWeighted(image_with_boxes, 1 - alpha, highlight, alpha, 0),
                image_with_boxes
            )

        # second object
        anno_id = anno[obj_ids[index1]]
        mask1 = lvis_api.ann_to_mask(anno_id)
        cv2.imwrite(str(out_dir / "object_1_mask.png"), 255 * mask1)
        patch1 = self.get_patch(cv2.cvtColor(image, cv2.COLOR_BGR2RGB), mask1)
        patch1 = cv2.cvtColor(patch1, cv2.COLOR_RGB2BGR)
        cv2.imwrite(str(out_dir / "object_1.png"), patch1)

        if IS_VERIFY:
            mask_color = np.stack([mask1 * 255] * 3, axis=-1).astype(np.uint8)
            highlight = np.zeros_like(image)
            highlight[:, :, 0] = 255 # blue channel
            alpha = 0.5
            image_with_boxes = np.where(
                mask_color == 255,
                cv2.addWeighted(image_with_boxes, 1 - alpha, highlight, alpha, 0),
                image_with_boxes
            )

        # third object
        anno_id = anno[obj_ids[index2]]
        mask2 = lvis_api.ann_to_mask(anno_id)
        cv2.imwrite(str(out_dir / "object_2_mask.png"), 255 * mask2)
        patch2 = self.get_patch(cv2.cvtColor(image, cv2.COLOR_BGR2RGB), mask2)
        patch2 = cv2.cvtColor(patch2, cv2.COLOR_RGB2BGR)
        cv2.imwrite(str(out_dir / "object_2.png"), patch2)

        if IS_VERIFY:
            mask_color = np.stack([mask2 * 255] * 3, axis=-1).astype(np.uint8)
            highlight = np.zeros_like(image)
            highlight[:, :, 1] = 255 # green channel
            alpha = 0.5
            image_with_boxes = np.where(
                mask_color == 255,
                cv2.addWeighted(image_with_boxes, 1 - alpha, highlight, alpha, 0),
                image_with_boxes
            )
            cv2.imwrite(str(out_dir / "highlighted_image.png"), image_with_boxes)

    def _get_sample(self, idx):
        """Load the constructed folder ``self.sample_list[idx]`` and build a collage.

        Relies on ``BaseDataset._construct_collage`` (defined elsewhere).
        """
        sample_path = os.path.join(self.construct_dataset_dir, self.sample_list[idx])
        image = cv2.cvtColor(cv2.imread(os.path.join(sample_path, "image.jpg")), cv2.COLOR_BGR2RGB)
        object_0 = cv2.cvtColor(cv2.imread(os.path.join(sample_path, "object_0.png")), cv2.COLOR_BGR2RGB)
        object_1 = cv2.cvtColor(cv2.imread(os.path.join(sample_path, "object_1.png")), cv2.COLOR_BGR2RGB)
        mask_0 = cv2.imread(os.path.join(sample_path, "object_0_mask.png"), cv2.IMREAD_GRAYSCALE)
        mask_1 = cv2.imread(os.path.join(sample_path, "object_1_mask.png"), cv2.IMREAD_GRAYSCALE)
        collage = self._construct_collage(image, object_0, object_1, mask_0, mask_1)
        return collage

    def __len__(self):
        # Re-lists the directory on every call so newly built samples count.
        return len(os.listdir(self.construct_dataset_dir))
253
+
254
+
255
if __name__ == "__main__":
    '''
    two-object case: train/test: 34610/8859
    '''
    import argparse

    # CLI driver: builds the overlapping-object dataset from LVIS annotations.
    parser = argparse.ArgumentParser(description="LVISDataset Analysis")
    parser.add_argument("--dataset_dir", type=str, required=True, help="Path to the dataset directory.")
    parser.add_argument("--construct_dataset_dir", type=str, default='bin', help="Path to the debug bin directory.")
    parser.add_argument("--dataset_name", type=str, default='COCO', help="Dataset name.")
    parser.add_argument('--is_train', action='store_true', help="Train/Test")
    parser.add_argument('--is_build_data', action='store_true', help="Build data")
    parser.add_argument('--is_multiple', action='store_true', help="Multiple/Two objects")
    parser.add_argument("--area_ratio", type=float, default=0.01171, help="Area ratio for filtering out small objects.")
    parser.add_argument("--obj_thr", type=int, default=20, help="Object threshold for filtering.")
    parser.add_argument("--index", type=int, default=0, help="Index of the sample to test.")
    args = parser.parse_args()

    # LVIS reuses COCO images, so both splits' image folders are searched.
    image_dirs = [
        Path(args.dataset_dir) / args.dataset_name / "train2017",
        Path(args.dataset_dir) / args.dataset_name / "val2017",
    ]

    if args.is_train:
        json_path = Path(args.dataset_dir) / args.dataset_name / "lvis_v1/lvis_v1_train.json"
        max_num = 2000000
    else:
        json_path = Path(args.dataset_dir) / args.dataset_name / "lvis_v1/lvis_v1_val.json"
        max_num = 30000

    dataset = LVISDataset(
        construct_dataset_dir = args.construct_dataset_dir,
        obj_thr = args.obj_thr,
        area_ratio = args.area_ratio,
    )

    # Load the LVIS index once; annotations are grouped per image id.
    lvis_api = LVIS(json_path)
    img_ids = sorted(lvis_api.imgs.keys())
    imgs_info = lvis_api.load_imgs(img_ids)
    annos = [lvis_api.img_ann_map[img_id] for img_id in img_ids]

    if args.is_build_data:
        if not args.is_multiple:
            # NOTE(review): max_num may exceed len(img_ids); the loop would
            # then raise IndexError at the end — confirm intended behavior.
            for index in range(max_num):
                dataset._intersect_2_obj(image_dirs, lvis_api, imgs_info, annos, index)
                # dataset._intersect_3_obj(image_dirs, lvis_api, imgs_info, annos, index)
        else:
            # Replay already-constructed samples through collage building.
            for index in range(len(os.listdir(args.construct_dataset_dir))):
                collage = dataset._get_sample(index)
datasets/mapillary_vistas.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import cv2
3
+ import numpy as np
4
+ import os
5
+ from PIL import Image
6
+ from .data_utils import *
7
+ from .base import BaseDataset
8
+ from util.box_ops import mask_to_bbox_xywh, compute_iou_matrix, draw_bboxes
9
+ from pathlib import Path
10
+ import shutil
11
+
12
+ IS_VERIFY = False
13
+
14
class MapillaryVistasDataset(BaseDataset):
    """Builds a two-overlapping-vehicle dataset from Mapillary Vistas.

    Instance maps encode class id and instance id in one uint16 value
    (value // 256 = class label, value % 256 = instance id).  Only vehicle
    classes (bus/car/caravan/truck) are considered; the most-overlapping
    pair is exported as per-image folders under ``construct_dataset_dir``.
    """

    def __init__(self, construct_dataset_dir, obj_thr=20, area_ratio=0.02):
        # obj_thr: object-count threshold (stored but not used in this class body).
        self.obj_thr = obj_thr
        self.construct_dataset_dir = construct_dataset_dir
        os.makedirs(Path(self.construct_dataset_dir), exist_ok=True)
        # Objects smaller than area_ratio * image_area are discarded.
        self.area_ratio = area_ratio
        self.sample_list = os.listdir(self.construct_dataset_dir)

    def _intersect_2_obj(self, image_dir, instance_dir, labels, idx):
        """Export the two most-overlapping vehicles of instance map ``idx``.

        ``labels`` is the Vistas config label list; it is currently unused
        (the vehicle class ids are hard-coded below).  Returns early with
        an [Info] message when fewer than two large vehicles exist or when
        none of their boxes overlap.
        """
        json_list = os.listdir(instance_dir)
        image_name = json_list[idx][:-4]
        image_path = os.path.join(image_dir, image_name+'.jpg')
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        instance_path = os.path.join(instance_dir, image_name+'.png')
        instance_image = Image.open(instance_path)
        instance_array = np.array(instance_image, dtype=np.uint16)

        # Split the packed uint16 map: high byte = class, low byte = instance.
        instance_label_array = np.array(instance_array / 256, dtype=np.uint8)
        instance_ids_array = np.array(instance_array % 256, dtype=np.uint8)

        img_h, img_w = image.shape[0:2]
        image_area = img_h*img_w

        # vehicle_keywords = ['car', 'truck', 'bus']
        # excluded_keywords = ['bicycle']

        # vehicle_ids = []
        # for idx, label in enumerate(labels):
        #     name = label['name'].lower()
        #     if any(k in name for k in vehicle_keywords) and not any(k in name for k in excluded_keywords):
        #         vehicle_ids.append(idx)

        '''
        ids: 107, 'name': 'object--vehicle--bus', 'readable': 'Bus', 'color': [0, 60, 100]
        ids: 108, 'name': 'object--vehicle--car', 'readable': 'Car', 'color': [0, 0, 142]
        ids: 109, 'name': 'object--vehicle--caravan', 'readable': 'Caravan', 'color': [0, 0, 90]
        ids: 114, 'name': 'object--vehicle--truck', 'readable': 'Truck', 'color': [0, 0, 70]
        '''

        target_class_ids = [107, 108, 109, 114]
        # NOTE(review): range(max_instance) below skips the instance whose
        # id equals max_instance — confirm whether that is intentional.
        max_instance = np.max(instance_ids_array)

        # Enumerate every (class, instance) combination; `counter` is the
        # flat index used again by the lookup loops further down.
        obj_ids = []
        obj_areas = []
        obj_bbox = []
        counter = 0
        for target_id in target_class_ids:
            semantic_mask = (instance_label_array == target_id)
            for idx in range(max_instance):
                instance_mask = (instance_ids_array == idx)
                mask = np.logical_and(semantic_mask, instance_mask).astype(np.uint8)
                area = np.sum(mask)
                bbox = mask_to_bbox_xywh(mask)
                if area > image_area * self.area_ratio:
                    obj_ids.append(counter)
                    obj_areas.append(area)
                    obj_bbox.append(bbox)
                counter += 1

        if len(obj_bbox) < 2:
            print(f"[Info] Skip image index {image_name} due to insufficient bbox.")
            return

        # filter by IOU
        bbox_xyxy = []
        for box in obj_bbox:
            x, y, w, h = box
            bbox_xyxy.append([x, y, x + w, y + h])
        bbox_xyxy = np.array(bbox_xyxy) # shape: [N, 4]
        os.makedirs(Path(self.construct_dataset_dir) / image_name, exist_ok=True)

        if IS_VERIFY:
            image_with_boxes = draw_bboxes(image, bbox_xyxy)
            cv2.imwrite(str(Path(self.construct_dataset_dir) / image_name / "bboxes_image.png"), cv2.cvtColor(image_with_boxes, cv2.COLOR_RGB2BGR))


        iou_matrix = compute_iou_matrix(bbox_xyxy)
        np.fill_diagonal(iou_matrix, -1) # Exclude self-comparisons (i.e., each box with itself)

        # Pick the pair with the largest IoU.
        max_index = np.unravel_index(np.argmax(iou_matrix), iou_matrix.shape)
        index0, index1 = max_index[0], max_index[1]
        max_iou = iou_matrix[index0, index1]

        if max_iou <= 0:
            print(f"[Info] Skip image index {image_name} due to no overlapping bboxes.")
            return

        dst = Path(self.construct_dataset_dir) / image_name / "image.jpg"
        dst.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy(image_path, dst)

        # Re-walk the same (class, instance) enumeration to rebuild the mask
        # of the flat index obj_ids[index0].
        counter = 0
        found = False
        for target_id in target_class_ids:
            semantic_mask = (instance_label_array == target_id)
            for idx in range(max_instance):
                if counter == obj_ids[index0]:
                    instance_mask = (instance_ids_array == idx)
                    mask = np.logical_and(semantic_mask, instance_mask).astype(np.uint8)
                    found = True
                    break
                counter += 1
            if found:
                break
        cv2.imwrite(str(Path(self.construct_dataset_dir) / image_name / "object_0_mask.png"), 255*mask)
        # image is already RGB here, matching get_patch's expectation elsewhere.
        patch = self.get_patch(image, mask)
        patch = cv2.cvtColor(patch, cv2.COLOR_RGB2BGR)
        cv2.imwrite(str(Path(self.construct_dataset_dir) / image_name / "object_0.png"), patch)

        if IS_VERIFY:
            mask_color = np.stack([mask * 255]*3, axis=-1).astype(np.uint8)
            highlight = np.zeros_like(image)
            highlight[:, :, 2] = 255 # red channel
            alpha = 0.5
            image_with_boxes = np.where(mask_color == 255, cv2.addWeighted(image_with_boxes, 1 - alpha, highlight, alpha, 0), image_with_boxes)

        # Same lookup for the second object of the pair.
        counter = 0
        found = False
        for target_id in target_class_ids:
            semantic_mask = (instance_label_array == target_id)
            for idx in range(max_instance):
                if counter == obj_ids[index1]:
                    instance_mask = (instance_ids_array == idx)
                    mask = np.logical_and(semantic_mask, instance_mask).astype(np.uint8)
                    found = True
                    break
                counter += 1
            if found:
                break

        cv2.imwrite(str(Path(self.construct_dataset_dir) / image_name / "object_1_mask.png"), 255*mask)
        patch = self.get_patch(image, mask)
        patch = cv2.cvtColor(patch, cv2.COLOR_RGB2BGR)
        cv2.imwrite(str(Path(self.construct_dataset_dir) / image_name / "object_1.png"), patch)

        if IS_VERIFY:
            mask_color = np.stack([mask * 255]*3, axis=-1).astype(np.uint8)
            highlight = np.zeros_like(image)
            highlight[:, :, 0] = 255 # blue channel
            alpha = 0.5
            image_with_boxes = np.where(mask_color == 255, cv2.addWeighted(image_with_boxes, 1 - alpha, highlight, alpha, 0), image_with_boxes)
            cv2.imwrite(str(Path(self.construct_dataset_dir) / image_name / "highlighted_image.png"), cv2.cvtColor(image_with_boxes, cv2.COLOR_RGB2BGR))

    def _get_sample(self, idx):
        """Load the constructed folder ``self.sample_list[idx]`` and build a collage.

        Relies on ``BaseDataset._construct_collage`` (defined elsewhere).
        """
        sample_path = os.path.join(self.construct_dataset_dir, self.sample_list[idx])
        image = cv2.cvtColor(cv2.imread(os.path.join(sample_path, "image.jpg")), cv2.COLOR_BGR2RGB)
        object_0 = cv2.cvtColor(cv2.imread(os.path.join(sample_path, "object_0.png")), cv2.COLOR_BGR2RGB)
        object_1 = cv2.cvtColor(cv2.imread(os.path.join(sample_path, "object_1.png")), cv2.COLOR_BGR2RGB)
        mask_0 = cv2.imread(os.path.join(sample_path, "object_0_mask.png"), cv2.IMREAD_GRAYSCALE)
        mask_1 = cv2.imread(os.path.join(sample_path, "object_1_mask.png"), cv2.IMREAD_GRAYSCALE)
        collage = self._construct_collage(image, object_0, object_1, mask_0, mask_1)
        return collage

    def __len__(self):
        # Re-lists the directory on every call so newly built samples count.
        return len(os.listdir(self.construct_dataset_dir))
171
+
172
+
173
if __name__ == "__main__":
    '''
    two-object case: train/test: 603/190
    '''
    import argparse

    # CLI driver: builds the overlapping-vehicle dataset from Mapillary Vistas.
    parser = argparse.ArgumentParser(description="MapillaryVistasDataset Analysis")
    parser.add_argument("--dataset_dir", type=str, required=True, help="Path to the dataset directory.")
    parser.add_argument("--construct_dataset_dir", type=str, default='bin', help="Path to the debug bin directory.")
    parser.add_argument("--dataset_name", type=str, default='MVD', help="Dataset name.")
    parser.add_argument('--is_train', action='store_true', help="Train/Test")
    parser.add_argument('--is_build_data', action='store_true', help="Build data")
    parser.add_argument('--is_multiple', action='store_true', help="Multiple/Two objects")
    parser.add_argument("--area_ratio", type=float, default=0.01171, help="Area ratio for filtering out small objects.")
    parser.add_argument("--obj_thr", type=int, default=20, help="Object threshold for filtering.")
    parser.add_argument("--index", type=int, default=0, help="Index of the sample to test.")
    args = parser.parse_args()

    # Label metadata from the dataset config (v2.0 annotation scheme).
    version = "v2.0" # "v1.2"
    config_path = Path(args.dataset_dir) / args.dataset_name / f'config_{version}.json'
    with open(config_path) as config_file:
        config = json.load(config_file)
    labels = config['labels']

    if args.is_train:
        image_dir = Path(args.dataset_dir) / args.dataset_name / "training" / "images"
        instance_dir = Path(args.dataset_dir) / args.dataset_name / "training" / "v2.0" / "instances"
        max_num = 18000
    else:
        image_dir = Path(args.dataset_dir) / args.dataset_name / "validation" / "images"
        instance_dir = Path(args.dataset_dir) / args.dataset_name / "validation" / "v2.0" / "instances"
        max_num = 2000

    dataset = MapillaryVistasDataset(
        construct_dataset_dir = args.construct_dataset_dir,
        obj_thr = args.obj_thr,
        area_ratio = args.area_ratio,
    )

    if args.is_build_data:
        if not args.is_multiple:
            # NOTE(review): max_num may exceed the number of instance maps;
            # the loop would then raise IndexError — confirm intended behavior.
            for index in range(max_num):
                dataset._intersect_2_obj(image_dir, instance_dir, labels, index)
                print('Done index ', index)
        else:
            # Replay already-constructed samples through collage building.
            for index in range(len(os.listdir(args.construct_dataset_dir))):
                collage = dataset._get_sample(index)
220
+
221
+
222
+ '''
223
+ 25,000 high-resolution images
224
+ 124 semantic object categories
225
+ 100 instance-specifically annotated categories
226
+ Global reach, covering 6 continents
227
+ Variety of weather, season, time of day, camera, and viewpoint
228
+ '''
datasets/objects365.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import cv2
3
+ import numpy as np
4
+ import os
5
+ from .data_utils import *
6
+ from .base import BaseDataset
7
+ from pycocotools import mask as mask_utils
8
+ from pathlib import Path
9
+ from util.box_ops import compute_iou_matrix, draw_bboxes
10
+ import shutil
11
+
12
+ IS_VERIFY = False
13
+ IS_BOX = False
14
+
15
def save_bboxes(bbox_xyxy, save_path="bboxes.txt"):
    """Append bounding boxes (xyxy) to a text file, one box per line.

    A single box given as a flat sequence is promoted to one row.  Values
    are written space-separated with two-decimal precision; the file is
    opened in append mode so repeated calls accumulate.
    """
    rows = np.atleast_2d(bbox_xyxy)
    with open(save_path, "a") as out_file:
        np.savetxt(out_file, rows, fmt="%.2f", delimiter=" ")
19
+
20
class Objects365Dataset(BaseDataset):
    """Builds a two-overlapping-object dataset from Objects365 labels.

    Per-image JSON label files (with COCO-style polygon segmentations)
    are walked recursively; for each image the pair of sufficiently large
    objects with the highest box IoU is exported to a per-image folder
    under ``construct_dataset_dir``.
    """

    def __init__(self, construct_dataset_dir, obj_thr=20, area_ratio=0.02):
        # obj_thr: object-count threshold (stored but not used in this class body).
        self.obj_thr = obj_thr
        self.construct_dataset_dir = construct_dataset_dir
        os.makedirs(Path(self.construct_dataset_dir), exist_ok=True)
        # Objects smaller than area_ratio * image_area are discarded.
        self.area_ratio = area_ratio
        self.sample_list = os.listdir(self.construct_dataset_dir)

    def _get_all_file_paths_recursive(self, root_dir):
        """Return absolute paths of every file under ``root_dir`` (recursive)."""
        all_files = []
        for dirpath, _, filenames in os.walk(root_dir):
            for f in filenames:
                abs_path = os.path.abspath(os.path.join(dirpath, f))
                all_files.append(abs_path)
        return all_files

    def _get_image_path(self, file_name):
        """Return the first existing path for ``file_name`` among ``self.image_dir``.

        Raises FileNotFoundError when the file exists in none of the dirs.
        """
        for img_dir in self.image_dir:
            path = img_dir / file_name
            if path.exists():
                return str(path)
        raise FileNotFoundError(f"File {file_name} not found in any of the image_dir.")

    def _intersect_2_obj(self, image_dir, json_dir, idx):
        """Export the two most-overlapping large objects of label file ``idx``.

        NOTE(review): the full recursive walk of ``json_dir`` is redone on
        every call — consider caching ``self.json_list`` across calls.
        """
        self.image_dir = image_dir
        self.json_list = self._get_all_file_paths_recursive(json_dir)
        json_path = self.json_list[idx]
        # Label files are named <image>.json inside a per-subset directory.
        image_name = json_path.split('/')[-1]
        image_subset = json_path.split('/')[-2]

        image_path = os.path.join(os.path.join(image_dir, image_subset), image_name[:-5]+'.jpg')
        image = cv2.imread(image_path)

        with open(json_path) as f:
            data = json.load(f)
            # image_id is loaded but currently unused below.
            image_id = data["image_id"]
            annotations = data["annotations"]

        img_h, img_w = image.shape[0:2]
        image_area = img_h*img_w

        anno = annotations

        # filter by area
        obj_ids = []
        obj_areas = []
        obj_bbox = []
        for i in range(len(anno)):
            obj = anno[i]
            area = obj['area']
            bbox = obj['bbox'] # xyhw
            if area > image_area * self.area_ratio:
                obj_ids.append(i)
                obj_areas.append(area)
                obj_bbox.append(bbox)

        if len(obj_bbox) < 2:
            print(f"[Info] Skip image index {image_name[:-5]} due to insufficient bbox.")
            return

        # filter by IOU
        bbox_xyxy = []
        for box in obj_bbox:
            x, y, w, h = box
            bbox_xyxy.append([x, y, x + w, y + h])
        bbox_xyxy = np.array(bbox_xyxy) # shape: [N, 4]

        if IS_VERIFY:
            os.makedirs(Path(self.construct_dataset_dir) / image_name[:-5], exist_ok=True)
            image_with_boxes = draw_bboxes(image, bbox_xyxy)
            cv2.imwrite(str(Path(self.construct_dataset_dir) / image_name[:-5] / "bboxes_image.png"), image_with_boxes)

        iou_matrix = compute_iou_matrix(bbox_xyxy)
        np.fill_diagonal(iou_matrix, -1) # Exclude self-comparisons (i.e., each box with itself)

        # Pick the pair with the largest IoU.
        max_index = np.unravel_index(np.argmax(iou_matrix), iou_matrix.shape)
        index0, index1 = max_index[0], max_index[1]
        max_iou = iou_matrix[index0, index1]

        if max_iou <= 0:
            print(f"[Info] Skip image index {image_name[:-5]} due to no overlapping bboxes.")
            return

        if IS_BOX:
            # Debug dump of the selected pair; note the machine-specific paths.
            save_bboxes(bbox_xyxy[index0], '/home/hang18/links/projects/rrg-vislearn/hang18/bboxes0.txt')
            save_bboxes(bbox_xyxy[index1], '/home/hang18/links/projects/rrg-vislearn/hang18/bboxes1.txt')

        os.makedirs(Path(self.construct_dataset_dir) / image_name[:-5], exist_ok=True)
        # cv2.imwrite(str(Path(self.construct_dataset_dir) / image_name[:-4] / "image.jpg"), image) # source image
        dst = Path(self.construct_dataset_dir) / image_name[:-5] / "image.jpg"
        dst.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy(image_path, dst)

        # Rasterize the COCO-style polygon segmentation into a binary mask.
        segmentation = anno[obj_ids[index0]]["segmentation"]
        rles = mask_utils.frPyObjects(segmentation, img_h, img_w)
        rle = mask_utils.merge(rles)
        mask = mask_utils.decode(rle)
        cv2.imwrite(str(Path(self.construct_dataset_dir) / image_name[:-5] / "object_0_mask.png"), 255*mask)
        # get_patch is defined on BaseDataset and works in RGB space.
        patch = self.get_patch(cv2.cvtColor(image, cv2.COLOR_BGR2RGB), mask)
        patch = cv2.cvtColor(patch, cv2.COLOR_RGB2BGR)
        cv2.imwrite(str(Path(self.construct_dataset_dir) / image_name[:-5] / "object_0.png"), patch)

        if IS_VERIFY:
            mask_color = np.stack([mask * 255]*3, axis=-1).astype(np.uint8)
            highlight = np.zeros_like(image)
            highlight[:, :, 2] = 255 # red channel
            alpha = 0.5
            image_with_boxes = np.where(mask_color == 255, cv2.addWeighted(image_with_boxes, 1 - alpha, highlight, alpha, 0), image_with_boxes)

        segmentation = anno[obj_ids[index1]]["segmentation"]
        rles = mask_utils.frPyObjects(segmentation, img_h, img_w)
        rle = mask_utils.merge(rles)
        mask = mask_utils.decode(rle)
        cv2.imwrite(str(Path(self.construct_dataset_dir) / image_name[:-5] / "object_1_mask.png"), 255*mask)
        patch = self.get_patch(cv2.cvtColor(image, cv2.COLOR_BGR2RGB), mask)
        patch = cv2.cvtColor(patch, cv2.COLOR_RGB2BGR)
        cv2.imwrite(str(Path(self.construct_dataset_dir) / image_name[:-5] / "object_1.png"), patch)

        if IS_VERIFY:
            mask_color = np.stack([mask * 255]*3, axis=-1).astype(np.uint8)
            highlight = np.zeros_like(image)
            highlight[:, :, 0] = 255 # blue channel
            alpha = 0.5
            image_with_boxes = np.where(mask_color == 255, cv2.addWeighted(image_with_boxes, 1 - alpha, highlight, alpha, 0), image_with_boxes)
            cv2.imwrite(str(Path(self.construct_dataset_dir) / image_name[:-5] / "highlighted_image.png"), image_with_boxes)

    def _get_sample(self, idx):
        """Load the constructed folder ``self.sample_list[idx]`` and build a collage.

        Relies on ``BaseDataset._construct_collage`` (defined elsewhere).
        """
        sample_path = os.path.join(self.construct_dataset_dir, self.sample_list[idx])
        image = cv2.cvtColor(cv2.imread(os.path.join(sample_path, "image.jpg")), cv2.COLOR_BGR2RGB)
        object_0 = cv2.cvtColor(cv2.imread(os.path.join(sample_path, "object_0.png")), cv2.COLOR_BGR2RGB)
        object_1 = cv2.cvtColor(cv2.imread(os.path.join(sample_path, "object_1.png")), cv2.COLOR_BGR2RGB)
        mask_0 = cv2.imread(os.path.join(sample_path, "object_0_mask.png"), cv2.IMREAD_GRAYSCALE)
        mask_1 = cv2.imread(os.path.join(sample_path, "object_1_mask.png"), cv2.IMREAD_GRAYSCALE)
        collage = self._construct_collage(image, object_0, object_1, mask_0, mask_1)
        return collage

    def __len__(self):
        # Re-lists the directory on every call so newly built samples count.
        return len(os.listdir(self.construct_dataset_dir))
158
+
159
+
160
if __name__ == "__main__":
    '''
    two-object case: train/test: TODO/51791
    '''
    import argparse

    # CLI driver: builds the overlapping-object dataset from Objects365 labels.
    parser = argparse.ArgumentParser(description="Objects365Dataset Analysis")
    parser.add_argument("--dataset_dir", type=str, required=True, help="Path to the dataset directory.")
    parser.add_argument("--construct_dataset_dir", type=str, default='bin', help="Path to the debug bin directory.")
    parser.add_argument("--dataset_name", type=str, default='object365', help="Dataset name.")
    parser.add_argument('--is_train', action='store_true', help="Train/Test")
    parser.add_argument('--is_build_data', action='store_true', help="Build data")
    parser.add_argument('--is_multiple', action='store_true', help="Multiple/Two objects")
    parser.add_argument("--area_ratio", type=float, default=0.01171, help="Area ratio for filtering out small objects.")
    parser.add_argument("--obj_thr", type=int, default=20, help="Object threshold for filtering.")
    parser.add_argument("--index", type=int, default=0, help="Index of the sample to test.")
    args = parser.parse_args()

    if args.is_train:
        image_dir = Path(args.dataset_dir) / args.dataset_name / "images" / "train"
        json_dir = Path(args.dataset_dir) / args.dataset_name / "labels" / "train"
        max_num = 1742289
    else:
        image_dir = Path(args.dataset_dir) / args.dataset_name / "images" / "val"
        json_dir = Path(args.dataset_dir) / args.dataset_name / "labels" / "val"
        max_num = 80000

    dataset = Objects365Dataset(
        # json_dir = json_dir,
        construct_dataset_dir = args.construct_dataset_dir,
        obj_thr = args.obj_thr,
        area_ratio = args.area_ratio,
    )

    if args.is_build_data:
        if not args.is_multiple:
            # NOTE(review): max_num may exceed the number of label files;
            # the loop would then raise IndexError — confirm intended behavior.
            for index in range(0, max_num):
                dataset._intersect_2_obj(image_dir, json_dir, index)
        else:
            # Replay already-constructed samples through collage building.
            for index in range(len(os.listdir(args.construct_dataset_dir))):
                collage = dataset._get_sample(index)
datasets/viton_hd.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import os
4
+ from PIL import Image
5
+ from .data_utils import *
6
+ from .base import BaseDataset
7
+ from pathlib import Path
8
+ from util.box_ops import mask_to_bbox_xywh, draw_bboxes, compute_iou_matrix
9
+ import shutil
10
+
11
+ IS_VERIFY = False
12
+
13
class VITONHDDataset(BaseDataset):
    """Two-object dataset built from VITON-HD images and their part-parse maps.

    Operates in two phases:
      1. ``_intersect_2_obj`` scans the raw VITON-HD assets and, for each
         usable image, materialises a folder under ``construct_dataset_dir``
         containing the original image plus the two largest segmented objects
         and their binary masks.
      2. ``_get_sample`` / ``__len__`` then serve those pre-built folders as
         collages (collage construction lives in the unseen
         ``_construct_collage`` helper on the base class).
    """

    def __init__(self, construct_dataset_dir, obj_thr=20, area_ratio=0.02):
        """
        Args:
            construct_dataset_dir: Root folder where per-image sample folders
                are written/read; created if missing.
            obj_thr: Object-count threshold. Stored but not referenced in this
                class's visible code (kept for parity with sibling datasets).
            area_ratio: Minimum object area as a fraction of full image area;
                smaller segments are discarded.
        """
        self.obj_thr = obj_thr
        self.construct_dataset_dir = construct_dataset_dir
        os.makedirs(Path(self.construct_dataset_dir), exist_ok=True)
        self.area_ratio = area_ratio
        # NOTE(review): snapshot taken at construction time. Folders created
        # later by _intersect_2_obj are counted by __len__ (which re-lists the
        # directory) but are NOT visible to _get_sample, which indexes this
        # cached list — confirm the build and consume phases never share one
        # instance.
        self.sample_list = os.listdir(self.construct_dataset_dir)

    def _intersect_2_obj(self, asset_dir, idx):
        """Build one sample folder from raw asset ``idx``.

        Reads ``asset_dir/image/<name>`` and its palette parse map from
        ``asset_dir/image-parse-v3/<name>.png``, keeps segments larger than
        ``area_ratio`` of the image, and writes the image plus the two
        largest objects (patch + mask each) into
        ``construct_dataset_dir/<name>/``. Skips (with an [Info] print and no
        output) when fewer than two segments survive filtering.

        Raises:
            IndexError: if ``idx`` exceeds the number of files in
                ``asset_dir/image``.
        """
        image_dir = os.path.join(asset_dir, 'image')
        image_list = os.listdir(image_dir)
        image_path = os.path.join(image_dir, image_list[idx])
        image_name = os.path.basename(image_path)
        # BGR, per cv2 convention. NOTE(review): returns None if the file is
        # unreadable — unchecked; .shape below would then raise.
        image = cv2.imread(image_path)

        mask_dir = os.path.join(asset_dir, 'image-parse-v3')
        # image_name[:-4] strips a 3-char extension (".jpg"); the parse map is
        # a palette ("P" mode) PNG whose pixel values are part labels.
        segmentation_path = os.path.join(mask_dir, image_name[:-4]+'.png')
        segmentation = Image.open(segmentation_path).convert('P')
        segmentation = np.array(segmentation)

        h, w = image.shape[0:2]
        image_area = h*w

        ids = np.unique(segmentation)
        ids = [ i for i in ids if i!=0 ] # remove background mask
        if len(ids) < 2:
            print(f"[Info] Skip image index {image_name[:-4]} due to insufficient bbox.")
            return

        # filter by area: keep only labels whose mask covers more than
        # area_ratio of the whole image.
        obj_ids = []
        obj_areas = []
        obj_bbox = []
        for i in ids:
            mask_id = (segmentation == int(i)).astype(np.uint8)
            bbox = mask_to_bbox_xywh(mask_id) # xyhw
            area = np.sum(mask_id)
            if area > image_area * self.area_ratio:
                obj_ids.append(i)
                obj_areas.append(area)
                obj_bbox.append(bbox)

        if len(obj_bbox) < 2:
            print(f"[Info] Skip image index {image_name[:-4]} due to insufficient bbox.")
            return

        # filter by IOU
        # Convert (x, y, w, h) boxes to (x1, y1, x2, y2). Note this unpacking
        # shadows the image h/w above — harmless, image_area is already fixed.
        bbox_xyxy = []
        for box in obj_bbox:
            x, y, w, h = box
            bbox_xyxy.append([x, y, x + w, y + h])
        bbox_xyxy = np.array(bbox_xyxy) # shape: [N, 4]

        if IS_VERIFY:
            # Debug artifact: image with all candidate boxes drawn.
            os.makedirs(Path(self.construct_dataset_dir) / image_name[:-4], exist_ok=True)
            image_with_boxes = draw_bboxes(image, bbox_xyxy)
            cv2.imwrite(str(Path(self.construct_dataset_dir) / image_name[:-4] / "bboxes_image.png"), image_with_boxes)

        # NOTE(review): iou_matrix is computed but never read afterwards — the
        # "filter by IOU" step appears unfinished in this version.
        iou_matrix = compute_iou_matrix(bbox_xyxy)
        np.fill_diagonal(iou_matrix, -1) # Exclude self-comparisons (i.e., each box with itself)

        # Rank surviving objects by area, largest first. The earlier
        # len(obj_bbox) < 2 guard guarantees at least two entries, so the
        # [0]/[1] accesses below are safe.
        sorted_obj_ids = np.argsort(obj_areas)[::-1]
        assert len(sorted_obj_ids) > 0

        index0 = sorted_obj_ids[0]
        index1 = sorted_obj_ids[1]

        # Copy the source image into the sample folder.
        os.makedirs(Path(self.construct_dataset_dir) / image_name[:-4], exist_ok=True)
        dst = Path(self.construct_dataset_dir) / image_name[:-4] / "image.jpg"
        dst.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy(image_path, dst)

        # Largest object: write its binary mask (0/255) and its RGB patch.
        # get_patch (base class, not visible here) expects RGB, hence the
        # BGR<->RGB round-trip around it.
        mask = (segmentation == int(obj_ids[index0])).astype(np.uint8)
        cv2.imwrite(str(Path(self.construct_dataset_dir) / image_name[:-4] / "object_0_mask.png"), 255*mask)
        patch = self.get_patch(cv2.cvtColor(image, cv2.COLOR_BGR2RGB), mask)
        patch = cv2.cvtColor(patch, cv2.COLOR_RGB2BGR)
        cv2.imwrite(str(Path(self.construct_dataset_dir) / image_name[:-4] / "object_0.png"), patch)

        if IS_VERIFY:
            # Overlay object 0 in red on the debug box image.
            mask_color = np.stack([mask * 255]*3, axis=-1).astype(np.uint8)
            highlight = np.zeros_like(image)
            highlight[:, :, 2] = 255 # red channel
            alpha = 0.5
            image_with_boxes = np.where(mask_color == 255, cv2.addWeighted(image_with_boxes, 1 - alpha, highlight, alpha, 0), image_with_boxes)

        # Second-largest object: same mask + patch outputs.
        mask = (segmentation == int(obj_ids[index1])).astype(np.uint8)
        cv2.imwrite(str(Path(self.construct_dataset_dir) / image_name[:-4] / "object_1_mask.png"), 255*mask)
        patch = self.get_patch(cv2.cvtColor(image, cv2.COLOR_BGR2RGB), mask)
        patch = cv2.cvtColor(patch, cv2.COLOR_RGB2BGR)
        cv2.imwrite(str(Path(self.construct_dataset_dir) / image_name[:-4] / "object_1.png"), patch)

        if IS_VERIFY:
            # Overlay object 1 in blue and persist the combined visualization.
            mask_color = np.stack([mask * 255]*3, axis=-1).astype(np.uint8)
            highlight = np.zeros_like(image)
            highlight[:, :, 0] = 255 # blue channel
            alpha = 0.5
            image_with_boxes = np.where(mask_color == 255, cv2.addWeighted(image_with_boxes, 1 - alpha, highlight, alpha, 0), image_with_boxes)
            cv2.imwrite(str(Path(self.construct_dataset_dir) / image_name[:-4] / "highlighted_image.png"), image_with_boxes)

    def _get_sample(self, idx):
        """Load pre-built sample ``idx`` and return its constructed collage.

        Reads the image (as RGB), both object patches (as RGB) and both
        grayscale masks from the sample folder and delegates composition to
        ``_construct_collage`` (defined outside this class).
        """
        sample_path = os.path.join(self.construct_dataset_dir, self.sample_list[idx])
        image = cv2.cvtColor(cv2.imread(os.path.join(sample_path, "image.jpg")), cv2.COLOR_BGR2RGB)
        object_0 = cv2.cvtColor(cv2.imread(os.path.join(sample_path, "object_0.png")), cv2.COLOR_BGR2RGB)
        object_1 = cv2.cvtColor(cv2.imread(os.path.join(sample_path, "object_1.png")), cv2.COLOR_BGR2RGB)
        mask_0 = cv2.imread(os.path.join(sample_path, "object_0_mask.png"), cv2.IMREAD_GRAYSCALE)
        mask_1 = cv2.imread(os.path.join(sample_path, "object_1_mask.png"), cv2.IMREAD_GRAYSCALE)
        collage = self._construct_collage(image, object_0, object_1, mask_0, mask_1)
        return collage

    def __len__(self):
        # Counts the directory live (not the cached self.sample_list).
        return len(os.listdir(self.construct_dataset_dir))
124
+
125
+
126
if __name__ == "__main__":
    '''
    two-object case: train/test: 11626/2028
    '''
    import argparse

    parser = argparse.ArgumentParser(description="VITONHDDataset Analysis")
    parser.add_argument("--dataset_dir", type=str, required=True, help="Path to the dataset directory.")
    parser.add_argument("--construct_dataset_dir", type=str, default='bin', help="Path to the debug bin directory.")
    parser.add_argument("--dataset_name", type=str, default='VitonHD', help="Dataset name.")
    parser.add_argument('--is_train', action='store_true', help="Train/Test")
    parser.add_argument('--is_build_data', action='store_true', help="Build data")
    parser.add_argument('--is_multiple', action='store_true', help="Multiple/Two objects")
    parser.add_argument("--area_ratio", type=float, default=0.01171, help="Area ratio for filtering out small objects.")
    parser.add_argument("--obj_thr", type=int, default=20, help="Object threshold for filtering.")
    parser.add_argument("--index", type=int, default=0, help="Index of the sample to test.")
    args = parser.parse_args()

    # VITON-HD layout: <dataset_dir>/<dataset_name>/{train,test}/image, ...
    if args.is_train:
        asset_dir = Path(args.dataset_dir) / args.dataset_name / "train"
    else:
        asset_dir = Path(args.dataset_dir) / args.dataset_name / "test"

    dataset = VITONHDDataset(
        construct_dataset_dir = args.construct_dataset_dir,
        obj_thr = args.obj_thr,
        area_ratio = args.area_ratio,
    )

    # Upper bound on how many raw images to process.
    max_num = 20000

    if args.is_build_data:
        if not args.is_multiple:
            # Fix: bound the loop by the number of images actually present.
            # The previous hard-coded range(max_num) made _intersect_2_obj
            # index past the end of its image list (IndexError) whenever the
            # split contains fewer than max_num files (e.g. the 2032-image
            # test split).
            num_images = len(os.listdir(asset_dir / "image"))
            for index in range(min(max_num, num_images)):
                dataset._intersect_2_obj(asset_dir, index)
        else:
            # Consume already-built sample folders one by one (smoke test).
            for index in range(len(os.listdir(args.construct_dataset_dir))):
                collage = dataset._get_sample(index)
datasets/webdataset.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import webdataset as wds
2
+ from torch.utils.data import IterableDataset
3
+ from PIL import Image
4
+ import numpy as np
5
+ import cv2
6
+
7
class MultiWebDataset(IterableDataset):
    """Streams collage samples from WebDataset tar shards.

    Each shard record must contain five members (see the ``rename`` mapping
    in ``__iter__``): a background image, two object crops and their binary
    masks. Every record is normalised to numpy (H×W×3 uint8 images,
    {0, 255} uint8 masks) and handed to ``construct_collage_fn``; the
    resulting collage is yielded.
    """

    def __init__(
        self,
        urls,
        construct_collage_fn,
        shuffle_size=0,
        seed=0,
        decode_mode="pil",
    ):
        """
        Args:
            urls: Shard URL(s)/pattern(s) accepted by ``wds.WebDataset``.
            construct_collage_fn: Callable
                ``(bg, obj0, obj1, mask0, mask1) -> collage`` over numpy arrays.
            shuffle_size: In-memory sample shuffle buffer size; 0 disables it.
            seed: Stored for reproducibility. NOTE(review): not currently wired
                into the shard/sample shuffling — confirm intended behavior.
            decode_mode: Decode spec forwarded to ``wds.decode``
                (e.g. "pil", "rgb"). Default "pil" matches prior behavior.
        """
        super().__init__()
        self.urls = urls
        self.shuffle_size = shuffle_size
        self.seed = seed
        self.decode_mode = decode_mode
        self.construct_collage_fn = construct_collage_fn

    def _to_rgb_np(self, img):
        """Convert a decoded image to an H×W×3 numpy array.

        PIL images are converted to RGB; grayscale arrays are expanded to
        3 channels; RGBA arrays lose their alpha channel; 3-channel arrays
        pass through unchanged (assumed already RGB, which holds for the
        "pil" decode path).

        Raises:
            TypeError: for anything other than PIL.Image or numpy.ndarray.
        """
        if isinstance(img, Image.Image):
            return np.array(img.convert("RGB"))
        elif isinstance(img, np.ndarray):
            if img.ndim == 2:
                return cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
            if img.ndim == 3 and img.shape[2] == 4:
                return img[:, :, :3]
            return img
        else:
            raise TypeError(f"Unsupported image type: {type(img)}")

    def _to_mask_np(self, img):
        """Convert a decoded mask to a strictly binary {0, 255} uint8 H×W array.

        Thresholds at 127 so antialiased or compression-blurred masks come
        out binary. NOTE(review): the 3-channel branch uses COLOR_BGR2GRAY
        although the "pil" decode path yields RGB; the luma weights differ
        slightly, but the >127 threshold makes this immaterial for
        near-binary masks.

        Raises:
            TypeError: for anything other than PIL.Image or numpy.ndarray.
        """
        if isinstance(img, Image.Image):
            m = np.array(img.convert("L"))
        elif isinstance(img, np.ndarray):
            if img.ndim == 3:
                m = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            else:
                m = img
        else:
            raise TypeError(f"Unsupported mask type: {type(img)}")
        m = (m > 127).astype(np.uint8) * 255
        return m

    def __iter__(self):
        # Build the pipeline lazily so each DataLoader worker owns its own.
        ds = wds.WebDataset(self.urls, shardshuffle=True, empty_check=False)

        if self.shuffle_size and self.shuffle_size > 0:
            ds = ds.shuffle(self.shuffle_size)

        # Fix: honour the decode_mode constructor argument. It was previously
        # stored but ignored (decode hard-coded to "pil"); the default value
        # keeps existing callers' behavior unchanged.
        ds = ds.decode(self.decode_mode)

        # Map tar member names to the short keys used below.
        ds = ds.rename(
            bg="bg.jpg",
            obj0="obj0.png",
            mask0="mask0.png",
            obj1="obj1.png",
            mask1="mask1.png",
        )

        for sample in ds:
            bg = sample["bg"]
            obj0 = sample["obj0"]
            obj1 = sample["obj1"]
            mask0 = sample["mask0"]
            mask1 = sample["mask1"]

            bg_np = self._to_rgb_np(bg)
            obj0_np = self._to_rgb_np(obj0)
            obj1_np = self._to_rgb_np(obj1)
            mask0_np = self._to_mask_np(mask0)
            mask1_np = self._to_mask_np(mask1)

            collage = self.construct_collage_fn(
                bg_np, obj0_np, obj1_np, mask0_np, mask1_np
            )
            yield collage