OmniShotCut/architecture/position_encoding.py
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Various positional encodings for the transformer.
"""
import math
import torch
from torch import nn
from util.misc import NestedTensor


class PositionEmbeddingSine(nn.Module):
    """
    A standard sinusoidal position embedding, very similar to the one used in the
    "Attention Is All You Need" paper, generalized to work on images and extended
    to 3 dimensions (+ temporal) based on the VisTR implementation
    (https://github.com/YuqingWang1029/VisTR/blob/master/models/position_encoding.py).
    A usage sketch with the expected input/output shapes follows the class definition.
    """

    def __init__(self, num_pos_feats=64, num_frames=200, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        self.frames = num_frames
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, tensor_list: NestedTensor):
        x = tensor_list.tensors
        mask = tensor_list.mask
        assert mask is not None
        # Fold the temporal dimension out of the batch axis: (B*T, H, W) -> (B, T, H, W).
        n, h, w = mask.shape
        mask = mask.reshape(n // self.frames, self.frames, h, w)
        not_mask = ~mask
        # Cumulative sums over the unmasked region give the (temporal, vertical, horizontal) positions.
        z_embed = not_mask.cumsum(1, dtype=torch.float32)
        y_embed = not_mask.cumsum(2, dtype=torch.float32)
        x_embed = not_mask.cumsum(3, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            z_embed = z_embed / (z_embed[:, -1:, :, :] + eps) * self.scale
            y_embed = y_embed / (y_embed[:, :, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, :, -1:] + eps) * self.scale
        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
        pos_x = x_embed[:, :, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, :, None] / dim_t
        pos_z = z_embed[:, :, :, :, None] / dim_t
        # Interleave sine (even channels) and cosine (odd channels) for each axis.
        pos_x = torch.stack((pos_x[:, :, :, :, 0::2].sin(), pos_x[:, :, :, :, 1::2].cos()), dim=5).flatten(4)
        pos_y = torch.stack((pos_y[:, :, :, :, 0::2].sin(), pos_y[:, :, :, :, 1::2].cos()), dim=5).flatten(4)
        pos_z = torch.stack((pos_z[:, :, :, :, 0::2].sin(), pos_z[:, :, :, :, 1::2].cos()), dim=5).flatten(4)
        # Concatenate the three axes and move features to the channel dim: (B, T, 3*num_pos_feats, H, W).
        pos = torch.cat((pos_z, pos_y, pos_x), dim=4).permute(0, 1, 4, 2, 3)
        return pos
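

# ---------------------------------------------------------------------------
# Minimal usage sketch for PositionEmbeddingSine (an illustrative assumption,
# not part of the original training code). It shows the expected input/output
# shapes; a SimpleNamespace with .tensors/.mask attributes stands in for a
# NestedTensor so the sketch does not depend on util.misc internals.
# ---------------------------------------------------------------------------
#   from types import SimpleNamespace
#   batch, frames, h, w = 2, 4, 8, 8
#   sample = SimpleNamespace(
#       tensors=torch.zeros(batch * frames, 256, h, w),            # backbone features
#       mask=torch.zeros(batch * frames, h, w, dtype=torch.bool),  # False = valid pixel
#   )
#   embed = PositionEmbeddingSine(num_pos_feats=64, num_frames=frames, normalize=True)
#   pos = embed(sample)                                            # (batch, frames, 3 * 64, h, w)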


def build_position_encoding(args):
    # Modified from 2 dimensions to 3 dimensions (+ temporal).
    N_steps = args.hidden_dim // 3
    if args.position_embedding in ('v2', 'sine'):
        # TODO: find a better way of exposing other arguments
        position_embedding = PositionEmbeddingSine(N_steps, num_frames=args.max_process_window_length, normalize=True)
    else:
        raise ValueError(f"not supported {args.position_embedding}")
    return position_embedding
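

if __name__ == "__main__":
    # Hedged smoke test (an addition for illustration, not part of the original
    # repository): builds the encoding from a stand-in args namespace carrying
    # only the attributes that build_position_encoding actually reads above.
    from types import SimpleNamespace

    args = SimpleNamespace(hidden_dim=384, position_embedding="sine", max_process_window_length=4)
    pe = build_position_encoding(args)

    batch, frames, h, w = 2, args.max_process_window_length, 8, 8
    sample = SimpleNamespace(
        tensors=torch.zeros(batch * frames, args.hidden_dim, h, w),
        mask=torch.zeros(batch * frames, h, w, dtype=torch.bool),
    )
    pos = pe(sample)
    print(pos.shape)  # expected: (batch, frames, 3 * (hidden_dim // 3), h, w)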