gijl committed on
Commit 3d4f07b · verified · 1 Parent(s): 5201445

Delete model.py

Files changed (1)
  1. model.py +0 -67
model.py DELETED
@@ -1,67 +0,0 @@
- import torch
- import torch.nn as nn
- import math
- from torch.nn import functional as F
-
- class SelfAttention(nn.Module):
-     def __init__(self, n_embd=768, n_head=8):
-         super().__init__()
-         self.qkv = nn.Linear(n_embd, n_embd * 3, bias=False)
-         self.proj = nn.Linear(n_embd, n_embd)
-         self.n_head = n_head
-
-         # Causal mask so the model cannot attend to future positions
-         # Sized at 256 to match the block size of the checkpoint weights
-         self.register_buffer("tril", torch.tril(torch.ones(256, 256))
-                              .view(1, 1, 256, 256))
-
-     def forward(self, x):
-         B, T, C = x.shape
-         q, k, v = self.qkv(x).split(C, dim=2)
-
-         q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
-         k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
-         v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
-
-         att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.shape[-1]))
-
-         # Apply the mask: hide future tokens
-         att = att.masked_fill(self.tril[:, :, :T, :T] == 0, float('-inf'))
-
-         att = torch.softmax(att, dim=-1)
-         y = att @ v
-         y = y.transpose(1, 2).contiguous().view(B, T, C)
-         return self.proj(y)
-
- class Block(nn.Module):
-     def __init__(self, n_embd=768, n_head=8):
-         super().__init__()
-         self.ln1 = nn.LayerNorm(n_embd)
-         self.attn = SelfAttention(n_embd, n_head)
-         self.ln2 = nn.LayerNorm(n_embd)
-         self.mlp = nn.Sequential(
-             nn.Linear(n_embd, 4 * n_embd),
-             nn.GELU(),
-             nn.Linear(4 * n_embd, n_embd),
-         )
-
-     def forward(self, x):
-         x = x + self.attn(self.ln1(x))
-         x = x + self.mlp(self.ln2(x))
-         return x
-
- class MedicalMasterAI(nn.Module):
-     def __init__(self, vocab_size=115, n_layer=48, n_head=8, n_embd=768):
-         super().__init__()
-         self.token_embedding = nn.Embedding(vocab_size, n_embd)
-         # Set to 256 based on the shape reported in the weight-loading error log
-         self.position_embedding = nn.Parameter(torch.zeros(1, 256, n_embd))
-         self.blocks = nn.Sequential(*[Block(n_embd, n_head) for _ in range(n_layer)])
-         self.ln_f = nn.LayerNorm(n_embd)
-         self.lm_head = nn.Linear(n_embd, vocab_size)
-
-     def forward(self, idx):
-         b, t = idx.shape
-         x = self.token_embedding(idx) + self.position_embedding[:, :t, :]
-         x = self.blocks(x)
-         return self.lm_head(self.ln_f(x))
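
For reference, a minimal sketch of how the deleted module could be exercised before removal. It assumes the definitions above are still importable as model.py; the batch size, sequence length, and dummy token ids below are illustrative assumptions, while the class name and its defaults (vocab_size=115, block size 256) come from the file itself.

# Illustrative only: instantiate the deleted model and run a forward pass.
import torch
from model import MedicalMasterAI  # assumes model.py as defined above is on the path

model = MedicalMasterAI()              # defaults: vocab_size=115, n_layer=48, n_head=8, n_embd=768
idx = torch.randint(0, 115, (2, 256))  # dummy token ids, shape (batch=2, T=256); T must be <= 256
logits = model(idx)                    # -> (2, 256, 115) logits over the vocabulary
print(logits.shape)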