| from huggingface_hub import PyTorchModelHubMixin
|
|
|
|
|
|
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| import math
|
|
|
|
|
|
|
# Default model hyperparameters. NOTE: the classes below take all of these as
# explicit constructor arguments, so these module-level values act only as
# defaults for a training script to import (block_size has no default here).
n_embed = 64  # embedding dimension (must be divisible by n_head)

n_head = 4  # number of attention heads per block

n_layer = 4  # number of stacked TransformerBlocks

dropout = 0.1  # dropout probability for attention weights and residuals
|
|
|
|
|
|
|
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention with a causal (lower-triangular) mask.

    All heads are computed in one batched projection: the input is mapped to
    concatenated query/key/value tensors, attention is applied per head, and
    the head outputs are merged and projected back to the embedding width.
    """

    def __init__(self, n_embed, n_head, block_size, dropout):
        super().__init__()
        self.n_embed = n_embed
        self.n_head = n_head
        # Per-head channel width; assumes n_embed is divisible by n_head.
        self.head_size = n_embed // n_head

        # Fused projection that produces Q, K and V with a single matmul.
        self.c_attn = nn.Linear(n_embed, 3 * n_embed, bias=False)
        # Output projection back into the residual stream.
        self.c_proj = nn.Linear(n_embed, n_embed, bias=False)
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

        # Causal mask buffer: position t may only attend to positions <= t.
        # Registered as a buffer so it moves with the module but is not trained.
        mask = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer('tril', mask.view(1, 1, block_size, block_size))

    def forward(self, x):
        batch, seq_len, emb = x.shape

        # One projection, then split the last dim into query/key/value.
        q, k, v = self.c_attn(x).split(self.n_embed, dim=2)

        # Rearrange each to (batch, head, seq, head_size).
        def split_heads(t):
            return t.view(batch, seq_len, self.n_head, self.head_size).transpose(1, 2)

        q = split_heads(q)
        k = split_heads(k)
        v = split_heads(v)

        # Scaled dot-product scores; scale by 1/sqrt(head_size).
        scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_size))
        # Hide future positions before the softmax.
        scores = scores.masked_fill(self.tril[:, :, :seq_len, :seq_len] == 0, float('-inf'))
        attn = self.attn_dropout(F.softmax(scores, dim=-1))

        # Weighted sum over values, merge heads, project, and apply dropout.
        merged = (attn @ v).transpose(1, 2).contiguous().view(batch, seq_len, emb)
        return self.resid_dropout(self.c_proj(merged))
|
|
|
|
|
|
|
class FeedForward(nn.Module):
    """Position-wise MLP: expand to 4x the embedding width, GELU, project back.

    Applied independently at every sequence position after attention.
    """

    def __init__(self, n_embed, dropout):
        super().__init__()
        hidden = 4 * n_embed  # standard 4x expansion factor
        layers = [
            nn.Linear(n_embed, hidden),
            nn.GELU(),
            nn.Linear(hidden, n_embed),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out
|
|
|
|
|
|
|
|
|
class TransformerBlock(nn.Module):
    """Pre-norm Transformer decoder block: attention, then a feed-forward MLP,
    each wrapped in a residual connection around a LayerNorm'd sub-layer."""

    def __init__(self, n_embed, n_head, block_size, dropout):
        super().__init__()
        # LayerNorm applied *before* each sub-layer (pre-norm arrangement).
        self.ln_1 = nn.LayerNorm(n_embed)
        self.attn = CausalSelfAttention(n_embed, n_head, block_size, dropout)
        self.ln_2 = nn.LayerNorm(n_embed)
        self.ffn = FeedForward(n_embed, dropout)

    def forward(self, x):
        # Residual around the attention sub-layer.
        attn_out = self.attn(self.ln_1(x))
        x = x + attn_out
        # Residual around the feed-forward sub-layer.
        ffn_out = self.ffn(self.ln_2(x))
        return x + ffn_out
|
|
|
|
|
|
|
class TinyLLM(nn.Module, PyTorchModelHubMixin):
    """The complete Decoder-Only Transformer model.

    Token embeddings plus learned absolute position embeddings feed a stack
    of TransformerBlocks, followed by a final LayerNorm and a linear head
    producing next-token logits over the vocabulary. Inherits
    PyTorchModelHubMixin for save/load via the Hugging Face Hub.
    """

    def __init__(self, vocab_size, n_embed, n_head, n_layer, block_size, dropout):
        super().__init__()

        # Maximum context length the model supports (size of the position table).
        self.block_size = block_size

        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
        # One learned embedding per position in [0, block_size).
        self.position_embedding_table = nn.Embedding(block_size, n_embed)

        self.blocks = nn.Sequential(*[
            TransformerBlock(n_embed, n_head, block_size, dropout)
            for _ in range(n_layer)
        ])

        self.ln_f = nn.LayerNorm(n_embed)
        self.lm_head = nn.Linear(n_embed, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the training loss.

        Args:
            idx: (B, T) integer tensor of token indices, T <= block_size.
            targets: optional (B, T) integer tensor of target token indices.

        Returns:
            (logits, loss). With targets=None, logits is (B, T, vocab_size)
            and loss is None; with targets given, logits is flattened to
            (B*T, vocab_size) and loss is the mean cross-entropy.

        Raises:
            ValueError: if T exceeds block_size.
        """
        B, T = idx.shape

        # Fail fast with a clear message instead of an opaque index error
        # from the position-embedding lookup on overlong input.
        if T > self.block_size:
            raise ValueError(
                f"Sequence length {T} exceeds model block_size {self.block_size}"
            )

        tok_emb = self.token_embedding_table(idx)            # (B, T, n_embed)
        pos = torch.arange(T, device=idx.device)             # (T,)
        pos_emb = self.position_embedding_table(pos)         # (T, n_embed)

        # Broadcast-add position embeddings across the batch.
        x = tok_emb + pos_emb

        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)                             # (B, T, vocab_size)

        loss = None
        if targets is not None:
            # Flatten batch and time so F.cross_entropy sees (N, C) vs (N,).
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss