{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": {"provenance": [], "gpuType": "T4"}, "kernelspec": {"name": "python3", "display_name": "Python 3"}, "language_info": {"name": "python"}, "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "metadata": {}, "source": ["# šŸŽµ MuseMorphic v0.3 — Lightweight MIDI Generator\n", "šŸ“„ [Architecture](https://huggingface.co/asdf98/MuseMorphic)"] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 1. Install & Clone\n", "!pip install -q torch torchvision torchaudio einops datasets pretty_midi midiutil huggingface_hub\n", "!rm -rf /content/MuseMorphic\n", "!git clone https://huggingface.co/asdf98/MuseMorphic /content/MuseMorphic\n", "import sys; sys.path.insert(0, '/content/MuseMorphic/musemorphic')\n", "print('āœ… Ready')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 2. GPU check\n", "import torch, time\n", "print(f'PyTorch {torch.__version__}, CUDA {torch.cuda.is_available()}')\n", "if torch.cuda.is_available():\n", " print(f'GPU: {torch.cuda.get_device_name(0)}, VRAM: {torch.cuda.get_device_properties(0).total_mem/1e9:.1f}GB')\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "# T4 does NOT support BF16 — use FP16 with GradScaler\n", "amp_dtype = torch.float16 if torch.cuda.is_available() else torch.float32\n", "use_scaler = (amp_dtype == torch.float16)\n", "print(f'AMP dtype: {amp_dtype}, GradScaler: {use_scaler}')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 3. Config\n", "MODEL_SIZE = 'small' # @param ['tiny', 'small', 'medium']\n", "DATASET = 'synthetic' # @param ['auto', 'maestro', 'synthetic']\n", "MAX_PIECES = 200 # @param {type: 'integer'}\n", "VAE_EPOCHS = 15 # @param {type: 'integer'}\n", "MAMBA_EPOCHS = 30 # @param {type: 'integer'}\n", "BATCH_SIZE = 32 # @param {type: 'integer'}\n", "LEARNING_RATE = 3e-4 # @param {type: 'number'}\n", "OUTPUT_DIR = '/content/checkpoints'" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 4. Build model\n", "from model import MuseMorphicConfig, model_summary\n", "CFGS = {\n", " 'tiny': MuseMorphicConfig(d_model=128, vae_encoder_layers=2, vae_decoder_layers=2,\n", " vae_n_heads=4, vae_d_ff=256, latent_dim=32, mamba_d_model=128, mamba_n_layers=4, mamba_d_state=8),\n", " 'small': MuseMorphicConfig(d_model=256, vae_encoder_layers=3, vae_decoder_layers=3,\n", " vae_n_heads=4, vae_d_ff=512, latent_dim=64, mamba_d_model=256, mamba_n_layers=8, mamba_d_state=16),\n", " 'medium': MuseMorphicConfig(d_model=384, vae_encoder_layers=4, vae_decoder_layers=4,\n", " vae_n_heads=6, vae_d_ff=768, latent_dim=96, mamba_d_model=384, mamba_n_layers=12, mamba_d_state=16),\n", "}\n", "config = CFGS[MODEL_SIZE]\n", "_ = model_summary(config)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 5. 
Load & preprocess data\n", "import logging; logging.basicConfig(level=logging.INFO)\n", "from data_pipeline import auto_select_dataset, load_dataset_notes, preprocess_dataset, _generate_synthetic_dataset\n", "from tokenizer import REMIPlusTokenizer\n", "\n", "tokenizer = REMIPlusTokenizer()\n", "t0 = time.time()\n", "\n", "if DATASET == 'synthetic':\n", " pieces = _generate_synthetic_dataset(MAX_PIECES)\n", "elif DATASET == 'maestro':\n", " try: pieces = load_dataset_notes('maestro_v1_sustain', max_pieces=MAX_PIECES)\n", " except Exception as e: print(f'āš ļø {e}'); pieces = _generate_synthetic_dataset(MAX_PIECES)\n", "else:\n", " try: pieces = load_dataset_notes(auto_select_dataset(), max_pieces=MAX_PIECES)\n", " except Exception as e: print(f'āš ļø {e}'); pieces = _generate_synthetic_dataset(MAX_PIECES)\n", "\n", "print(f'Loaded {len(pieces)} pieces in {time.time()-t0:.1f}s')\n", "\n", "t0 = time.time()\n", "phrases, controls = preprocess_dataset(pieces, tokenizer, max_phrase_len=config.vae_max_seq_len)\n", "print(f'Preprocessed → {len(phrases)} phrases in {time.time()-t0:.1f}s')\n", "print(f'Avg length: {sum(len(p) for p in phrases)/max(len(phrases),1):.1f} tokens')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 6. STAGE 1 — PhraseVAE Training\n", "# ================================================================\n", "# KEY PERF DECISIONS:\n", "# - Pre-convert ALL phrases to a single padded tensor (no per-item padding)\n", "# - num_workers=0 (avoids Colab multiprocessing deadlocks)\n", "# - Warmup forward pass before timing (first CUDA call compiles kernels)\n", "# - Explicit timing at every step to identify any remaining bottleneck\n", "# ================================================================\n", "import random, math, numpy as np\n", "import torch.nn.functional as F\n", "from torch.utils.data import DataLoader, TensorDataset\n", "from model import PhraseVAE, ZClip, apply_span_mask_vectorized\n", "\n", "SEED = 42\n", "random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)\n", "if torch.cuda.is_available(): torch.cuda.manual_seed_all(SEED)\n", "\n", "# ---- PRE-CONVERT all phrases into one big padded tensor ----\n", "# This eliminates per-item Python list padding in __getitem__\n", "print('Pre-converting phrases to tensor...')\n", "t0 = time.time()\n", "max_len = config.vae_max_seq_len\n", "all_ids = torch.zeros(len(phrases), max_len, dtype=torch.long)\n", "for i, p in enumerate(phrases):\n", " L = min(len(p), max_len)\n", " all_ids[i, :L] = torch.tensor(p[:L], dtype=torch.long)\n", "print(f'Tensor shape: {all_ids.shape}, took {time.time()-t0:.2f}s')\n", "\n", "# TensorDataset + num_workers=0 = zero overhead DataLoader\n", "train_ds = TensorDataset(all_ids)\n", "train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,\n", " num_workers=0, pin_memory=False, drop_last=True)\n", "\n", "# Create model\n", "vae = PhraseVAE(config).to(device)\n", "print(f'PhraseVAE: {sum(p.numel() for p in vae.parameters()):,} params')\n", "\n", "optimizer = torch.optim.AdamW(vae.parameters(), lr=LEARNING_RATE, weight_decay=0.01)\n", "scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=500, T_mult=2, eta_min=1e-6)\n", "zclip = ZClip(z_thresh=2.5)\n", "scaler = torch.amp.GradScaler() if use_scaler else None\n", "\n", "# ---- WARMUP: one forward+backward to compile CUDA kernels ----\n", "print('\\nWarmup forward pass (compiling CUDA kernels)...')\n", "t0 = time.time()\n", "vae.train()\n", 
"warmup_batch = all_ids[:BATCH_SIZE].to(device)\n", "with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=torch.cuda.is_available()):\n", " warmup_out = vae(warmup_batch, kl_weight=0.0)\n", "warmup_out['loss'].backward()\n", "optimizer.zero_grad(set_to_none=True)\n", "if torch.cuda.is_available(): torch.cuda.synchronize()\n", "print(f'Warmup done in {time.time()-t0:.2f}s (loss={warmup_out[\"loss\"].item():.4f})')\n", "\n", "# ---- Training loop ----\n", "print('\\n' + '='*60 + '\\nStarting PhraseVAE Training\\n' + '='*60)\n", "\n", "pretrain_ep = max(1, VAE_EPOCHS // 5)\n", "ae_ep = max(1, VAE_EPOCHS * 3 // 5)\n", "vae_ep = max(1, VAE_EPOCHS - pretrain_ep - ae_ep)\n", "stages = [('1a-Pretrain', pretrain_ep, 0.0, True),\n", " ('1b-AE', ae_ep, 0.0, False),\n", " ('1c-VAE', vae_ep, 0.01, False)]\n", "\n", "global_step = 0\n", "history = {'loss': [], 'recon': [], 'kl': []}\n", "\n", "for stage_name, n_epochs, kl_weight, use_masking in stages:\n", " print(f'\\n--- {stage_name} ({n_epochs} ep, KL={kl_weight}, mask={use_masking}) ---')\n", " if stage_name == '1c-VAE':\n", " for pg in optimizer.param_groups: pg['lr'] = LEARNING_RATE * 0.1\n", "\n", " for epoch in range(n_epochs):\n", " vae.train()\n", " epoch_loss = 0; n_batches = 0\n", " t_epoch = time.time()\n", "\n", " for (batch_data,) in train_loader: # TensorDataset returns tuple\n", " token_ids = batch_data.to(device, non_blocking=True)\n", "\n", " if use_masking:\n", " input_ids = apply_span_mask_vectorized(token_ids, mask_id=config.mask_token_id)\n", " else:\n", " input_ids = token_ids\n", "\n", " optimizer.zero_grad(set_to_none=True)\n", "\n", " with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=torch.cuda.is_available()):\n", " outputs = vae(input_ids, target_tokens=token_ids, kl_weight=kl_weight)\n", " loss = outputs['loss']\n", "\n", " if torch.isnan(loss) or torch.isinf(loss):\n", " print(f'āš ļø NaN step {global_step}'); optimizer.zero_grad(set_to_none=True); continue\n", "\n", " if use_scaler:\n", " scaler.scale(loss).backward()\n", " scaler.unscale_(optimizer); zclip(vae)\n", " scaler.step(optimizer); scaler.update()\n", " else:\n", " loss.backward(); zclip(vae); optimizer.step()\n", " scheduler.step()\n", "\n", " epoch_loss += loss.item(); n_batches += 1; global_step += 1\n", " history['loss'].append(loss.item())\n", " history['recon'].append(outputs['recon_loss'].item())\n", " history['kl'].append(outputs['kl_loss'].item())\n", "\n", " elapsed = time.time() - t_epoch\n", " avg = epoch_loss / max(n_batches, 1)\n", " print(f' Ep {epoch+1}/{n_epochs} | loss={avg:.4f} recon={outputs[\"recon_loss\"].item():.4f} '\n", " f'kl={outputs[\"kl_loss\"].item():.4f} | {elapsed:.1f}s ({n_batches/max(elapsed,0.01):.1f} batch/s)')\n", "\n", "print(f'\\nāœ… PhraseVAE done! {global_step} steps')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 7. 
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 7. STAGE 2 — LatentMamba Training\n", "from model import LatentMamba\n", "\n", "# Freeze the VAE: stage 2 trains only the latent sequence model\n", "vae.eval()\n", "for p in vae.parameters(): p.requires_grad = False\n", "\n", "print('Encoding phrases...')\n", "t0 = time.time()\n", "all_latents = []\n", "enc_loader = DataLoader(TensorDataset(all_ids), batch_size=128, num_workers=0)\n", "with torch.no_grad():\n", "    for (batch_data,) in enc_loader:\n", "        with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=torch.cuda.is_available()):\n", "            z, _, _ = vae.encode(batch_data.to(device))\n", "        all_latents.append(z.cpu())\n", "all_z = torch.cat(all_latents, dim=0)\n", "print(f'Encoded {all_z.shape[0]} phrases in {time.time()-t0:.1f}s')\n", "\n", "# Slice the latent stream into windows of SEQ_LEN with 50% overlap\n", "SEQ_LEN = min(64, max(4, len(all_z) // 4))\n", "assert len(all_z) > SEQ_LEN, 'Too few phrases to build even one latent sequence'\n", "latent_seqs = torch.stack([all_z[i:i+SEQ_LEN] for i in range(0, len(all_z)-SEQ_LEN, SEQ_LEN//2)])\n", "print(f'{latent_seqs.shape[0]} sequences of len {SEQ_LEN}')\n", "\n", "lat_loader = DataLoader(TensorDataset(latent_seqs),\n", "    batch_size=min(BATCH_SIZE, latent_seqs.shape[0]), shuffle=True, num_workers=0, drop_last=True)\n", "\n", "mamba = LatentMamba(config).to(device)\n", "print(f'LatentMamba: {sum(p.numel() for p in mamba.parameters()):,} params')\n", "\n", "m_opt = torch.optim.AdamW(mamba.parameters(), lr=LEARNING_RATE*0.5, weight_decay=0.01)\n", "m_sched = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(m_opt, T_0=300, T_mult=2, eta_min=1e-6)\n", "m_zclip = ZClip(z_thresh=2.5)\n", "m_scaler = torch.amp.GradScaler() if use_scaler else None\n", "\n", "print('\\n' + '='*60 + '\\nLatentMamba Training\\n' + '='*60)\n", "mamba_history = {'mse': [], 'cos': []}\n", "\n", "for epoch in range(MAMBA_EPOCHS):\n", "    mamba.train(); e_loss = 0; nb = 0; t0 = time.time()\n", "    for (z_batch,) in lat_loader:\n", "        # Teacher forcing: predict latent t+1 from latents up to t\n", "        z_seq = z_batch.to(device); z_in, z_tgt = z_seq[:,:-1], z_seq[:,1:]\n", "        m_opt.zero_grad(set_to_none=True)\n", "        with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=torch.cuda.is_available()):\n", "            z_pred = mamba(z_in)\n", "            mse = F.mse_loss(z_pred, z_tgt)\n", "            cos = 1.0 - F.cosine_similarity(z_pred.reshape(-1,z_pred.shape[-1]),\n", "                                            z_tgt.reshape(-1,z_tgt.shape[-1]), dim=-1).mean()\n", "            loss = mse + 0.1 * cos\n", "        if torch.isnan(loss): m_opt.zero_grad(set_to_none=True); continue\n", "        if use_scaler:\n", "            m_scaler.scale(loss).backward(); m_scaler.unscale_(m_opt); m_zclip(mamba)\n", "            m_scaler.step(m_opt); m_scaler.update()\n", "        else: loss.backward(); m_zclip(mamba); m_opt.step()\n", "        m_sched.step(); e_loss += loss.item(); nb += 1\n", "        mamba_history['mse'].append(mse.item()); mamba_history['cos'].append(cos.item())\n", "    if (epoch+1) % 5 == 0 or epoch == 0:\n", "        print(f'  Ep {epoch+1}/{MAMBA_EPOCHS} | loss={e_loss/max(nb,1):.6f} mse={mse.item():.6f} '\n", "              f'cos={cos.item():.4f} | {time.time()-t0:.1f}s')\n", "print('\\nāœ… LatentMamba done!')" ] },
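{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 7b. (Optional) Copy-last baseline check\n", "# A minimal sketch: a next-latent predictor is only useful if it beats the\n", "# trivial baseline of predicting that the next latent equals the current one.\n", "# Compares the two MSEs on a single batch from lat_loader (training data, so\n", "# this is a rough check, not an evaluation); uses only names defined in step 7.\n", "mamba.eval()\n", "with torch.no_grad():\n", "    (zb,) = next(iter(lat_loader))\n", "    zb = zb.to(device)\n", "    z_in, z_tgt = zb[:, :-1], zb[:, 1:]\n", "    pred_mse = F.mse_loss(mamba(z_in), z_tgt).item()\n", "    copy_mse = F.mse_loss(z_in, z_tgt).item()\n", "print(f'Mamba one-step MSE: {pred_mse:.6f}')\n", "print(f'Copy-last baseline: {copy_mse:.6f} (Mamba should be lower)')" ] },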
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 8. Training curves\n", "import matplotlib.pyplot as plt\n", "fig, axes = plt.subplots(1, 3, figsize=(18, 5))\n", "for ax, data, title, c in [(axes[0],history['loss'],'VAE Loss','blue'),\n", "                           (axes[1],history['kl'],'KL','red'),(axes[2],mamba_history['mse'],'Mamba MSE','green')]:\n", "    ax.plot(data, alpha=0.3, color=c)\n", "    w = min(50, max(1, len(data)//5))  # moving-average window for the smoothed curve\n", "    if w > 1: ax.plot(np.convolve(data, np.ones(w)/w, 'valid'), color=c, lw=2)\n", "    ax.set_title(title); ax.grid(True, alpha=0.3)\n", "plt.tight_layout(); plt.savefig('/content/curves.png', dpi=150); plt.show()" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 9. Generate!\n", "from model import MuseMorphic\n", "from tokenizer import notes_to_midi_file\n", "\n", "full_model = MuseMorphic(config).to(device)\n", "full_model.phrase_vae = vae; full_model.latent_mamba = mamba; full_model.eval()\n", "\n", "N_PHRASES = 16; TEMP = 0.7\n", "with torch.no_grad():\n", "    # Mamba proposes a sequence of phrase latents; the VAE decoder then\n", "    # autoregressively expands each latent into tokens\n", "    z_gen = mamba.generate(N_PHRASES, temperature=TEMP, batch_size=1)\n", "    all_tokens = []\n", "    for t in range(z_gen.shape[1]):\n", "        z = z_gen[:, t]\n", "        ids = [config.bos_token_id]\n", "        for _ in range(128):\n", "            inp = torch.tensor([ids], dtype=torch.long, device=device)\n", "            with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=torch.cuda.is_available()):\n", "                logits = vae.decode(z, inp)\n", "            tok = torch.multinomial(F.softmax(logits[0,-1]/max(TEMP,0.1), dim=-1), 1).item()\n", "            ids.append(tok)\n", "            if tok == config.eos_token_id: break\n", "        all_tokens.extend(tokenizer.decode(ids))\n", "\n", "notes = tokenizer.tokens_to_midi_notes(all_tokens)\n", "print(f'{len(notes)} notes from {len(all_tokens)} tokens')\n", "if notes:\n", "    notes_to_midi_file(notes, '/content/generated.mid')\n", "    print('šŸŽµ Saved /content/generated.mid')" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 10. Save\n", "import os; from dataclasses import asdict\n", "os.makedirs(OUTPUT_DIR, exist_ok=True)\n", "torch.save({'vae_state_dict': vae.state_dict(), 'mamba_state_dict': mamba.state_dict(),\n", "            'config': asdict(config)}, f'{OUTPUT_DIR}/model.pt')\n", "tokenizer.save(f'{OUTPUT_DIR}/tokenizer')\n", "print(f'āœ… Saved to {OUTPUT_DIR}')" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 11. Listen\n", "# pretty_midi's .fluidsynth() needs the fluidsynth library and a soundfont;\n", "# if either is missing, fall back to a download hint\n", "try:\n", "    from IPython.display import Audio, display; import pretty_midi\n", "    audio = pretty_midi.PrettyMIDI('/content/generated.mid').fluidsynth(fs=22050)\n", "    display(Audio(audio, rate=22050))\n", "except Exception as e:\n", "    print(f'Playback failed ({e}); download /content/generated.mid to listen')" ] } ] }