# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Various positional encodings for the transformer.
"""
import math

import torch
from torch import nn

from util.misc import NestedTensor


class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the "Attention Is All You Need" paper, generalized to work on images.
    Modified for 3 dims (+ temporal) based on the VisTR implementation
    (https://github.com/YuqingWang1029/VisTR/blob/master/models/position_encoding.py).
    """
    def __init__(self, num_pos_feats=64, num_frames=200, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        self.frames = num_frames
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, tensor_list: NestedTensor):
        x = tensor_list.tensors
        mask = tensor_list.mask
        assert mask is not None
        n, h, w = mask.shape
        # Unflatten the (batch * frames) dimension back into (batch, frames, h, w)
        mask = mask.reshape(n // self.frames, self.frames, h, w)
        not_mask = ~mask
        # Cumulative sums over the valid region give position indices along the
        # temporal (z), vertical (y) and horizontal (x) axes
        z_embed = not_mask.cumsum(1, dtype=torch.float32)
        y_embed = not_mask.cumsum(2, dtype=torch.float32)
        x_embed = not_mask.cumsum(3, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            z_embed = z_embed / (z_embed[:, -1:, :, :] + eps) * self.scale
            y_embed = y_embed / (y_embed[:, :, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, :, None] / dim_t
        pos_z = z_embed[:, :, :, :, None] / dim_t
        # Interleave sine (even indices) and cosine (odd indices) over the feature dimension
        pos_x = torch.stack((pos_x[:, :, :, :, 0::2].sin(), pos_x[:, :, :, :, 1::2].cos()), dim=5).flatten(4)
        pos_y = torch.stack((pos_y[:, :, :, :, 0::2].sin(), pos_y[:, :, :, :, 1::2].cos()), dim=5).flatten(4)
        pos_z = torch.stack((pos_z[:, :, :, :, 0::2].sin(), pos_z[:, :, :, :, 1::2].cos()), dim=5).flatten(4)
        # Concatenate (temporal, vertical, horizontal) encodings and move features to dim 2:
        # output shape is (batch, frames, 3 * num_pos_feats, h, w)
        pos = torch.cat((pos_z, pos_y, pos_x), dim=4).permute(0, 1, 4, 2, 3)
        return pos


def build_position_encoding(args):
    # Modified from 2 dimensions to 3 dimensions (+ temporal)
    N_steps = args.hidden_dim // 3
    if args.position_embedding in ('v2', 'sine'):
        # TODO find a better way of exposing other arguments
        position_embedding = PositionEmbeddingSine(N_steps, num_frames=args.max_process_window_length,
                                                    normalize=True)
    else:
        raise ValueError(f"not supported {args.position_embedding}")

    return position_embedding
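

# Minimal usage sketch (not part of the original module): shows the expected input/output
# shapes of PositionEmbeddingSine. It assumes util.misc.NestedTensor is constructed as
# NestedTensor(tensors, mask), as in DETR; the chosen sizes (4 frames, 8x8 feature map,
# 128 features per axis) are illustrative only.
if __name__ == "__main__":
    frames, h, w = 4, 8, 8
    feats = torch.zeros(frames, 256, h, w)                 # (batch * frames, C, H, W)
    mask = torch.zeros(frames, h, w, dtype=torch.bool)     # False = valid pixel
    pe = PositionEmbeddingSine(num_pos_feats=128, num_frames=frames, normalize=True)
    pos = pe(NestedTensor(feats, mask))
    # Expected: torch.Size([1, 4, 384, 8, 8]) -> (batch, frames, 3 * num_pos_feats, H, W)
    print(pos.shape)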