| |
| |
|
|
| import os |
| import cv2 |
| import numpy as np |
| from loguru import logger |
| from functools import wraps |
| from pycocotools.coco import COCO |
| from torch.utils.data.dataset import Dataset as torchDataset |
|
|
# Human-readable class names for the COCO object-detection dataset, in the
# conventional listing order used by detection frameworks.
COCO_CLASSES = (
    'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
    'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
    'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite',
    'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork',
    'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut',
    'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
    'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors',
    'teddy bear', 'hair drier', 'toothbrush')
|
|
|
|
def remove_useless_info(coco):
    """
    Remove useless info in coco dataset. COCO object is modified inplace.
    This function is mainly used for saving memory (save about 30% mem).

    Args:
        coco: a ``pycocotools.coco.COCO`` instance; any other input is a no-op.
    """
    if not isinstance(coco, COCO):
        return
    dataset = coco.dataset
    dataset.pop("info", None)
    dataset.pop("licenses", None)
    # Per-image metadata fields that are never read downstream.
    # Use .get() so a dataset without an "images" key does not raise
    # (the "annotations" access below is guarded the same way).
    for img in dataset.get("images", []):
        img.pop("license", None)
        img.pop("coco_url", None)
        img.pop("date_captured", None)
        img.pop("flickr_url", None)
    # Segmentation polygons dominate annotation memory and are unused
    # for box-only detection.
    for anno in dataset.get("annotations", []):
        anno.pop("segmentation", None)
|
|
|
|
class Dataset(torchDataset):
    """Subclass of :class:`torch.utils.data.Dataset` that supports on-the-fly
    resizing of the network ``input_dim``.

    Args:
        input_dimension (tuple): (width,height) tuple with default dimensions of the network
    """

    def __init__(self, input_dimension, mosaic=True):
        super().__init__()
        # Only the first two components are kept as the default dimension.
        self.__input_dim = input_dimension[:2]
        self.enable_mosaic = mosaic

    @property
    def input_dim(self):
        """
        Dimension that can be used by transforms to set the correct image size, etc.
        This allows transforms to have a single source of truth
        for the input dimension of the network.

        Return:
            list: Tuple containing the current width,height
        """
        # A dynamically assigned ``self._input_dim`` (e.g. set by a resizing
        # scheduler) takes precedence over the constructor default.
        try:
            return self._input_dim
        except AttributeError:
            return self.__input_dim

    @staticmethod
    def mosaic_getitem(getitem_fn):
        """
        Decorator method that needs to be used around the ``__getitem__`` method. |br|
        This decorator enables the closing mosaic

        Example:
            >>> class CustomSet(ln.data.Dataset):
            ...     def __len__(self):
            ...         return 10
            ...     @ln.data.Dataset.mosaic_getitem
            ...     def __getitem__(self, index):
            ...         return self.enable_mosaic
        """

        @wraps(getitem_fn)
        def wrapper(self, index):
            # A plain int is a normal lookup; anything else is treated as a
            # (mosaic_flag, real_index) pair used to toggle mosaic per call.
            if isinstance(index, int):
                return getitem_fn(self, index)
            self.enable_mosaic = index[0]
            return getitem_fn(self, index[1])

        return wrapper
