| import random |
| import math |
| import numpy as np |
| import torch |
| from torch.nn import functional as F |
|
|
|
|
| class SceneMaskGenerator: |
| def __init__( |
| self, |
| input_size, |
| min_num_patches=16, |
| max_num_patches_ratio=0.5, |
| min_aspect=0.3, |
| ): |
| if not isinstance(input_size, tuple): |
| input_size = (input_size,) * 2 |
| self.input_size = input_size |
| self.num_patches = input_size[0] * input_size[1] |
|
|
| self.min_num_patches = min_num_patches |
| self.max_num_patches = max_num_patches_ratio * self.num_patches |
|
|
| self.log_aspect_ratio = (math.log(min_aspect), -math.log(min_aspect)) |
|
|
| def _mask(self, mask, max_mask_patches): |
| delta = 0 |
| for _ in range(4): |
| target_area = random.uniform(self.min_num_patches, max_mask_patches) |
| aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) |
| h = int(round(math.sqrt(target_area * aspect_ratio))) |
| w = int(round(math.sqrt(target_area / aspect_ratio))) |
| height, width = self.input_size |
| if w < width and h < height: |
| top = random.randint(0, height - h) |
| left = random.randint(0, width - w) |
|
|
| num_masked = mask[top : top + h, left : left + w].sum() |
| |
| if 0 < h * w - num_masked <= max_mask_patches: |
| mask[top : top + h, left : left + w] = 1 |
| delta = h * w - num_masked |
| break |
| return delta |
|
|
| def __call__(self, head_mask): |
| mask = np.zeros(shape=self.input_size, dtype=bool) |
| mask_count = 0 |
| num_masking_patches = random.uniform(self.min_num_patches, self.max_num_patches) |
| while mask_count < num_masking_patches: |
| max_mask_patches = num_masking_patches - mask_count |
| delta = self._mask(mask, max_mask_patches) |
| if delta == 0: |
| break |
| else: |
| mask_count += delta |
|
|
| mask = torch.from_numpy(mask).unsqueeze(0) |
| head_mask = ( |
| F.interpolate(head_mask.unsqueeze(0), mask.shape[-2:]).squeeze(0) < 0.5 |
| ) |
| return torch.logical_and(mask, head_mask).squeeze(0) |
|
|
|
|
| class HeadMaskGenerator: |
| def __init__( |
| self, |
| input_size, |
| min_num_patches=4, |
| max_num_patches_ratio=0.5, |
| min_aspect=0.3, |
| ): |
| if not isinstance(input_size, tuple): |
| input_size = (input_size,) * 2 |
| self.input_size = input_size |
| self.num_patches = input_size[0] * input_size[1] |
|
|
| self.min_num_patches = min_num_patches |
| self.max_num_patches_ratio = max_num_patches_ratio |
|
|
| self.log_aspect_ratio = (math.log(min_aspect), -math.log(min_aspect)) |
|
|
| def __call__( |
| self, |
| x_min, |
| y_min, |
| x_max, |
| y_max, |
| ): |
| height = math.floor((y_max - y_min) * self.input_size[0]) |
| width = math.floor((x_max - x_min) * self.input_size[1]) |
| origin_area = width * height |
| if origin_area < self.min_num_patches: |
| return torch.zeros(size=self.input_size, dtype=bool) |
|
|
| target_area = random.uniform( |
| self.min_num_patches, self.max_num_patches_ratio * origin_area |
| ) |
| aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) |
| h = min(int(round(math.sqrt(target_area * aspect_ratio))), height) |
| w = min(int(round(math.sqrt(target_area / aspect_ratio))), width) |
| top = random.randint(0, height - h) + int(y_min * self.input_size[0]) |
| left = random.randint(0, width - w) + int(x_min * self.input_size[1]) |
| mask = torch.zeros(size=self.input_size, dtype=bool) |
| mask[top : top + h, left : left + w] = True |
| return mask |
|
|
|
|
| class MaskGenerator: |
| def __init__( |
| self, |
| input_size, |
| mask_scene: bool = False, |
| mask_head: bool = False, |
| min_scene_patches=16, |
| max_scene_patches_ratio=0.5, |
| min_head_patches=4, |
| max_head_patches_ratio=0.5, |
| min_aspect=0.3, |
| mask_prob=0.2, |
| head_prob=0.2, |
| ): |
| if not isinstance(input_size, tuple): |
| input_size = (input_size,) * 2 |
| self.input_size = input_size |
| if mask_scene: |
| self.scene_mask_generator = SceneMaskGenerator( |
| input_size, min_scene_patches, max_scene_patches_ratio, min_aspect |
| ) |
| else: |
| self.scene_mask_generator = None |
|
|
| if mask_head: |
| self.head_mask_generator = HeadMaskGenerator( |
| input_size, min_head_patches, max_head_patches_ratio, min_aspect |
| ) |
| else: |
| self.head_mask_generator = None |
|
|
| self.no_mask = not (mask_scene or mask_head) |
| self.mask_head = mask_head and not mask_scene |
| self.mask_scene = mask_scene and not mask_head |
| self.scene_prob = mask_prob |
| self.head_prob = head_prob |
|
|
| def __call__( |
| self, |
| x_min, |
| y_min, |
| x_max, |
| y_max, |
| head_mask, |
| ): |
| mask_scene = random.random() < self.scene_prob |
| mask_head = random.random() < self.head_prob |
| no_mask = ( |
| self.no_mask |
| or (self.mask_head and not mask_head) |
| or (self.mask_scene and not mask_scene) |
| or not (mask_scene or mask_head) |
| ) |
| if no_mask: |
| return torch.zeros(size=self.input_size, dtype=bool) |
| if self.mask_scene: |
| return self.scene_mask_generator(head_mask) |
| if self.mask_head: |
| return self.head_mask_generator(x_min, y_min, x_max, y_max) |
| if mask_head and mask_scene: |
| return torch.logical_or( |
| self.scene_mask_generator(head_mask), |
| self.head_mask_generator(x_min, y_min, x_max, y_max), |
| ) |
| elif mask_head: |
| return self.head_mask_generator(x_min, y_min, x_max, y_max) |
| return self.scene_mask_generator(head_mask) |
|
|