| import glob |
| from typing import Callable, Optional |
| from os import path as osp |
|
|
| import torch |
| from torch.utils.data.dataset import Dataset |
| import torchvision.transforms.functional as TF |
| import numpy as np |
| import pandas as pd |
| from PIL import Image |
|
|
| from . import augmentation |
| from . import data_utils as utils |
| from .masking import MaskGenerator |
|
|
|
|
class VideoAttentionTarget(Dataset):
    """Frame-level VideoAttentionTarget dataset for gaze-target prediction.

    Annotations are read from ``<anno_root>/<show>/<clip>/*.txt``. Each file is
    a header-less CSV with columns ``path, x_min, y_min, x_max, y_max, gaze_x,
    gaze_y``. Frame paths are stored as ``<show>/<clip>/<path>`` and resolved
    against ``image_root`` (frames) and ``head_root`` (optional head masks).

    Args:
        image_root: Directory that frame paths are resolved against.
        anno_root: Directory containing one sub-directory per show.
        head_root: Directory of optional per-frame head masks, laid out like
            ``image_root``; frames without a mask file get an all-zero mask.
        transform: Applied to the RGB PIL image in ``__getitem__``; may be None.
        input_size: Resolution of the head channel and of the resized head mask.
        output_size: Side length of the gaze heatmap.
        quant_labelmap: Use ``utils.draw_labelmap`` (True) or
            ``utils.draw_labelmap_no_quant`` (False) to render the heatmap.
        is_train: Enables the augmentation pipeline and ``mask_generator``.
        mask_generator: Optional callable whose output is returned as
            ``image_masks`` (train mode only).
        bbox_jitter, rand_crop, rand_flip, color_jitter, rand_rotate, rand_lsj:
            Forwarded to the corresponding ``augmentation`` classes (train only).
        sample_frac: Fraction of rows kept from each annotation file, sampled
            with ``sample_seed``. Values >= 1 keep every row in file order.
        sample_seed: ``random_state`` for the per-file subsampling.
    """

    # Column names of the header-less per-clip annotation files.
    _ANNO_COLUMNS = ("path", "x_min", "y_min", "x_max", "y_max", "gaze_x", "gaze_y")

    def __init__(
        self,
        image_root: str,
        anno_root: str,
        head_root: str,
        transform: Callable,
        input_size: int,
        output_size: int,
        quant_labelmap: bool = True,
        is_train: bool = True,
        *,
        mask_generator: Optional[MaskGenerator] = None,
        bbox_jitter: float = 0.5,
        rand_crop: float = 0.5,
        rand_flip: float = 0.5,
        color_jitter: float = 0.5,
        rand_rotate: float = 0.0,
        rand_lsj: float = 0.0,
        sample_frac: float = 0.2,
        sample_seed: int = 42,
    ):
        self.df = self._load_frames(anno_root, sample_frac, sample_seed)
        self.length = len(self.df)

        self.data_dir = image_root
        self.head_dir = head_root
        self.transform = transform
        self.draw_labelmap = (
            utils.draw_labelmap if quant_labelmap else utils.draw_labelmap_no_quant
        )
        self.is_train = is_train

        self.input_size = input_size
        self.output_size = output_size

        if self.is_train:
            # Applied in __getitem__ to (image, head box, gaze, head mask, size).
            self.augment = augmentation.AugmentationList(
                [
                    augmentation.ColorJitter(color_jitter),
                    augmentation.BoxJitter(bbox_jitter),
                    augmentation.RandomCrop(rand_crop),
                    augmentation.RandomFlip(rand_flip),
                    augmentation.RandomRotate(rand_rotate),
                    augmentation.RandomLSJ(rand_lsj),
                ]
            )

        self.mask_generator = mask_generator

    @classmethod
    def _load_frames(
        cls, anno_root: str, sample_frac: float, sample_seed: int
    ) -> pd.DataFrame:
        """Read, subsample and validate every annotation file under ``anno_root``."""
        frames = []
        for show_dir in glob.glob(osp.join(anno_root, "*")):
            for sequence_path in glob.glob(osp.join(show_dir, "*", "*.txt")):
                df = pd.read_csv(
                    sequence_path,
                    header=None,
                    index_col=False,
                    names=list(cls._ANNO_COLUMNS),
                )

                # Derive <show>/<clip> with os.path rather than str.split("/")
                # so the prefix is also correct where the separator isn't "/".
                clip_dir = osp.dirname(sequence_path)
                clip = osp.basename(clip_dir)
                show_name = osp.basename(osp.dirname(clip_dir))
                df["path"] = [osp.join(show_name, clip, name) for name in df["path"]]

                # Centre of the head box, stored as the "eye" position.
                df["eye_x"] = (df["x_min"] + df["x_max"]) / 2
                df["eye_y"] = (df["y_min"] + df["y_max"]) / 2

                if sample_frac < 1.0:
                    df = df.sample(frac=sample_frac, random_state=sample_seed)
                frames.extend(df.values.tolist())

        df = pd.DataFrame(frames, columns=[*cls._ANNO_COLUMNS, "eye_x", "eye_y"])

        # Keep only boxes with max >= min on both axes (NaN comparisons fail,
        # so rows with missing coordinates are dropped too). Works on an empty
        # frame list as well.
        valid = (df["x_max"] >= df["x_min"]) & (df["y_max"] >= df["y_min"])
        # reset_index() without drop=True keeps the pre-filter row number in an
        # "index" column; __getitem__ reads fields by name and ignores it.
        return df.loc[valid].reset_index()

    def _load_head_mask(self, path, width, height):
        """Return the head mask for ``path`` resized to (width, height), or zeros."""
        mask_path = osp.join(self.head_dir, path)
        if not osp.exists(mask_path):
            return Image.fromarray(np.zeros((height, width), dtype=np.float32))
        with Image.open(mask_path) as raw:
            return raw.resize((width, height))

    @staticmethod
    def _expand_bbox(x_min, y_min, x_max, y_max, width, height, margin=0.1):
        """Order the box corners and grow the box by ``margin``, clamped to the image."""
        # Rows with max < min are already dropped when the dataset is built;
        # this guards against externally modified ``self.df``.
        if x_max < x_min:
            x_min, x_max = x_max, x_min
        if y_max < y_min:
            y_min, y_max = y_max, y_min

        x_min = max(x_min - margin * abs(x_max - x_min), 0)
        y_min = max(y_min - margin * abs(y_max - y_min), 0)
        # NOTE: the max-side margin is computed from the already-expanded min
        # corner, so the growth is slightly asymmetric. Kept as-is to preserve
        # the existing preprocessing.
        x_max = min(x_max + margin * abs(x_max - x_min), width - 1)
        y_max = min(y_max + margin * abs(y_max - y_min), height - 1)
        return x_min, y_min, x_max, y_max

    def __getitem__(self, index):
        """Load, augment and encode one annotated frame.

        Returns a dict with:
            images: ``transform(img)``, or the RGB PIL image if ``transform`` is None.
            head_channels: ``utils.get_head_box_channel`` output at ``input_size``
                resolution, with a leading channel dimension.
            heatmaps: gaze heatmap rendered on an (output_size, output_size) grid;
                all zeros when the gaze target is outside the frame.
            gazes: gaze point divided by image width/height, then augmented.
            gaze_inouts: 1.0 if the gaze target is inside the frame, else 0.0.
            head_masks: head mask (a tensor resized to (input_size, input_size)
                when ``transform`` is set, otherwise a PIL image).
            imsize: (width, height) after augmentation.
            image_masks: ``mask_generator`` output (train mode with a generator only).
        """
        row = self.df.iloc[index]
        path = row["path"]
        gaze_x, gaze_y = row["gaze_x"], row["gaze_y"]
        # A gaze point of (-1, -1) marks a target outside the frame.
        gaze_inside = gaze_x != -1 or gaze_y != -1

        with Image.open(osp.join(self.data_dir, path)) as raw:
            img = raw.convert("RGB")
        width, height = img.size
        head_mask = self._load_head_mask(path, width, height)

        x_min, y_min, x_max, y_max = self._expand_bbox(
            *map(float, (row["x_min"], row["y_min"], row["x_max"], row["y_max"])),
            width,
            height,
        )
        gaze_x, gaze_y = gaze_x / width, gaze_y / height

        if self.is_train:
            img, bbox, gaze, head_mask, size = self.augment(
                img,
                (x_min, y_min, x_max, y_max),
                (gaze_x, gaze_y),
                head_mask,
                (width, height),
            )
            x_min, y_min, x_max, y_max = bbox
            gaze_x, gaze_y = gaze
            width, height = size

        head_channel = utils.get_head_box_channel(
            x_min,
            y_min,
            x_max,
            y_max,
            width,
            height,
            resolution=self.input_size,
            coordconv=False,
        ).unsqueeze(0)

        use_image_mask = self.is_train and self.mask_generator is not None
        if use_image_mask:
            image_mask = self.mask_generator(
                x_min / width,
                y_min / height,
                x_max / width,
                y_max / height,
                head_channel,
            )

        if self.transform is not None:
            img = self.transform(img)
            head_mask = TF.to_tensor(
                TF.resize(head_mask, (self.input_size, self.input_size))
            )

        gaze_heatmap = torch.zeros(self.output_size, self.output_size)
        # Only draw a target for in-frame gaze: for out-of-frame samples the
        # (-1, -1) sentinel, once normalized, would place a Gaussian just
        # outside the heatmap corner and leak part of it into the label.
        if gaze_inside:
            gaze_heatmap = self.draw_labelmap(
                gaze_heatmap,
                [gaze_x * self.output_size, gaze_y * self.output_size],
                3,
                type="Gaussian",
            )

        out_dict = {
            "images": img,
            "head_channels": head_channel,
            "heatmaps": gaze_heatmap,
            "gazes": torch.FloatTensor([gaze_x, gaze_y]),
            "gaze_inouts": torch.FloatTensor([gaze_inside]),
            "head_masks": head_mask,
            "imsize": torch.IntTensor([width, height]),
        }
        if use_image_mask:
            out_dict["image_masks"] = image_mask
        return out_dict

    def __len__(self):
        return self.length
|
|