| import torch | |
| import math | |
class TimeEncoding(torch.nn.Module):
    """Sinusoidal feature encoding of scalar (time) values.

    Every input value ``t`` is multiplied by ``dim`` geometrically spaced
    frequencies ``f_i = 1 / base ** (i / dim)`` for ``i = 0 .. dim - 1``, and
    the encoding ``[sin(t * f_0..f_{dim-1}), cos(t * f_0..f_{dim-1})]`` is
    returned along the last axis. Note the feature width is ``2 * dim``,
    not ``dim``.

    Args:
        dim: Number of frequencies. The output has ``2 * dim`` features.
        base: Base of the geometric frequency progression. Frequencies lie in
            ``(1 / base, 1]``. Defaults to 10, the value this module has
            always used.
    """

    def __init__(self, dim: int, base: float = 10) -> None:
        super().__init__()
        self.dim = dim
        # NOTE(review): transformer/diffusion-style encodings conventionally
        # use base 10000; with base 10 all frequencies fall in (0.1, 1], a
        # very narrow band. Kept as the default for backward compatibility —
        # confirm this is intended.
        self.base = base

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """Encode ``t`` into sin/cos features.

        Args:
            t: Tensor of time values of any shape, including 0-d. Integer
                tensors are promoted to floating point by the multiply.

        Returns:
            Tensor of shape ``(*t.shape, 2 * dim)``. The first ``dim``
            features are sines and the last ``dim`` are cosines.
        """
        exponents = torch.arange(self.dim, device=t.device).float() / self.dim
        freqs = 1 / (self.base ** exponents)
        # Append a trailing frequency axis so any leading shape broadcasts;
        # for 1-D t this is the same as t.unsqueeze(1).
        angles = t[..., None] * freqs
        # Concatenate on the feature (last) axis so sin/cos never land on a
        # batch axis when t has more than one dimension.
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)