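# TinyMozart: a small GPT-style Transformer that composes music as a sequence
# of REMI tokens, decodes it to MIDI, and optionally renders audio via FluidSynth.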
import torch
import torch.nn as nn
from torch.nn import functional as F
from miditok import REMI, TokenizerConfig
from midi2audio import FluidSynth
import os
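
# Model hyperparameters -- these must match the values used when the checkpoint was trained.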
n_embd = 512
n_head = 8
n_layer = 8
block_size = 1024
dropout = 0.3
vocab_size = 387
device = 'cuda' if torch.cuda.is_available() else 'cpu'
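
# Multi-head causal self-attention with a fused QKV projection.
# F.scaled_dot_product_attention dispatches to a fast fused kernel when one is available.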
class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.c_attn = nn.Linear(n_embd, 3 * n_embd, bias=False)  # fused Q, K, V projection
        self.c_proj = nn.Linear(n_embd, n_embd)
        self.dropout = dropout

    def forward(self, x):
        B, T, C = x.size()
        qkv = self.c_attn(x)
        q, k, v = qkv.split(n_embd, dim=2)
        # (B, T, C) -> (B, num_heads, T, head_size)
        q = q.view(B, T, self.num_heads, self.head_size).transpose(1, 2)
        k = k.view(B, T, self.num_heads, self.head_size).transpose(1, 2)
        v = v.view(B, T, self.num_heads, self.head_size).transpose(1, 2)

        y = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=True
        )
        # Re-assemble the heads: (B, num_heads, T, head_size) -> (B, T, C)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)
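
# Position-wise feed-forward network: 4x expansion, GELU, and dropout.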
class FeedForward(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)
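
# Pre-LayerNorm Transformer block: attention and feed-forward, each behind a residual connection.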
class Block(nn.Module):
    def __init__(self, n_embd, n_head):
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x
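
# Decoder-only language model over REMI music tokens. Learned positional
# embeddings cap the usable context at block_size tokens.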
class TinyMozart(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx):
        B, T = idx.shape
        x = self.token_embedding_table(idx) + self.position_embedding_table(torch.arange(T, device=idx.device))
        x = self.blocks(x)
        logits = self.lm_head(self.ln_f(x))
        return logits
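
# Rebuild the REMI tokenizer; its settings must match the ones used at training
# time so that the vocabulary lines up with vocab_size above.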
config = TokenizerConfig(
    num_velocities=16,
    use_chords=True,
    use_tempos=True,
    use_time_signatures=True
)
tokenizer = REMI(config)
model = TinyMozart(vocab_size).to(device)
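
# Load the trained weights. Checkpoints saved from an nn.DataParallel / DDP
# model carry a 'module.' prefix on every key, which is stripped before loading.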
best_path = 'model.pt'
if os.path.exists(best_path):
    checkpoint = torch.load(best_path, map_location=device)
    state_dict = checkpoint['model_state_dict']

    new_state_dict = {}
    for k, v in state_dict.items():
        name = k[7:] if k.startswith('module.') else k
        new_state_dict[name] = v

    model.load_state_dict(new_state_dict)
    print(f"✅ Model loaded! (Iter {checkpoint['iter']}, Best Val Loss: {checkpoint.get('best_val_loss', 'unknown')})")
else:
    print(f"❌ Checkpoint not found at {best_path}")
model.eval()  # disable dropout for inference
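
# Autoregressive sampling with temperature scaling, a short-window repetition
# penalty, top-k filtering, and nucleus (top-p) filtering.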
@torch.no_grad()
def generate_pro(max_len=3000, temp=1.05, top_p=0.95, top_k=25, rep_penalty=1.5):
    print("🎹 TinyMozart is composing music...")
    x = torch.zeros((1, 1), dtype=torch.long, device=device)  # seed with token id 0

    for _ in range(max_len):
        x_cond = x[:, -block_size:]  # crop the context to the model's block size
        logits = model(x_cond)[:, -1, :] / temp

        # Penalize tokens seen in the last 10 steps. Divide positive logits and
        # multiply negative ones; dividing a negative logit would *raise* its
        # probability instead of lowering it.
        for token in set(x[0, -10:].tolist()):
            if logits[0, token] > 0:
                logits[0, token] /= rep_penalty
            else:
                logits[0, token] *= rep_penalty

        # Top-k: keep only the k highest-scoring tokens.
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        logits[logits < v[:, [-1]]] = -float('Inf')

        # Nucleus (top-p): drop the tail once cumulative probability exceeds top_p.
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0  # always keep the single most likely token
        logits.scatter_(1, sorted_indices[sorted_indices_to_remove].unsqueeze(0), -float('Inf'))

        probs = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        x = torch.cat((x, next_token), dim=1)

    return x[0].cpu().numpy().tolist()
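
# Compose, then drop the id-0 seed token so it doesn't end up in the score.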
tokens = generate_pro()
if tokens[0] == 0:
    tokens = tokens[1:]
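
# Decode the token ids back into a score and write it out as a MIDI file.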
midi = tokenizer.decode([tokens])
midi.dump_midi("mozart_masterpiece.mid")
print("✅ MIDI saved: mozart_masterpiece.mid")
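
# Render the MIDI to WAV with FluidSynth if a General MIDI soundfont is
# installed (the path below is the Debian/Ubuntu default for FluidR3_GM).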
SF2_PATH = "/usr/share/sounds/sf2/FluidR3_GM.sf2"
if os.path.exists(SF2_PATH):
    print("🎵 Generating Audio...")
    fs = FluidSynth(SF2_PATH)
    fs.midi_to_audio("mozart_masterpiece.mid", "mozart_masterpiece.wav")
    from IPython.display import Audio, display
    display(Audio("mozart_masterpiece.wav"))
else:
    print("⚠️ FluidSynth Soundfont not found!")
|