LH-Tech-AI committed on
Commit
641b7ce
·
verified ·
1 Parent(s): ed73909

Create train.py

Browse files
Files changed (1) hide show
  1. train.py +256 -0
train.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### ------------------------------------------------------------------------------------------------ ###
2
+ ### First: do `apt-get update && apt-get install -y fluidsynth` and `pip install miditok midi2audio` ###
3
+ ### ------------------------------------------------------------------------------------------------ ###
4
+
5
+ ### IMPORTS ###
6
+ import os
7
+ import requests
8
+ import zipfile
9
+ import numpy as np
10
+ from miditok import REMI
11
+ from pathlib import Path
12
+ from tqdm import tqdm
13
+ import torch
14
+ import torch.nn as nn
15
+ from torch.nn import functional as F
16
+ import time
17
+
18
+ ### DATA LOADING ###
19
# Remote archive (MAESTRO v3.0.0, MIDI-only zip) and the local paths used
# while preparing it.
MIDI_URL = "https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0-midi.zip"
ZIP_FILE = "maestro_midi.zip"   # where the downloaded archive is stored
EXTRACT_PATH = "maestro_raw"    # directory the archive is unpacked into
DATA_DIR = "data"               # directory that receives train.bin / val.bin

# Create the output directory up front; an already existing one is fine.
Path(DATA_DIR).mkdir(parents=True, exist_ok=True)
24
+
25
def download_and_prepare():
    """Download the MIDI archive, tokenize it and write train/val token files.

    Steps (the first two are skipped when their output already exists):
      1. Download ``MIDI_URL`` to ``ZIP_FILE``.
      2. Extract the archive into ``EXTRACT_PATH``.
      3. Tokenize every ``*.mid*`` file below ``EXTRACT_PATH`` with a REMI
         tokenizer and concatenate all token ids into one flat stream.
      4. Split the stream 90/10 and write ``train.bin`` / ``val.bin`` into
         ``DATA_DIR`` as raw ``uint16`` values.

    Returns:
        None. When no file could be tokenized an error is printed and the
        function returns without writing anything.

    Raises:
        requests.HTTPError: if the download answers with an error status.
    """
    # Only REMI is imported at file level; the config class is needed as well.
    from miditok import TokenizerConfig

    if not os.path.exists(ZIP_FILE):
        print("Downloading MIDI dataset...")
        # Stream into a temporary file and rename at the end, so an aborted
        # or failed download is never mistaken for a complete archive.
        partial_path = ZIP_FILE + ".part"
        with requests.get(MIDI_URL, stream=True, timeout=60) as r:
            r.raise_for_status()
            with open(partial_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        os.replace(partial_path, ZIP_FILE)

    if not os.path.exists(EXTRACT_PATH):
        print("Unpacking files...")
        with zipfile.ZipFile(ZIP_FILE, 'r') as zip_ref:
            zip_ref.extractall(EXTRACT_PATH)

    config = TokenizerConfig(
        num_velocities=16,
        use_chords=True,
        use_tempos=True,
        use_time_signatures=True
    )
    tokenizer = REMI(config)

    all_tokens = []
    midi_paths = list(Path(EXTRACT_PATH).rglob("*.mid*"))

    print(f"Tokenizing {len(midi_paths)} MIDI files...")
    failed = 0
    for path in tqdm(midi_paths):
        try:
            midi_tokens = tokenizer(path)
            # The tokenizer may hand back one sequence or a list of them;
            # only the first sequence is used.
            seq = midi_tokens[0] if isinstance(midi_tokens, list) else midi_tokens
            ids = seq.ids
        except Exception as e:
            # Best effort: one unreadable file must not abort the whole run,
            # but the failure is reported instead of being dropped silently.
            failed += 1
            tqdm.write(f"Skipping {path}: {e}")
            continue
        all_tokens.extend(ids)

    if failed:
        print(f"Skipped {failed} of {len(midi_paths)} files due to errors.")

    if len(all_tokens) == 0:
        print("ERROR: No tokens processed!")
        return

    data = np.array(all_tokens, dtype=np.uint16)
    split_at = int(len(data) * 0.9)
    train_data = data[:split_at]
    val_data = data[split_at:]

    train_data.tofile(os.path.join(DATA_DIR, 'train.bin'))
    val_data.tofile(os.path.join(DATA_DIR, 'val.bin'))

    print("Preparation done!")
    print(f"Train Tokens: {len(train_data)} | Val Tokens: {len(val_data)}")
    print(f"Vocab size: {len(tokenizer)}")
78
+
79
# Executed at module level: runs the download/tokenize step before any
# training code below is reached.
download_and_prepare()

### TRAINING ###
# Batching and schedule
batch_size = 64                    # sequences per micro-batch
block_size = 1024                  # tokens per sequence (context length)
max_iters = 20000                  # optimizer steps
learning_rate = 5e-4
gradient_accumulation_steps = 4    # micro-batches per optimizer step
eval_interval = 250                # evaluate/checkpoint every N iterations
eval_iters = 100                   # batches averaged per evaluation split

# Model size
n_embd = 512
n_head = 8
n_layer = 8
dropout = 0.3
# NOTE(review): hard-coded; presumably the value printed as "Vocab size" by
# download_and_prepare -- confirm the two agree.
vocab_size = 387

# Files
data_dir = 'data'
checkpoint_path = 'tinymozart_ckpt.pt'
best_model_path = 'tinymozart_best.pt'
log_path = 'training_log.txt'

# Fall back to CPU so the script still starts on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
99
+
100
def get_batch(data):
    """Draw a random batch of input/target windows from a 1-D token array.

    ``batch_size`` start offsets are sampled uniformly; each input is the
    ``block_size`` tokens from its offset and the matching target is the same
    window shifted one token to the right.

    Args:
        data: indexable NumPy array of token ids, longer than ``block_size``.

    Returns:
        Tuple ``(x, y)`` of int64 tensors on ``device``, each of shape
        ``(batch_size, block_size)``.
    """
    starts = torch.randint(len(data) - block_size, (batch_size,)).tolist()
    inputs, targets = [], []
    for start in starts:
        stop = start + block_size
        inputs.append(torch.from_numpy(data[start:stop].astype(np.int64)))
        targets.append(torch.from_numpy(data[start + 1:stop + 1].astype(np.int64)))
    return torch.stack(inputs).to(device), torch.stack(targets).to(device)
105
+
106
@torch.no_grad()
def estimate_loss(model, train_data, val_data):
    """Average the model loss over ``eval_iters`` random batches per split.

    The model is switched to eval mode for the measurement and back to train
    mode before returning.

    Args:
        model: callable returning ``(logits, loss)`` for ``(x, y)``.
        train_data: token array for the ``'train'`` split.
        val_data: token array for the ``'val'`` split.

    Returns:
        Dict with keys ``'train'`` and ``'val'`` mapping to 0-dim tensors
        holding the mean loss of each split.
    """
    model.eval()
    splits = {'train': train_data, 'val': val_data}
    results = {}
    for name, tokens in splits.items():
        per_batch = torch.zeros(eval_iters)
        for i in range(eval_iters):
            xb, yb = get_batch(tokens)
            loss = model(xb, yb)[1]
            # .mean() collapses a per-replica loss vector to a scalar.
            per_batch[i] = loss.mean().item()
        results[name] = per_batch.mean()
    model.train()
    return results
119
+
120
+ # --- 3. Architecture ---
121
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with a fused query/key/value projection.

    The layer width is derived from its own arguments
    (``num_heads * head_size``) rather than from module-level globals, so the
    layer can be built for any consistent head configuration.

    Args:
        num_heads: number of attention heads.
        head_size: feature width of a single head.
        dropout_p: attention dropout probability applied in training mode.
            ``None`` (default) uses the module-level ``dropout`` setting.
    """

    def __init__(self, num_heads, head_size, dropout_p=None):
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        embd = num_heads * head_size
        # One matmul produces q, k and v side by side along the last axis.
        self.c_attn = nn.Linear(embd, 3 * embd, bias=False)
        self.c_proj = nn.Linear(embd, embd)
        self.dropout = dropout if dropout_p is None else dropout_p

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: tensor of shape ``(B, T, C)`` with ``C == num_heads * head_size``.

        Returns:
            Tensor of shape ``(B, T, C)``.
        """
        B, T, C = x.size()
        q, k, v = self.c_attn(x).split(C, dim=2)

        # (B, T, C) -> (B, num_heads, T, head_size)
        head_shape = (B, T, self.num_heads, self.head_size)
        q = q.view(head_shape).transpose(1, 2)
        k = k.view(head_shape).transpose(1, 2)
        v = v.view(head_shape).transpose(1, 2)

        y = F.scaled_dot_product_attention(
            q, k, v,
            # Dropout only while training; evaluation is deterministic.
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=True
        )

        # Merge the heads back into one feature axis.
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)
147
+
148
class FeedForward(nn.Module):
    """Per-position MLP: widen to 4x, GELU, project back, then dropout.

    Args:
        n_embd: input and output feature width.
    """

    def __init__(self, n_embd):
        super().__init__()
        hidden = 4 * n_embd
        layers = [
            nn.Linear(n_embd, hidden),
            nn.GELU(),
            nn.Linear(hidden, n_embd),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the MLP and return the result."""
        return self.net(x)
153
+
154
class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual.

    Args:
        n_embd: feature width of the block.
        n_head: number of attention heads; each head gets
            ``n_embd // n_head`` features.
    """

    def __init__(self, n_embd, n_head):
        super().__init__()
        self.sa = MultiHeadAttention(n_head, n_embd // n_head)
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        """Return ``x`` after the attention and feed-forward residual steps."""
        after_attention = x + self.sa(self.ln1(x))
        return after_attention + self.ffwd(self.ln2(after_attention))
165
+
166
class TinyMozart(nn.Module):
    """Decoder-only transformer language model over token ids.

    Token and learned position embeddings are summed, passed through
    ``n_layer`` blocks and a final LayerNorm, then projected to vocabulary
    logits.

    Args:
        vocab_size: number of distinct token ids.
    """

    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        layers = [Block(n_embd, n_head) for _ in range(n_layer)]
        self.blocks = nn.Sequential(*layers)
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        """Compute logits and, when targets are given, the cross-entropy loss.

        Args:
            idx: 2-D tensor of token ids, shape ``(B, T)``.
            targets: optional tensor of target ids with the same number of
                elements as ``idx``.

        Returns:
            Tuple ``(logits, loss)``; ``loss`` is ``None`` without targets.
        """
        _, seq_len = idx.shape
        positions = torch.arange(seq_len, device=idx.device)
        hidden = self.token_embedding_table(idx) + self.position_embedding_table(positions)
        hidden = self.blocks(hidden)
        logits = self.lm_head(self.ln_f(hidden))
        if targets is None:
            return logits, None
        # Flatten batch and time so every position is one classification.
        flat_logits = logits.view(-1, logits.size(-1))
        loss = F.cross_entropy(flat_logits, targets.view(-1))
        return logits, loss
181
+
182
def train():
    """Train TinyMozart on the prepared token files with periodic evaluation.

    Loads ``train.bin`` / ``val.bin`` from ``data_dir``, builds the model
    (wrapped in ``nn.DataParallel`` when several GPUs are visible), resumes
    from ``checkpoint_path`` or else ``best_model_path`` when one exists, and
    runs up to ``max_iters`` optimizer steps with gradient accumulation,
    gradient clipping and cosine learning-rate annealing.

    Every ``eval_interval`` iterations the train/val loss is estimated,
    appended to ``log_path``, and a checkpoint is written; the checkpoint is
    additionally stored as ``best_model_path`` when the validation loss
    improved.

    Raises:
        ValueError: if a split is not longer than ``block_size`` tokens.
    """
    train_data = np.fromfile(os.path.join(data_dir, 'train.bin'), dtype=np.uint16)
    val_data = np.fromfile(os.path.join(data_dir, 'val.bin'), dtype=np.uint16)
    for split_name, split in (('train', train_data), ('val', val_data)):
        if len(split) <= block_size:
            raise ValueError(
                f"{split_name} split has {len(split)} tokens; "
                f"more than block_size={block_size} are required"
            )

    model = TinyMozart(vocab_size).to(device)
    # Handle on the unwrapped module: checkpoints are saved and loaded through
    # it so they do not depend on whether DataParallel was active.
    raw_model = model

    if torch.cuda.device_count() > 1:
        print(f"🚀 Using {torch.cuda.device_count()} GPUs!")
        model = nn.DataParallel(model)

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.1)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_iters)

    start_iter = 0
    best_val_loss = float('inf')

    # Prefer the rolling checkpoint, fall back to the best-model file.
    if os.path.exists(checkpoint_path):
        target_ckpt = checkpoint_path
    elif os.path.exists(best_model_path):
        target_ckpt = best_model_path
    else:
        target_ckpt = None

    if target_ckpt:
        print(f"Loading checkpoint from {target_ckpt}...")
        checkpoint = torch.load(target_ckpt, map_location=device)
        # Checkpoints written while DataParallel was active carry a "module."
        # prefix on every key; strip it so they load into the bare model.
        prefix = 'module.'
        model_state = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in checkpoint['model_state_dict'].items()
        }
        raw_model.load_state_dict(model_state)
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # 'iter' is the last iteration that was completed before saving,
        # so training continues with the one after it.
        start_iter = checkpoint['iter'] + 1
        best_val_loss = float(checkpoint.get('best_val_loss', float('inf')))
        if 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        else:
            # Older checkpoints have no scheduler state: realign the step
            # counter so the cosine schedule continues instead of restarting.
            scheduler.last_epoch = start_iter
        print(f"Resuming from iter {start_iter} with best_val_loss {best_val_loss:.4f}")

    model.train()
    t0 = time.time()
    steps_since_log = 0

    for step in range(start_iter, max_iters):
        optimizer.zero_grad(set_to_none=True)
        accum_loss = 0.0

        for _ in range(gradient_accumulation_steps):
            xb, yb = get_batch(train_data)
            _, loss = model(xb, yb)
            # .mean() collapses the per-replica losses under DataParallel;
            # dividing keeps the accumulated gradient an average.
            loss = loss.mean() / gradient_accumulation_steps
            loss.backward()
            accum_loss += loss.item()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()
        scheduler.step()
        steps_since_log += 1

        if step % 50 == 0:
            now = time.time()
            # Divide by the steps actually run since the last report, which
            # is fewer than 50 for the first report after (re)start.
            ms_per_step = (now - t0) * 1000 / steps_since_log
            t0 = now
            steps_since_log = 0
            print(f"Iter {step}: Loss {accum_loss:.4f} | {ms_per_step:.1f}ms/step", flush=True)

        if step % eval_interval == 0:
            losses = estimate_loss(model, train_data, val_data)
            train_loss = float(losses['train'])
            val_loss = float(losses['val'])
            print(f">>> EVAL {step}: Train {train_loss:.4f}, Val {val_loss:.4f}", flush=True)

            with open(log_path, 'a') as f:
                f.write(f"{step},{train_loss:.4f},{val_loss:.4f}\n")

            # Update the best value first so both files record the same,
            # current best_val_loss.
            is_best = val_loss < best_val_loss
            if is_best:
                best_val_loss = val_loss

            checkpoint = {
                'iter': step,
                'model_state_dict': raw_model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'best_val_loss': best_val_loss,
            }
            torch.save(checkpoint, checkpoint_path)

            if is_best:
                torch.save(checkpoint, best_model_path)
                print(f"✨ New best model saved! (Loss: {best_val_loss:.4f})")


if __name__ == "__main__":
    train()