Abdullah-Nazhat committed on
Commit
208ef89
·
verified ·
1 Parent(s): a8c7584

Update litetensormapper.py

Browse files
Files changed (1) hide show
  1. litetensormapper.py +35 -12
litetensormapper.py CHANGED
@@ -2,8 +2,6 @@ import torch
2
  from torch import nn, Tensor
3
 
4
 
5
-
6
-
7
  class VecDyT(nn.Module):
8
  def __init__(self, input_shape):
9
 
@@ -24,14 +22,34 @@ class VecDyGeluSine(nn.Module):
24
  self.alpha = nn.Parameter(torch.randn(input_shape))
25
  self.beta = nn.Parameter(torch.randn(input_shape))
26
  self.gamma = nn.Parameter(torch.randn(1))
27
- self.eta = nn.Parameter(torch.randn(1))
28
  self.gelu = nn.GELU()
29
 
30
  def forward(self, x):
31
- x = self.gamma * self.gelu(self.alpha * x) + self.eta * torch.sin(self.beta * x)
 
32
 
33
  return x
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
 
37
  class TTT(nn.Module):
@@ -65,12 +83,12 @@ class TTT(nn.Module):
65
 
66
  return out
67
 
68
- class FFUnit(nn.Module):
69
  def __init__(self,dim):
70
 
71
  super().__init__()
72
 
73
- self.proj = nn.Linear(dim,dim,bias=False)
74
  self.modulate = VecDyGeluSine(dim)
75
 
76
 
@@ -91,24 +109,28 @@ class LiteTensorMapperBlock(nn.Module):
91
 
92
  self.norm_1 = VecDyT(dim)
93
  self.norm_2 = VecDyT(dim)
94
- self.memory = TTT(dim)
95
  self.feedforward = FFUnit(dim)
96
 
97
 
98
  def forward(self, x):
99
 
100
 
101
- memorypath, FeedForwardpath = x, x
102
 
103
  memorypath = self.norm_1(memorypath)
104
 
105
  memorypath = self.memory(memorypath)
106
 
107
- FeedForwardpath = self.norm_2(FeedForwardpath)
108
 
109
- FeedForwardpath = self.feedforward(FeedForwardpath)
110
 
111
- x = memorypath + FeedForwardpath
 
 
 
 
112
 
113
  return x
114
 
@@ -123,4 +145,5 @@ class LiteTensorMapper(nn.Module):
123
 
124
  def forward(self, x):
125
 
126
- return self.model(x)
 
 
2
  from torch import nn, Tensor
3
 
4
 
 
 
5
  class VecDyT(nn.Module):
6
  def __init__(self, input_shape):
7
 
 
22
  self.alpha = nn.Parameter(torch.randn(input_shape))
23
  self.beta = nn.Parameter(torch.randn(input_shape))
24
  self.gamma = nn.Parameter(torch.randn(1))
25
+ self.etta = nn.Parameter(torch.randn(1))
26
  self.gelu = nn.GELU()
27
 
28
  def forward(self, x):
29
+
30
+ x = self.gamma * self.gelu(self.alpha * x) + self.etta * torch.sin(self.beta * x)
31
 
32
  return x
33
 
34
+ class FFUnit(nn.Module):
35
+ def __init__(self,dim):
36
+
37
+ super().__init__()
38
+
39
+ self.proj = nn.Linear(dim,dim,bias=False)
40
+ self.modulate = VecDyGeluSine(dim)
41
+
42
+
43
+ def forward(self, x):
44
+
45
+ u, v = x, x
46
+
47
+ u = self.modulate(u)
48
+ v = self.proj(v)
49
+ g = u * v
50
+
51
+ return g
52
+
53
 
54
 
55
  class TTT(nn.Module):
 
83
 
84
  return out
85
 
86
+ class FFUnit_TTT(nn.Module):
87
  def __init__(self,dim):
88
 
89
  super().__init__()
90
 
91
+ self.proj = TTT(dim)
92
  self.modulate = VecDyGeluSine(dim)
93
 
94
 
 
109
 
110
  self.norm_1 = VecDyT(dim)
111
  self.norm_2 = VecDyT(dim)
112
+ self.memory = FFUnit_TTT(dim)
113
  self.feedforward = FFUnit(dim)
114
 
115
 
116
  def forward(self, x):
117
 
118
 
119
+ memorypath,residual = x, x
120
 
121
  memorypath = self.norm_1(memorypath)
122
 
123
  memorypath = self.memory(memorypath)
124
 
125
+ x = memorypath + residual
126
 
127
+ FFpath, residual = x, x
128
 
129
+ FFpath = self.norm_2(FFpath)
130
+
131
+ FFpath = self.feedforward(FFpath)
132
+
133
+ x = FFpath + residual
134
 
135
  return x
136
 
 
145
 
146
  def forward(self, x):
147
 
148
+ return self.model(x)
149
+