teszenofficial commited on
Commit
e398fee
·
verified ·
1 Parent(s): d646f4f

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +202 -0
model.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+
6
+
7
+ class MultiHeadSelfAttention(nn.Module):
8
+ """Multi-Head Self-Attention mechanism"""
9
+
10
+ def __init__(self, d_model, n_heads, dropout=0.1):
11
+ super().__init__()
12
+ assert d_model % n_heads == 0
13
+
14
+ self.d_model = d_model
15
+ self.n_heads = n_heads
16
+ self.d_k = d_model // n_heads
17
+
18
+ self.q_linear = nn.Linear(d_model, d_model)
19
+ self.k_linear = nn.Linear(d_model, d_model)
20
+ self.v_linear = nn.Linear(d_model, d_model)
21
+ self.out_linear = nn.Linear(d_model, d_model)
22
+
23
+ self.dropout = nn.Dropout(dropout)
24
+
25
+ def forward(self, x, mask=None):
26
+ batch_size, seq_len, d_model = x.size()
27
+
28
+ # Linear projections
29
+ Q = self.q_linear(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
30
+ K = self.k_linear(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
31
+ V = self.v_linear(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
32
+
33
+ # Scaled dot-product attention
34
+ scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
35
+
36
+ if mask is not None:
37
+ scores = scores.masked_fill(mask == 0, float('-inf'))
38
+
39
+ attn_weights = F.softmax(scores, dim=-1)
40
+ attn_weights = self.dropout(attn_weights)
41
+
42
+ context = torch.matmul(attn_weights, V)
43
+ context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
44
+
45
+ output = self.out_linear(context)
46
+ return output
47
+
48
+
49
+ class FeedForward(nn.Module):
50
+ """Position-wise Feed-Forward Network"""
51
+
52
+ def __init__(self, d_model, d_ff, dropout=0.1):
53
+ super().__init__()
54
+ self.linear1 = nn.Linear(d_model, d_ff)
55
+ self.linear2 = nn.Linear(d_ff, d_model)
56
+ self.dropout = nn.Dropout(dropout)
57
+
58
+ def forward(self, x):
59
+ return self.linear2(self.dropout(F.gelu(self.linear1(x))))
60
+
61
+
62
+ class TransformerBlock(nn.Module):
63
+ """Single Transformer Decoder Block"""
64
+
65
+ def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
66
+ super().__init__()
67
+ self.attention = MultiHeadSelfAttention(d_model, n_heads, dropout)
68
+ self.feed_forward = FeedForward(d_model, d_ff, dropout)
69
+ self.ln1 = nn.LayerNorm(d_model)
70
+ self.ln2 = nn.LayerNorm(d_model)
71
+ self.dropout1 = nn.Dropout(dropout)
72
+ self.dropout2 = nn.Dropout(dropout)
73
+
74
+ def forward(self, x, mask=None):
75
+ # Self-attention with residual connection
76
+ attn_output = self.attention(self.ln1(x), mask)
77
+ x = x + self.dropout1(attn_output)
78
+
79
+ # Feed-forward with residual connection
80
+ ff_output = self.feed_forward(self.ln2(x))
81
+ x = x + self.dropout2(ff_output)
82
+
83
+ return x
84
+
85
+
86
+ class MTPMiniModel(nn.Module):
87
+ """MTP Mini - GPT-style Transformer Language Model"""
88
+
89
+ def __init__(self, vocab_size, d_model=256, n_layers=4, n_heads=4,
90
+ d_ff=1024, max_seq_len=128, dropout=0.1):
91
+ super().__init__()
92
+
93
+ self.vocab_size = vocab_size
94
+ self.d_model = d_model
95
+ self.max_seq_len = max_seq_len
96
+
97
+ # Token embeddings
98
+ self.token_embedding = nn.Embedding(vocab_size, d_model)
99
+
100
+ # Positional embeddings (learnable)
101
+ self.position_embedding = nn.Embedding(max_seq_len, d_model)
102
+
103
+ # Transformer blocks
104
+ self.blocks = nn.ModuleList([
105
+ TransformerBlock(d_model, n_heads, d_ff, dropout)
106
+ for _ in range(n_layers)
107
+ ])
108
+
109
+ # Final layer norm
110
+ self.ln_f = nn.LayerNorm(d_model)
111
+
112
+ # Output projection to vocabulary
113
+ self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
114
+
115
+ # Weight tying
116
+ self.lm_head.weight = self.token_embedding.weight
117
+
118
+ self.dropout = nn.Dropout(dropout)
119
+
120
+ # Initialize weights
121
+ self.apply(self._init_weights)
122
+
123
+ def _init_weights(self, module):
124
+ if isinstance(module, nn.Linear):
125
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
126
+ if module.bias is not None:
127
+ torch.nn.init.zeros_(module.bias)
128
+ elif isinstance(module, nn.Embedding):
129
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
130
+ elif isinstance(module, nn.LayerNorm):
131
+ torch.nn.init.zeros_(module.bias)
132
+ torch.nn.init.ones_(module.weight)
133
+
134
+ def forward(self, input_ids, targets=None):
135
+ batch_size, seq_len = input_ids.size()
136
+
137
+ # Create causal mask
138
+ mask = torch.tril(torch.ones(seq_len, seq_len, device=input_ids.device)).view(1, 1, seq_len, seq_len)
139
+
140
+ # Token embeddings + positional embeddings
141
+ positions = torch.arange(0, seq_len, device=input_ids.device).unsqueeze(0)
142
+ tok_emb = self.token_embedding(input_ids)
143
+ pos_emb = self.position_embedding(positions)
144
+ x = self.dropout(tok_emb + pos_emb)
145
+
146
+ # Pass through transformer blocks
147
+ for block in self.blocks:
148
+ x = block(x, mask)
149
+
150
+ # Final layer norm
151
+ x = self.ln_f(x)
152
+
153
+ # Project to vocabulary
154
+ logits = self.lm_head(x)
155
+
156
+ # Calculate loss if targets provided
157
+ loss = None
158
+ if targets is not None:
159
+ loss = F.cross_entropy(logits.view(-1, self.vocab_size), targets.view(-1))
160
+
161
+ return logits, loss
162
+
163
+ def generate(self, input_ids, max_new_tokens=50, temperature=1.0, top_k=50, top_p=0.9):
164
+ """Autoregressive generation with sampling"""
165
+ self.eval()
166
+
167
+ with torch.no_grad():
168
+ for _ in range(max_new_tokens):
169
+ # Crop to max_seq_len
170
+ input_ids_cond = input_ids if input_ids.size(1) <= self.max_seq_len else input_ids[:, -self.max_seq_len:]
171
+
172
+ # Forward pass
173
+ logits, _ = self(input_ids_cond)
174
+ logits = logits[:, -1, :] / temperature
175
+
176
+ # Top-k filtering
177
+ if top_k > 0:
178
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
179
+ logits[logits < v[:, [-1]]] = float('-inf')
180
+
181
+ # Top-p (nucleus) filtering
182
+ if top_p < 1.0:
183
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
184
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
185
+ sorted_indices_to_remove = cumulative_probs > top_p
186
+ sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
187
+ sorted_indices_to_remove[:, 0] = 0
188
+ indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
189
+ logits[indices_to_remove] = float('-inf')
190
+
191
+ # Sample from distribution
192
+ probs = F.softmax(logits, dim=-1)
193
+ next_token = torch.multinomial(probs, num_samples=1)
194
+
195
+ # Append to sequence
196
+ input_ids = torch.cat([input_ids, next_token], dim=1)
197
+
198
+ return input_ids
199
+
200
+ def count_parameters(self):
201
+ """Count trainable parameters"""
202
+ return sum(p.numel() for p in self.parameters() if p.requires_grad)