"""Minimal GPT-style decoder-only language model, assembled by build_model()."""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
def build_model(vocab_size, n_layer, n_head, n_embd, block_size, dropout=0.2):
    class Head(nn.Module):
        """One head of causal self-attention."""
        def __init__(self, n_embd, head_size):
            super().__init__()
            self.key = nn.Linear(n_embd, head_size, bias=False)
            self.query = nn.Linear(n_embd, head_size, bias=False)
            self.value = nn.Linear(n_embd, head_size, bias=False)
            # lower-triangular mask so each position attends only to itself and earlier positions
            self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))
            self.scale = head_size ** -0.5
            self.drop = nn.Dropout(dropout)

        def forward(self, x):
            B, T, C = x.shape
            k = self.key(x)
            q = self.query(x)
            wei = (q @ k.transpose(-2, -1)) * self.scale  # (B, T, T) attention scores
            wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
            wei = torch.softmax(wei, dim=-1)
            wei = self.drop(wei)
            v = self.value(x)
            return wei @ v  # (B, T, head_size)
    class MultiHeadAttention(nn.Module):
        """Several attention heads in parallel, followed by an output projection."""
        def __init__(self, n_embd, n_head):
            super().__init__()
            head_size = n_embd // n_head
            self.heads = nn.ModuleList([Head(n_embd, head_size) for _ in range(n_head)])
            self.proj = nn.Linear(n_embd, n_embd)
            self.drop = nn.Dropout(dropout)

        def forward(self, x):
            x = torch.cat([h(x) for h in self.heads], dim=-1)  # concat heads back to n_embd channels
            return self.drop(self.proj(x))
    class FeedForward(nn.Module):
        """Position-wise MLP with a 4x hidden expansion."""
        def __init__(self, n_embd):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(n_embd, 4 * n_embd),
                nn.ReLU(),
                nn.Linear(4 * n_embd, n_embd),
                nn.Dropout(dropout),
            )

        def forward(self, x):
            return self.net(x)
    class Block(nn.Module):
        """Transformer block: pre-norm self-attention, then a pre-norm MLP, each with a residual connection."""
        def __init__(self, n_embd, n_head):
            super().__init__()
            self.ln1 = nn.LayerNorm(n_embd)
            self.sa = MultiHeadAttention(n_embd, n_head)
            self.ln2 = nn.LayerNorm(n_embd)
            self.ffw = FeedForward(n_embd)

        def forward(self, x):
            x = x + self.sa(self.ln1(x))
            x = x + self.ffw(self.ln2(x))
            return x
    class GPTLanguageModel(nn.Module):
        """Token + positional embeddings, a stack of Blocks, and a linear head over the vocabulary."""
        def __init__(self):
            super().__init__()
            self.block_size = block_size
            self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
            self.position_embedding_table = nn.Embedding(block_size, n_embd)
            self.blocks = nn.Sequential(*[Block(n_embd, n_head) for _ in range(n_layer)])
            self.ln_f = nn.LayerNorm(n_embd)
            self.lm_head = nn.Linear(n_embd, vocab_size)

        def forward(self, idx, targets=None):
            B, T = idx.shape
            tok = self.token_embedding_table(idx)                                    # (B, T, n_embd)
            pos = self.position_embedding_table(torch.arange(T, device=idx.device))  # (T, n_embd)
            x = tok + pos
            x = self.blocks(x)
            x = self.ln_f(x)
            logits = self.lm_head(x)                                                 # (B, T, vocab_size)
            loss = None
            if targets is not None:
                loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
            return logits, loss
        @torch.no_grad()
        def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
            for _ in range(max_new_tokens):
                idx_cond = idx[:, -self.block_size:]  # crop the context to the last block_size tokens
                logits, _ = self(idx_cond)
                logits = logits[:, -1, :] / max(temperature, 1e-6)
                if top_k is not None:
                    # keep only the top_k logits; everything else is set to -inf before the softmax
                    _, ix = torch.topk(logits, k=min(top_k, logits.size(-1)))
                    mask = torch.ones_like(logits, dtype=torch.bool)
                    mask.scatter_(1, ix, False)
                    logits = logits.masked_fill(mask, float("-inf"))
                probs = torch.softmax(logits, dim=-1)
                next_id = torch.multinomial(probs, num_samples=1)
                idx = torch.cat((idx, next_id), dim=1)
            return idx
    return GPTLanguageModel()
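
# A minimal usage sketch of build_model(). The hyperparameter values below are
# illustrative assumptions, not values taken from this file.
if __name__ == "__main__":
    model = build_model(vocab_size=65, n_layer=4, n_head=4, n_embd=128, block_size=64)
    idx = torch.zeros((1, 1), dtype=torch.long)   # a single "token 0" prompt
    logits, loss = model(idx)                     # loss is None when no targets are given
    sample = model.generate(idx, max_new_tokens=50, temperature=0.8, top_k=40)
    print(sample.shape)                           # (1, 51): prompt plus 50 generated tokens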