File size: 432 Bytes
a3682cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch
import math


class TimeEncoding(torch.nn.Module):
    """Sinusoidal encoding of time values.

    Every time value ``t`` is multiplied by ``dim`` frequencies
    ``base ** (-k / dim)`` for ``k = 0 .. dim - 1``.  The sines of the
    resulting angles are concatenated with their cosines, so the encoding's
    last dimension is ``2 * dim`` (all sines first, then all cosines).

    Args:
        dim: Number of frequencies; the encoding width is ``2 * dim``.
        base: Geometric base of the frequency ladder.  The default of 10
            keeps this module's established behaviour.  The Transformer
            positional encoding (Vaswani et al.) uses 10000, which spreads
            the frequencies over a much wider range.
    """

    def __init__(self, dim, base=10.0):
        super().__init__()
        self.dim = dim
        self.base = base

    def forward(self, t):
        """Encode ``t``.

        Args:
            t: Tensor of time values of any shape, including a 0-d scalar.
                Integer tensors are promoted to floating point by the
                multiplication with the float32 frequencies.

        Returns:
            Tensor of shape ``(*t.shape, 2 * self.dim)``: sines of the
            angles in the first ``dim`` channels, cosines in the last ``dim``.
        """
        exponents = (
            torch.arange(self.dim, device=t.device, dtype=torch.float32)
            / self.dim
        )
        freqs = 1.0 / (self.base ** exponents)

        # A trailing axis (rather than axis 1) lets the frequency vector
        # broadcast against every element of t whatever its rank, and keeps
        # the sin/cos channels on the last axis.
        angles = t.unsqueeze(-1) * freqs

        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)