asdf98 committed on
Commit
408dd91
·
verified ·
1 Parent(s): a5069eb

v0.3: Fix training hang — precompute tensors, num_workers=0, torch.compile, warmup batch, timing diagnostics

Browse files
Files changed (1) hide show
  1. MuseMorphic_Training.ipynb +179 -253
MuseMorphic_Training.ipynb CHANGED
@@ -2,32 +2,17 @@
2
  "nbformat": 4,
3
  "nbformat_minor": 0,
4
  "metadata": {
5
- "colab": {
6
- "provenance": [],
7
- "gpuType": "T4"
8
- },
9
- "kernelspec": {
10
- "name": "python3",
11
- "display_name": "Python 3"
12
- },
13
- "language_info": {
14
- "name": "python"
15
- },
16
  "accelerator": "GPU"
17
  },
18
  "cells": [
19
  {
20
  "cell_type": "markdown",
21
  "metadata": {},
22
- "source": [
23
- "# 🎵 MuseMorphic: Lightweight MIDI Generator (v0.2)\n",
24
- "\n",
25
- "**A novel consumer-grade architecture for controllable, infinite-length MIDI generation.**\n",
26
- "\n",
27
- "v0.2 Performance fixes: weight_norm (not spectral_norm), chunked SSM scan, vectorized masking.\n",
28
- "\n",
29
- "📄 [Architecture Details](https://huggingface.co/asdf98/MuseMorphic)"
30
- ]
31
  },
32
  {
33
  "cell_type": "code",
@@ -35,19 +20,12 @@
35
  "metadata": {},
36
  "outputs": [],
37
  "source": [
38
- "# ============================================================\n",
39
- "# 1. Install Dependencies\n",
40
- "# ============================================================\n",
41
- "!pip install -q torch torchvision torchaudio\n",
42
- "!pip install -q einops datasets pretty_midi midiutil huggingface_hub\n",
43
- "\n",
44
- "# Clone MuseMorphic repo (fresh pull to get v0.2 fixes)\n",
45
  "!rm -rf /content/MuseMorphic\n",
46
  "!git clone https://huggingface.co/asdf98/MuseMorphic /content/MuseMorphic\n",
47
- "\n",
48
- "import sys\n",
49
- "sys.path.insert(0, '/content/MuseMorphic/musemorphic')\n",
50
- "print('✅ Dependencies installed!')"
51
  ]
52
  },
53
  {
@@ -56,24 +34,16 @@
56
  "metadata": {},
57
  "outputs": [],
58
  "source": [
59
- "# ============================================================\n",
60
- "# 2. Check GPU & Hardware\n",
61
- "# ============================================================\n",
62
- "import torch\n",
63
- "print(f'PyTorch: {torch.__version__}')\n",
64
- "print(f'CUDA: {torch.cuda.is_available()}')\n",
65
  "if torch.cuda.is_available():\n",
66
- " print(f'GPU: {torch.cuda.get_device_name(0)}')\n",
67
- " print(f'VRAM: {torch.cuda.get_device_properties(0).total_mem / 1e9:.1f} GB')\n",
68
- " print(f'BF16: {torch.cuda.is_bf16_supported()}')\n",
69
- "\n",
70
- "if torch.cuda.is_available() and torch.cuda.is_bf16_supported():\n",
71
- " DTYPE = 'bf16'; amp_dtype = torch.bfloat16\n",
72
- "elif torch.cuda.is_available():\n",
73
- " DTYPE = 'fp16'; amp_dtype = torch.float16\n",
74
- "else:\n",
75
- " DTYPE = 'fp32'; amp_dtype = torch.float32\n",
76
- "print(f'\\nUsing: {DTYPE}')"
77
  ]
78
  },
79
  {
@@ -82,20 +52,15 @@
82
  "metadata": {},
83
  "outputs": [],
84
  "source": [
85
- "# ============================================================\n",
86
- "# 3. Configuration\n",
87
- "# ============================================================\n",
88
  "MODEL_SIZE = 'small' # @param ['tiny', 'small', 'medium']\n",
89
- "DATASET = 'auto' # @param ['auto', 'maestro', 'synthetic']\n",
90
- "MAX_PIECES = 500 # @param {type: 'integer'}\n",
91
  "VAE_EPOCHS = 15 # @param {type: 'integer'}\n",
92
  "MAMBA_EPOCHS = 30 # @param {type: 'integer'}\n",
93
  "BATCH_SIZE = 32 # @param {type: 'integer'}\n",
94
  "LEARNING_RATE = 3e-4 # @param {type: 'number'}\n",
95
- "OUTPUT_DIR = '/content/checkpoints'\n",
96
- "PUSH_TO_HUB = False\n",
97
- "HUB_MODEL_ID = ''\n",
98
- "print(f'Config: {MODEL_SIZE} model, {DATASET} data, VAE {VAE_EPOCHS}ep + Mamba {MAMBA_EPOCHS}ep')"
99
  ]
100
  },
101
  {
@@ -104,24 +69,18 @@
104
  "metadata": {},
105
  "outputs": [],
106
  "source": [
107
- "# ============================================================\n",
108
- "# 4. Build Model\n",
109
- "# ============================================================\n",
110
- "from model import MuseMorphicConfig, MuseMorphic, model_summary\n",
111
- "\n",
112
- "SIZE_CONFIGS = {\n",
113
  " 'tiny': MuseMorphicConfig(d_model=128, vae_encoder_layers=2, vae_decoder_layers=2,\n",
114
- " vae_n_heads=4, vae_d_ff=256, latent_dim=32, mamba_d_model=128,\n",
115
- " mamba_n_layers=4, mamba_d_state=8, mamba_expand=2),\n",
116
  " 'small': MuseMorphicConfig(d_model=256, vae_encoder_layers=3, vae_decoder_layers=3,\n",
117
- " vae_n_heads=4, vae_d_ff=512, latent_dim=64, mamba_d_model=256,\n",
118
- " mamba_n_layers=8, mamba_d_state=16, mamba_expand=2),\n",
119
  " 'medium': MuseMorphicConfig(d_model=384, vae_encoder_layers=4, vae_decoder_layers=4,\n",
120
- " vae_n_heads=6, vae_d_ff=768, latent_dim=96, mamba_d_model=384,\n",
121
- " mamba_n_layers=12, mamba_d_state=16, mamba_expand=2),\n",
122
  "}\n",
123
- "config = SIZE_CONFIGS[MODEL_SIZE]\n",
124
- "model = model_summary(config)"
125
  ]
126
  },
127
  {
@@ -130,32 +89,29 @@
130
  "metadata": {},
131
  "outputs": [],
132
  "source": [
133
- "# ============================================================\n",
134
- "# 5. Load & Preprocess Dataset\n",
135
- "# ============================================================\n",
136
- "import logging\n",
137
- "logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s')\n",
138
  "from data_pipeline import auto_select_dataset, load_dataset_notes, preprocess_dataset, _generate_synthetic_dataset\n",
139
  "from tokenizer import REMIPlusTokenizer\n",
140
  "\n",
141
  "tokenizer = REMIPlusTokenizer()\n",
142
- "if DATASET == 'auto': dataset_name = auto_select_dataset()\n",
143
- "elif DATASET == 'maestro': dataset_name = 'maestro_v1_sustain'\n",
144
- "else: dataset_name = None\n",
145
- "\n",
146
- "if dataset_name:\n",
147
- " try:\n",
148
- " pieces = load_dataset_notes(dataset_name, max_pieces=MAX_PIECES)\n",
149
- " print(f'✅ Loaded {len(pieces)} pieces from {dataset_name}')\n",
150
- " except Exception as e:\n",
151
- " print(f'⚠️ {e}\\nFalling back to synthetic data...')\n",
152
- " pieces = _generate_synthetic_dataset(MAX_PIECES)\n",
153
- "else:\n",
154
  " pieces = _generate_synthetic_dataset(MAX_PIECES)\n",
155
- " print(f'✅ Generated {len(pieces)} synthetic pieces')\n",
 
 
 
 
 
 
 
156
  "\n",
 
157
  "phrases, controls = preprocess_dataset(pieces, tokenizer, max_phrase_len=config.vae_max_seq_len)\n",
158
- "print(f'📊 {len(phrases)} phrases, avg {sum(len(p) for p in phrases)/len(phrases):.1f} tokens')"
 
159
  ]
160
  },
161
  {
@@ -164,75 +120,99 @@
164
  "metadata": {},
165
  "outputs": [],
166
  "source": [
167
- "# ============================================================\n",
168
- "# 6. Training — Stage 1: PhraseVAE\n",
169
- "# ============================================================\n",
170
- "import time, random, math, numpy as np\n",
 
 
 
 
 
171
  "import torch.nn.functional as F\n",
172
- "from torch.utils.data import DataLoader, Dataset\n",
173
  "from model import PhraseVAE, ZClip, apply_span_mask_vectorized\n",
174
  "\n",
175
- "# Seed\n",
176
  "SEED = 42\n",
177
  "random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)\n",
178
  "if torch.cuda.is_available(): torch.cuda.manual_seed_all(SEED)\n",
179
  "\n",
180
- "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
181
- "use_scaler = (amp_dtype == torch.float16)\n",
182
- "scaler = torch.amp.GradScaler() if use_scaler else None\n",
183
- "\n",
184
- "class PhraseDS(Dataset):\n",
185
- " def __init__(self, phrases, max_len, pad_id=0):\n",
186
- " self.phrases = phrases; self.max_len = max_len; self.pad_id = pad_id\n",
187
- " def __len__(self): return len(self.phrases)\n",
188
- " def __getitem__(self, idx):\n",
189
- " ids = self.phrases[idx][:self.max_len]\n",
190
- " return torch.tensor(ids + [self.pad_id] * (self.max_len - len(ids)), dtype=torch.long)\n",
191
- "\n",
192
- "train_ds = PhraseDS(phrases, config.vae_max_seq_len)\n",
193
  "train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,\n",
194
- " num_workers=2, pin_memory=True, drop_last=True)\n",
195
  "\n",
 
196
  "vae = PhraseVAE(config).to(device)\n",
197
  "print(f'PhraseVAE: {sum(p.numel() for p in vae.parameters()):,} params')\n",
198
  "\n",
199
  "optimizer = torch.optim.AdamW(vae.parameters(), lr=LEARNING_RATE, weight_decay=0.01)\n",
200
  "scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=500, T_mult=2, eta_min=1e-6)\n",
201
  "zclip = ZClip(z_thresh=2.5)\n",
 
202
  "\n",
203
- "# ---- Training Loop (uses vectorized span masking NO Python loop over batch) ----\n",
 
 
 
 
 
 
 
 
 
 
 
 
204
  "print('\\n' + '='*60 + '\\nStarting PhraseVAE Training\\n' + '='*60)\n",
205
  "\n",
206
- "pretrain_epochs = max(1, VAE_EPOCHS // 5)\n",
207
- "ae_epochs = max(1, VAE_EPOCHS * 3 // 5)\n",
208
- "vae_epochs_count = max(1, VAE_EPOCHS - pretrain_epochs - ae_epochs)\n",
209
- "stages = [('1a-Pretrain', pretrain_epochs, 0.0, True),\n",
210
- " ('1b-AE', ae_epochs, 0.0, False),\n",
211
- " ('1c-VAE', vae_epochs_count, 0.01, False)]\n",
212
  "\n",
213
  "global_step = 0\n",
214
  "history = {'loss': [], 'recon': [], 'kl': []}\n",
215
  "\n",
216
  "for stage_name, n_epochs, kl_weight, use_masking in stages:\n",
217
- " print(f'\\n--- Stage {stage_name} ({n_epochs} epochs, KL={kl_weight}) ---')\n",
218
  " if stage_name == '1c-VAE':\n",
219
  " for pg in optimizer.param_groups: pg['lr'] = LEARNING_RATE * 0.1\n",
220
  "\n",
221
  " for epoch in range(n_epochs):\n",
222
- " vae.train(); epoch_loss = 0; n_batches = 0; t0 = time.time()\n",
223
- " for batch in train_loader:\n",
224
- " token_ids = batch.to(device)\n",
225
- " # VECTORIZED masking — runs entirely on GPU, no Python loop\n",
226
- " input_ids = apply_span_mask_vectorized(token_ids, mask_id=config.mask_token_id) if use_masking else token_ids\n",
 
 
 
 
 
 
227
  "\n",
228
  " optimizer.zero_grad(set_to_none=True)\n",
229
- " with torch.autocast(device_type=device.type, dtype=amp_dtype):\n",
 
230
  " outputs = vae(input_ids, target_tokens=token_ids, kl_weight=kl_weight)\n",
231
  " loss = outputs['loss']\n",
232
  "\n",
233
  " if torch.isnan(loss) or torch.isinf(loss):\n",
234
- " print(f'⚠️ NaN at step {global_step}! Skipping...')\n",
235
- " optimizer.zero_grad(set_to_none=True); continue\n",
236
  "\n",
237
  " if use_scaler:\n",
238
  " scaler.scale(loss).backward()\n",
@@ -247,15 +227,12 @@
247
  " history['recon'].append(outputs['recon_loss'].item())\n",
248
  " history['kl'].append(outputs['kl_loss'].item())\n",
249
  "\n",
250
- " elapsed = time.time() - t0\n",
251
- " avg_loss = epoch_loss / max(n_batches, 1)\n",
252
- " print(f' Epoch {epoch+1}/{n_epochs} | Loss: {avg_loss:.4f} | '\n",
253
- " f'Recon: {outputs[\"recon_loss\"].item():.4f} | '\n",
254
- " f'KL: {outputs[\"kl_loss\"].item():.4f} | '\n",
255
- " f'LR: {optimizer.param_groups[0][\"lr\"]:.2e} | {elapsed:.1f}s | '\n",
256
- " f'{n_batches/elapsed:.1f} batch/s')\n",
257
  "\n",
258
- "print(f'\\n✅ PhraseVAE training complete! ({global_step} steps)')"
259
  ]
260
  },
261
  {
@@ -264,82 +241,64 @@
264
  "metadata": {},
265
  "outputs": [],
266
  "source": [
267
- "# ============================================================\n",
268
- "# 7. Training — Stage 2: LatentMamba\n",
269
- "# ============================================================\n",
270
  "from model import LatentMamba\n",
271
  "\n",
272
  "vae.eval()\n",
273
  "for p in vae.parameters(): p.requires_grad = False\n",
274
  "\n",
275
- "print('Encoding phrases to latent space...')\n",
 
276
  "all_latents = []\n",
277
- "encode_loader = DataLoader(train_ds, batch_size=128, shuffle=False, num_workers=2)\n",
278
  "with torch.no_grad():\n",
279
- " for batch in encode_loader:\n",
280
- " with torch.autocast(device_type=device.type, dtype=amp_dtype):\n",
281
- " z, _, _ = vae.encode(batch.to(device))\n",
282
  " all_latents.append(z.cpu())\n",
283
  "all_z = torch.cat(all_latents, dim=0)\n",
284
- "print(f'Encoded {all_z.shape[0]} phrases latent dim {all_z.shape[1]}')\n",
285
- "\n",
286
- "SEQ_LEN = min(64, len(all_z) // 4)\n",
287
- "latent_seqs = [all_z[i:i+SEQ_LEN] for i in range(0, len(all_z) - SEQ_LEN, SEQ_LEN // 2)]\n",
288
- "print(f'{len(latent_seqs)} latent sequences of length {SEQ_LEN}')\n",
289
  "\n",
290
- "class LatentDS(Dataset):\n",
291
- " def __init__(self, seqs): self.seqs = seqs\n",
292
- " def __len__(self): return len(self.seqs)\n",
293
- " def __getitem__(self, idx): return self.seqs[idx]\n",
294
  "\n",
295
- "latent_loader = DataLoader(LatentDS(latent_seqs),\n",
296
- " batch_size=min(BATCH_SIZE, len(latent_seqs)), shuffle=True, drop_last=True)\n",
297
  "\n",
298
  "mamba = LatentMamba(config).to(device)\n",
299
  "print(f'LatentMamba: {sum(p.numel() for p in mamba.parameters()):,} params')\n",
300
  "\n",
301
- "mamba_opt = torch.optim.AdamW(mamba.parameters(), lr=LEARNING_RATE*0.5, weight_decay=0.01)\n",
302
- "mamba_sched = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(mamba_opt, T_0=300, T_mult=2, eta_min=1e-6)\n",
303
- "mamba_zclip = ZClip(z_thresh=2.5)\n",
304
- "mamba_scaler = torch.amp.GradScaler() if use_scaler else None\n",
305
  "\n",
306
- "print('\\n' + '='*60 + '\\nStarting LatentMamba Training\\n' + '='*60)\n",
307
  "mamba_history = {'mse': [], 'cos': []}\n",
308
  "\n",
309
  "for epoch in range(MAMBA_EPOCHS):\n",
310
- " mamba.train(); epoch_loss = 0; n_batches = 0; t0 = time.time()\n",
311
- " for batch in latent_loader:\n",
312
- " z_seq = batch.to(device)\n",
313
- " z_input, z_target = z_seq[:, :-1], z_seq[:, 1:]\n",
314
- " mamba_opt.zero_grad(set_to_none=True)\n",
315
- " with torch.autocast(device_type=device.type, dtype=amp_dtype):\n",
316
- " z_pred = mamba(z_input)\n",
317
- " mse_loss = F.mse_loss(z_pred, z_target)\n",
318
- " cos_loss = 1.0 - F.cosine_similarity(\n",
319
- " z_pred.reshape(-1, z_pred.shape[-1]),\n",
320
- " z_target.reshape(-1, z_target.shape[-1]), dim=-1).mean()\n",
321
- " loss = mse_loss + 0.1 * cos_loss\n",
322
- "\n",
323
- " if torch.isnan(loss) or torch.isinf(loss):\n",
324
- " mamba_opt.zero_grad(set_to_none=True); continue\n",
325
- "\n",
326
  " if use_scaler:\n",
327
- " mamba_scaler.scale(loss).backward()\n",
328
- " mamba_scaler.unscale_(mamba_opt); mamba_zclip(mamba)\n",
329
- " mamba_scaler.step(mamba_opt); mamba_scaler.update()\n",
330
- " else:\n",
331
- " loss.backward(); mamba_zclip(mamba); mamba_opt.step()\n",
332
- " mamba_sched.step()\n",
333
- " epoch_loss += loss.item(); n_batches += 1\n",
334
- " mamba_history['mse'].append(mse_loss.item())\n",
335
- " mamba_history['cos'].append(cos_loss.item())\n",
336
- "\n",
337
- " if (epoch + 1) % 5 == 0 or epoch == 0:\n",
338
- " elapsed = time.time() - t0\n",
339
- " print(f' Epoch {epoch+1}/{MAMBA_EPOCHS} | Loss: {epoch_loss/max(n_batches,1):.6f} | '\n",
340
- " f'MSE: {mse_loss.item():.6f} | Cos: {cos_loss.item():.4f} | {elapsed:.1f}s')\n",
341
- "\n",
342
- "print('\\n✅ LatentMamba training complete!')"
343
  ]
344
  },
345
  {
@@ -348,22 +307,16 @@
348
  "metadata": {},
349
  "outputs": [],
350
  "source": [
351
- "# ============================================================\n",
352
- "# 8. Plot Training Curves\n",
353
- "# ============================================================\n",
354
  "import matplotlib.pyplot as plt\n",
355
  "fig, axes = plt.subplots(1, 3, figsize=(18, 5))\n",
356
- "\n",
357
- "for ax, data, title, color in [\n",
358
- " (axes[0], history['loss'], 'PhraseVAE Loss', 'blue'),\n",
359
- " (axes[1], history['kl'], 'KL Divergence', 'red'),\n",
360
- " (axes[2], mamba_history['mse'], 'LatentMamba MSE', 'green')]:\n",
361
- " ax.plot(data, alpha=0.3, color=color)\n",
362
  " w = min(50, max(1, len(data)//5))\n",
363
- " if w > 1: ax.plot(np.convolve(data, np.ones(w)/w, 'valid'), color=color, lw=2)\n",
364
- " ax.set_title(title); ax.set_xlabel('Step'); ax.grid(True, alpha=0.3)\n",
365
- "\n",
366
- "plt.tight_layout(); plt.savefig('/content/training_curves.png', dpi=150); plt.show()"
367
  ]
368
  },
369
  {
@@ -372,46 +325,34 @@
372
  "metadata": {},
373
  "outputs": [],
374
  "source": [
375
- "# ============================================================\n",
376
- "# 9. Generate Music!\n",
377
- "# ============================================================\n",
378
  "from model import MuseMorphic\n",
379
  "from tokenizer import notes_to_midi_file\n",
380
  "\n",
381
  "full_model = MuseMorphic(config).to(device)\n",
382
- "full_model.phrase_vae = vae\n",
383
- "full_model.latent_mamba = mamba\n",
384
- "full_model.eval()\n",
385
- "\n",
386
- "N_PHRASES = 16; TEMPERATURE = 0.7\n",
387
- "print(f'Generating {N_PHRASES} phrases at temperature {TEMPERATURE}...')\n",
388
  "\n",
 
389
  "with torch.no_grad():\n",
390
- " z_gen = mamba.generate(n_phrases=N_PHRASES, temperature=TEMPERATURE, batch_size=1)\n",
391
- " print(f'Latent shape: {z_gen.shape}')\n",
392
- "\n",
393
  " all_tokens = []\n",
394
  " for t in range(z_gen.shape[1]):\n",
395
  " z = z_gen[:, t]\n",
396
- " gen_ids = [config.bos_token_id]\n",
397
  " for _ in range(128):\n",
398
- " inp = torch.tensor([gen_ids], dtype=torch.long, device=device)\n",
399
- " with torch.autocast(device_type=device.type, dtype=amp_dtype):\n",
400
  " logits = vae.decode(z, inp)\n",
401
- " probs = F.softmax(logits[0, -1] / max(TEMPERATURE, 0.1), dim=-1)\n",
402
- " tok = torch.multinomial(probs, 1).item()\n",
403
- " gen_ids.append(tok)\n",
404
  " if tok == config.eos_token_id: break\n",
405
- " all_tokens.extend(tokenizer.decode(gen_ids))\n",
406
  "\n",
407
  "notes = tokenizer.tokens_to_midi_notes(all_tokens)\n",
408
- "print(f'Generated {len(notes)} notes from {len(all_tokens)} tokens')\n",
409
- "\n",
410
  "if notes:\n",
411
- " notes_to_midi_file(notes, '/content/generated_music.mid')\n",
412
- " print('🎵 MIDI saved to /content/generated_music.mid')\n",
413
- "else:\n",
414
- " print('⚠️ No notes. Try more training epochs or different temperature.')"
415
  ]
416
  },
417
  {
@@ -420,22 +361,13 @@
420
  "metadata": {},
421
  "outputs": [],
422
  "source": [
423
- "# ============================================================\n",
424
- "# 10. Save Model\n",
425
- "# ============================================================\n",
426
- "import os\n",
427
- "from dataclasses import asdict\n",
428
  "os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
429
- "save_path = os.path.join(OUTPUT_DIR, 'musemorphic_model.pt')\n",
430
  "torch.save({'vae_state_dict': vae.state_dict(), 'mamba_state_dict': mamba.state_dict(),\n",
431
- " 'config': asdict(config), 'training_history': {'vae': history, 'mamba': mamba_history}}, save_path)\n",
432
- "tokenizer.save(os.path.join(OUTPUT_DIR, 'tokenizer'))\n",
433
- "print(f'✅ Model saved to {save_path} ({os.path.getsize(save_path)/1e6:.1f} MB)')\n",
434
- "\n",
435
- "if PUSH_TO_HUB and HUB_MODEL_ID:\n",
436
- " from huggingface_hub import HfApi\n",
437
- " HfApi().upload_folder(folder_path=OUTPUT_DIR, repo_id=HUB_MODEL_ID, repo_type='model')\n",
438
- " print(f'✅ Pushed to https://huggingface.co/{HUB_MODEL_ID}')"
439
  ]
440
  },
441
  {
@@ -444,18 +376,12 @@
444
  "metadata": {},
445
  "outputs": [],
446
  "source": [
447
- "# ============================================================\n",
448
- "# 11. Listen to Generated MIDI\n",
449
- "# ============================================================\n",
450
  "try:\n",
451
- " from IPython.display import Audio, display\n",
452
- " import pretty_midi\n",
453
- " pm = pretty_midi.PrettyMIDI('/content/generated_music.mid')\n",
454
- " audio = pm.fluidsynth(fs=22050)\n",
455
- " print('🎧 Listen:'); display(Audio(audio, rate=22050))\n",
456
- "except Exception as e:\n",
457
- " print(f'Audio playback unavailable: {e}')\n",
458
- " print('Download /content/generated_music.mid and play in any MIDI player.')"
459
  ]
460
  }
461
  ]
 
