asdf98 committed
Commit 2bf4fa5 · verified · 1 Parent(s): c163568

FIX: Use vectorized span masking in notebook (no Python loop over batch)

Files changed (1): MuseMorphic_Training.ipynb (+149, -504)

MuseMorphic_Training.ipynb CHANGED
@@ -20,26 +20,13 @@
  "cell_type": "markdown",
  "metadata": {},
  "source": [
- "# 🎵 MuseMorphic: Lightweight MIDI Generator\n",
+ "# 🎵 MuseMorphic: Lightweight MIDI Generator (v0.2)\n",
  "\n",
  "**A novel consumer-grade architecture for controllable, infinite-length MIDI generation.**\n",
  "\n",
- "Key features:\n",
- "- **~33M parameters** — trains on free Colab T4, inference <1GB VRAM\n",
- "- **O(n) complexity** — Mamba SSM backbone, no quadratic attention\n",
- "- **Two-stage hierarchical** — PhraseVAE (compress) + LatentMamba (generate)\n",
- "- **Music-native** — FME embeddings with harmonic awareness\n",
- "- **Controllable** — tempo, key, density, style conditioning\n",
- "- **Infinite generation** — fixed-size recurrent state, no memory growth\n",
- "- **Training-stable by design** — σReparam + ZClip + Pre-LN + BF16\n",
+ "v0.2 Performance fixes: weight_norm (not spectral_norm), chunked SSM scan, vectorized masking.\n",
  "\n",
- "📄 [Architecture Paper/README](https://huggingface.co/asdf98/MuseMorphic)\n",
- "\n",
- "---\n",
- "\n",
- "## Setup\n",
- "\n",
- "Run this cell first to install all dependencies."
+ "📄 [Architecture Details](https://huggingface.co/asdf98/MuseMorphic)"
  ]
 },
 {
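
The v0.2 note in the hunk above lists a chunked SSM scan among the performance fixes. The kernel itself lives in the repo's model.py, which this commit does not touch, so the following is only a hedged sketch of the general idea: a diagonal linear recurrence h_t = a_t * h_(t-1) + bx_t can be evaluated chunk by chunk, vectorized inside each chunk via a cumprod/cumsum identity, with the hidden state carried across chunk boundaries. The function name and shapes below are illustrative assumptions, not MuseMorphic's actual API.

```python
import torch

def chunked_diag_scan(a, bx, chunk: int = 256, eps: float = 1e-12):
    """Hypothetical sketch of a chunked scan for h_t = a_t * h_{t-1} + bx_t.

    a, bx: (B, L, D), with a in (0, 1) as for exp-parameterized SSM decay.
    Inside a chunk, h_t = P_t * (h0 + cumsum(bx / P)) with P = cumprod(a);
    the state h0 is carried across chunk boundaries.
    """
    B, L, D = a.shape
    h0 = torch.zeros(B, D, dtype=a.dtype, device=a.device)
    outs = []
    for s in range(0, L, chunk):
        a_c, bx_c = a[:, s:s + chunk], bx[:, s:s + chunk]
        P = torch.cumprod(a_c, dim=1)                       # (B, C, D) running decay products
        # closed form of the recurrence within the chunk, fully vectorized
        h = P * (h0.unsqueeze(1) + torch.cumsum(bx_c / P.clamp_min(eps), dim=1))
        h0 = h[:, -1]                                       # carry state into the next chunk
        outs.append(h)
    return torch.cat(outs, dim=1)                           # (B, L, D)
```

Chunking bounds peak activation memory and keeps the cumprod from decaying over more than `chunk` steps (the usual numerical caveat of this identity when a_t < 1), while the per-chunk math stays fully vectorized.
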
@@ -52,15 +39,14 @@
  "# 1. Install Dependencies\n",
  "# ============================================================\n",
  "!pip install -q torch torchvision torchaudio\n",
- "!pip install -q einops datasets pretty_midi midiutil\n",
- "!pip install -q huggingface_hub\n",
+ "!pip install -q einops datasets pretty_midi midiutil huggingface_hub\n",
  "\n",
- "# Clone MuseMorphic repo\n",
- "!git clone https://huggingface.co/asdf98/MuseMorphic /content/MuseMorphic 2>/dev/null || (cd /content/MuseMorphic && git pull)\n",
+ "# Clone MuseMorphic repo (fresh pull to get v0.2 fixes)\n",
+ "!rm -rf /content/MuseMorphic\n",
+ "!git clone https://huggingface.co/asdf98/MuseMorphic /content/MuseMorphic\n",
  "\n",
  "import sys\n",
  "sys.path.insert(0, '/content/MuseMorphic/musemorphic')\n",
- "\n",
  "print('✅ Dependencies installed!')"
 ]
 },
@@ -74,36 +60,20 @@
  "# 2. Check GPU & Hardware\n",
  "# ============================================================\n",
  "import torch\n",
- "import os\n",
- "\n",
- "print(f'PyTorch version: {torch.__version__}')\n",
- "print(f'CUDA available: {torch.cuda.is_available()}')\n",
+ "print(f'PyTorch: {torch.__version__}')\n",
+ "print(f'CUDA: {torch.cuda.is_available()}')\n",
  "if torch.cuda.is_available():\n",
  " print(f'GPU: {torch.cuda.get_device_name(0)}')\n",
  " print(f'VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')\n",
- " print(f'BF16 support: {torch.cuda.is_bf16_supported()}')\n",
- "else:\n",
- " print('⚠️ No GPU detected. Training will be slow but functional on CPU.')\n",
+ " print(f'BF16: {torch.cuda.is_bf16_supported()}')\n",
  "\n",
- "# Auto-detect best dtype\n",
  "if torch.cuda.is_available() and torch.cuda.is_bf16_supported():\n",
- " DTYPE = 'bf16'\n",
- " print('\\n✅ Using BFloat16 (optimal: no loss scaling needed)')\n",
+ " DTYPE = 'bf16'; amp_dtype = torch.bfloat16\n",
  "elif torch.cuda.is_available():\n",
- " DTYPE = 'fp16'\n",
- " print('\\n⚠️ Using Float16 (T4 GPU — BF16 not supported, using FP16 with gradient scaling)')\n",
+ " DTYPE = 'fp16'; amp_dtype = torch.float16\n",
  "else:\n",
- " DTYPE = 'fp32'\n",
- " print('\\n📋 Using Float32 (CPU mode)')"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Configuration\n",
- "\n",
- "Adjust these settings based on your GPU and desired output quality."
+ " DTYPE = 'fp32'; amp_dtype = torch.float32\n",
+ "print(f'\\nUsing: {DTYPE}')"
 ]
 },
 {
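
The dtype cell above condenses v0.1's verbose detection into one-liners and sets `amp_dtype` up front. For readers new to mixed precision, here is a self-contained sketch of how that `amp_dtype` drives `torch.autocast` and when a `GradScaler` is required; the tiny Linear model is illustrative only, not part of the notebook.

```python
import torch

# Mirror the notebook's logic: bf16 autocast needs no loss scaling,
# fp16 pairs with a GradScaler, CPU falls back to plain fp32.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
amp_dtype = (torch.bfloat16 if device.type == 'cuda' and torch.cuda.is_bf16_supported()
             else torch.float16 if device.type == 'cuda' else torch.float32)
scaler = torch.amp.GradScaler() if amp_dtype == torch.float16 else None

model = torch.nn.Linear(8, 1).to(device)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(4, 8, device=device), torch.randn(4, 1, device=device)

with torch.autocast(device_type=device.type, dtype=amp_dtype,
                    enabled=device.type == 'cuda'):
    loss = torch.nn.functional.mse_loss(model(x), y)

if scaler is not None:
    # fp16 path: scale the loss so small gradients survive the cast,
    # then let the scaler manage optimizer stepping and rescaling
    scaler.scale(loss).backward()
    scaler.step(opt)
    scaler.update()
else:
    # bf16/fp32 path: bf16 keeps fp32's exponent range, so no scaling is needed
    loss.backward()
    opt.step()
```
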
@@ -115,35 +85,17 @@
  "# ============================================================\n",
  "# 3. Configuration\n",
  "# ============================================================\n",
- "\n",
- "# ---- Model Size Presets ----\n",
- "# 'tiny' : ~8M params — Fast experiments, lower quality\n",
- "# 'small' : ~33M params — Default, good quality (recommended for Colab T4)\n",
- "# 'medium': ~65M params — Better quality, needs more VRAM\n",
- "\n",
  "MODEL_SIZE = 'small' # @param ['tiny', 'small', 'medium']\n",
- "\n",
- "# ---- Dataset ----\n",
- "# 'auto' : Auto-select best available\n",
- "# 'maestro' : Classical piano (MAESTRO)\n",
- "# 'synthetic': Generated data (for testing)\n",
- "DATASET = 'auto' # @param ['auto', 'maestro', 'synthetic']\n",
- "MAX_PIECES = 500 # @param {type: 'integer'}\n",
- "\n",
- "# ---- Training ----\n",
- "VAE_EPOCHS = 15 # @param {type: 'integer'}\n",
- "MAMBA_EPOCHS = 30 # @param {type: 'integer'}\n",
- "BATCH_SIZE = 32 # @param {type: 'integer'}\n",
- "LEARNING_RATE = 3e-4 # @param {type: 'number'}\n",
- "\n",
- "# ---- Output ----\n",
- "OUTPUT_DIR = '/content/checkpoints' # @param {type: 'string'}\n",
- "PUSH_TO_HUB = False # @param {type: 'boolean'}\n",
- "HUB_MODEL_ID = '' # @param {type: 'string'}\n",
- "\n",
- "print(f'Model size: {MODEL_SIZE}')\n",
- "print(f'Dataset: {DATASET}')\n",
- "print(f'Training: VAE {VAE_EPOCHS}ep + Mamba {MAMBA_EPOCHS}ep, batch={BATCH_SIZE}')"
+ "DATASET = 'auto' # @param ['auto', 'maestro', 'synthetic']\n",
+ "MAX_PIECES = 500 # @param {type: 'integer'}\n",
+ "VAE_EPOCHS = 15 # @param {type: 'integer'}\n",
+ "MAMBA_EPOCHS = 30 # @param {type: 'integer'}\n",
+ "BATCH_SIZE = 32 # @param {type: 'integer'}\n",
+ "LEARNING_RATE = 3e-4 # @param {type: 'number'}\n",
+ "OUTPUT_DIR = '/content/checkpoints'\n",
+ "PUSH_TO_HUB = False\n",
+ "HUB_MODEL_ID = ''\n",
+ "print(f'Config: {MODEL_SIZE} model, {DATASET} data, VAE {VAE_EPOCHS}ep + Mamba {MAMBA_EPOCHS}ep')"
 ]
 },
 {
@@ -153,63 +105,25 @@
  "outputs": [],
  "source": [
  "# ============================================================\n",
- "# 4. Build Model Configuration\n",
+ "# 4. Build Model\n",
  "# ============================================================\n",
  "from model import MuseMorphicConfig, MuseMorphic, model_summary\n",
  "\n",
- "# Model size presets\n",
  "SIZE_CONFIGS = {\n",
- " 'tiny': MuseMorphicConfig(\n",
- " d_model=128,\n",
- " vae_encoder_layers=2,\n",
- " vae_decoder_layers=2,\n",
- " vae_n_heads=4,\n",
- " vae_d_ff=256,\n",
- " latent_dim=32,\n",
- " mamba_d_model=128,\n",
- " mamba_n_layers=4,\n",
- " mamba_d_state=8,\n",
- " mamba_expand=2,\n",
- " ),\n",
- " 'small': MuseMorphicConfig(\n",
- " d_model=256,\n",
- " vae_encoder_layers=3,\n",
- " vae_decoder_layers=3,\n",
- " vae_n_heads=4,\n",
- " vae_d_ff=512,\n",
- " latent_dim=64,\n",
- " mamba_d_model=256,\n",
- " mamba_n_layers=8,\n",
- " mamba_d_state=16,\n",
- " mamba_expand=2,\n",
- " ),\n",
- " 'medium': MuseMorphicConfig(\n",
- " d_model=384,\n",
- " vae_encoder_layers=4,\n",
- " vae_decoder_layers=4,\n",
- " vae_n_heads=6,\n",
- " vae_d_ff=768,\n",
- " latent_dim=96,\n",
- " mamba_d_model=384,\n",
- " mamba_n_layers=12,\n",
- " mamba_d_state=16,\n",
- " mamba_expand=2,\n",
- " ),\n",
+ " 'tiny': MuseMorphicConfig(d_model=128, vae_encoder_layers=2, vae_decoder_layers=2,\n",
+ " vae_n_heads=4, vae_d_ff=256, latent_dim=32, mamba_d_model=128,\n",
+ " mamba_n_layers=4, mamba_d_state=8, mamba_expand=2),\n",
+ " 'small': MuseMorphicConfig(d_model=256, vae_encoder_layers=3, vae_decoder_layers=3,\n",
+ " vae_n_heads=4, vae_d_ff=512, latent_dim=64, mamba_d_model=256,\n",
+ " mamba_n_layers=8, mamba_d_state=16, mamba_expand=2),\n",
+ " 'medium': MuseMorphicConfig(d_model=384, vae_encoder_layers=4, vae_decoder_layers=4,\n",
+ " vae_n_heads=6, vae_d_ff=768, latent_dim=96, mamba_d_model=384,\n",
+ " mamba_n_layers=12, mamba_d_state=16, mamba_expand=2),\n",
 "}\n",
- "\n",
 "config = SIZE_CONFIGS[MODEL_SIZE]\n",
 "model = model_summary(config)"
 ]
 },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Data Preparation\n",
- "\n",
- "Automatically downloads and preprocesses MIDI data."
- ]
- },
 {
 "cell_type": "code",
 "execution_count": null,
@@ -221,60 +135,27 @@
  "# ============================================================\n",
  "import logging\n",
  "logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s')\n",
- "\n",
- "from data_pipeline import prepare_training_data, auto_select_dataset, load_dataset_notes\n",
- "from data_pipeline import preprocess_dataset, _generate_synthetic_dataset\n",
+ "from data_pipeline import auto_select_dataset, load_dataset_notes, preprocess_dataset, _generate_synthetic_dataset\n",
  "from tokenizer import REMIPlusTokenizer\n",
  "\n",
- "# Select dataset\n",
- "if DATASET == 'auto':\n",
- " dataset_name = auto_select_dataset()\n",
- "elif DATASET == 'maestro':\n",
- " dataset_name = 'maestro_v1_sustain'\n",
- "elif DATASET == 'synthetic':\n",
- " dataset_name = None\n",
- "else:\n",
- " dataset_name = DATASET\n",
- "\n",
- "# Load and preprocess\n",
  "tokenizer = REMIPlusTokenizer()\n",
+ "if DATASET == 'auto': dataset_name = auto_select_dataset()\n",
+ "elif DATASET == 'maestro': dataset_name = 'maestro_v1_sustain'\n",
+ "else: dataset_name = None\n",
  "\n",
- "if dataset_name is not None:\n",
+ "if dataset_name:\n",
  " try:\n",
  " pieces = load_dataset_notes(dataset_name, max_pieces=MAX_PIECES)\n",
  " print(f'✅ Loaded {len(pieces)} pieces from {dataset_name}')\n",
  " except Exception as e:\n",
- " print(f'⚠️ Failed to load {dataset_name}: {e}')\n",
- " print('Falling back to synthetic data...')\n",
+ " print(f'⚠️ {e}\\nFalling back to synthetic data...')\n",
  " pieces = _generate_synthetic_dataset(MAX_PIECES)\n",
  "else:\n",
  " pieces = _generate_synthetic_dataset(MAX_PIECES)\n",
  " print(f'✅ Generated {len(pieces)} synthetic pieces')\n",
  "\n",
- "# Preprocess\n",
  "phrases, controls = preprocess_dataset(pieces, tokenizer, max_phrase_len=config.vae_max_seq_len)\n",
- "print(f'\\n📊 Dataset Summary:')\n",
- "print(f' Total phrases: {len(phrases)}')\n",
- "print(f' Avg phrase length: {sum(len(p) for p in phrases)/len(phrases):.1f} tokens')\n",
- "print(f' Vocab size: {tokenizer.vocab_size}')\n",
- "print(f' Sample phrase (first 10 tokens): {phrases[0][:10]}')"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Training\n",
- "\n",
- "Two-stage training with curriculum:\n",
- "\n",
- "**Stage 1 — PhraseVAE** (compress phrases to latent vectors):\n",
- "- 1a. Span-infilling pretraining (learn REMI grammar)\n",
- "- 1b. Autoencoder (pure reconstruction, KL=0)\n",
- "- 1c. VAE fine-tuning (KL weight = 0.01)\n",
- "\n",
- "**Stage 2 — LatentMamba** (generate latent sequences):\n",
- "- Predict next phrase latent from history, O(n) complexity"
+ "print(f'📊 {len(phrases)} phrases, avg {sum(len(p) for p in phrases)/len(phrases):.1f} tokens')"
 ]
 },
 {
@@ -286,160 +167,95 @@
  "# ============================================================\n",
  "# 6. Training — Stage 1: PhraseVAE\n",
  "# ============================================================\n",
- "import time\n",
- "import random\n",
- "import math\n",
- "import numpy as np\n",
- "import torch\n",
+ "import time, random, math, numpy as np\n",
  "import torch.nn.functional as F\n",
  "from torch.utils.data import DataLoader, Dataset\n",
- "from model import PhraseVAE, ZClip\n",
+ "from model import PhraseVAE, ZClip, apply_span_mask_vectorized\n",
  "\n",
- "# Set seed\n",
+ "# Seed\n",
  "SEED = 42\n",
- "random.seed(SEED)\n",
- "np.random.seed(SEED)\n",
- "torch.manual_seed(SEED)\n",
- "if torch.cuda.is_available():\n",
- " torch.cuda.manual_seed_all(SEED)\n",
+ "random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)\n",
+ "if torch.cuda.is_available(): torch.cuda.manual_seed_all(SEED)\n",
  "\n",
- "# Device & dtype\n",
  "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
- "if DTYPE == 'bf16' and torch.cuda.is_available() and torch.cuda.is_bf16_supported():\n",
- " amp_dtype = torch.bfloat16\n",
- "elif DTYPE == 'fp16' and torch.cuda.is_available():\n",
- " amp_dtype = torch.float16\n",
- "else:\n",
- " amp_dtype = torch.float32\n",
+ "use_scaler = (amp_dtype == torch.float16)\n",
+ "scaler = torch.amp.GradScaler() if use_scaler else None\n",
  "\n",
- "# Dataset\n",
  "class PhraseDS(Dataset):\n",
  " def __init__(self, phrases, max_len, pad_id=0):\n",
- " self.phrases = phrases\n",
- " self.max_len = max_len\n",
- " self.pad_id = pad_id\n",
- " def __len__(self):\n",
- " return len(self.phrases)\n",
+ " self.phrases = phrases; self.max_len = max_len; self.pad_id = pad_id\n",
+ " def __len__(self): return len(self.phrases)\n",
  " def __getitem__(self, idx):\n",
  " ids = self.phrases[idx][:self.max_len]\n",
- " padded = ids + [self.pad_id] * (self.max_len - len(ids))\n",
- " return torch.tensor(padded, dtype=torch.long)\n",
+ " return torch.tensor(ids + [self.pad_id] * (self.max_len - len(ids)), dtype=torch.long)\n",
  "\n",
  "train_ds = PhraseDS(phrases, config.vae_max_seq_len)\n",
- "train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, \n",
+ "train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,\n",
  " num_workers=2, pin_memory=True, drop_last=True)\n",
  "\n",
- "# Create VAE\n",
  "vae = PhraseVAE(config).to(device)\n",
- "vae_params = sum(p.numel() for p in vae.parameters())\n",
- "print(f'PhraseVAE parameters: {vae_params:,} ({vae_params/1e6:.2f}M)')\n",
+ "print(f'PhraseVAE: {sum(p.numel() for p in vae.parameters()):,} params')\n",
  "\n",
- "# Optimizer\n",
  "optimizer = torch.optim.AdamW(vae.parameters(), lr=LEARNING_RATE, weight_decay=0.01)\n",
  "scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=500, T_mult=2, eta_min=1e-6)\n",
  "zclip = ZClip(z_thresh=2.5)\n",
  "\n",
- "# FP16 needs GradScaler, BF16 does not\n",
- "use_scaler = (amp_dtype == torch.float16)\n",
- "scaler = torch.amp.GradScaler() if use_scaler else None\n",
+ "# ---- Training Loop (uses vectorized span masking — NO Python loop over batch) ----\n",
+ "print('\\n' + '='*60 + '\\nStarting PhraseVAE Training\\n' + '='*60)\n",
  "\n",
- "# Span masking helper\n",
- "def apply_span_mask(token_ids, mask_prob=0.15, mask_id=3, span_len=3):\n",
- " masked = token_ids.clone()\n",
- " B, L = masked.shape\n",
- " for b in range(B):\n",
- " n_masks = max(1, int(L * mask_prob / span_len))\n",
- " for _ in range(n_masks):\n",
- " start = random.randint(1, max(1, L - span_len - 1))\n",
- " end = min(start + span_len, L)\n",
- " masked[b, start:end] = mask_id\n",
- " return masked\n",
- "\n",
- "# ---- Training Loop ----\n",
- "print('\\n' + '='*60)\n",
- "print('Starting PhraseVAE Training')\n",
- "print('='*60)\n",
- "\n",
- "# Compute total epochs\n",
  "pretrain_epochs = max(1, VAE_EPOCHS // 5)\n",
  "ae_epochs = max(1, VAE_EPOCHS * 3 // 5)\n",
- "vae_epochs = max(1, VAE_EPOCHS - pretrain_epochs - ae_epochs)\n",
- "\n",
- "stages = [\n",
- " ('1a-Pretrain', pretrain_epochs, 0.0, True),\n",
- " ('1b-AE', ae_epochs, 0.0, False),\n",
- " ('1c-VAE', vae_epochs, 0.01, False),\n",
- "]\n",
+ "vae_epochs_count = max(1, VAE_EPOCHS - pretrain_epochs - ae_epochs)\n",
+ "stages = [('1a-Pretrain', pretrain_epochs, 0.0, True),\n",
+ " ('1b-AE', ae_epochs, 0.0, False),\n",
+ " ('1c-VAE', vae_epochs_count, 0.01, False)]\n",
  "\n",
  "global_step = 0\n",
  "history = {'loss': [], 'recon': [], 'kl': []}\n",
  "\n",
  "for stage_name, n_epochs, kl_weight, use_masking in stages:\n",
  " print(f'\\n--- Stage {stage_name} ({n_epochs} epochs, KL={kl_weight}) ---')\n",
- " \n",
- " # Lower LR for VAE fine-tuning stage\n",
  " if stage_name == '1c-VAE':\n",
- " for pg in optimizer.param_groups:\n",
- " pg['lr'] = LEARNING_RATE * 0.1\n",
- " \n",
+ " for pg in optimizer.param_groups: pg['lr'] = LEARNING_RATE * 0.1\n",
+ "\n",
  " for epoch in range(n_epochs):\n",
- " vae.train()\n",
- " epoch_loss = 0\n",
- " n_batches = 0\n",
- " t0 = time.time()\n",
- " \n",
+ " vae.train(); epoch_loss = 0; n_batches = 0; t0 = time.time()\n",
  " for batch in train_loader:\n",
  " token_ids = batch.to(device)\n",
- " \n",
- " # Apply masking for pretraining\n",
- " if use_masking:\n",
- " input_ids = apply_span_mask(token_ids)\n",
- " else:\n",
- " input_ids = token_ids\n",
- " \n",
- " optimizer.zero_grad()\n",
- " \n",
+ " # VECTORIZED masking — runs entirely on GPU, no Python loop\n",
+ " input_ids = apply_span_mask_vectorized(token_ids, mask_id=config.mask_token_id) if use_masking else token_ids\n",
+ "\n",
+ " optimizer.zero_grad(set_to_none=True)\n",
  " with torch.autocast(device_type=device.type, dtype=amp_dtype):\n",
  " outputs = vae(input_ids, target_tokens=token_ids, kl_weight=kl_weight)\n",
  " loss = outputs['loss']\n",
- " \n",
- " # NaN check\n",
+ "\n",
  " if torch.isnan(loss) or torch.isinf(loss):\n",
- " print(f'⚠️ NaN/Inf at step {global_step}! Skipping...')\n",
- " optimizer.zero_grad()\n",
- " continue\n",
- " \n",
+ " print(f'⚠️ NaN at step {global_step}! Skipping...')\n",
+ " optimizer.zero_grad(set_to_none=True); continue\n",
+ "\n",
  " if use_scaler:\n",
  " scaler.scale(loss).backward()\n",
- " scaler.unscale_(optimizer)\n",
- " zclip(vae)\n",
- " scaler.step(optimizer)\n",
- " scaler.update()\n",
+ " scaler.unscale_(optimizer); zclip(vae)\n",
+ " scaler.step(optimizer); scaler.update()\n",
  " else:\n",
- " loss.backward()\n",
- " zclip(vae)\n",
- " optimizer.step()\n",
- " \n",
+ " loss.backward(); zclip(vae); optimizer.step()\n",
  " scheduler.step()\n",
- " \n",
- " epoch_loss += loss.item()\n",
- " n_batches += 1\n",
- " global_step += 1\n",
- " \n",
- " # Log\n",
+ "\n",
+ " epoch_loss += loss.item(); n_batches += 1; global_step += 1\n",
  " history['loss'].append(loss.item())\n",
  " history['recon'].append(outputs['recon_loss'].item())\n",
  " history['kl'].append(outputs['kl_loss'].item())\n",
- " \n",
+ "\n",
  " elapsed = time.time() - t0\n",
  " avg_loss = epoch_loss / max(n_batches, 1)\n",
- " lr = optimizer.param_groups[0]['lr']\n",
  " print(f' Epoch {epoch+1}/{n_epochs} | Loss: {avg_loss:.4f} | '\n",
  " f'Recon: {outputs[\"recon_loss\"].item():.4f} | '\n",
  " f'KL: {outputs[\"kl_loss\"].item():.4f} | '\n",
- " f'LR: {lr:.2e} | Time: {elapsed:.1f}s')\n",
+ " f'LR: {optimizer.param_groups[0][\"lr\"]:.2e} | {elapsed:.1f}s | '\n",
+ " f'{n_batches/elapsed:.1f} batch/s')\n",
  "\n",
- "print(f'\\n✅ PhraseVAE training complete! ({global_step} total steps)')"
+ "print(f'\\n✅ PhraseVAE training complete! ({global_step} steps)')"
 ]
 },
 {
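
The hunk above carries the change named in the commit message: the removed `apply_span_mask` looped over every batch element in Python, while v0.2 calls `apply_span_mask_vectorized`, imported from model. Its body is not shown in this diff, so the sketch below is an assumed batch-parallel equivalent of the removed helper (only the function name and call site appear in the diff): all span starts are drawn in one randint call and expanded into spans with broadcasting.

```python
import torch

def apply_span_mask_vectorized(token_ids, mask_prob=0.15, mask_id=3, span_len=3):
    # Hypothetical sketch; the real implementation lives in MuseMorphic's model.py.
    # Same parameters as the removed loop version: mask ~mask_prob of each
    # sequence in spans of span_len, never touching position 0 (BOS).
    B, L = token_ids.shape
    n_spans = max(1, int(L * mask_prob / span_len))
    # (B, n_spans) random span starts, drawn for the whole batch in one call
    starts = torch.randint(1, max(2, L - span_len), (B, n_spans),
                           device=token_ids.device)
    # (B, n_spans, span_len) absolute positions covered by each span
    offsets = torch.arange(span_len, device=token_ids.device)
    positions = (starts.unsqueeze(-1) + offsets).clamp(max=L - 1)
    # mark all span positions at once, then substitute the mask token
    mask = torch.zeros(B, L, dtype=torch.bool, device=token_ids.device)
    mask.scatter_(1, positions.reshape(B, -1), True)
    return torch.where(mask, torch.full_like(token_ids, mask_id), token_ids)
```

The loop version issued roughly B * n_masks separate slice assignments from the Python interpreter per batch; here the same expected coverage comes from one randint, one scatter_, and one where, all executed on the GPU.
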
@@ -453,122 +269,77 @@
  "# ============================================================\n",
  "from model import LatentMamba\n",
  "\n",
- "# Freeze VAE encoder\n",
  "vae.eval()\n",
- "for p in vae.parameters():\n",
- " p.requires_grad = False\n",
+ "for p in vae.parameters(): p.requires_grad = False\n",
  "\n",
- "# Encode all phrases to latent space\n",
  "print('Encoding phrases to latent space...')\n",
  "all_latents = []\n",
- "encode_loader = DataLoader(train_ds, batch_size=64, shuffle=False, num_workers=2)\n",
- "\n",
+ "encode_loader = DataLoader(train_ds, batch_size=128, shuffle=False, num_workers=2)\n",
  "with torch.no_grad():\n",
  " for batch in encode_loader:\n",
- " token_ids = batch.to(device)\n",
  " with torch.autocast(device_type=device.type, dtype=amp_dtype):\n",
- " z, _, _ = vae.encode(token_ids)\n",
+ " z, _, _ = vae.encode(batch.to(device))\n",
  " all_latents.append(z.cpu())\n",
- "\n",
 "all_z = torch.cat(all_latents, dim=0)\n",
- "print(f'Encoded {all_z.shape[0]} phrases to latent dim {all_z.shape[1]}')\n",
+ "print(f'Encoded {all_z.shape[0]} phrases latent dim {all_z.shape[1]}')\n",
  "\n",
- "# Create latent sequences (group phrases into chunks)\n",
  "SEQ_LEN = min(64, len(all_z) // 4)\n",
- "latent_seqs = []\n",
- "for i in range(0, len(all_z) - SEQ_LEN, SEQ_LEN // 2): # 50% overlap\n",
- " latent_seqs.append(all_z[i:i+SEQ_LEN])\n",
- "\n",
- "print(f'Created {len(latent_seqs)} latent sequences of length {SEQ_LEN}')\n",
+ "latent_seqs = [all_z[i:i+SEQ_LEN] for i in range(0, len(all_z) - SEQ_LEN, SEQ_LEN // 2)]\n",
+ "print(f'{len(latent_seqs)} latent sequences of length {SEQ_LEN}')\n",
  "\n",
  "class LatentDS(Dataset):\n",
- " def __init__(self, seqs):\n",
- " self.seqs = seqs\n",
- " def __len__(self):\n",
- " return len(self.seqs)\n",
- " def __getitem__(self, idx):\n",
- " return self.seqs[idx]\n",
+ " def __init__(self, seqs): self.seqs = seqs\n",
+ " def __len__(self): return len(self.seqs)\n",
+ " def __getitem__(self, idx): return self.seqs[idx]\n",
  "\n",
- "latent_ds = LatentDS(latent_seqs)\n",
- "latent_loader = DataLoader(latent_ds, batch_size=min(BATCH_SIZE, len(latent_seqs)),\n",
- " shuffle=True, drop_last=True)\n",
+ "latent_loader = DataLoader(LatentDS(latent_seqs),\n",
+ " batch_size=min(BATCH_SIZE, len(latent_seqs)), shuffle=True, drop_last=True)\n",
  "\n",
- "# Create LatentMamba\n",
  "mamba = LatentMamba(config).to(device)\n",
- "mamba_params = sum(p.numel() for p in mamba.parameters())\n",
- "print(f'LatentMamba parameters: {mamba_params:,} ({mamba_params/1e6:.2f}M)')\n",
+ "print(f'LatentMamba: {sum(p.numel() for p in mamba.parameters()):,} params')\n",
  "\n",
- "# Optimizer\n",
- "mamba_optimizer = torch.optim.AdamW(mamba.parameters(), lr=LEARNING_RATE * 0.5, weight_decay=0.01)\n",
- "mamba_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(\n",
- " mamba_optimizer, T_0=300, T_mult=2, eta_min=1e-6)\n",
+ "mamba_opt = torch.optim.AdamW(mamba.parameters(), lr=LEARNING_RATE*0.5, weight_decay=0.01)\n",
+ "mamba_sched = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(mamba_opt, T_0=300, T_mult=2, eta_min=1e-6)\n",
  "mamba_zclip = ZClip(z_thresh=2.5)\n",
- "\n",
  "mamba_scaler = torch.amp.GradScaler() if use_scaler else None\n",
  "\n",
- "# Training loop\n",
- "print('\\n' + '='*60)\n",
- "print('Starting LatentMamba Training')\n",
- "print('='*60)\n",
- "\n",
+ "print('\\n' + '='*60 + '\\nStarting LatentMamba Training\\n' + '='*60)\n",
  "mamba_history = {'mse': [], 'cos': []}\n",
  "\n",
  "for epoch in range(MAMBA_EPOCHS):\n",
- " mamba.train()\n",
- " epoch_loss = 0\n",
- " n_batches = 0\n",
- " t0 = time.time()\n",
- " \n",
+ " mamba.train(); epoch_loss = 0; n_batches = 0; t0 = time.time()\n",
  " for batch in latent_loader:\n",
  " z_seq = batch.to(device)\n",
- " z_input = z_seq[:, :-1]\n",
- " z_target = z_seq[:, 1:]\n",
- " \n",
- " mamba_optimizer.zero_grad()\n",
- " \n",
+ " z_input, z_target = z_seq[:, :-1], z_seq[:, 1:]\n",
+ " mamba_opt.zero_grad(set_to_none=True)\n",
  " with torch.autocast(device_type=device.type, dtype=amp_dtype):\n",
  " z_pred = mamba(z_input)\n",
  " mse_loss = F.mse_loss(z_pred, z_target)\n",
  " cos_loss = 1.0 - F.cosine_similarity(\n",
  " z_pred.reshape(-1, z_pred.shape[-1]),\n",
- " z_target.reshape(-1, z_target.shape[-1]), dim=-1\n",
- " ).mean()\n",
+ " z_target.reshape(-1, z_target.shape[-1]), dim=-1).mean()\n",
  " loss = mse_loss + 0.1 * cos_loss\n",
- " \n",
+ "\n",
  " if torch.isnan(loss) or torch.isinf(loss):\n",
- " print(f'⚠️ NaN/Inf at epoch {epoch}! Skipping...')\n",
- " mamba_optimizer.zero_grad()\n",
- " continue\n",
- " \n",
+ " mamba_opt.zero_grad(set_to_none=True); continue\n",
+ "\n",
  " if use_scaler:\n",
  " mamba_scaler.scale(loss).backward()\n",
- " mamba_scaler.unscale_(mamba_optimizer)\n",
- " mamba_zclip(mamba)\n",
- " mamba_scaler.step(mamba_optimizer)\n",
- " mamba_scaler.update()\n",
+ " mamba_scaler.unscale_(mamba_opt); mamba_zclip(mamba)\n",
+ " mamba_scaler.step(mamba_opt); mamba_scaler.update()\n",
  " else:\n",
- " loss.backward()\n",
- " mamba_zclip(mamba)\n",
- " mamba_optimizer.step()\n",
- " \n",
- " mamba_scheduler.step()\n",
- " \n",
- " epoch_loss += loss.item()\n",
- " n_batches += 1\n",
- " \n",
+ " loss.backward(); mamba_zclip(mamba); mamba_opt.step()\n",
+ " mamba_sched.step()\n",
+ " epoch_loss += loss.item(); n_batches += 1\n",
  " mamba_history['mse'].append(mse_loss.item())\n",
  " mamba_history['cos'].append(cos_loss.item())\n",
- " \n",
- " elapsed = time.time() - t0\n",
- " avg_loss = epoch_loss / max(n_batches, 1)\n",
- " lr = mamba_optimizer.param_groups[0]['lr']\n",
- " \n",
+ "\n",
  " if (epoch + 1) % 5 == 0 or epoch == 0:\n",
- " print(f' Epoch {epoch+1}/{MAMBA_EPOCHS} | Loss: {avg_loss:.6f} | '\n",
- " f'MSE: {mse_loss.item():.6f} | Cos: {cos_loss.item():.4f} | '\n",
- " f'LR: {lr:.2e} | Time: {elapsed:.1f}s')\n",
+ " elapsed = time.time() - t0\n",
+ " print(f' Epoch {epoch+1}/{MAMBA_EPOCHS} | Loss: {epoch_loss/max(n_batches,1):.6f} | '\n",
+ " f'MSE: {mse_loss.item():.6f} | Cos: {cos_loss.item():.4f} | {elapsed:.1f}s')\n",
  "\n",
- "print(f'\\n✅ LatentMamba training complete!')"
+ "print('\\n✅ LatentMamba training complete!')"
 ]
 },
 {
@@ -581,51 +352,18 @@
  "# 8. Plot Training Curves\n",
  "# ============================================================\n",
  "import matplotlib.pyplot as plt\n",
- "\n",
  "fig, axes = plt.subplots(1, 3, figsize=(18, 5))\n",
  "\n",
- "# VAE Loss\n",
- "axes[0].plot(history['loss'], alpha=0.3, color='blue')\n",
- "window = min(50, len(history['loss']) // 5) if len(history['loss']) > 10 else 1\n",
- "if window > 1:\n",
- " smoothed = np.convolve(history['loss'], np.ones(window)/window, mode='valid')\n",
- " axes[0].plot(smoothed, color='blue', linewidth=2)\n",
- "axes[0].set_title('PhraseVAE Loss')\n",
- "axes[0].set_xlabel('Step')\n",
- "axes[0].set_ylabel('Loss')\n",
- "axes[0].grid(True, alpha=0.3)\n",
- "\n",
- "# VAE KL\n",
- "axes[1].plot(history['kl'], alpha=0.5, color='red')\n",
- "axes[1].set_title('KL Divergence')\n",
- "axes[1].set_xlabel('Step')\n",
- "axes[1].set_ylabel('KL')\n",
- "axes[1].grid(True, alpha=0.3)\n",
- "\n",
- "# Mamba Loss\n",
- "axes[2].plot(mamba_history['mse'], alpha=0.3, color='green')\n",
- "window = min(50, len(mamba_history['mse']) // 5) if len(mamba_history['mse']) > 10 else 1\n",
- "if window > 1:\n",
- " smoothed = np.convolve(mamba_history['mse'], np.ones(window)/window, mode='valid')\n",
- " axes[2].plot(smoothed, color='green', linewidth=2)\n",
- "axes[2].set_title('LatentMamba MSE Loss')\n",
- "axes[2].set_xlabel('Step')\n",
- "axes[2].set_ylabel('MSE')\n",
- "axes[2].grid(True, alpha=0.3)\n",
- "\n",
- "plt.tight_layout()\n",
- "plt.savefig('/content/training_curves.png', dpi=150)\n",
- "plt.show()\n",
- "print('📊 Training curves saved to /content/training_curves.png')"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Generation\n",
+ "for ax, data, title, color in [\n",
+ " (axes[0], history['loss'], 'PhraseVAE Loss', 'blue'),\n",
+ " (axes[1], history['kl'], 'KL Divergence', 'red'),\n",
+ " (axes[2], mamba_history['mse'], 'LatentMamba MSE', 'green')]:\n",
+ " ax.plot(data, alpha=0.3, color=color)\n",
+ " w = min(50, max(1, len(data)//5))\n",
+ " if w > 1: ax.plot(np.convolve(data, np.ones(w)/w, 'valid'), color=color, lw=2)\n",
+ " ax.set_title(title); ax.set_xlabel('Step'); ax.grid(True, alpha=0.3)\n",
  "\n",
- "Generate MIDI music using the trained model!"
+ "plt.tight_layout(); plt.savefig('/content/training_curves.png', dpi=150); plt.show()"
 ]
 },
 {
@@ -638,80 +376,42 @@
  "# 9. Generate Music!\n",
  "# ============================================================\n",
  "from model import MuseMorphic\n",
- "from tokenizer import REMIPlusTokenizer, notes_to_midi_file\n",
+ "from tokenizer import notes_to_midi_file\n",
  "\n",
- "# Assemble full model\n",
  "full_model = MuseMorphic(config).to(device)\n",
  "full_model.phrase_vae = vae\n",
  "full_model.latent_mamba = mamba\n",
  "full_model.eval()\n",
  "\n",
- "# Unfreeze VAE for generation (was frozen for Stage 2)\n",
- "# (No gradient computation needed for generation anyway)\n",
- "\n",
- "# Generation settings\n",
- "N_PHRASES = 16 # @param {type: 'integer'}\n",
- "TEMPERATURE = 0.7 # @param {type: 'number'}\n",
- "\n",
+ "N_PHRASES = 16; TEMPERATURE = 0.7\n",
  "print(f'Generating {N_PHRASES} phrases at temperature {TEMPERATURE}...')\n",
  "\n",
  "with torch.no_grad():\n",
- " # Generate latent sequence\n",
- " z_generated = mamba.generate(\n",
- " n_phrases=N_PHRASES,\n",
- " temperature=TEMPERATURE,\n",
- " batch_size=1,\n",
- " )\n",
- " print(f'Generated latent shape: {z_generated.shape}')\n",
- " \n",
- " # Decode each phrase latent to tokens\n",
+ " z_gen = mamba.generate(n_phrases=N_PHRASES, temperature=TEMPERATURE, batch_size=1)\n",
+ " print(f'Latent shape: {z_gen.shape}')\n",
+ "\n",
  " all_tokens = []\n",
- " for t in range(z_generated.shape[1]):\n",
- " z = z_generated[:, t] # (1, latent_dim)\n",
- " \n",
- " # Autoregressive decode\n",
- " generated_ids = [config.bos_token_id]\n",
- " max_decode_len = 128\n",
- " \n",
- " for _ in range(max_decode_len):\n",
- " input_tensor = torch.tensor([generated_ids], dtype=torch.long, device=device)\n",
+ " for t in range(z_gen.shape[1]):\n",
+ " z = z_gen[:, t]\n",
+ " gen_ids = [config.bos_token_id]\n",
+ " for _ in range(128):\n",
+ " inp = torch.tensor([gen_ids], dtype=torch.long, device=device)\n",
  " with torch.autocast(device_type=device.type, dtype=amp_dtype):\n",
- " logits = vae.decode(z, input_tensor)\n",
- " \n",
- " next_logits = logits[0, -1] / max(TEMPERATURE, 0.1)\n",
- " probs = F.softmax(next_logits, dim=-1)\n",
- " next_token = torch.multinomial(probs, 1).item()\n",
- " generated_ids.append(next_token)\n",
- " \n",
- " if next_token == config.eos_token_id:\n",
- " break\n",
- " \n",
- " phrase_tokens = tokenizer.decode(generated_ids)\n",
- " all_tokens.extend(phrase_tokens)\n",
- "\n",
- "print(f'\\nGenerated {len(all_tokens)} REMI+ tokens')\n",
- "print(f'Sample tokens: {all_tokens[:20]}')\n",
- "\n",
- "# Convert to MIDI notes\n",
+ " logits = vae.decode(z, inp)\n",
+ " probs = F.softmax(logits[0, -1] / max(TEMPERATURE, 0.1), dim=-1)\n",
+ " tok = torch.multinomial(probs, 1).item()\n",
+ " gen_ids.append(tok)\n",
+ " if tok == config.eos_token_id: break\n",
+ " all_tokens.extend(tokenizer.decode(gen_ids))\n",
+ "\n",
  "notes = tokenizer.tokens_to_midi_notes(all_tokens)\n",
- "print(f'Extracted {len(notes)} notes')\n",
+ "print(f'Generated {len(notes)} notes from {len(all_tokens)} tokens')\n",
  "\n",
  "if notes:\n",
- " # Write MIDI file\n",
- " output_midi = '/content/generated_music.mid'\n",
- " success = notes_to_midi_file(notes, output_midi)\n",
- " if success:\n",
- " print(f'\\n🎵 MIDI file saved to: {output_midi}')\n",
- " print(f' Notes: {len(notes)}')\n",
- " if notes:\n",
- " total_duration = max(n[\"start\"] + n[\"duration\"] for n in notes)\n",
- " print(f' Duration: ~{total_duration/480:.1f} beats')\n",
- " pitches = [n[\"pitch\"] for n in notes]\n",
- " print(f' Pitch range: {min(pitches)}-{max(pitches)}')\n",
- " else:\n",
- " print('⚠️ MIDI writing failed. Install midiutil: pip install midiutil')\n",
+ " notes_to_midi_file(notes, '/content/generated_music.mid')\n",
+ " print('🎵 MIDI saved to /content/generated_music.mid')\n",
  "else:\n",
- " print('⚠️ No notes generated. Try training for more epochs or adjusting temperature.')"
+ " print('⚠️ No notes. Try more training epochs or different temperature.')"
 ]
 },
 {
@@ -725,36 +425,16 @@
  "# ============================================================\n",
  "import os\n",
  "from dataclasses import asdict\n",
- "\n",
  "os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
- "\n",
- "# Save model\n",
  "save_path = os.path.join(OUTPUT_DIR, 'musemorphic_model.pt')\n",
- "torch.save({\n",
- " 'vae_state_dict': vae.state_dict(),\n",
- " 'mamba_state_dict': mamba.state_dict(),\n",
- " 'config': asdict(config),\n",
- " 'training_history': {\n",
- " 'vae': history,\n",
- " 'mamba': mamba_history,\n",
- " }\n",
- "}, save_path)\n",
- "print(f'✅ Model saved to {save_path}')\n",
- "print(f' File size: {os.path.getsize(save_path)/1e6:.1f} MB')\n",
- "\n",
- "# Save tokenizer\n",
+ "torch.save({'vae_state_dict': vae.state_dict(), 'mamba_state_dict': mamba.state_dict(),\n",
+ " 'config': asdict(config), 'training_history': {'vae': history, 'mamba': mamba_history}}, save_path)\n",
  "tokenizer.save(os.path.join(OUTPUT_DIR, 'tokenizer'))\n",
- "print(f'✅ Tokenizer saved')\n",
+ "print(f'✅ Model saved to {save_path} ({os.path.getsize(save_path)/1e6:.1f} MB)')\n",
  "\n",
- "# Optional: Push to HF Hub\n",
  "if PUSH_TO_HUB and HUB_MODEL_ID:\n",
  " from huggingface_hub import HfApi\n",
- " api = HfApi()\n",
- " api.upload_folder(\n",
- " folder_path=OUTPUT_DIR,\n",
- " repo_id=HUB_MODEL_ID,\n",
- " repo_type='model',\n",
- " )\n",
+ " HfApi().upload_folder(folder_path=OUTPUT_DIR, repo_id=HUB_MODEL_ID, repo_type='model')\n",
  " print(f'✅ Pushed to https://huggingface.co/{HUB_MODEL_ID}')"
 ]
 },
@@ -765,52 +445,17 @@
  "outputs": [],
  "source": [
  "# ============================================================\n",
- "# 11. Listen to Generated MIDI (in Colab)\n",
+ "# 11. Listen to Generated MIDI\n",
  "# ============================================================\n",
  "try:\n",
  " from IPython.display import Audio, display\n",
  " import pretty_midi\n",
- " \n",
- " # Synthesize MIDI to audio using FluidSynth\n",
  " pm = pretty_midi.PrettyMIDI('/content/generated_music.mid')\n",
  " audio = pm.fluidsynth(fs=22050)\n",
- " \n",
- " print('🎧 Listen to generated music:')\n",
- " display(Audio(audio, rate=22050))\n",
+ " print('🎧 Listen:'); display(Audio(audio, rate=22050))\n",
  "except Exception as e:\n",
- " print(f'Audio playback not available: {e}')\n",
- " print('Download the MIDI file and play it in any MIDI player.')\n",
- " print('File: /content/generated_music.mid')"
+ " print(f'Audio playback unavailable: {e}')\n",
+ " print('Download /content/generated_music.mid and play in any MIDI player.')"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "---\n",
- "\n",
- "## Architecture Summary\n",
- "\n",
- "### Novel Contributions\n",
- "\n",
- "1. **First SSM-based latent music generator**: Mamba operating on compressed phrase latents\n",
- "2. **FME with log-frequency encoding**: Physics-aware embeddings respecting harmonic series\n",
- "3. **Multi-attribute control via latent conditioning**: Tempo, key, density, style\n",
- "4. **Guaranteed training stability stack**: σReparam + ZClip + Pre-LN + BF16 + label smoothing\n",
- "5. **Three-stage PhraseVAE curriculum**: Prevents posterior collapse\n",
- "6. **Sub-1GB inference**: Phrase-level Mamba recurrence with fixed-size state\n",
- "\n",
- "### Key References\n",
- "\n",
- "- Gu & Dao (2023). **Mamba**: Linear-Time Sequence Modeling with Selective State Spaces\n",
- "- **MIDI-RWKV** (2025). Personalizable Long-Context Symbolic Music Infilling\n",
- "- **PhraseVAE** (2024). Phrase-level latent diffusion for music\n",
- "- **FME** (2022). Domain-Knowledge-Inspired Music Embedding\n",
- "- **σReparam** (2023). Stabilizing Transformer Training\n",
- "- **ZClip** (2025). Adaptive Spike Mitigation for LLM Pre-Training\n",
- "- **REMI** (2020). Pop Music Transformer\n",
- "\n",
- "📄 Full architecture document: [https://huggingface.co/asdf98/MuseMorphic](https://huggingface.co/asdf98/MuseMorphic)"
 ]
 }
 ]
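
The remaining v0.2 note, "weight_norm (not spectral_norm)", pairs with the σReparam entry in the stability stack that this commit removes from the summary above: σReparam is a spectral-normalization-style reparameterization, and swapping it for weight normalization drops the per-forward power iteration used to estimate the weight's largest singular value. A minimal sketch with PyTorch's built-in parametrizations follows; the bare Linear layers are illustrative, not MuseMorphic modules.

```python
import torch.nn as nn
from torch.nn.utils.parametrizations import spectral_norm, weight_norm

# v0.1-style: spectral_norm re-runs a power iteration on every forward
# pass to divide the weight by its estimated largest singular value.
proj_v01 = spectral_norm(nn.Linear(256, 256))

# v0.2-style: weight_norm stores the weight as magnitude * direction
# (g * v / ||v||), a cheap elementwise rescale with no power iteration.
proj_v02 = weight_norm(nn.Linear(256, 256))
```

Both reparameterizations smooth optimization of the projection weights, but only spectral_norm pays the per-step estimation cost, which is presumably the overhead the v0.2 note trades away.
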