JustinAngel committed on
Commit 6733f6f · verified · 1 Parent(s): a91eca4

Upload 3 files

Files changed (3)
  1. config.json +15 -0
  2. model.safetensors +3 -0
  3. modeling_workshop_gpt.py +96 -0
config.json ADDED
@@ -0,0 +1,15 @@
+ {
+   "architectures": [
+     "WorkshopGPTForCausalLM"
+   ],
+   "block_size": 1024,
+   "dtype": "float32",
+   "model_type": "workshop_gpt",
+   "n_embd": 768,
+   "n_head": 12,
+   "n_inner": 3072,
+   "n_layer": 12,
+   "rope_theta": 10000.0,
+   "transformers_version": "5.0.0",
+   "vocab_size": 50304
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b837fb282a8a03a3a1ba64fd1ca3b2ddba8b036fd12e718b0e6750c3adcc1460
+ size 648892992
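
As a rough consistency check (not part of the commit), the checkpoint size lines up with the hyperparameters in config.json: weights are stored in float32 and the input and output embeddings are separate parameters, so the raw tensor data accounts for nearly all of the 648,892,992 bytes, with the small remainder being the safetensors header.

# Back-of-the-envelope parameter count from config.json (assumption: the layer
# structure defined in modeling_workshop_gpt.py below, with untied embeddings).
n_embd, n_layer, n_inner, vocab = 768, 12, 3072, 50304

embed   = vocab * n_embd          # tok_embeddings
lm_head = vocab * n_embd          # output projection (cloned from embeddings, not tied)
attn    = 4 * n_embd * n_embd     # q/k/v/output projections per layer
mlp     = 2 * n_embd * n_inner    # fc_in + fc_out per layer
norms   = 2 * n_embd              # sa_norm + mlp_norm scales per layer
final   = n_embd                  # final RMSNorm scale

total = embed + lm_head + n_layer * (attn + mlp + norms) + final
print(total)      # 162220800 parameters
print(total * 4)  # 648883200 bytes of float32 weights (~648.9 MB)
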
modeling_workshop_gpt.py ADDED
@@ -0,0 +1,96 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from transformers import PretrainedConfig, PreTrainedModel
+
+ class WorkshopGPTConfig(PretrainedConfig):
+     model_type = "workshop_gpt"
+     def __init__(self, n_layer=12, n_head=12, n_embd=768, vocab_size=50304,
+                  block_size=1024, n_inner=3072, rope_theta=10000.0, **kwargs):
+         super().__init__(**kwargs)
+         self.n_layer, self.n_head, self.n_embd = n_layer, n_head, n_embd
+         self.vocab_size, self.block_size = vocab_size, block_size
+         self.n_inner, self.rope_theta = n_inner, rope_theta
+
+ class RMSNorm(nn.Module):
+     def __init__(self, dim, eps=1e-6):
+         super().__init__()
+         self.scale = nn.Parameter(torch.ones(dim))
+         self.eps = eps
+     def forward(self, x):
+         return x * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps).type_as(x) * self.scale
+
+ class RotaryPositionalEmbeddings(nn.Module):
+     def __init__(self, dim, max_seq_len=1024, base=10000.0):
+         super().__init__()
+         self.dim, self.max_seq_len, self.base = dim, max_seq_len, base
+         theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+         self.register_buffer("theta", theta, persistent=False)
+         self._build_cache(max_seq_len)
+     def _build_cache(self, seq_len):
+         seq = torch.arange(seq_len, device=self.theta.device)
+         freqs = torch.outer(seq, self.theta)
+         self.register_buffer("cache", torch.stack([freqs.cos(), freqs.sin()], dim=-1), persistent=False)
+     def forward(self, x, *, input_pos=None):
+         seq_len = x.shape[-2]
+         if seq_len > self.cache.shape[0]: self._build_cache(seq_len)
+         cache = self.cache[:seq_len] if input_pos is None else self.cache[input_pos]
+         x1, x2 = x.float().unflatten(-1, (-1, 2)).unbind(-1)
+         cos, sin = cache.unbind(-1)
+         shape = [1] * (x.ndim - 2) + list(cos.shape)
+         cos, sin = cos.view(*shape), sin.view(*shape)
+         return torch.stack([x1*cos - x2*sin, x1*sin + x2*cos], dim=-1).flatten(-2).type_as(x)
+
+ class ReluSquaredMLP(nn.Module):
+     def __init__(self, dim, hidden_dim):
+         super().__init__()
+         self.fc_in = nn.Linear(dim, hidden_dim, bias=False)
+         self.fc_out = nn.Linear(hidden_dim, dim, bias=False)
+     def forward(self, x):
+         return self.fc_out(F.relu(self.fc_in(x)).square())
+
+ class CausalSelfAttention(nn.Module):
+     def __init__(self, n_embd, n_head, head_dim, rope):
+         super().__init__()
+         self.n_head, self.head_dim = n_head, head_dim
+         self.q_proj = nn.Linear(n_embd, n_embd, bias=False)
+         self.k_proj = nn.Linear(n_embd, n_embd, bias=False)
+         self.v_proj = nn.Linear(n_embd, n_embd, bias=False)
+         self.output_proj = nn.Linear(n_embd, n_embd, bias=False)
+         self.rope = rope
+     def forward(self, x):
+         B, T, C = x.shape
+         q = self.q_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
+         k = self.k_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
+         v = self.v_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
+         q, k = self.rope(q), self.rope(k)
+         y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
+         return self.output_proj(y.transpose(1, 2).contiguous().view(B, T, C))
+
+ class TransformerBlock(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         hd = config.n_embd // config.n_head
+         rope = RotaryPositionalEmbeddings(hd, config.block_size, config.rope_theta)
+         self.sa_norm = RMSNorm(config.n_embd)
+         self.attn = CausalSelfAttention(config.n_embd, config.n_head, hd, rope)
+         self.mlp_norm = RMSNorm(config.n_embd)
+         self.mlp = ReluSquaredMLP(config.n_embd, config.n_inner)
+     def forward(self, x):
+         x = x + self.attn(self.sa_norm(x))
+         return x + self.mlp(self.mlp_norm(x))
+
+ class WorkshopGPTForCausalLM(PreTrainedModel):
+     config_class = WorkshopGPTConfig
+     def __init__(self, config):
+         super().__init__(config)
+         self.tok_embeddings = nn.Embedding(config.vocab_size, config.n_embd)
+         self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layer)])
+         self.norm = RMSNorm(config.n_embd)
+         self.output = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+         self.output.weight = nn.Parameter(self.tok_embeddings.weight.clone())
+     def forward(self, input_ids, **kwargs):
+         x = self.tok_embeddings(input_ids)
+         for layer in self.layers:
+             x = layer(x)
+         return type("Out", (), {"logits": self.output(self.norm(x))})()
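
A minimal usage sketch (not part of the commit): since config.json does not register an auto_map, the simplest way to exercise the checkpoint is to import the classes above directly. The local directory name "workshop-gpt" is a placeholder for wherever these three files are downloaded.

# Minimal usage sketch (assumption: the three files above sit in a local
# directory named "workshop-gpt"; that name is hypothetical).
import torch
from modeling_workshop_gpt import WorkshopGPTForCausalLM

model = WorkshopGPTForCausalLM.from_pretrained("workshop-gpt")
model.eval()

# The forward pass returns an object exposing .logits of shape
# (batch, seq_len, vocab_size); greedy-pick a next token from dummy input ids.
input_ids = torch.randint(0, model.config.vocab_size, (1, 16))
with torch.no_grad():
    logits = model(input_ids).logits
print(logits.shape, logits[:, -1, :].argmax(dim=-1))
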