temporal-twins-code / src /tgn /time_encoding.py
temporal-twins-anon's picture
Add anonymous Temporal Twins code release
a3682cf verified
raw
history blame contribute delete
432 Bytes
import torch
import math
class TimeEncoding(torch.nn.Module):
    """Sinusoidal (sin/cos) encoding of time values.

    Each time value ``t`` is multiplied by ``dim`` fixed frequencies
    ``f_i = base ** (-i / dim)`` for ``i = 0 .. dim - 1`` and mapped to
    ``[sin(t * f_0), ..., sin(t * f_{dim-1}), cos(t * f_0), ..., cos(t * f_{dim-1})]``.
    The encoding therefore has ``2 * dim`` features, not ``dim``.

    Args:
        dim: Number of frequencies; the output feature width is ``2 * dim``.
        base: Controls how quickly the frequencies decay. With the default
            of 10 the frequencies span 1 down to ``10 ** (-(dim - 1) / dim)``
            (just above 0.1), a narrow band; a larger base (e.g. 10000, the
            value used by Transformer positional encodings) covers a wider
            range of time scales.
    """

    # Class-level default so instances restored without a ``base`` attribute
    # (e.g. unpickled whole-module checkpoints) still encode with base 10.
    base = 10

    def __init__(self, dim, base=10):
        super().__init__()
        self.dim = dim
        self.base = base

    def forward(self, t):
        """Encode the time values in ``t``.

        Args:
            t: Tensor of time values of any shape, including 0-d.
                NOTE(review): presumably timestamps or time deltas; the
                expected unit/scale is not established here -- confirm
                with callers.

        Returns:
            Tensor of shape ``(*t.shape, 2 * dim)``. The first ``dim``
            channels along the last axis are sines and the last ``dim`` are
            cosines. For 1-D input of shape ``(N,)`` this is ``(N, 2 * dim)``.
        """
        device = t.device
        # Frequencies are built in float32; the result dtype follows torch's
        # type promotion between ``t`` and these frequencies.
        exponents = torch.arange(self.dim, device=device).float()
        freqs = 1 / (self.base ** (exponents / self.dim))
        # Add a trailing feature axis so ``t`` broadcasts against ``freqs``
        # for any input rank, not only 1-D.
        angles = t.unsqueeze(-1) * freqs
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)