|
|
|
|
class COCODataset(Dataset):
    """
    COCO dataset class.

    Annotations are parsed into memory once at construction; images are read
    from disk (and resized to fit ``img_size``) on demand.
    """

    def __init__(
        self,
        data_dir='data/COCO',
        json_file="instances_train2017.json",
        name="train2017",
        img_size=(416, 416),
        preproc=None
    ):
        """
        COCO dataset initialization. Annotation data are read into memory by COCO API.
        Args:
            data_dir (str): dataset root directory
            json_file (str): COCO json file name
            name (str): COCO data name (e.g. 'train2017' or 'val2017')
            img_size (tuple(int)): target image size after pre-processing
            preproc: data augmentation strategy
        """
        super().__init__(img_size)
        self.data_dir = data_dir
        self.json_file = json_file
        self.coco = COCO(os.path.join(self.data_dir, "annotations", self.json_file))
        remove_useless_info(self.coco)  # drop unused fields to save memory
        self.ids = self.coco.getImgIds()
        cat_ids = self.coco.getCatIds()  # fetch once, reuse below
        self.class_ids = sorted(cat_ids)
        self.cats = self.coco.loadCats(cat_ids)
        self._classes = tuple(c["name"] for c in self.cats)
        # Optional in-RAM cache of pre-resized, padded images; None means
        # every access loads from disk via load_resized_img().
        self.imgs = None
        self.name = name
        self.img_size = img_size
        self.preproc = preproc
        self.annotations = self._load_coco_annotations()

    def __len__(self):
        return len(self.ids)

    def __del__(self):
        # __init__ may have raised before self.imgs was assigned (e.g. a
        # missing annotation file), so guard against AttributeError during
        # interpreter teardown.
        if hasattr(self, "imgs"):
            del self.imgs

    def _load_coco_annotations(self):
        # Pre-compute the annotation tuple for every image id.
        return [self.load_anno_from_ids(_ids) for _ids in self.ids]

    def load_anno_from_ids(self, id_):
        """Build the cached annotation tuple for a single image id.

        Returns:
            tuple: (res, img_info, resized_info, file_name) where
                res (numpy.ndarray): (num_objs, 5) array of
                    [x1, y1, x2, y2, class_index] boxes already scaled to
                    the resized image;
                img_info (tuple): original (height, width);
                resized_info (tuple): (height, width) after resizing;
                file_name (str): image file name.
        """
        im_ann = self.coco.loadImgs(id_)[0]
        width = im_ann["width"]
        height = im_ann["height"]
        anno_ids = self.coco.getAnnIds(imgIds=[int(id_)], iscrowd=False)
        annotations = self.coco.loadAnns(anno_ids)
        objs = []
        for obj in annotations:
            # Clip the xywh box to the image bounds and convert to xyxy.
            x1 = np.max((0, obj["bbox"][0]))
            y1 = np.max((0, obj["bbox"][1]))
            x2 = np.min((width, x1 + np.max((0, obj["bbox"][2]))))
            y2 = np.min((height, y1 + np.max((0, obj["bbox"][3]))))
            if obj["area"] > 0 and x2 >= x1 and y2 >= y1:
                obj["clean_bbox"] = [x1, y1, x2, y2]
                objs.append(obj)
        num_objs = len(objs)
        res = np.zeros((num_objs, 5))
        for ix, obj in enumerate(objs):
            # Class label is the index into the sorted category-id list,
            # not the raw COCO category id.
            cls = self.class_ids.index(obj["category_id"])
            res[ix, 0:4] = obj["clean_bbox"]
            res[ix, 4] = cls
        # Scale factor that fits the image inside img_size while keeping
        # the aspect ratio; boxes are pre-scaled to the resized image.
        r = min(self.img_size[0] / height, self.img_size[1] / width)
        res[:, :4] *= r
        img_info = (height, width)
        resized_info = (int(height * r), int(width * r))
        file_name = (
            im_ann["file_name"]
            if "file_name" in im_ann
            # Fall back to COCO's zero-padded 12-digit naming scheme.
            else "{:012}".format(id_) + ".jpg"
        )
        return res, img_info, resized_info, file_name

    def load_anno(self, index):
        """Return only the (num_objs, 5) label array for the given index."""
        return self.annotations[index][0]

    def load_resized_img(self, index):
        """Load the image and resize it to fit img_size, keeping aspect ratio."""
        img = self.load_image(index)
        r = min(self.img_size[0] / img.shape[0], self.img_size[1] / img.shape[1])
        resized_img = cv2.resize(
            img,
            (int(img.shape[1] * r), int(img.shape[0] * r)),
            interpolation=cv2.INTER_LINEAR,
        ).astype(np.uint8)
        return resized_img

    def load_image(self, index):
        """Read the raw image for `index` from disk (as loaded by cv2)."""
        file_name = self.annotations[index][3]
        img_file = os.path.join(self.data_dir, self.name, file_name)
        img = cv2.imread(img_file)
        # NOTE(review): assert is stripped under `python -O`; kept because
        # callers may rely on AssertionError for a missing file.
        assert img is not None, f"file named {img_file} not found"
        return img

    def pull_item(self, index):
        """Return (img, labels, img_info, img_id) without preprocessing."""
        id_ = self.ids[index]
        res, img_info, resized_info, _ = self.annotations[index]
        if self.imgs is not None:
            # Cached images are padded; crop back to the true resized extent.
            pad_img = self.imgs[index]
            img = pad_img[: resized_info[0], : resized_info[1], :].copy()
        else:
            img = self.load_resized_img(index)
        return img, res.copy(), img_info, np.array([id_])

    @Dataset.mosaic_getitem
    def __getitem__(self, index):
        """
        One image / label pair for the given index is picked up and pre-processed.

        Args:
            index (int): data index

        Returns:
            img (numpy.ndarray): pre-processed image
            target (torch.Tensor): pre-processed label data.
                The shape is :math:`[max_labels, 5]`.
                each label consists of [class, xc, yc, w, h]:
                    class (float): class index.
                    xc, yc (float) : center of bbox whose values range from 0 to 1.
                    w, h (float) : size of bbox whose values range from 0 to 1.
            img_info : tuple of h, w.
                h, w (int): original shape of the image
            img_id (int): same as the input index. Used for evaluation.
        """
        img, target, img_info, img_id = self.pull_item(index)
        if self.preproc is not None:
            img, target = self.preproc(img, target, self.input_dim)
        return img, target, img_info, img_id
|
|