Eeppa commited on
Commit
f2c7fb0
·
verified ·
1 Parent(s): 0afbae0

Update modeling_tinybuddy.py

Browse files
Files changed (1) hide show
  1. modeling_tinybuddy.py +36 -68
modeling_tinybuddy.py CHANGED
@@ -1,46 +1,13 @@
1
  """
2
  Tiny GPT-style transformer (~30M params target).
3
-
4
- Config:
5
- - 6 layers
6
- - 8 heads
7
- - d_model = 256
8
- - vocab_size = 32000 (chosen to push param count up to ~30M, since the
9
- transformer blocks themselves only have ~5M params at d_model=256/L=6;
10
- the embedding + tied LM head dominates the parameter budget.)
11
-
12
- Parameter accounting (approx):
13
- Token embedding : 32000 * 256 = 8,192,000
14
- LM head (untied) : 256 * 32000 + 32000 = 8,224,000
15
- Positional emb : 512 * 256 = 131,072
16
- Per block (x6):
17
- attn (qkv+out) : 4 * 256 * 256 + 4*256 = 263,168
18
- mlp (2 linear): 256*1024 + 1024 + 1024*256+256 = 525,568
19
- 2x LayerNorm : 4 * 256 = 1,024
20
- block total = 789,760
21
- Blocks total : 6 * 789,760 = 4,738,560
22
- Final LN : 512
23
- ---------------------------------------------------------
24
- TOTAL ~ 21.3M (tied) or ~29.5M (untied lm head) -> ~30M ✓
25
  """
26
 
27
  import math
28
  import torch
29
  import torch.nn as nn
30
  import torch.nn.functional as F
31
- from dataclasses import dataclass
32
-
33
-
34
- @dataclass
35
- class GPTConfig:
36
- vocab_size: int = 50000
37
- block_size: int = 512 # max context length
38
- n_layer: int = 6
39
- n_head: int = 8
40
- n_embd: int = 256
41
- mlp_ratio: int = 4 # hidden = 4 * n_embd
42
- dropout: float = 0.0
43
- tie_weights: bool = False # False -> ~30M params; True -> ~21M
44
 
45
 
46
  class CausalSelfAttention(nn.Module):
@@ -97,59 +64,59 @@ class Block(nn.Module):
97
  return x
98
 
99
 
100
- class TinyGPT(nn.Module):
101
- def __init__(self, cfg: GPTConfig):
102
- super().__init__()
103
- self.cfg = cfg
104
- self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.n_embd)
105
- self.pos_emb = nn.Embedding(cfg.block_size, cfg.n_embd)
106
- self.drop = nn.Dropout(cfg.dropout)
107
- self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layer)])
108
- self.ln_f = nn.LayerNorm(cfg.n_embd)
109
- self.lm_head = nn.Linear(cfg.n_embd, cfg.vocab_size, bias=False)
110
- if cfg.tie_weights:
 
 
111
  self.lm_head.weight = self.tok_emb.weight
112
- self.apply(self._init_weights)
113
 
114
- @staticmethod
115
- def _init_weights(m):
116
- if isinstance(m, nn.Linear):
117
- nn.init.normal_(m.weight, mean=0.0, std=0.02)
118
- if m.bias is not None:
119
- nn.init.zeros_(m.bias)
120
- elif isinstance(m, nn.Embedding):
121
- nn.init.normal_(m.weight, mean=0.0, std=0.02)
122
 
123
  def num_params(self, non_embedding=False):
124
  n = sum(p.numel() for p in self.parameters())
125
  if non_embedding:
126
  n -= self.tok_emb.weight.numel() + self.pos_emb.weight.numel()
127
- if not self.cfg.tie_weights:
128
  n -= self.lm_head.weight.numel()
129
  return n
130
 
131
- def forward(self, idx, targets=None):
132
- B, T = idx.shape
133
- assert T <= self.cfg.block_size, f"sequence length {T} > block_size {self.cfg.block_size}"
134
- pos = torch.arange(T, device=idx.device)
135
- x = self.tok_emb(idx) + self.pos_emb(pos)[None, :, :]
136
  x = self.drop(x)
137
  for blk in self.blocks:
138
  x = blk(x)
139
  x = self.ln_f(x)
140
  logits = self.lm_head(x)
141
  loss = None
142
- if targets is not None:
143
  loss = F.cross_entropy(logits.view(-1, logits.size(-1)),
144
- targets.view(-1), ignore_index=-100)
145
- return logits, loss
146
 
