# Source: LH-Tech-AI — commit ed73909 "Create use.py" (verified)
# Before running, install the FluidSynth General MIDI soundfont (used below to render the MIDI to WAV):
# apt-get update -y > /dev/null
# apt-get install -y fluid-soundfont-gm > /dev/null
import torch
import torch.nn as nn
from torch.nn import functional as F
from miditok import REMI, TokenizerConfig
from midi2audio import FluidSynth
import os
# Model hyperparameters. They must agree with the saved checkpoint, otherwise
# load_state_dict() further down fails with shape mismatches.
n_embd, n_head, n_layer = 512, 8, 8   # width, heads per layer, number of blocks
block_size = 1024                     # rows of the position-embedding table (max context)
dropout = 0.3                         # dropout probability; only active in training mode
vocab_size = 387                      # rows of the token embedding / outputs of lm_head

# Prefer the GPU when PyTorch can see one.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with a single fused q/k/v projection.

    Args:
        num_heads: number of attention heads.
        head_size: per-head width. The reshapes in ``forward`` require
            ``num_heads * head_size == n_embd`` (the module-level constant).
    """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        # Attribute names double as state_dict keys of the saved checkpoint.
        self.c_attn = nn.Linear(n_embd, 3 * n_embd, bias=False)  # fused q, k, v
        self.c_proj = nn.Linear(n_embd, n_embd)                   # output projection
        self.dropout = dropout  # attention dropout probability, used only while training

    def forward(self, x):
        batch, seq_len, channels = x.size()
        # (B, T, 3*C) -> (B, T, 3, heads, hs) -> (3, B, heads, T, hs), then peel off q/k/v.
        fused = self.c_attn(x).view(batch, seq_len, 3, self.num_heads, self.head_size)
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)
        attn_dropout = self.dropout if self.training else 0.0
        out = F.scaled_dot_product_attention(q, k, v, dropout_p=attn_dropout, is_causal=True)
        # Merge the heads back: (B, heads, T, hs) -> (B, T, C).
        out = out.transpose(1, 2).contiguous().view(batch, seq_len, channels)
        return self.c_proj(out)
class FeedForward(nn.Module):
    """Position-wise MLP: widen 4x, GELU, project back to n_embd, then dropout."""

    def __init__(self, n_embd):
        super().__init__()
        hidden = 4 * n_embd
        # Layer order fixes the checkpoint keys (net.0.*, net.2.*); do not reorder.
        layers = [
            nn.Linear(n_embd, hidden),
            nn.GELU(),
            nn.Linear(hidden, n_embd),
            nn.Dropout(dropout),  # module-level dropout probability
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
class Block(nn.Module):
    """Pre-LayerNorm transformer block: residual attention, then residual MLP."""

    def __init__(self, n_embd, n_head):
        super().__init__()
        # Per-head width via integer division (assumes n_embd is divisible by n_head).
        self.sa = MultiHeadAttention(n_head, n_embd // n_head)
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        # Each sub-layer sees a normalised input; its output is added to the residual stream.
        attended = x + self.sa(self.ln1(x))
        return attended + self.ffwd(self.ln2(attended))
class TinyMozart(nn.Module):
    """Decoder-only transformer language model over token ids.

    Positions come from a learned embedding table with ``block_size`` rows, so
    inputs longer than ``block_size`` tokens fail at the position lookup.
    """

    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*(Block(n_embd, n_head) for _ in range(n_layer)))
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx):
        """Map token ids of shape (B, T) to logits of shape (B, T, vocab_size)."""
        _, seq_len = idx.shape
        positions = torch.arange(seq_len, device=idx.device)
        hidden = self.token_embedding_table(idx) + self.position_embedding_table(positions)
        hidden = self.ln_f(self.blocks(hidden))
        return self.lm_head(hidden)
# REMI tokenizer. Presumably this configuration must match the one used when the
# checkpoint was trained (TODO confirm): token ids only mean something relative
# to the tokenizer's vocabulary.
config = TokenizerConfig(
    num_velocities=16,
    use_chords=True,
    use_tempos=True,
    use_time_signatures=True,
)
tokenizer = REMI(config)
if len(tokenizer) != vocab_size:
    # The model's embedding and output sizes are fixed by vocab_size; a different
    # tokenizer vocabulary means generated ids will not decode to the intended events.
    print(f"⚠️ Tokenizer vocabulary size ({len(tokenizer)}) != model vocab_size ({vocab_size})")

model = TinyMozart(vocab_size).to(device)
best_path = 'model.pt'


def _strip_wrapper_prefixes(key, prefixes=('module.', '_orig_mod.')):
    """Return *key* with any leading wrapper prefixes removed.

    'module.' is added by DataParallel/DistributedDataParallel and '_orig_mod.'
    by torch.compile; the two may be nested in either order.
    """
    while True:
        for prefix in prefixes:
            if key.startswith(prefix):
                key = key[len(prefix):]
                break
        else:
            return key


if os.path.exists(best_path):
    # NOTE(review): torch.load may unpickle arbitrary objects; only load trusted checkpoints.
    checkpoint = torch.load(best_path, map_location=device)
    # Accept both a training checkpoint dict ({'model_state_dict': ..., 'iter': ...})
    # and a bare state_dict.
    state_dict = checkpoint.get('model_state_dict', checkpoint)
    new_state_dict = {_strip_wrapper_prefixes(k): v for k, v in state_dict.items()}
    model.load_state_dict(new_state_dict)
    print(f"✅ Model loaded! (Iter {checkpoint.get('iter', 'unknown')}, "
          f"Best Val Loss: {checkpoint.get('best_val_loss', 'unknown')})")
else:
    # The script keeps going, so generation will use randomly initialised weights.
    print(f"❌ Checkpoint not found at {best_path}")
model.eval()
@torch.no_grad()
def generate_pro(max_len=3000, temp=1.05, top_p=0.95, top_k=25, rep_penalty=1.5, *, rep_window=10):
    """Autoregressively sample token ids from the module-level ``model``.

    Generation starts from a single token id 0. At each step at most the last
    ``block_size`` tokens are fed to the model. The final position's logits then
    go through, in order: temperature scaling, the repetition penalty, top-k
    filtering and nucleus (top-p) filtering, and one token is sampled.

    Args:
        max_len: number of tokens to sample. The returned list has
            ``max_len + 1`` entries, including the start token.
        temp: softmax temperature; larger values flatten the distribution.
        top_p: nucleus threshold. Only the most likely tokens up to and including
            the one where cumulative probability exceeds ``top_p`` are kept.
            ``>= 1`` disables the filter.
        top_k: keep only the ``top_k`` highest logits; ``<= 0`` disables the filter.
        rep_penalty: penalty for every token id seen in the last ``rep_window``
            positions. Positive logits are divided by it and negative logits
            multiplied, so a value > 1 always makes those tokens less likely.
        rep_window: number of trailing tokens the repetition penalty inspects;
            ``<= 0`` disables the penalty.

    Returns:
        The sampled sequence as a list of Python ints, beginning with the 0 start token.
    """
    print("🎹 TinyMozart is composing music...")
    x = torch.zeros((1, 1), dtype=torch.long, device=device)
    for _ in range(max_len):
        x_cond = x[:, -block_size:]  # the position table only covers block_size tokens
        logits = model(x_cond)[:, -1, :] / temp

        # Repetition penalty (CTRL-style). Dividing a *negative* logit by a penalty > 1
        # would raise it and make the token more likely, so negatives are multiplied.
        # The rep_window > 0 guard matters: x[0, -0:] would select the whole sequence.
        if rep_window > 0 and rep_penalty != 1.0:
            recent = x[0, -rep_window:].unique()
            picked = logits[0, recent]
            logits[0, recent] = torch.where(picked > 0, picked / rep_penalty, picked * rep_penalty)

        # Top-k: mask everything scoring below the k-th largest logit.
        if top_k > 0:
            kth_best = torch.topk(logits, min(top_k, logits.size(-1))).values[:, [-1]]
            logits[logits < kth_best] = -float('Inf')

        # Top-p (nucleus) filtering.
        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_to_remove = cumulative_probs > top_p
            # Shift right so the token that crosses the threshold survives; always keep the best.
            sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1].clone()
            sorted_to_remove[..., 0] = False
            # Map the mask from sorted order back to vocabulary order (works for any batch size).
            to_remove = sorted_to_remove.scatter(1, sorted_indices, sorted_to_remove)
            logits = logits.masked_fill(to_remove, -float('Inf'))

        probs = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        x = torch.cat((x, next_token), dim=1)
    return x[0].tolist()
MIDI_PATH = "mozart_masterpiece.mid"
WAV_PATH = "mozart_masterpiece.wav"

tokens = generate_pro()
# Drop the 0 start token that generate_pro seeds the sequence with.
if tokens and tokens[0] == 0:
    tokens = tokens[1:]
midi = tokenizer.decode([tokens])
midi.dump_midi(MIDI_PATH)
print(f"✅ MIDI saved: {MIDI_PATH}")

SF2_PATH = "/usr/share/sounds/sf2/FluidR3_GM.sf2"
if os.path.exists(SF2_PATH):
    print("🎵 Generating Audio...")
    fs = FluidSynth(SF2_PATH)
    fs.midi_to_audio(MIDI_PATH, WAV_PATH)
    print(f"🎵 Audio saved: {WAV_PATH}")
    # `display` is only a builtin inside IPython/Jupyter. Import it explicitly so a
    # plain `python use.py` run does not raise NameError, and skip playback when
    # IPython is not installed.
    try:
        from IPython.display import Audio, display
    except ImportError:
        pass
    else:
        display(Audio(WAV_PATH))
else:
    print("⚠️ FluidSynth Soundfont not found!")