| from os import path as osp |
| from typing import Callable, Optional |
|
|
| import torch |
| from torch.utils.data import Dataset |
| from torchvision.transforms import functional as TF |
| from PIL import Image |
| import pandas as pd |
|
|
| from . import augmentation |
| from .masking import MaskGenerator |
| from . import data_utils as utils |
|
|
|
|
class GazeFollow(Dataset):
    """GazeFollow gaze-target dataset backed by a CSV annotation file.

    Each item is a dict with the keys ``images``, ``head_channels``,
    ``heatmaps``, ``gazes``, ``gaze_inouts``, ``head_masks`` and ``imsize``.
    In training mode with a ``mask_generator`` it also has ``image_masks``.

    * Training mode (``is_train=True``): one item per CSV row whose ``inout``
      column is not -1. Images go through the configured augmentation list.
    * Evaluation mode (``is_train=False``): rows are grouped by
      ``(path, eye_x)`` and one item is produced per group. The item carries
      every annotated gaze point of that group, padded with ``(-1, -1)``.
    """

    # Columns shared by the training and evaluation CSVs, in file order.
    _LEADING_COLUMNS = (
        "path",
        "idx",
        "body_bbox_x",
        "body_bbox_y",
        "body_bbox_w",
        "body_bbox_h",
        "eye_x",
        "eye_y",
        "gaze_x",
        "gaze_y",
        "bbox_x_min",
        "bbox_y_min",
        "bbox_x_max",
        "bbox_y_max",
    )
    _TRAILING_COLUMNS = ("meta0", "meta1")
    # The training CSV has an extra "inout" column before the meta columns.
    _TRAIN_COLUMNS = _LEADING_COLUMNS + ("inout",) + _TRAILING_COLUMNS
    _TEST_COLUMNS = _LEADING_COLUMNS + _TRAILING_COLUMNS
    # Per-row training targets, in the order _train_annotation unpacks them.
    _TRAIN_TARGET_COLUMNS = (
        "bbox_x_min",
        "bbox_y_min",
        "bbox_x_max",
        "bbox_y_max",
        "eye_x",
        "eye_y",
        "gaze_x",
        "gaze_y",
        "inout",
    )
    # Columns kept for evaluation before grouping.
    _TEST_KEEP_COLUMNS = (
        "path",
        "eye_x",
        "eye_y",
        "gaze_x",
        "gaze_y",
        "bbox_x_min",
        "bbox_y_min",
        "bbox_x_max",
        "bbox_y_max",
    )
    # Evaluation gaze lists are padded with (-1, -1) up to this many rows.
    # Groups with more points are NOT truncated.
    _NUM_TEST_GAZES = 20
    # Fraction of the head-box extent added on each side of the box.
    _HEAD_BBOX_PAD = 0.1
    # Third positional argument of draw_labelmap (presumably the Gaussian
    # sigma in heatmap pixels -- confirm in data_utils).
    _LABELMAP_SIGMA = 3

    def __init__(
        self,
        image_root: str,
        anno_root: str,
        head_root: str,
        transform: Callable,
        input_size: int,
        output_size: int,
        quant_labelmap: bool = True,
        is_train: bool = True,
        *,
        mask_generator: Optional[MaskGenerator] = None,
        bbox_jitter: float = 0.5,
        rand_crop: float = 0.5,
        rand_flip: float = 0.5,
        color_jitter: float = 0.5,
        rand_rotate: float = 0.0,
        rand_lsj: float = 0.0,
    ):
        """Load the annotation CSV and configure preprocessing.

        Args:
            image_root: Directory that each row's ``path`` is joined to when
                opening the scene image.
            anno_root: Path of the annotation CSV passed to ``pd.read_csv``.
            head_root: Directory that the same ``path`` is joined to when
                opening the head-mask image.
            transform: Applied to the PIL image in ``__getitem__``. Skipped
                when ``None``.
            input_size: Side length used for the head-box channel and for
                resizing the head mask.
            output_size: Side length of the gaze heatmap.
            quant_labelmap: Render heatmaps with ``utils.draw_labelmap``
                (True) or ``utils.draw_labelmap_no_quant`` (False).
            is_train: Selects the training split (row-wise samples,
                augmentation) or the evaluation split (samples grouped by
                ``(path, eye_x)``).
            mask_generator: Optional callable used only in training. Its
                output is returned under ``image_masks``.
            bbox_jitter, rand_crop, rand_flip, color_jitter, rand_rotate,
            rand_lsj: Each is passed to the matching ``augmentation`` op
                (presumably the probability/strength of that op -- confirm
                in the augmentation module).
        """
        if is_train:
            self._load_train_annotations(anno_root)
        else:
            self._load_test_annotations(anno_root)

        self.data_dir = image_root
        self.head_dir = head_root
        self.transform = transform
        self.is_train = is_train

        self.input_size = input_size
        self.output_size = output_size

        self.draw_labelmap = (
            utils.draw_labelmap if quant_labelmap else utils.draw_labelmap_no_quant
        )

        # Augmentation is only built (and only applied) for the training split.
        if self.is_train:
            self.augment = augmentation.AugmentationList(
                [
                    augmentation.ColorJitter(color_jitter),
                    augmentation.BoxJitter(bbox_jitter),
                    augmentation.RandomCrop(rand_crop),
                    augmentation.RandomFlip(rand_flip),
                    augmentation.RandomRotate(rand_rotate),
                    augmentation.RandomLSJ(rand_lsj),
                ]
            )

        self.mask_generator = mask_generator

    @staticmethod
    def _read_annotations(anno_root, column_names):
        """Read the header-less annotation CSV at *anno_root* into a DataFrame."""
        return pd.read_csv(
            anno_root,
            sep=",",
            names=list(column_names),
            index_col=False,
            encoding="utf-8-sig",
        )

    def _load_train_annotations(self, anno_root):
        """Populate ``X_train``, ``y_train`` and ``length`` from the training CSV."""
        df = self._read_annotations(anno_root, self._TRAIN_COLUMNS)
        # Rows with inout == -1 are excluded from training.
        df = df[df["inout"] != -1].reset_index(drop=True)
        self.y_train = df[list(self._TRAIN_TARGET_COLUMNS)]
        self.X_train = df["path"]
        self.length = len(df)

    def _load_test_annotations(self, anno_root):
        """Populate ``X_test``, ``keys`` and ``length`` from the evaluation CSV."""
        df = self._read_annotations(anno_root, self._TEST_COLUMNS)
        grouped = df[list(self._TEST_KEEP_COLUMNS)].groupby(["path", "eye_x"])
        self.keys = list(grouped.groups.keys())
        self.X_test = grouped
        self.length = len(self.keys)

    def _train_annotation(self, index):
        """Return ``(path, head_bbox, (gaze_x, gaze_y), gaze_inside)`` for a training row."""
        path = self.X_train.iloc[index]
        (
            x_min,
            y_min,
            x_max,
            y_max,
            _eye_x,
            _eye_y,
            gaze_x,
            gaze_y,
            inout,
        ) = self.y_train.iloc[index]
        return path, (x_min, y_min, x_max, y_max), (gaze_x, gaze_y), bool(inout)

    def _test_annotation(self, index):
        """Return ``(path, head_bbox, cont_gaze)`` for evaluation sample *index*.

        ``cont_gaze`` is a FloatTensor with one ``[gaze_x, gaze_y]`` row per
        annotation in the group. It is padded with ``[-1, -1]`` rows up to
        ``_NUM_TEST_GAZES`` rows.
        """
        group = self.X_test.get_group(self.keys[index])
        cont_gaze = [[row["gaze_x"], row["gaze_y"]] for _, row in group.iterrows()]
        # A non-positive count yields an empty extension, i.e. no truncation.
        cont_gaze.extend([[-1, -1]] * (self._NUM_TEST_GAZES - len(cont_gaze)))
        # All rows of a group share the same image path; the head box is
        # taken from the group's last row.
        last = group.iloc[-1]
        bbox = (
            last["bbox_x_min"],
            last["bbox_y_min"],
            last["bbox_x_max"],
            last["bbox_y_max"],
        )
        return last["path"], bbox, torch.FloatTensor(cont_gaze)

    @staticmethod
    def _expand_head_bbox(bbox, width, height, pad=_HEAD_BBOX_PAD):
        """Order, pad and clamp a head box to the image bounds.

        Swaps inverted min/max coordinates. Grows the box by ``pad`` times its
        extent on each side. Clamps it to ``[0, width - 1] x [0, height - 1]``.
        """
        x_min, y_min, x_max, y_max = map(float, bbox)
        if x_max < x_min:
            x_min, x_max = x_max, x_min
        if y_max < y_min:
            y_min, y_max = y_max, y_min

        x_min = max(x_min - pad * abs(x_max - x_min), 0)
        y_min = max(y_min - pad * abs(y_max - y_min), 0)
        # NOTE(review): the max side is padded using the already expanded and
        # clamped min coordinate, so right/bottom grow slightly more than
        # left/top. Kept as-is to preserve existing preprocessing; confirm
        # whether this asymmetry is intended.
        x_max = min(x_max + pad * abs(x_max - x_min), width - 1)
        y_max = min(y_max + pad * abs(y_max - y_min), height - 1)
        return x_min, y_min, x_max, y_max

    def _train_heatmap(self, gaze_x, gaze_y):
        """Render the single-point gaze heatmap for a training sample."""
        size = self.output_size
        return self.draw_labelmap(
            torch.zeros(size, size),
            [gaze_x * size, gaze_y * size],
            self._LABELMAP_SIGMA,
            type="Gaussian",
        )

    def _test_heatmap(self, cont_gaze):
        """Average the heatmaps of all non-padding gaze points in *cont_gaze*."""
        size = self.output_size
        heatmap = torch.zeros(size, size)
        num_valid = 0
        for gaze_x, gaze_y in cont_gaze:
            if gaze_x == -1:
                continue  # padding row
            num_valid += 1
            heatmap += self.draw_labelmap(
                torch.zeros(size, size),
                [gaze_x * size, gaze_y * size],
                self._LABELMAP_SIGMA,
                type="Gaussian",
            )
        # Guard 0/0: a group whose points are all -1 previously produced an
        # all-NaN heatmap. It now yields an all-zero one.
        if num_valid > 0:
            heatmap /= num_valid
        return heatmap

    def __getitem__(self, index):
        """Load, preprocess and return sample *index* as a dict.

        ``gazes`` is a FloatTensor ``[gaze_x, gaze_y]`` in training mode. In
        evaluation mode it is the padded per-annotation gaze tensor. Gaze
        coordinates are multiplied by ``output_size`` to place heatmap peaks,
        so they are presumably normalized to [0, 1].
        """
        if self.is_train:
            path, bbox, (gaze_x, gaze_y), gaze_inside = self._train_annotation(index)
            cont_gaze = None
        else:
            path, bbox, cont_gaze = self._test_annotation(index)
            gaze_inside = True

        img = Image.open(osp.join(self.data_dir, path)).convert("RGB")
        head_mask = Image.open(osp.join(self.head_dir, path))
        width, height = img.size
        x_min, y_min, x_max, y_max = self._expand_head_bbox(bbox, width, height)

        if self.is_train:
            img, bbox, gaze, head_mask, size = self.augment(
                img,
                (x_min, y_min, x_max, y_max),
                (gaze_x, gaze_y),
                head_mask,
                (width, height),
            )
            x_min, y_min, x_max, y_max = bbox
            gaze_x, gaze_y = gaze
            width, height = size

        head_channel = utils.get_head_box_channel(
            x_min,
            y_min,
            x_max,
            y_max,
            width,
            height,
            resolution=self.input_size,
            coordconv=False,
        ).unsqueeze(0)

        # Called before `transform` to keep the original call order (either
        # may consume random state).
        use_image_mask = self.is_train and self.mask_generator is not None
        if use_image_mask:
            image_mask = self.mask_generator(
                x_min / width,
                y_min / height,
                x_max / width,
                y_max / height,
                head_channel,
            )

        if self.transform is not None:
            img = self.transform(img)
        head_mask = TF.to_tensor(
            TF.resize(head_mask, (self.input_size, self.input_size))
        )

        if self.is_train:
            gaze_heatmap = self._train_heatmap(gaze_x, gaze_y)
            gazes = torch.FloatTensor([gaze_x, gaze_y])
        else:
            gaze_heatmap = self._test_heatmap(cont_gaze)
            gazes = cont_gaze

        sample = {
            "images": img,
            "head_channels": head_channel,
            "heatmaps": gaze_heatmap,
            "gazes": gazes,
            "gaze_inouts": torch.FloatTensor([gaze_inside]),
            "head_masks": head_mask,
            "imsize": torch.IntTensor([width, height]),
        }
        if use_image_mask:
            sample["image_masks"] = image_mask
        return sample

    def __len__(self):
        """Return the number of samples (CSV rows for training, groups for evaluation)."""
        return self.length
|
|