ziadkassem committed on
Commit
9c26aec
·
verified ·
1 Parent(s): 9170f60

Upload StoryGPT/model/rmsnorm.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. StoryGPT/model/rmsnorm.py +34 -0
StoryGPT/model/rmsnorm.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ """
6
+ Formula
7
+
8
+ RMS(x) = sqrt( mean(x^2) )
9
+
10
+ x_norm = x / RMS(x) * gamma
11
+
12
+ """
13
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization over the last dimension.

    Computes ``x * rsqrt(mean(x**2, dim=-1) + eps) * gamma``, where ``gamma``
    is a learnable per-feature scale initialised to ones.

    Args:
        cfg: Configuration mapping; only ``cfg["emb_dim"]`` is read and it
            sets the length of ``gamma``.
        eps: Small constant added to the mean square before the reciprocal
            square root, for numerical stability.
    """

    def __init__(self, cfg, eps: float = 1e-8) -> None:
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(cfg["emb_dim"]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` along its last dimension and scale by ``gamma``.

        Args:
            x: Floating-point tensor whose last dimension broadcasts against
                ``gamma`` (length ``cfg["emb_dim"]``).

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Half-precision inputs are reduced in float32: eps=1e-8 is not
        # representable in float16, so it would otherwise vanish and an
        # all-zero row would evaluate 0 / 0.
        if x.dtype in (torch.float16, torch.bfloat16):
            compute_dtype = torch.float32
        else:
            compute_dtype = x.dtype
        x_c = x.to(compute_dtype)

        # keepdim=True keeps a trailing size-1 axis so the statistic
        # broadcasts against x.
        mean_sq = x_c.pow(2).mean(dim=-1, keepdim=True)

        # eps goes inside the root: sqrt has an infinite derivative at 0,
        # so adding eps afterwards gives NaN gradients for all-zero rows.
        inv_rms = torch.rsqrt(mean_sq + self.eps)

        # Cast back to the input dtype before scaling so the result follows
        # the usual type promotion between x and gamma.
        return (x_c * inv_rms).to(x.dtype) * self.gamma
22
+
23
+
24
+ """
25
+ Why keepdim=True is needed:
26
+
27
+ x.shape # (2, 4, 8)
28
+ rms.shape # (2, 4) without keepdim - the last dimension is gone
29
+
30
+ and we use dim=-1 so we select the last dim (8)
31
+
32
+ so with keepdim=True the shape becomes (2,4,1), which broadcasts against (2,4,8) when x is divided by it
33
+
34
+ """