File size: 4,096 Bytes
1d0a8b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import math, torch, torch.nn as nn, torch.nn.functional as F

def build_model(vocab_size, n_layer, n_head, n_embd, block_size, dropout=0.2):
    """Build a decoder-only (GPT-style) transformer language model.

    The module classes are defined inside this factory so they can close over
    the hyper-parameters (``block_size``, ``dropout`` and friends) instead of
    passing them through every constructor.

    Args:
        vocab_size: Number of distinct token ids; sizes the token embedding
            table and the output head.
        n_layer: Number of transformer blocks applied in sequence.
        n_head: Number of causal self-attention heads per block.
        n_embd: Embedding width; must be a multiple of ``n_head``.
        block_size: Maximum context length (also sizes the position table).
        dropout: Dropout probability applied to attention weights, to the
            attention output projection and to the feed-forward output.

    Returns:
        A freshly constructed ``GPTLanguageModel`` (an ``nn.Module``).

    Raises:
        ValueError: If ``n_head`` is not positive or ``n_embd`` is not
            divisible by ``n_head`` (the concatenated heads would otherwise
            not match the ``n_embd``-wide projection, failing only at
            forward time).
    """
    if n_head <= 0 or n_embd % n_head != 0:
        raise ValueError(
            f"n_embd ({n_embd}) must be divisible by a positive n_head ({n_head})"
        )

    class Head(nn.Module):
        """One head of causal (masked) scaled dot-product self-attention."""

        def __init__(self, n_embd, head_size):
            super().__init__()
            self.key = nn.Linear(n_embd, head_size, bias=False)
            self.query = nn.Linear(n_embd, head_size, bias=False)
            self.value = nn.Linear(n_embd, head_size, bias=False)
            # Lower-triangular mask: position t may only attend to positions <= t.
            # A buffer follows the module across devices but is not a parameter.
            self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))
            self.scale = head_size ** -0.5
            self.drop = nn.Dropout(dropout)

        def forward(self, x):
            """Map ``x`` of shape (B, T, n_embd) to (B, T, head_size)."""
            _, T, _ = x.shape
            k = self.key(x)
            q = self.query(x)
            wei = (q @ k.transpose(-2, -1)) * self.scale
            # Hide future positions; -inf becomes weight 0 after softmax.
            wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
            wei = torch.softmax(wei, dim=-1)
            wei = self.drop(wei)
            v = self.value(x)
            return wei @ v

    class MultiHeadAttention(nn.Module):
        """Run ``n_head`` heads in parallel, concatenate, then project back."""

        def __init__(self, n_embd, n_head):
            super().__init__()
            head_size = n_embd // n_head
            self.heads = nn.ModuleList([Head(n_embd, head_size) for _ in range(n_head)])
            self.proj = nn.Linear(n_embd, n_embd)
            self.drop = nn.Dropout(dropout)

        def forward(self, x):
            x = torch.cat([h(x) for h in self.heads], dim=-1)
            return self.drop(self.proj(x))

    class FeedForward(nn.Module):
        """Position-wise MLP: expand 4x, ReLU, project back, dropout."""

        def __init__(self, n_embd):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(n_embd, 4 * n_embd),
                nn.ReLU(),
                nn.Linear(4 * n_embd, n_embd),
                nn.Dropout(dropout),
            )

        def forward(self, x):
            return self.net(x)

    class Block(nn.Module):
        """Pre-LayerNorm transformer block with residual connections."""

        def __init__(self, n_embd, n_head):
            super().__init__()
            self.ln1 = nn.LayerNorm(n_embd)
            self.sa = MultiHeadAttention(n_embd, n_head)
            self.ln2 = nn.LayerNorm(n_embd)
            self.ffw = FeedForward(n_embd)

        def forward(self, x):
            x = x + self.sa(self.ln1(x))
            x = x + self.ffw(self.ln2(x))
            return x

    class GPTLanguageModel(nn.Module):
        """Token + learned position embeddings, ``n_layer`` blocks, LM head."""

        def __init__(self):
            super().__init__()
            self.block_size = block_size
            self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
            self.position_embedding_table = nn.Embedding(block_size, n_embd)
            self.blocks = nn.Sequential(*[Block(n_embd, n_head) for _ in range(n_layer)])
            self.ln_f = nn.LayerNorm(n_embd)
            self.lm_head = nn.Linear(n_embd, vocab_size)

        def forward(self, idx, targets=None):
            """Return ``(logits, loss)`` for token ids ``idx`` of shape (B, T).

            ``logits`` has shape (B, T, vocab_size). ``loss`` is the
            cross-entropy (PyTorch's default mean reduction) against
            ``targets`` of shape (B, T), or ``None`` when no targets are given.

            Raises:
                ValueError: If ``T`` exceeds ``block_size``. Without this check
                    the position lookup fails with an opaque index error
                    (or a device-side assert on CUDA).
            """
            _, T = idx.shape
            if T > self.block_size:
                raise ValueError(
                    f"sequence length {T} exceeds block_size {self.block_size}"
                )
            tok = self.token_embedding_table(idx)
            pos = self.position_embedding_table(torch.arange(T, device=idx.device))
            x = tok + pos
            x = self.blocks(x)
            x = self.ln_f(x)
            logits = self.lm_head(x)
            loss = None
            if targets is not None:
                # reshape (not view) so non-contiguous targets such as
                # transposed or strided slices are accepted too.
                loss = F.cross_entropy(
                    logits.reshape(-1, logits.size(-1)), targets.reshape(-1)
                )
            return logits, loss

        @torch.no_grad()
        def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
            """Append ``max_new_tokens`` sampled token ids to ``idx`` and return it.

            Only the last ``block_size`` tokens are fed back as context, and
            ``idx`` must contain at least one token per row. ``temperature`` is
            clamped below at 1e-6, so values at or near zero sample almost
            greedily. When ``top_k`` is given, sampling is restricted to the
            ``top_k`` highest-scoring tokens. This method does not switch the
            module to eval mode, so dropout stays active if the model is
            training.
            """
            for _ in range(max_new_tokens):
                idx_cond = idx[:, -self.block_size:]
                logits, _ = self(idx_cond)
                logits = logits[:, -1, :] / max(temperature, 1e-6)
                if top_k is not None:
                    _, keep = torch.topk(logits, k=min(top_k, logits.size(-1)))
                    # Mask out everything except exactly the top-k indices.
                    mask = torch.ones_like(logits, dtype=torch.bool)
                    mask.scatter_(1, keep, False)
                    logits = logits.masked_fill(mask, float("-inf"))
                probs = torch.softmax(logits, dim=-1)
                next_id = torch.multinomial(probs, num_samples=1)
                idx = torch.cat((idx, next_id), dim=1)
            return idx

    return GPTLanguageModel()