147
- @torch.no_grad()
148
  def generate(self, idx, max_new_tokens=100, temperature=1.0, top_k=None):
149
  self.eval()
150
  for _ in range(max_new_tokens):
151
- idx_cond = idx if idx.size(1) <= self.cfg.block_size else idx[:, -self.cfg.block_size:]
152
- logits, _ = self(idx_cond)
153
  logits = logits[:, -1, :] / max(temperature, 1e-6)
154
  if top_k is not None:
155
  v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
@@ -161,9 +128,10 @@ class TinyGPT(nn.Module):
161
 
162
 
163
  if __name__ == "__main__":
 
164
  cfg = GPTConfig()
165
  m = TinyGPT(cfg)
166
  total = m.num_params()
167
  nonemb = m.num_params(non_embedding=True)
168
  print(f"Total params : {total:,} (~{total/1e6:.2f}M)")
169
- print(f"Non-embedding params: {nonemb:,} (~{nonemb/1e6:.2f}M)")
 
1
  """
2
  Tiny GPT-style transformer (~30M params target).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  """
4
 
5
  import math
6
  import torch
7
  import torch.nn as nn
8
  import torch.nn.functional as F
9
+ from transformers import PreTrainedModel
10
+ from configuration_tinybuddy import GPTConfig
 
 
 
 
 
 
 
 
 
 
 
11
 
12
 
13
  class CausalSelfAttention(nn.Module):
 
64
  return x
65
 
66
 
67
+ class TinyGPT(PreTrainedModel):
68
+ config_class = GPTConfig
69
+
70
+ def __init__(self, config: GPTConfig):
71
+ super().__init__(config)
72
+ self.config = config
73
+ self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
74
+ self.pos_emb = nn.Embedding(config.block_size, config.n_embd)
75
+ self.drop = nn.Dropout(config.dropout)
76
+ self.blocks = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
77
+ self.ln_f = nn.LayerNorm(config.n_embd)
78
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
79
+ if config.tie_weights:
80
  self.lm_head.weight = self.tok_emb.weight
81
+ self.post_init()
82
 
83
+ def _init_weights(self, module):
84
+ if isinstance(module, nn.Linear):
85
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range if hasattr(self.config, 'initializer_range') else 0.02)
86
+ if module.bias is not None:
87
+ module.bias.data.zero_()
88
+ elif isinstance(module, nn.Embedding):
89
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range if hasattr(self.config, 'initializer_range') else 0.02)
 
90
 
91
  def num_params(self, non_embedding=False):
92
  n = sum(p.numel() for p in self.parameters())
93
  if non_embedding:
94
  n -= self.tok_emb.weight.numel() + self.pos_emb.weight.numel()
95
+ if not self.config.tie_weights:
96
  n -= self.lm_head.weight.numel()
97
  return n
98
 
99
+ def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
100
+ B, T = input_ids.shape
101
+ assert T <= self.config.block_size, f"sequence length {T} > block_size {self.config.block_size}"
102
+ pos = torch.arange(T, device=input_ids.device)
103
+ x = self.tok_emb(input_ids) + self.pos_emb(pos)[None, :, :]
104
  x = self.drop(x)
105
  for blk in self.blocks:
106
  x = blk(x)
107
  x = self.ln_f(x)
108
  logits = self.lm_head(x)
109
  loss = None
110
+ if labels is not None:
111
  loss = F.cross_entropy(logits.view(-1, logits.size(-1)),
112
+ labels.view(-1), ignore_index=-100)
113
+ return (logits,) if loss is None else (logits, loss)
114
 
 
115
  def generate(self, idx, max_new_tokens=100, temperature=1.0, top_k=None):
116
  self.eval()
117
  for _ in range(max_new_tokens):
118
+ idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
119
+ logits = self(idx_cond)[0]
120
  logits = logits[:, -1, :] / max(temperature, 1e-6)
121
  if top_k is not None:
122
  v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
 
128
 
129
 
130
  if __name__ == "__main__":
131
+ from configuration_tinybuddy import GPTConfig
132
  cfg = GPTConfig()
133
  m = TinyGPT(cfg)
134
  total = m.num_params()
135
  nonemb = m.num_params(non_embedding=True)
136
  print(f"Total params : {total:,} (~{total/1e6:.2f}M)")
137
+ print(f"Non-embedding params: {nonemb:,} (~{nonemb/1e6:.2f}M)")