"""Workshop GPT: a minimal GPT-style causal language model for pretraining.

Source: workshop-v1-pretraining / modeling_workshop_gpt.py
(uploaded by JustinAngel, commit e411988).
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig, PreTrainedModel, GenerationMixin
from transformers.modeling_outputs import CausalLMOutput
class WorkshopGPTConfig(PretrainedConfig):
    """Configuration for the workshop GPT model.

    Defaults describe a GPT-2-small-sized model (12 layers, 12 heads,
    768-dim embeddings, 1024-token context).
    """

    model_type = "workshop_gpt"

    def __init__(
        self,
        n_layer=12,
        n_head=12,
        n_embd=768,
        vocab_size=50304,
        block_size=1024,
        n_inner=3072,
        rope_theta=10000.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.n_layer = n_layer        # number of transformer blocks
        self.n_head = n_head          # attention heads per block
        self.n_embd = n_embd          # model / embedding width
        self.vocab_size = vocab_size  # vocabulary size
        self.block_size = block_size  # maximum sequence length
        self.n_inner = n_inner        # MLP hidden width
        self.rope_theta = rope_theta  # RoPE frequency base
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (no mean-centering, no bias)."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        # Learnable per-channel gain, initialized to the identity.
        self.scale = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        # Compute the inverse RMS in float32 for numerical stability,
        # cast back to the input dtype, then apply the learned gain.
        mean_sq = x.float().pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_sq + self.eps).type_as(x)
        return x * inv_rms * self.scale
class RotaryPositionalEmbeddings(nn.Module):
    """Rotary position embeddings (RoPE) applied over the last dimension.

    Rotates consecutive feature pairs by position-dependent angles. A
    (seq_len, dim/2, 2) cos/sin table is cached lazily and rebuilt
    whenever a longer sequence or a different device is seen.
    """

    def __init__(self, dim, max_seq_len=1024, base=10000.0):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base
        self.cache = None  # lazily built cos/sin table

    def _build_cache(self, seq_len, device):
        # Standard RoPE frequency schedule: base^(-2i/dim) for pair i.
        inv_freq = 1.0 / (
            self.base
            ** (torch.arange(0, self.dim, 2, device=device).float() / self.dim)
        )
        positions = torch.arange(seq_len, device=device)
        angles = torch.outer(positions, inv_freq)
        self.cache = torch.stack([angles.cos(), angles.sin()], dim=-1)

    def forward(self, x, *, input_pos=None):
        seq_len = x.shape[-2]
        needs_rebuild = (
            self.cache is None
            or self.cache.shape[0] < seq_len
            or self.cache.device != x.device
        )
        if needs_rebuild:
            self._build_cache(max(seq_len, self.max_seq_len), x.device)
        # Either take the leading positions or gather explicit ones.
        table = self.cache[input_pos] if input_pos is not None else self.cache[:seq_len]
        cos, sin = table.unbind(-1)
        # Broadcast the (seq, dim/2) tables over leading batch/head dims.
        view_shape = [1] * (x.ndim - 2) + list(cos.shape)
        cos = cos.view(*view_shape)
        sin = sin.view(*view_shape)
        # Split features into (even, odd) pairs and rotate each pair.
        even, odd = x.float().unflatten(-1, (-1, 2)).unbind(-1)
        rotated = torch.stack(
            [even * cos - odd * sin, even * sin + odd * cos], dim=-1
        )
        return rotated.flatten(-2).type_as(x)
class ReluSquaredMLP(nn.Module):
    """Position-wise feed-forward network with a relu(h)^2 activation."""

    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.fc_in = nn.Linear(dim, hidden_dim, bias=False)
        self.fc_out = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x):
        # Squared-ReLU activation between the two bias-free projections.
        hidden = self.fc_in(x)
        activated = F.relu(hidden).square()
        return self.fc_out(activated)
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention with a causal mask and rotary embeddings."""

    def __init__(self, n_embd, n_head, head_dim, rope):
        super().__init__()
        self.n_head = n_head
        self.head_dim = head_dim
        self.q_proj = nn.Linear(n_embd, n_embd, bias=False)
        self.k_proj = nn.Linear(n_embd, n_embd, bias=False)
        self.v_proj = nn.Linear(n_embd, n_embd, bias=False)
        self.output_proj = nn.Linear(n_embd, n_embd, bias=False)
        self.rope = rope  # rotary-embedding module, applied to q and k

    def forward(self, x):
        batch, seq, dim = x.shape

        def split_heads(t):
            # (B, T, C) -> (B, n_head, T, head_dim)
            return t.view(batch, seq, self.n_head, self.head_dim).transpose(1, 2)

        q = self.rope(split_heads(self.q_proj(x)))
        k = self.rope(split_heads(self.k_proj(x)))
        v = split_heads(self.v_proj(x))
        # Fused kernel; is_causal builds the lower-triangular mask.
        attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        merged = attn_out.transpose(1, 2).contiguous().view(batch, seq, dim)
        return self.output_proj(merged)
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention then MLP, each with a residual."""

    def __init__(self, config):
        super().__init__()
        head_dim = config.n_embd // config.n_head
        # Each block constructs its own rotary-embedding module (and cache).
        rope = RotaryPositionalEmbeddings(
            head_dim, config.block_size, config.rope_theta
        )
        self.sa_norm = RMSNorm(config.n_embd)
        self.attn = CausalSelfAttention(config.n_embd, config.n_head, head_dim, rope)
        self.mlp_norm = RMSNorm(config.n_embd)
        self.mlp = ReluSquaredMLP(config.n_embd, config.n_inner)

    def forward(self, x):
        x = x + self.attn(self.sa_norm(x))
        x = x + self.mlp(self.mlp_norm(x))
        return x
class WorkshopGPTForCausalLM(PreTrainedModel, GenerationMixin):
    """Decoder-only causal LM with tied input/output embeddings.

    There is no separate ``lm_head``: logits are computed with
    ``F.linear`` against ``tok_embeddings.weight``, so tying is implicit
    and ``_tied_weights_keys`` stays empty.
    """

    config_class = WorkshopGPTConfig
    _tied_weights_keys = {}

    def __init__(self, config):
        super().__init__(config)
        self.tok_embeddings = nn.Embedding(config.vocab_size, config.n_embd)
        self.layers = nn.ModuleList(
            [TransformerBlock(config) for _ in range(config.n_layer)]
        )
        self.norm = RMSNorm(config.n_embd)

    def forward(self, input_ids, labels=None, **kwargs):
        """Run the decoder stack and compute logits (and optionally loss).

        Args:
            input_ids: (batch, seq) token ids.
            labels: optional (batch, seq) target ids. When provided, a
                shifted next-token cross-entropy loss is returned
                (previously `labels` was silently ignored and loss was
                always None, which made the model untrainable with
                ``transformers.Trainer``).
            **kwargs: ignored (e.g. attention_mask from ``generate()``).

        Returns:
            CausalLMOutput with ``logits`` and, if labels were given, ``loss``.
        """
        x = self.tok_embeddings(input_ids)
        for layer in self.layers:
            x = layer(x)
        # Output projection reuses the embedding matrix (weight tying).
        logits = F.linear(self.norm(x), self.tok_embeddings.weight)
        loss = None
        if labels is not None:
            # Shift so position t predicts token t+1; ignore_index=-100
            # follows the transformers convention for masked targets.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )
        return CausalLMOutput(loss=loss, logits=logits)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # No KV cache: re-run the full sequence on every generation step.
        return {"input_ids": input_ids}

    @property
    def all_tied_weights_keys(self):
        # Tying is implicit (no lm_head module), so nothing to report.
        return {}