TCMVince commited on
Commit
0efcf9c
·
verified ·
1 Parent(s): d972a70

Upload positional.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. positional.py +27 -0
positional.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # positional.py
2
+ import torch
3
+ import torch.nn as nn
4
+ from math import log
5
+
6
+ class PositionalEncoding(nn.Module):
7
+ def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 500):
8
+ super().__init__()
9
+ self.dropout = nn.Dropout(p=dropout)
10
+
11
+ position = torch.arange(max_len, dtype=torch.float).unsqueeze(1) # (max_len, 1)
12
+ div_term = torch.exp(
13
+ torch.arange(0, d_model, 2, dtype=torch.float) * (-log(10000.0) / d_model)
14
+ ) # (d_model/2,)
15
+
16
+ pe = torch.zeros(max_len, d_model, dtype=torch.float) # (max_len, d_model)
17
+ pe[:, 0::2] = torch.sin(position * div_term) # (max_len, d_model/2)
18
+ pe[:, 1::2] = torch.cos(position * div_term)
19
+
20
+ pe = pe.unsqueeze(0)
21
+ self.register_buffer("pe", pe, persistent=False) # buffer, pas paramètre
22
+
23
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
24
+ # x: (B, S, D)
25
+ s = x.size(1)
26
+ x = x + self.pe[:, :s, :] # (1,S,D) broadcast -> (B,S,D)
27
+ return self.dropout(x)