| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import PretrainedConfig, PreTrainedModel, GenerationMixin |
| from transformers.modeling_outputs import CausalLMOutput |
|
|
|
|
class WorkshopGPTConfig(PretrainedConfig):
    """Hyperparameter container for the WorkshopGPT architecture.

    Stores layer count, head count, embedding width, vocabulary size,
    maximum sequence length, MLP hidden width, and the RoPE base frequency.
    Extra keyword arguments are forwarded to ``PretrainedConfig``.
    """

    model_type = "workshop_gpt"

    def __init__(self, n_layer=12, n_head=12, n_embd=768, vocab_size=50304,
                 block_size=1024, n_inner=3072, rope_theta=10000.0, **kwargs):
        super().__init__(**kwargs)
        # Assign all architecture hyperparameters in one place; setattr goes
        # through the same __setattr__ path as direct assignment would.
        hyperparams = {
            "n_layer": n_layer,
            "n_head": n_head,
            "n_embd": n_embd,
            "vocab_size": vocab_size,
            "block_size": block_size,
            "n_inner": n_inner,
            "rope_theta": rope_theta,
        }
        for name, value in hyperparams.items():
            setattr(self, name, value)
|
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square layer normalisation (no mean subtraction, no bias).

    Normalises the last dimension by its RMS and applies a learnable
    per-channel gain.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        # Learnable gain, initialised to the identity transform.
        self.scale = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        # Statistics are computed in float32 for numerical stability and the
        # inverse-RMS factor is cast back to the input dtype before scaling.
        mean_sq = x.float().pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_sq + self.eps).type_as(x)
        return x * inv_rms * self.scale
|
|
|
|
class RotaryPositionalEmbeddings(nn.Module):
    """Rotary position embeddings (RoPE) over interleaved channel pairs.

    Adjacent channel pairs ``(x[2i], x[2i+1])`` are rotated by a
    position-dependent angle. Cos/sin tables are built lazily and rebuilt
    whenever the requested length or device changes.
    """

    def __init__(self, dim, max_seq_len=1024, base=10000.0):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base
        # Lazily-built table of shape (seq_len, dim // 2, 2) = (cos, sin).
        # NOTE(review): a plain attribute, not a registered buffer — the
        # device check in forward() compensates for .to(device) not moving it.
        self.cache = None

    def _build_cache(self, seq_len, device):
        # Inverse frequencies for each channel pair, outer-multiplied with
        # integer positions to give per-(position, pair) rotation angles.
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device).float() / self.dim))
        positions = torch.arange(seq_len, device=device)
        angles = torch.outer(positions, inv_freq)
        self.cache = torch.stack([angles.cos(), angles.sin()], dim=-1)

    def forward(self, x, *, input_pos=None):
        cur_len = x.shape[-2]
        stale = (
            self.cache is None
            or self.cache.shape[0] < cur_len
            or self.cache.device != x.device
        )
        if stale:
            self._build_cache(max(cur_len, self.max_seq_len), x.device)
        # Either the leading positions 0..T-1 or explicit caller positions.
        if input_pos is None:
            table = self.cache[:cur_len]
        else:
            table = self.cache[input_pos]
        # Split channels into (even, odd) halves of each interleaved pair.
        even, odd = x.float().unflatten(-1, (-1, 2)).unbind(-1)
        cos, sin = table.unbind(-1)
        # Insert leading singleton dims so the table broadcasts over batch/heads.
        bcast_shape = [1] * (x.ndim - 2) + list(cos.shape)
        cos = cos.view(*bcast_shape)
        sin = sin.view(*bcast_shape)
        rotated = torch.stack([even * cos - odd * sin, even * sin + odd * cos], dim=-1)
        return rotated.flatten(-2).type_as(x)
|
|
|
|
class ReluSquaredMLP(nn.Module):
    """Position-wise feed-forward network with a squared-ReLU activation.

    Computes ``fc_out(relu(fc_in(x)) ** 2)`` with bias-free projections.
    """

    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.fc_in = nn.Linear(dim, hidden_dim, bias=False)
        self.fc_out = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x):
        hidden = self.fc_in(x)
        # Squared ReLU: relu(h)^2, elementwise.
        activated = F.relu(hidden).square()
        return self.fc_out(activated)
