gijl committed on
Commit
cde5e54
·
verified ·
1 Parent(s): 6c811ca

Upload 2 files

Browse files
Files changed (2) hide show
  1. model.py +55 -0
  2. tokenizer_config.json +233 -0
model.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+ from torch.nn import functional as F
5
+
6
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    One bias-free linear layer produces queries, keys and values in a single
    pass. They are split into ``n_head`` heads, attended over, merged back
    and sent through an output projection.

    NOTE(review): no attention mask is applied, so every position attends to
    every other position, including later ones. Confirm this matches how the
    model is trained and sampled.

    Args:
        n_embd: Size of the last dimension of the input and of the output.
        n_head: Number of attention heads; must divide ``n_embd`` evenly.

    Raises:
        ValueError: If ``n_head`` is not positive or does not divide
            ``n_embd``.
    """

    def __init__(self, n_embd=768, n_head=8):
        super().__init__()
        # Fail at construction time instead of with an opaque ``view`` shape
        # error (or ZeroDivisionError) on the first forward pass.
        if n_head <= 0 or n_embd % n_head != 0:
            raise ValueError(
                f"n_embd ({n_embd}) must be divisible by a positive "
                f"n_head ({n_head})"
            )
        self.qkv = nn.Linear(n_embd, n_embd * 3, bias=False)
        self.proj = nn.Linear(n_embd, n_embd)
        self.n_head = n_head

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor with three dimensions ``(B, T, C)``; ``C`` is fed to
                the ``qkv`` layer, so it must equal the ``n_embd`` the
                module was built with.

        Returns:
            Tensor of shape ``(B, T, C)``.
        """
        B, T, C = x.shape
        head_dim = C // self.n_head

        def split_heads(t):
            # (B, T, C) -> (B, n_head, T, head_dim)
            return t.view(B, T, self.n_head, head_dim).transpose(1, 2)

        q, k, v = (split_heads(t) for t in self.qkv(x).split(C, dim=2))

        # Scale the scores by 1/sqrt(head_dim) before the softmax.
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.shape[-1]))
        att = torch.softmax(att, dim=-1)

        y = att @ v
        # (B, n_head, T, head_dim) -> (B, T, C): merge the heads back.
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.proj(y)
24
+
25
class Block(nn.Module):
    """One residual layer: normalised self-attention, then a normalised MLP.

    Each sub-layer reads a LayerNorm-ed copy of its input and its output is
    added back onto the un-normalised input.

    Args:
        n_embd: Size of the last dimension of the input and of the output.
        n_head: Number of heads handed to the attention sub-layer.
    """

    def __init__(self, n_embd=768, n_head=8):
        super().__init__()
        hidden_dim = 4 * n_embd  # width of the MLP's inner layer

        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = SelfAttention(n_embd, n_head)
        self.ln2 = nn.LayerNorm(n_embd)

        expand = nn.Linear(n_embd, hidden_dim)
        contract = nn.Linear(hidden_dim, n_embd)
        self.mlp = nn.Sequential(expand, nn.GELU(), contract)

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates."""
        after_attn = x + self.attn(self.ln1(x))
        mlp_out = self.mlp(self.ln2(after_attn))
        return after_attn + mlp_out
41
+
42
class MedicalMasterAI(nn.Module):
    """Stack of ``Block`` layers mapping token ids to vocabulary scores.

    Token embeddings are summed with a learned position table (initialised
    to zeros), passed through ``n_layer`` blocks and a final LayerNorm, and
    projected to ``vocab_size`` scores per position. No softmax is applied.

    Args:
        vocab_size: Number of token ids; sizes both the embedding table and
            the output projection.
        n_layer: Number of ``Block`` layers.
        n_head: Attention heads per block.
        n_embd: Embedding / hidden width.
        block_size: Number of positions in the position table, i.e. the
            longest sequence ``forward`` accepts. Defaults to 1024, the
            value that was previously hard-coded.
    """

    def __init__(
        self,
        vocab_size=115,
        n_layer=48,
        n_head=8,
        n_embd=768,
        block_size=1024,
    ):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, n_embd)
        self.position_embedding = nn.Parameter(
            torch.zeros(1, block_size, n_embd)
        )
        self.blocks = nn.Sequential(
            *[Block(n_embd, n_head) for _ in range(n_layer)]
        )
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx):
        """Compute per-position vocabulary scores.

        Args:
            idx: Two-dimensional tensor of token ids, shape ``(b, t)``.

        Returns:
            Tensor of shape ``(b, t, vocab_size)``.

        Raises:
            ValueError: If ``t`` exceeds the number of positions in the
                position table.
        """
        _, t = idx.shape
        # Read the limit from the parameter itself so it stays correct even
        # if the table was replaced (e.g. by loading a different checkpoint).
        max_positions = self.position_embedding.shape[1]
        if t > max_positions:
            # Without this check the slice below silently stops at
            # max_positions and the addition fails with a broadcasting error.
            raise ValueError(
                f"sequence length {t} exceeds the maximum of "
                f"{max_positions} positions"
            )
        x = self.token_embedding(idx) + self.position_embedding[:, :t, :]
        x = self.blocks(x)
        return self.lm_head(self.ln_f(x))
tokenizer_config.json ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "stoi": {
3
+ " ": 0,
4
+ "!": 1,
5
+ "(": 2,
6
+ ")": 3,
7
+ "*": 4,
8
+ "+": 5,
9
+ "-": 6,
10
+ ".": 7,
11
+ "/": 8,
12
+ "0": 9,
13
+ "1": 10,
14
+ "2": 11,
15
+ "3": 12,
16
+ "4": 13,
17
+ "5": 14,
18
+ "6": 15,
19
+ "7": 16,
20
+ "8": 17,
21
+ "9": 18,
22
+ "=": 19,
23
+ "A": 20,
24
+ "B": 21,
25
+ "C": 22,
26
+ "D": 23,
27
+ "E": 24,
28
+ "F": 25,
29
+ "G": 26,
30
+ "H": 27,
31
+ "I": 28,
32
+ "J": 29,
33
+ "K": 30,
34
+ "L": 31,
35
+ "M": 32,
36
+ "N": 33,
37
+ "O": 34,
38
+ "P": 35,
39
+ "Q": 36,
40
+ "R": 37,
41
+ "S": 38,
42
+ "T": 39,
43
+ "U": 40,
44
+ "V": 41,
45
+ "W": 42,
46
+ "X": 43,
47
+ "Y": 44,
48
+ "Z": 45,
49
+ "[": 46,
50
+ "]": 47,
51
+ "a": 48,
52
+ "b": 49,
53
+ "c": 50,
54
+ "d": 51,
55
+ "e": 52,
56
+ "f": 53,
57
+ "g": 54,
58
+ "h": 55,
59
+ "i": 56,
60
+ "j": 57,
61
+ "k": 58,
62
+ "l": 59,
63
+ "m": 60,
64
+ "n": 61,
65
+ "o": 62,
66
+ "p": 63,
67
+ "q": 64,
68
+ "r": 65,
69
+ "s": 66,
70
+ "t": 67,
71
+ "u": 68,
72
+ "v": 69,
73
+ "w": 70,
74
+ "x": 71,
75
+ "y": 72,
76
+ "z": 73,
77
+ "{": 74,
78
+ "}": 75,
79
+ "،": 76,
80
+ "؟": 77,
81
+ "ء": 78,
82
+ "آ": 79,
83
+ "أ": 80,
84
+ "ؤ": 81,
85
+ "إ": 82,
86
+ "ئ": 83,
87
+ "ب": 84,
88
+ "ة": 85,
89
+ "ت": 86,
90
+ "ث": 87,
91
+ "ج": 88,
92
+ "ح": 89,
93
+ "خ": 90,
94
+ "د": 91,
95
+ "ذ": 92,
96
+ "ر": 93,
97
+ "ز": 94,
98
+ "س": 95,
99
+ "ش": 96,
100
+ "ص": 97,
101
+ "ض": 98,
102
+ "ط": 99,
103
+ "ظ": 100,
104
+ "ع": 101,
105
+ "غ": 102,
106
+ "ـ": 103,
107
+ "ف": 104,
108
+ "ق": 105,
109
+ "ك": 106,
110
+ "ل": 107,
111
+ "م": 108,
112
+ "ن": 109,
113
+ "ه": 110,
114
+ "و": 111,
115
+ "ي": 112
116
+ },
117
+ "itos": {
118
+ "0": " ",
119
+ "1": "!",
120
+ "2": "(",
121
+ "3": ")",
122
+ "4": "*",
123
+ "5": "+",
124
+ "6": "-",
125
+ "7": ".",
126
+ "8": "/",
127
+ "9": "0",
128
+ "10": "1",
129
+ "11": "2",
130
+ "12": "3",
131
+ "13": "4",
132
+ "14": "5",
133
+ "15": "6",
134
+ "16": "7",
135
+ "17": "8",
136
+ "18": "9",
137
+ "19": "=",
138
+ "20": "A",
139
+ "21": "B",
140
+ "22": "C",
141
+ "23": "D",
142
+ "24": "E",
143
+ "25": "F",
144
+ "26": "G",
145
+ "27": "H",
146
+ "28": "I",
147
+ "29": "J",
148
+ "30": "K",
149
+ "31": "L",
150
+ "32": "M",
151
+ "33": "N",
152
+ "34": "O",
153
+ "35": "P",
154
+ "36": "Q",
155
+ "37": "R",
156
+ "38": "S",
157
+ "39": "T",
158
+ "40": "U",
159
+ "41": "V",
160
+ "42": "W",
161
+ "43": "X",
162
+ "44": "Y",
163
+ "45": "Z",
164
+ "46": "[",
165
+ "47": "]",
166
+ "48": "a",
167
+ "49": "b",
168
+ "50": "c",
169
+ "51": "d",
170
+ "52": "e",
171
+ "53": "f",
172
+ "54": "g",
173
+ "55": "h",
174
+ "56": "i",
175
+ "57": "j",
176
+ "58": "k",
177
+ "59": "l",
178
+ "60": "m",
179
+ "61": "n",
180
+ "62": "o",
181
+ "63": "p",
182
+ "64": "q",
183
+ "65": "r",
184
+ "66": "s",
185
+ "67": "t",
186
+ "68": "u",
187
+ "69": "v",
188
+ "70": "w",
189
+ "71": "x",
190
+ "72": "y",
191
+ "73": "z",
192
+ "74": "{",
193
+ "75": "}",
194
+ "76": "،",
195
+ "77": "؟",
196
+ "78": "ء",
197
+ "79": "آ",
198
+ "80": "أ",
199
+ "81": "ؤ",
200
+ "82": "إ",
201
+ "83": "ئ",
202
+ "84": "ب",
203
+ "85": "ة",
204
+ "86": "ت",
205
+ "87": "ث",
206
+ "88": "ج",
207
+ "89": "ح",
208
+ "90": "خ",
209
+ "91": "د",
210
+ "92": "ذ",
211
+ "93": "ر",
212
+ "94": "ز",
213
+ "95": "س",
214
+ "96": "ش",
215
+ "97": "ص",
216
+ "98": "ض",
217
+ "99": "ط",
218
+ "100": "ظ",
219
+ "101": "ع",
220
+ "102": "غ",
221
+ "103": "ـ",
222
+ "104": "ف",
223
+ "105": "ق",
224
+ "106": "ك",
225
+ "107": "ل",
226
+ "108": "م",
227
+ "109": "ن",
228
+ "110": "ه",
229
+ "111": "و",
230
+ "112": "ي"
231
+ },
232
+ "vocab_size": 115
233
+ }