import torch
import math


class TimeEncoding(torch.nn.Module):
    """Sinusoidal encoding of time values.

    Each input value ``t`` is multiplied by ``dim`` geometrically spaced
    frequencies ``1 / base ** (k / dim)`` for ``k = 0 .. dim - 1``. The
    resulting angles are encoded as ``[sin(angles), cos(angles)]``
    concatenated on the last axis. The output therefore has ``2 * dim``
    features, not ``dim``.

    Args:
        dim: Number of frequencies. The output feature size is ``2 * dim``.
        base: Base of the geometric frequency progression. The frequencies
            range from ``1`` down to ``base ** (-(dim - 1) / dim)``. Defaults
            to ``10`` to preserve the original behaviour.
            NOTE(review): the classic Transformer positional encoding uses
            ``10000``, which spreads the frequencies over a far wider range;
            confirm which one the model actually wants.

    Raises:
        ValueError: If ``base`` is not positive (a non-positive base would
            produce NaN/inf frequencies).
    """

    def __init__(self, dim, base=10):
        super().__init__()
        if base <= 0:
            raise ValueError(f"base must be positive, got {base!r}")
        self.dim = dim
        self.base = base

    def forward(self, t):
        """Encode the time values in ``t``.

        Args:
            t: Tensor of time values of arbitrary shape ``S``. It may have an
                integer or floating-point dtype. The common case is a 1-D
                tensor of shape ``(batch,)``.

        Returns:
            Tensor of shape ``S + (2 * dim,)``. The first ``dim`` features are
            sines and the last ``dim`` are cosines of ``t * freqs``.
        """
        k = torch.arange(self.dim, device=t.device).float()
        freqs = 1 / (self.base ** (k / self.dim))
        # Broadcast each time value against every frequency on a new trailing
        # axis. Using -1 instead of a fixed axis 1 gives the same result for
        # 1-D input, and it also supports 0-d and multi-dimensional inputs
        # without mixing the sin/cos features into a batch or sequence axis.
        angles = t.unsqueeze(-1) * freqs
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)