|
|
|
|
| class CausalSelfAttention(nn.Module): |
| def __init__(self, n_embd, n_head, head_dim, rope): |
| super().__init__() |
| self.n_head = n_head |
| self.head_dim = head_dim |
| self.q_proj = nn.Linear(n_embd, n_embd, bias=False) |
| self.k_proj = nn.Linear(n_embd, n_embd, bias=False) |
| self.v_proj = nn.Linear(n_embd, n_embd, bias=False) |
| self.output_proj = nn.Linear(n_embd, n_embd, bias=False) |
| self.rope = rope |
|
|
| def forward(self, x): |
| B, T, C = x.shape |
| q = self.q_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2) |
| k = self.k_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2) |
| v = self.v_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2) |
| q, k = self.rope(q), self.rope(k) |
| y = F.scaled_dot_product_attention(q, k, v, is_causal=True) |
| return self.output_proj(y.transpose(1, 2).contiguous().view(B, T, C)) |
|
|
|
|
| class TransformerBlock(nn.Module): |
| def __init__(self, config): |
| super().__init__() |
| hd = config.n_embd // config.n_head |
| rope = RotaryPositionalEmbeddings(hd, config.block_size, config.rope_theta) |
| self.sa_norm = RMSNorm(config.n_embd) |
| self.attn = CausalSelfAttention(config.n_embd, config.n_head, hd, rope) |
| self.mlp_norm = RMSNorm(config.n_embd) |
| self.mlp = ReluSquaredMLP(config.n_embd, config.n_inner) |
|
|
| def forward(self, x): |
| x = x + self.attn(self.sa_norm(x)) |
| return x + self.mlp(self.mlp_norm(x)) |
|
|
|
|
class WorkshopGPTForCausalLM(PreTrainedModel, GenerationMixin):
    """Decoder-only causal language model with a weight-tied output head.

    Logits are produced by multiplying the final hidden states with the
    token-embedding matrix (``F.linear`` against ``tok_embeddings.weight``),
    so there is no separate ``lm_head`` parameter.

    Fix: ``forward`` previously swallowed ``labels`` via ``**kwargs`` and
    never produced a loss, so ``output.loss`` was silently ``None`` and
    Trainer-style fine-tuning could not work. ``labels`` is now an explicit
    optional argument and yields the standard shifted cross-entropy loss.
    """

    config_class = WorkshopGPTConfig
    # No separate output-head module exists to tie, hence the empty mapping.
    _tied_weights_keys = {}

    def __init__(self, config):
        super().__init__(config)
        self.tok_embeddings = nn.Embedding(config.vocab_size, config.n_embd)
        self.layers = nn.ModuleList(
            [TransformerBlock(config) for _ in range(config.n_layer)]
        )
        self.norm = RMSNorm(config.n_embd)

    def forward(self, input_ids, labels=None, **kwargs):
        """Run the full stack and return logits (and a loss if labels given).

        Args:
            input_ids: (batch, seq) token ids.
            labels: optional (batch, seq) target ids; positions with -100
                are ignored in the loss (HF convention).
            **kwargs: accepted for HF-API compatibility but ignored —
                NOTE: ``attention_mask`` and ``past_key_values`` have no
                effect; attention is always causal over the full input.
        """
        x = self.tok_embeddings(input_ids)
        for layer in self.layers:
            x = layer(x)
        # Weight-tied output head: project with the embedding matrix.
        logits = F.linear(self.norm(x), self.tok_embeddings.weight)
        loss = None
        if labels is not None:
            # Standard causal-LM shift: position t predicts token t+1.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )
        return CausalLMOutput(loss=loss, logits=logits)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # No KV cache is kept, so every decoding step re-feeds the full
        # sequence accumulated so far.
        return {"input_ids": input_ids}

    @property
    def all_tied_weights_keys(self):
        # Mirrors _tied_weights_keys: nothing is tied across modules.
        return {}