2
  "nbformat": 4,
3
  "nbformat_minor": 0,
4
  "metadata": {
5
+ "colab": {"provenance": [], "gpuType": "T4"},
6
+ "kernelspec": {"name": "python3", "display_name": "Python 3"},
7
+ "language_info": {"name": "python"},
 
 
 
 
 
 
 
 
8
  "accelerator": "GPU"
9
  },
10
  "cells": [
11
  {
12
  "cell_type": "markdown",
13
  "metadata": {},
14
+ "source": ["# 🎵 MuseMorphic v0.3 — Lightweight MIDI Generator\n",
15
+ "📄 [Architecture](https://huggingface.co/asdf98/MuseMorphic)"]
 
 
 
 
 
 
 
16
  },
17
  {
18
  "cell_type": "code",
 
20
  "metadata": {},
21
  "outputs": [],
22
  "source": [
23
+ "# 1. Install & Clone\n",
24
+ "!pip install -q torch torchvision torchaudio einops datasets pretty_midi midiutil huggingface_hub\n",
 
 
 
 
 
25
  "!rm -rf /content/MuseMorphic\n",
26
  "!git clone https://huggingface.co/asdf98/MuseMorphic /content/MuseMorphic\n",
27
+ "import sys; sys.path.insert(0, '/content/MuseMorphic/musemorphic')\n",
28
+ "print('✅ Ready')"
 
 
29
  ]
30
  },
31
  {
 
34
  "metadata": {},
35
  "outputs": [],
36
  "source": [
37
+ "# 2. GPU check\n",
38
+ "import torch, time\n",
39
+ "print(f'PyTorch {torch.__version__}, CUDA {torch.cuda.is_available()}')\n",
 
 
 
40
  "if torch.cuda.is_available():\n",
41
+ " print(f'GPU: {torch.cuda.get_device_name(0)}, VRAM: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f}GB') # device properties expose total_memory (bytes); total_mem raises AttributeError\n",
42
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
43
+ "# T4 does NOT support BF16 — use FP16 with GradScaler\n",
44
+ "amp_dtype = torch.float16 if torch.cuda.is_available() else torch.float32\n",
45
+ "use_scaler = (amp_dtype == torch.float16)\n",
46
+ "print(f'AMP dtype: {amp_dtype}, GradScaler: {use_scaler}')"
 
 
 
 
 
47
  ]
48
  },
49
  {
 
52
  "metadata": {},
53
  "outputs": [],
54
  "source": [
55
+ "# 3. Config\n",
 
 
56
  "MODEL_SIZE = 'small' # @param ['tiny', 'small', 'medium']\n",
57
+ "DATASET = 'synthetic' # @param ['auto', 'maestro', 'synthetic']\n",
58
+ "MAX_PIECES = 200 # @param {type: 'integer'}\n",
59
  "VAE_EPOCHS = 15 # @param {type: 'integer'}\n",
60
  "MAMBA_EPOCHS = 30 # @param {type: 'integer'}\n",
61
  "BATCH_SIZE = 32 # @param {type: 'integer'}\n",
62
  "LEARNING_RATE = 3e-4 # @param {type: 'number'}\n",
63
+ "OUTPUT_DIR = '/content/checkpoints'"
 
 
 
64
  ]
65
  },
66
  {
 
69
  "metadata": {},
70
  "outputs": [],
71
  "source": [
72
+ "# 4. Build model\n",
73
+ "from model import MuseMorphicConfig, model_summary\n",
74
+ "CFGS = {\n",
 
 
 
75
  " 'tiny': MuseMorphicConfig(d_model=128, vae_encoder_layers=2, vae_decoder_layers=2,\n",
76
+ " vae_n_heads=4, vae_d_ff=256, latent_dim=32, mamba_d_model=128, mamba_n_layers=4, mamba_d_state=8),\n",
 
77
  " 'small': MuseMorphicConfig(d_model=256, vae_encoder_layers=3, vae_decoder_layers=3,\n",
78
+ " vae_n_heads=4, vae_d_ff=512, latent_dim=64, mamba_d_model=256, mamba_n_layers=8, mamba_d_state=16),\n",
 
79
  " 'medium': MuseMorphicConfig(d_model=384, vae_encoder_layers=4, vae_decoder_layers=4,\n",
80
+ " vae_n_heads=6, vae_d_ff=768, latent_dim=96, mamba_d_model=384, mamba_n_layers=12, mamba_d_state=16),\n",
 
81
  "}\n",
82
+ "config = CFGS[MODEL_SIZE]\n",
83
+ "_ = model_summary(config)"
84
  ]
85
  },
86
  {
 
89
  "metadata": {},
90
  "outputs": [],
91
  "source": [
92
+ "# 5. Load & preprocess data\n",
93
+ "import logging; logging.basicConfig(level=logging.INFO)\n",
 
 
 
94
  "from data_pipeline import auto_select_dataset, load_dataset_notes, preprocess_dataset, _generate_synthetic_dataset\n",
95
  "from tokenizer import REMIPlusTokenizer\n",
96
  "\n",
97
  "tokenizer = REMIPlusTokenizer()\n",
98
+ "t0 = time.time()\n",
99
+ "\n",
100
+ "if DATASET == 'synthetic':\n",
 
 
 
 
 
 
 
 
 
101
  " pieces = _generate_synthetic_dataset(MAX_PIECES)\n",
102
+ "elif DATASET == 'maestro':\n",
103
+ " try: pieces = load_dataset_notes('maestro_v1_sustain', max_pieces=MAX_PIECES)\n",
104
+ " except Exception as e: print(f'⚠️ {e}'); pieces = _generate_synthetic_dataset(MAX_PIECES)\n",
105
+ "else:\n",
106
+ " try: pieces = load_dataset_notes(auto_select_dataset(), max_pieces=MAX_PIECES)\n",
107
+ " except Exception as e: print(f'⚠️ {e}'); pieces = _generate_synthetic_dataset(MAX_PIECES)\n",
108
+ "\n",
109
+ "print(f'Loaded {len(pieces)} pieces in {time.time()-t0:.1f}s')\n",
110
  "\n",
111
+ "t0 = time.time()\n",
112
  "phrases, controls = preprocess_dataset(pieces, tokenizer, max_phrase_len=config.vae_max_seq_len)\n",
113
+ "print(f'Preprocessed {len(phrases)} phrases in {time.time()-t0:.1f}s')\n",
114
+ "print(f'Avg length: {sum(len(p) for p in phrases)/max(len(phrases),1):.1f} tokens')"
115
  ]
116
  },
117
  {
 
120
  "metadata": {},
121
  "outputs": [],
122
  "source": [
123
+ "# 6. STAGE 1 — PhraseVAE Training\n",
124
+ "# ================================================================\n",
125
+ "# KEY PERF DECISIONS:\n",
126
+ "# - Pre-convert ALL phrases to a single padded tensor (no per-item padding)\n",
127
+ "# - num_workers=0 (avoids Colab multiprocessing deadlocks)\n",
128
+ "# - Warmup forward pass before timing (first CUDA call compiles kernels)\n",
129
+ "# - Explicit timing at every step to identify any remaining bottleneck\n",
130
+ "# ================================================================\n",
131
+ "import random, math, numpy as np\n",
132
  "import torch.nn.functional as F\n",
133
+ "from torch.utils.data import DataLoader, TensorDataset\n",
134
  "from model import PhraseVAE, ZClip, apply_span_mask_vectorized\n",
135
  "\n",
 
136
  "SEED = 42\n",
137
  "random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)\n",
138
  "if torch.cuda.is_available(): torch.cuda.manual_seed_all(SEED)\n",
139
  "\n",
140
+ "# ---- PRE-CONVERT all phrases into one big padded tensor ----\n",
141
+ "# This eliminates per-item Python list padding in __getitem__\n",
142
+ "print('Pre-converting phrases to tensor...')\n",
143
+ "t0 = time.time()\n",
144
+ "max_len = config.vae_max_seq_len\n",
145
+ "all_ids = torch.zeros(len(phrases), max_len, dtype=torch.long)\n",
146
+ "for i, p in enumerate(phrases):\n",
147
+ " L = min(len(p), max_len)\n",
148
+ " all_ids[i, :L] = torch.tensor(p[:L], dtype=torch.long)\n",
149
+ "print(f'Tensor shape: {all_ids.shape}, took {time.time()-t0:.2f}s')\n",
150
+ "\n",
151
+ "# TensorDataset + num_workers=0 = zero overhead DataLoader\n",
152
+ "train_ds = TensorDataset(all_ids)\n",
153
  "train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,\n",
154
+ " num_workers=0, pin_memory=False, drop_last=True)\n",
155
  "\n",
156
+ "# Create model\n",
157
  "vae = PhraseVAE(config).to(device)\n",
158
  "print(f'PhraseVAE: {sum(p.numel() for p in vae.parameters()):,} params')\n",
159
  "\n",
160
  "optimizer = torch.optim.AdamW(vae.parameters(), lr=LEARNING_RATE, weight_decay=0.01)\n",
161
  "scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=500, T_mult=2, eta_min=1e-6)\n",
162
  "zclip = ZClip(z_thresh=2.5)\n",
163
+ "scaler = torch.amp.GradScaler() if use_scaler else None\n",
164
  "\n",
165
+ "# ---- WARMUP: one forward+backward to compile CUDA kernels ----\n",
166
+ "print('\\nWarmup forward pass (compiling CUDA kernels)...')\n",
167
+ "t0 = time.time()\n",
168
+ "vae.train()\n",
169
+ "warmup_batch = all_ids[:BATCH_SIZE].to(device)\n",
170
+ "with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=torch.cuda.is_available()):\n",
171
+ " warmup_out = vae(warmup_batch, kl_weight=0.0)\n",
172
+ "warmup_out['loss'].backward()\n",
173
+ "optimizer.zero_grad(set_to_none=True)\n",
174
+ "if torch.cuda.is_available(): torch.cuda.synchronize()\n",
175
+ "print(f'Warmup done in {time.time()-t0:.2f}s (loss={warmup_out[\"loss\"].item():.4f})')\n",
176
+ "\n",
177
+ "# ---- Training loop ----\n",
178
  "print('\\n' + '='*60 + '\\nStarting PhraseVAE Training\\n' + '='*60)\n",
179
  "\n",
180
+ "pretrain_ep = max(1, VAE_EPOCHS // 5)\n",
181
+ "ae_ep = max(1, VAE_EPOCHS * 3 // 5)\n",
182
+ "vae_ep = max(1, VAE_EPOCHS - pretrain_ep - ae_ep)\n",
183
+ "stages = [('1a-Pretrain', pretrain_ep, 0.0, True),\n",
184
+ " ('1b-AE', ae_ep, 0.0, False),\n",
185
+ " ('1c-VAE', vae_ep, 0.01, False)]\n",
186
  "\n",
187
  "global_step = 0\n",
188
  "history = {'loss': [], 'recon': [], 'kl': []}\n",
189
  "\n",
190
  "for stage_name, n_epochs, kl_weight, use_masking in stages:\n",
191
+ " print(f'\\n--- {stage_name} ({n_epochs} ep, KL={kl_weight}, mask={use_masking}) ---')\n",
192
  " if stage_name == '1c-VAE':\n",
193
  " for pg in optimizer.param_groups: pg['lr'] = LEARNING_RATE * 0.1\n",
194
  "\n",
195
  " for epoch in range(n_epochs):\n",
196
+ " vae.train()\n",
197
+ " epoch_loss = 0; n_batches = 0\n",
198
+ " t_epoch = time.time()\n",
199
+ "\n",
200
+ " for (batch_data,) in train_loader: # TensorDataset returns tuple\n",
201
+ " token_ids = batch_data.to(device, non_blocking=True)\n",
202
+ "\n",
203
+ " if use_masking:\n",
204
+ " input_ids = apply_span_mask_vectorized(token_ids, mask_id=config.mask_token_id)\n",
205
+ " else:\n",
206
+ " input_ids = token_ids\n",
207
  "\n",
208
  " optimizer.zero_grad(set_to_none=True)\n",
209
+ "\n",
210
+ " with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=torch.cuda.is_available()):\n",
211
  " outputs = vae(input_ids, target_tokens=token_ids, kl_weight=kl_weight)\n",
212
  " loss = outputs['loss']\n",
213
  "\n",
214
  " if torch.isnan(loss) or torch.isinf(loss):\n",
215
+ " print(f'⚠️ NaN step {global_step}'); optimizer.zero_grad(set_to_none=True); continue\n",
 
216
  "\n",
217
  " if use_scaler:\n",
218
  " scaler.scale(loss).backward()\n",
 
227
  " history['recon'].append(outputs['recon_loss'].item())\n",
228
  " history['kl'].append(outputs['kl_loss'].item())\n",
229
  "\n",
230
+ " elapsed = time.time() - t_epoch\n",
231
+ " avg = epoch_loss / max(n_batches, 1)\n",
232
+ " print(f' Ep {epoch+1}/{n_epochs} | loss={avg:.4f} recon={outputs[\"recon_loss\"].item():.4f} '\n",
233
+ " f'kl={outputs[\"kl_loss\"].item():.4f} | {elapsed:.1f}s ({n_batches/max(elapsed,0.01):.1f} batch/s)')\n",
 
 
 
234
  "\n",
235
+ "print(f'\\n✅ PhraseVAE done! {global_step} steps')"
236
  ]
237
  },
238
  {
 
241
  "metadata": {},
242
  "outputs": [],
243
  "source": [
244
+ "# 7. STAGE 2 — LatentMamba Training\n",
 
 
245
  "from model import LatentMamba\n",
246
  "\n",
247
  "vae.eval()\n",
248
  "for p in vae.parameters(): p.requires_grad = False\n",
249
  "\n",
250
+ "print('Encoding phrases...')\n",
251
+ "t0 = time.time()\n",
252
  "all_latents = []\n",
253
+ "enc_loader = DataLoader(TensorDataset(all_ids), batch_size=128, num_workers=0)\n",
254
  "with torch.no_grad():\n",
255
+ " for (batch_data,) in enc_loader:\n",
256
+ " with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=torch.cuda.is_available()):\n",
257
+ " z, _, _ = vae.encode(batch_data.to(device))\n",
258
  " all_latents.append(z.cpu())\n",
259
  "all_z = torch.cat(all_latents, dim=0)\n",
260
+ "print(f'Encoded {all_z.shape[0]} phrases in {time.time()-t0:.1f}s')\n",
 
 
 
 
261
  "\n",
262
+ "SEQ_LEN = min(64, max(4, len(all_z) // 4))\n",
263
+ "latent_seqs = torch.stack([all_z[i:i+SEQ_LEN] for i in range(0, len(all_z)-SEQ_LEN+1, SEQ_LEN//2)]) # +1 keeps the final full-length window, so the list is non-empty whenever len(all_z) >= SEQ_LEN\n",
264
+ "print(f'{latent_seqs.shape[0]} sequences of len {SEQ_LEN}')\n",
 
265
  "\n",
266
+ "lat_loader = DataLoader(TensorDataset(latent_seqs),\n",
267
+ " batch_size=min(BATCH_SIZE, latent_seqs.shape[0]), shuffle=True, num_workers=0, drop_last=True)\n",
268
  "\n",
269
  "mamba = LatentMamba(config).to(device)\n",
270
  "print(f'LatentMamba: {sum(p.numel() for p in mamba.parameters()):,} params')\n",
271
  "\n",
272
+ "m_opt = torch.optim.AdamW(mamba.parameters(), lr=LEARNING_RATE*0.5, weight_decay=0.01)\n",
273
+ "m_sched = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(m_opt, T_0=300, T_mult=2, eta_min=1e-6)\n",
274
+ "m_zclip = ZClip(z_thresh=2.5)\n",
275
+ "m_scaler = torch.amp.GradScaler() if use_scaler else None\n",
276
  "\n",
277
+ "print('\\n' + '='*60 + '\\nLatentMamba Training\\n' + '='*60)\n",
278
  "mamba_history = {'mse': [], 'cos': []}\n",
279
  "\n",
280
  "for epoch in range(MAMBA_EPOCHS):\n",
281
+ " mamba.train(); e_loss = 0; nb = 0; t0 = time.time()\n",
282
+ " for (z_batch,) in lat_loader:\n",
283
+ " z_seq = z_batch.to(device); z_in, z_tgt = z_seq[:,:-1], z_seq[:,1:]\n",
284
+ " m_opt.zero_grad(set_to_none=True)\n",
285
+ " with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=torch.cuda.is_available()):\n",
286
+ " z_pred = mamba(z_in)\n",
287
+ " mse = F.mse_loss(z_pred, z_tgt)\n",
288
+ " cos = 1.0 - F.cosine_similarity(z_pred.reshape(-1,z_pred.shape[-1]),\n",
289
+ " z_tgt.reshape(-1,z_tgt.shape[-1]), dim=-1).mean()\n",
290
+ " loss = mse + 0.1 * cos\n",
291
+ " if torch.isnan(loss): m_opt.zero_grad(set_to_none=True); continue\n",
 
 
 
 
 
292
  " if use_scaler:\n",
293
+ " m_scaler.scale(loss).backward(); m_scaler.unscale_(m_opt); m_zclip(mamba)\n",
294
+ " m_scaler.step(m_opt); m_scaler.update()\n",
295
+ " else: loss.backward(); m_zclip(mamba); m_opt.step()\n",
296
+ " m_sched.step(); e_loss += loss.item(); nb += 1\n",
297
+ " mamba_history['mse'].append(mse.item()); mamba_history['cos'].append(cos.item())\n",
298
+ " if (epoch+1) % 5 == 0 or epoch == 0:\n",
299
+ " print(f' Ep {epoch+1}/{MAMBA_EPOCHS} | loss={e_loss/max(nb,1):.6f} mse={mse.item():.6f} '\n",
300
+ " f'cos={cos.item():.4f} | {time.time()-t0:.1f}s')\n",
301
+ "print('\\n✅ LatentMamba done!')"
 
 
 
 
 
 
 
302
  ]
303
  },
304
  {
 
307
  "metadata": {},
308
  "outputs": [],
309
  "source": [
310
+ "# 8. Training curves\n",
 
 
311
  "import matplotlib.pyplot as plt\n",
312
  "fig, axes = plt.subplots(1, 3, figsize=(18, 5))\n",
313
+ "for ax, data, title, c in [(axes[0],history['loss'],'VAE Loss','blue'),\n",
314
+ " (axes[1],history['kl'],'KL','red'),(axes[2],mamba_history['mse'],'Mamba MSE','green')]:\n",
315
+ " ax.plot(data, alpha=0.3, color=c)\n",
 
 
 
316
  " w = min(50, max(1, len(data)//5))\n",
317
+ " if w > 1: ax.plot(np.convolve(data, np.ones(w)/w, 'valid'), color=c, lw=2)\n",
318
+ " ax.set_title(title); ax.grid(True, alpha=0.3)\n",
319
+ "plt.tight_layout(); plt.savefig('/content/curves.png', dpi=150); plt.show()"
 
320
  ]
321
  },
322
  {
 
325
  "metadata": {},
326
  "outputs": [],
327
  "source": [
328
+ "# 9. Generate!\n",
 
 
329
  "from model import MuseMorphic\n",
330
  "from tokenizer import notes_to_midi_file\n",
331
  "\n",
332
  "full_model = MuseMorphic(config).to(device)\n",
333
+ "full_model.phrase_vae = vae; full_model.latent_mamba = mamba; full_model.eval()\n",
 
 
 
 
 
334
  "\n",
335
+ "N_PHRASES = 16; TEMP = 0.7\n",
336
  "with torch.no_grad():\n",
337
+ " z_gen = mamba.generate(N_PHRASES, temperature=TEMP, batch_size=1)\n",
 
 
338
  " all_tokens = []\n",
339
  " for t in range(z_gen.shape[1]):\n",
340
  " z = z_gen[:, t]\n",
341
+ " ids = [config.bos_token_id]\n",
342
  " for _ in range(128):\n",
343
+ " inp = torch.tensor([ids], dtype=torch.long, device=device)\n",
344
+ " with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=torch.cuda.is_available()):\n",
345
  " logits = vae.decode(z, inp)\n",
346
+ " tok = torch.multinomial(F.softmax(logits[0,-1]/max(TEMP,0.1), dim=-1), 1).item()\n",
347
+ " ids.append(tok)\n",
 
348
  " if tok == config.eos_token_id: break\n",
349
+ " all_tokens.extend(tokenizer.decode(ids))\n",
350
  "\n",
351
  "notes = tokenizer.tokens_to_midi_notes(all_tokens)\n",
352
+ "print(f'{len(notes)} notes from {len(all_tokens)} tokens')\n",
 
353
  "if notes:\n",
354
+ " notes_to_midi_file(notes, '/content/generated.mid')\n",
355
+ " print('🎵 Saved /content/generated.mid')"
 
 
356
  ]
357
  },
358
  {
 
361
  "metadata": {},
362
  "outputs": [],
363
  "source": [
364
+ "# 10. Save\n",
365
+ "import os; from dataclasses import asdict\n",
 
 
 
366
  "os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
 
367
  "torch.save({'vae_state_dict': vae.state_dict(), 'mamba_state_dict': mamba.state_dict(),\n",
368
+ " 'config': asdict(config)}, f'{OUTPUT_DIR}/model.pt')\n",
369
+ "tokenizer.save(f'{OUTPUT_DIR}/tokenizer')\n",
370
+ "print(f'✅ Saved to {OUTPUT_DIR}')"
 
 
 
 
 
371
  ]
372
  },
373
  {
 
376
  "metadata": {},
377
  "outputs": [],
378
  "source": [
379
+ "# 11. Listen\n",
 
 
380
  "try:\n",
381
+ " from IPython.display import Audio, display; import pretty_midi\n",
382
+ " audio = pretty_midi.PrettyMIDI('/content/generated.mid').fluidsynth(fs=22050)\n",
383
+ " display(Audio(audio, rate=22050))\n",
384
+ "except: print('Download /content/generated.mid to listen')"
 
 
 
 
385
  ]
386
  }
387
  ]