"""
MuseMorphic: Lightweight Consumer-Grade MIDI Generation Architecture
====================================================================

A novel two-stage hierarchical architecture combining:
    Stage 1 - PhraseVAE:   Compress REMI+ tokens → 64-dim latent vectors
    Stage 2 - LatentMamba: Generate latent sequences with O(n) complexity

Key innovations:
    - O(n) complexity in the number of phrases (Selective SSM backbone)
    - Music-native FME embeddings (translational invariance, transposability)
    - ~33M parameters, trains on free Colab T4, inference <1GB VRAM
    - Controllable via multi-attribute conditioning
    - Infinite generation via fixed-size recurrent state
    - Training stability by design (σReparam, ZClip, Pre-LN, BF16, label smoothing)
"""
|
|
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from dataclasses import dataclass, field |
| from typing import Optional, List, Tuple, Dict |
| from einops import rearrange |
|
|
| |
| |
| |
|
|
@dataclass
class MuseMorphicConfig:
    """Hyperparameters for every MuseMorphic component, grouped by subsystem.

    Every field has a default, so ``MuseMorphicConfig()`` yields the reference
    configuration. Field order is part of the positional constructor
    signature and must not be changed.
    """

    # --- Vocabulary / special tokens ---
    vocab_size: int = 8192      # size of the token embedding and output projection
    pad_token_id: int = 0       # embedding padding_idx; ignored by the reconstruction loss
    bos_token_id: int = 1       # first token fed to autoregressive decoding
    eos_token_id: int = 2       # token that ends autoregressive decoding
    mask_token_id: int = 3      # not referenced in the portion of the file visible here

    # --- Token embedding (Fundamental Music Embedding) ---
    d_model: int = 256                  # width of the PhraseVAE transformer
    fme_base_pitch: float = 10000.0     # frequency base B of the pitch FME
    fme_base_duration: float = 1000.0   # frequency base B of the duration FME
    fme_base_onset: float = 5000.0      # frequency base B of the onset FME
    use_log_frequency: bool = True      # pitch FME maps MIDI pitch to log2(Hz) first

    # --- PhraseVAE ---
    vae_encoder_layers: int = 3
    vae_decoder_layers: int = 3
    vae_n_heads: int = 4
    vae_d_ff: int = 512
    vae_n_queries: int = 4          # learned query tokens in the encoder bottleneck
    latent_dim: int = 64            # size of one phrase latent vector
    vae_dropout: float = 0.1
    vae_max_seq_len: int = 256      # size of the learned position tables
    kl_beta: float = 0.01           # NOTE(review): PhraseVAE.forward takes kl_weight as an argument and does not read this
    label_smoothing: float = 0.1    # passed to the reconstruction cross-entropy

    # --- LatentMamba ---
    mamba_d_model: int = 256
    mamba_n_layers: int = 8
    mamba_d_state: int = 16     # SSM state size N per channel
    mamba_d_conv: int = 4       # depthwise convolution kernel size
    mamba_expand: int = 2       # inner width = mamba_d_model * mamba_expand
    mamba_dropout: float = 0.1
    max_phrases: int = 512      # phrase position table holds max_phrases + 1 entries

    # --- Conditioning vocabularies (embedding table sizes) ---
    n_tempo_bins: int = 45
    n_key_classes: int = 24
    n_time_sig_classes: int = 8
    n_density_bins: int = 10
    n_style_classes: int = 32

    # --- Training-stability switches ---
    use_sigma_reparam: bool = True  # selects SigmaReparamLinear in make_linear call sites
    use_pre_ln: bool = True         # not referenced in the portion of the file visible here
    zclip_z_thresh: float = 2.5     # same value as ZClip's constructor default
    zclip_alpha: float = 0.99       # same value as ZClip's constructor default

    # --- Optimisation (not read by the model code visible here) ---
    learning_rate: float = 3e-4
    weight_decay: float = 0.01
    warmup_steps: int = 500
    max_steps: int = 100000
    batch_size: int = 32
    gradient_accumulation_steps: int = 1
|
|
|
|
| |
| |
| |
|
|
class FundamentalMusicEmbedding(nn.Module):
    """
    Sinusoidal embedding of a scalar musical quantity (pitch, duration, onset).

    After Liang et al. (2022), "Domain-Knowledge-Inspired Music Embedding",
    with an optional log-frequency mapping for pitch.

    Each value ``f`` is encoded as ``[sin(w_k * f) + b_sin, cos(w_k * f) + b_cos]``
    with fixed geometric frequencies ``w_k = B ** (-2k / d_model)`` and learnable
    per-channel biases. Because the biases cancel in differences, the distance
    between two embeddings depends only on the difference of the (possibly
    log-mapped) input values.
    """

    def __init__(self, d_model: int, base_B: float = 10000.0, use_log_freq: bool = False):
        """
        Args:
            d_model: Output width. The embedding has ``2 * (d_model // 2)``
                channels, so an odd ``d_model`` yields ``d_model - 1`` channels.
            base_B: Base of the geometric frequency progression.
            use_log_freq: If True, inputs are treated as MIDI pitch numbers and
                converted to log2 of their frequency in Hz before encoding.
        """
        super().__init__()
        self.d_model = d_model
        self.use_log_freq = use_log_freq
        n_freqs = d_model // 2

        # Fixed (non-trainable) frequencies, stored as a buffer so they follow
        # the module across devices and into the state dict.
        exponents = -2.0 * torch.arange(n_freqs, dtype=torch.float32) / d_model
        self.register_buffer('w_k', base_B ** exponents)

        # Learnable offsets added after the sinusoids.
        self.b_sin = nn.Parameter(torch.zeros(n_freqs))
        self.b_cos = nn.Parameter(torch.zeros(n_freqs))

    def forward(self, values: torch.Tensor) -> torch.Tensor:
        """
        Args:
            values: Integer or float values, shape (batch, seq_len)
        Returns:
            Embedding, shape (batch, seq_len, 2 * (d_model // 2))
        """
        f = values.float()
        if self.use_log_freq:
            # MIDI note 69 is A4 = 440 Hz; the epsilon guards log2(0).
            f = torch.log2(440.0 * (2.0 ** ((f - 69.0) / 12.0)) + 1e-8)

        # (batch, seq_len, 1) broadcast against (n_freqs,) frequencies.
        f = f.unsqueeze(-1)
        sin_part = torch.sin(self.w_k * f) + self.b_sin
        cos_part = torch.cos(self.w_k * f) + self.b_cos
        return torch.cat((sin_part, cos_part), dim=-1)
|
|
|
|
class MusicTokenEmbedding(nn.Module):
    """
    Token embedding for REMI+ sequences.

    Sums a scaled learned token embedding, optional FME embeddings for pitch,
    duration and onset, and a learned position embedding, then applies
    LayerNorm and dropout.
    """

    def __init__(self, config: MuseMorphicConfig):
        super().__init__()
        self.config = config
        width = config.d_model

        # Learned lookup for every vocabulary entry (padding row stays zero).
        self.token_embed = nn.Embedding(config.vocab_size, width, padding_idx=config.pad_token_id)

        # One FME per musical attribute; only pitch may use log-frequency.
        self.pitch_fme = FundamentalMusicEmbedding(width, config.fme_base_pitch, config.use_log_frequency)
        self.duration_fme = FundamentalMusicEmbedding(width, config.fme_base_duration, False)
        self.onset_fme = FundamentalMusicEmbedding(width, config.fme_base_onset, False)

        # Learned absolute positions; sequences longer than vae_max_seq_len
        # would index past the end of this table.
        self.pos_embed = nn.Embedding(config.vae_max_seq_len, width)

        self.embed_ln = nn.LayerNorm(width)
        self.embed_dropout = nn.Dropout(config.vae_dropout)

        # Multiplier applied to the token embedding before the other terms are added.
        self.scale = math.sqrt(width)

    def forward(
        self,
        token_ids: torch.Tensor,
        pitch_values: Optional[torch.Tensor] = None,
        duration_values: Optional[torch.Tensor] = None,
        onset_values: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            token_ids: (batch, seq_len) BPE token indices
            pitch_values: (batch, seq_len) MIDI pitch values (0 where not applicable)
            duration_values: (batch, seq_len) duration ticks (0 where not applicable)
            onset_values: (batch, seq_len) onset positions (0 where not applicable)
        Returns:
            (batch, seq_len, d_model) embedded sequence.
        """
        batch, seq_len = token_ids.shape

        x = self.token_embed(token_ids) * self.scale

        # Add each attribute embedding only at positions whose value is > 0.
        attribute_streams = (
            (pitch_values, self.pitch_fme),
            (duration_values, self.duration_fme),
            (onset_values, self.onset_fme),
        )
        for values, fme in attribute_streams:
            if values is None:
                continue
            active = (values > 0).float().unsqueeze(-1)
            x = x + fme(values) * active

        positions = torch.arange(seq_len, device=token_ids.device).unsqueeze(0).expand(batch, -1)
        x = x + self.pos_embed(positions)

        return self.embed_dropout(self.embed_ln(x))
|
|
|
|
| |
| |
| |
|
|
class SigmaReparamLinear(nn.Module):
    """
    Linear layer with spectral reparameterization (σReparam).

    From Zhai et al. (2023) "Stabilizing Transformer Training by Preventing
    Attention Entropy Collapse" (arXiv:2303.06296).

        W_hat = (γ / σ(W)) * W
        y     = W_hat x + b

    where σ(W) is the spectral norm (largest singular value) and γ is a
    learnable scalar initialised to 1. Only the weight is rescaled; the bias
    is added unscaled.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        """
        Args:
            in_features: Size of the last input dimension.
            out_features: Size of the last output dimension.
            bias: Whether the wrapped linear layer has a bias term.
        """
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        # After this call ``self.linear.weight`` evaluates to W / σ(W).
        self.linear = nn.utils.parametrizations.spectral_norm(self.linear)
        self.gamma = nn.Parameter(torch.ones(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``(γ / σ(W)) W x + b`` over the last dimension of ``x``."""
        # Scale the weight, not the layer output: multiplying the output by γ
        # would also rescale the bias, which the formula above does not do.
        return F.linear(x, self.gamma * self.linear.weight, self.linear.bias)
|
|
|
|
def make_linear(in_f: int, out_f: int, bias: bool = True, sigma_reparam: bool = True) -> nn.Module:
    """Build a linear layer, σReparam-wrapped when ``sigma_reparam`` is true.

    Args:
        in_f: Input feature size.
        out_f: Output feature size.
        bias: Whether the layer has a bias term.
        sigma_reparam: Choose ``SigmaReparamLinear`` (True) or ``nn.Linear`` (False).
    """
    layer_cls = SigmaReparamLinear if sigma_reparam else nn.Linear
    return layer_cls(in_f, out_f, bias)
|
|
|
|
| |
| |
| |
|
|
class PreLNMultiHeadAttention(nn.Module):
    """
    Multi-head attention (self or cross) with optional σReparam projections.

    No LayerNorm is applied inside this module; callers normalise ``x`` first
    (``PreLNTransformerBlock`` does so).
    """

    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1,
                 sigma_reparam: bool = True, is_cross_attention: bool = False):
        """
        Args:
            d_model: Model width; must be divisible by ``n_heads``.
            n_heads: Number of attention heads.
            dropout: Attention dropout probability (active only in training).
            sigma_reparam: Use σReparam linear projections.
            is_cross_attention: If True, keys/values come from ``context``
                when one is supplied to ``forward``.
        """
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        # Kept as a public attribute; the attention kernel below applies its
        # own default scaling, so this value is not used in forward().
        self.scale = math.sqrt(self.d_head)

        self.q_proj = make_linear(d_model, d_model, sigma_reparam=sigma_reparam)
        self.k_proj = make_linear(d_model, d_model, sigma_reparam=sigma_reparam)
        self.v_proj = make_linear(d_model, d_model, sigma_reparam=sigma_reparam)
        self.out_proj = make_linear(d_model, d_model, sigma_reparam=sigma_reparam)

        self.attn_dropout = nn.Dropout(dropout)
        self.is_cross_attention = is_cross_attention

    @staticmethod
    def _with_causal(mask: torch.Tensor, q_len: int, k_len: int) -> torch.Tensor:
        """Fold a lower-triangular (top-left aligned) causal constraint into ``mask``."""
        allowed = torch.ones(q_len, k_len, dtype=torch.bool, device=mask.device).tril()
        if mask.dtype == torch.bool:
            # Boolean masks: True means "may attend".
            return mask & allowed
        # Additive masks: forbidden positions receive -inf.
        causal_bias = torch.zeros(q_len, k_len, dtype=mask.dtype, device=mask.device)
        return mask + causal_bias.masked_fill(~allowed, float("-inf"))

    def forward(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
        is_causal: bool = False,
    ) -> torch.Tensor:
        """
        Args:
            x: Query input, (batch, q_len, d_model).
            context: Key/value input, (batch, kv_len, d_model). Used only when
                the module was built with ``is_cross_attention=True``;
                otherwise keys and values come from ``x``.
            mask: Optional attention mask forwarded to
                ``scaled_dot_product_attention``; must be broadcastable to
                (batch, heads, q_len, kv_len).
            is_causal: Apply causal masking. May be combined with ``mask``.
        Returns:
            (batch, q_len, d_model)
        """
        q = self.q_proj(x)
        kv_input = context if self.is_cross_attention and context is not None else x
        k = self.k_proj(kv_input)
        v = self.v_proj(kv_input)

        # Split the feature dimension into heads.
        q = rearrange(q, 'b l (h d) -> b h l d', h=self.n_heads)
        k = rearrange(k, 'b s (h d) -> b h s d', h=self.n_heads)
        v = rearrange(v, 'b s (h d) -> b h s d', h=self.n_heads)

        # scaled_dot_product_attention rejects an explicit mask together with
        # is_causal=True, so merge the causal constraint into the mask.
        if mask is not None and is_causal:
            mask = self._with_causal(mask, q.shape[-2], k.shape[-2])
            is_causal = False

        attn_out = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=mask,
            dropout_p=self.attn_dropout.p if self.training else 0.0,
            is_causal=is_causal,
        )

        attn_out = rearrange(attn_out, 'b h l d -> b l (h d)')
        return self.out_proj(attn_out)
|
|
|
|
class PreLNFeedForward(nn.Module):
    """Gated feed-forward network: ``dropout(w2(silu(gate(x)) * w1(x)))``.

    The LayerNorm is applied by the caller, not inside this module.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1,
                 sigma_reparam: bool = True):
        """
        Args:
            d_model: Input and output width.
            d_ff: Hidden width of the gated branch.
            dropout: Dropout probability applied to the output.
            sigma_reparam: Use σReparam linear layers.
        """
        super().__init__()
        self.w1 = make_linear(d_model, d_ff, sigma_reparam=sigma_reparam)
        self.w2 = make_linear(d_ff, d_model, sigma_reparam=sigma_reparam)
        self.gate = make_linear(d_model, d_ff, sigma_reparam=sigma_reparam)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map (..., d_model) to (..., d_model) through the SiLU-gated hidden layer."""
        gate_act = F.silu(self.gate(x))
        hidden = gate_act * self.w1(x)
        return self.dropout(self.w2(hidden))
|
|
|
|
class PreLNTransformerBlock(nn.Module):
    """
    Transformer block in Pre-LayerNorm form.

    Every sublayer follows ``x = x + Sublayer(LayerNorm(x))``:
    self-attention, then (optionally) cross-attention over ``context``,
    then the gated feed-forward network.
    """

    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1,
                 sigma_reparam: bool = True, has_cross_attention: bool = False):
        """
        Args:
            d_model: Model width.
            n_heads: Attention heads per attention sublayer.
            d_ff: Hidden width of the feed-forward sublayer.
            dropout: Dropout probability for attention and feed-forward.
            sigma_reparam: Use σReparam linear layers.
            has_cross_attention: Add a cross-attention sublayer.
        """
        super().__init__()

        self.norm1 = nn.LayerNorm(d_model)
        self.self_attn = PreLNMultiHeadAttention(d_model, n_heads, dropout, sigma_reparam)

        self.has_cross_attention = has_cross_attention
        if has_cross_attention:
            self.norm_cross = nn.LayerNorm(d_model)
            self.cross_attn = PreLNMultiHeadAttention(
                d_model, n_heads, dropout, sigma_reparam, is_cross_attention=True
            )

        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = PreLNFeedForward(d_model, d_ff, dropout, sigma_reparam)

    def forward(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
        is_causal: bool = False,
    ) -> torch.Tensor:
        """
        Args:
            x: (batch, seq_len, d_model) input.
            context: Sequence attended by the cross-attention sublayer; the
                sublayer is skipped when this is None or the block was built
                without cross-attention.
            mask: Mask for the self-attention sublayer only.
            is_causal: Causal flag for the self-attention sublayer only.
        Returns:
            Tensor with the same shape as ``x``.
        """
        attended = self.self_attn(self.norm1(x), mask=mask, is_causal=is_causal)
        x = x + attended

        use_cross = self.has_cross_attention and context is not None
        if use_cross:
            crossed = self.cross_attn(self.norm_cross(x), context=context)
            x = x + crossed

        return x + self.ffn(self.norm2(x))
|
|
|
|
| |
| |
| |
|
|
class PhraseVAEEncoder(nn.Module):
    """
    Encode an embedded REMI+ phrase into the parameters of a Gaussian latent.

    Pipeline: transformer blocks → final LayerNorm → learned query tokens
    cross-attend over the sequence → flatten → linear heads for μ and log σ².
    """

    def __init__(self, config: MuseMorphicConfig):
        super().__init__()
        self.config = config
        width = config.d_model

        # Self-attention stack over the phrase tokens.
        self.layers = nn.ModuleList([
            PreLNTransformerBlock(
                width, config.vae_n_heads, config.vae_d_ff,
                config.vae_dropout, config.use_sigma_reparam
            )
            for _ in range(config.vae_encoder_layers)
        ])

        self.final_norm = nn.LayerNorm(width)

        # Bottleneck: a fixed number of learned queries summarise the sequence.
        self.query_tokens = nn.Parameter(torch.randn(config.vae_n_queries, width) * 0.02)
        self.bottleneck_attn = PreLNMultiHeadAttention(
            width, config.vae_n_heads, config.vae_dropout,
            config.use_sigma_reparam, is_cross_attention=True
        )
        self.bottleneck_norm = nn.LayerNorm(width)

        # Heads operate on the concatenation of all query outputs.
        flat_width = config.vae_n_queries * width
        self.to_mu = nn.Linear(flat_width, config.latent_dim)
        self.to_log_var = nn.Linear(flat_width, config.latent_dim)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: Embedded tokens (batch, seq_len, d_model)
            mask: Optional self-attention mask passed to every encoder layer.
        Returns:
            mu: (batch, latent_dim)
            log_var: (batch, latent_dim)
        """
        batch = x.shape[0]

        hidden = x
        for block in self.layers:
            hidden = block(hidden, mask=mask)
        hidden = self.final_norm(hidden)

        # NOTE(review): ``mask`` is not forwarded to the bottleneck attention,
        # so the queries attend over every position of ``hidden``.
        queries = self.query_tokens.unsqueeze(0).expand(batch, -1, -1)
        summary = self.bottleneck_attn(self.bottleneck_norm(queries), context=hidden)

        flat = summary.reshape(batch, -1)
        return self.to_mu(flat), self.to_log_var(flat)
|
|
|
|
class PhraseVAEDecoder(nn.Module):
    """
    Teacher-forced decoder from a phrase latent to REMI+ token logits.

    The latent is projected to ``vae_n_queries`` context vectors; causal
    transformer blocks over the token sequence cross-attend to that context.
    """

    def __init__(self, config: MuseMorphicConfig):
        super().__init__()
        self.config = config
        width = config.d_model

        # Latent vector → (vae_n_queries, d_model) cross-attention memory.
        self.latent_proj = nn.Linear(config.latent_dim, config.vae_n_queries * width)

        # The decoder owns its token and position tables.
        self.token_embed = nn.Embedding(config.vocab_size, width, padding_idx=config.pad_token_id)
        self.pos_embed = nn.Embedding(config.vae_max_seq_len, width)
        self.embed_scale = math.sqrt(width)

        self.layers = nn.ModuleList([
            PreLNTransformerBlock(
                width, config.vae_n_heads, config.vae_d_ff,
                config.vae_dropout, config.use_sigma_reparam,
                has_cross_attention=True
            )
            for _ in range(config.vae_decoder_layers)
        ])

        self.final_norm = nn.LayerNorm(width)
        self.output_proj = nn.Linear(width, config.vocab_size, bias=False)

    def forward(
        self,
        z: torch.Tensor,
        target_tokens: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            z: Latent vector (batch, latent_dim)
            target_tokens: Decoder input token ids (batch, seq_len); seq_len
                must not exceed ``vae_max_seq_len`` (size of the position table).
        Returns:
            logits: (batch, seq_len, vocab_size)
        """
        batch, seq_len = target_tokens.shape
        cfg = self.config

        memory = self.latent_proj(z).reshape(batch, cfg.vae_n_queries, cfg.d_model)

        # (1, seq_len) positions broadcast over the batch.
        positions = torch.arange(seq_len, device=target_tokens.device).unsqueeze(0)
        hidden = self.token_embed(target_tokens) * self.embed_scale + self.pos_embed(positions)

        for block in self.layers:
            hidden = block(hidden, context=memory, is_causal=True)

        return self.output_proj(self.final_norm(hidden))
|
|
|
|
class PhraseVAE(nn.Module):
    """
    Phrase-level VAE: REMI+ tokens → latent vector → token logits.

    ``forward`` returns reconstruction and KL losses; passing ``kl_weight=0``
    gives a plain autoencoder objective. The original notes describe a
    three-stage curriculum (span infilling, autoencoder, VAE fine-tuning);
    that schedule is driven by the caller, not by this class.
    """

    def __init__(self, config: MuseMorphicConfig):
        super().__init__()
        self.config = config

        # Input embedding shared by the encoder path only; the decoder has its own.
        self.embedding = MusicTokenEmbedding(config)

        self.encoder = PhraseVAEEncoder(config)
        self.decoder = PhraseVAEDecoder(config)

    def reparameterize(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
        """Sample z = μ + σ·ε in training mode; return μ unchanged in eval mode."""
        if not self.training:
            return mu
        sigma = torch.exp(0.5 * log_var)
        noise = torch.randn_like(sigma)
        return mu + sigma * noise

    def encode(self, token_ids: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Embed and encode tokens; returns ``(z, mu, log_var)``.

        ``kwargs`` are forwarded to the embedding (pitch/duration/onset values).
        """
        # NOTE(review): no attention mask is passed to the encoder here.
        embedded = self.embedding(token_ids, **kwargs)
        mu, log_var = self.encoder(embedded)
        return self.reparameterize(mu, log_var), mu, log_var

    def decode(self, z: torch.Tensor, target_tokens: torch.Tensor) -> torch.Tensor:
        """Decode latent vector to token logits (teacher-forced on ``target_tokens``)."""
        return self.decoder(z, target_tokens)

    def forward(
        self,
        token_ids: torch.Tensor,
        target_tokens: Optional[torch.Tensor] = None,
        kl_weight: float = 0.01,
        **kwargs
    ) -> Dict[str, torch.Tensor]:
        """
        Full forward pass with loss computation.

        Args:
            token_ids: Input tokens (batch, seq_len)
            target_tokens: Reconstruction targets (batch, seq_len); defaults to
                ``token_ids``. The decoder is fed ``target_tokens[:, :-1]``
                and scored against ``target_tokens[:, 1:]``.
            kl_weight: β multiplying the KL term in the total loss.
            **kwargs: Forwarded to ``encode``.
        Returns:
            Dict with keys 'loss', 'recon_loss', 'kl_loss', 'z', 'mu',
            'log_var' and 'logits'.
        """
        targets = token_ids if target_tokens is None else target_tokens

        z, mu, log_var = self.encode(token_ids, **kwargs)

        # Next-token prediction: inputs drop the last token, labels drop the first.
        logits = self.decode(z, targets[:, :-1])
        labels = targets[:, 1:]

        recon_loss = F.cross_entropy(
            logits.reshape(-1, self.config.vocab_size),
            labels.reshape(-1),
            ignore_index=self.config.pad_token_id,
            label_smoothing=self.config.label_smoothing,
        )

        # KL(q(z|x) || N(0, I)), summed over latent dims, averaged over the batch.
        kl_terms = 1 + log_var - mu.pow(2) - log_var.exp()
        kl_loss = -0.5 * torch.mean(torch.sum(kl_terms, dim=-1))

        return {
            'loss': recon_loss + kl_weight * kl_loss,
            'recon_loss': recon_loss,
            'kl_loss': kl_loss,
            'z': z,
            'mu': mu,
            'log_var': log_var,
            'logits': logits,
        }
|
|
|
|
| |
| |
| |
|
|
class SelectiveSSM(nn.Module):
    """
    Selective State Space Model (Mamba core), pure-PyTorch reference version.

    After Gu & Dao (2023) "Mamba: Linear-Time Sequence Modeling with Selective
    State Spaces" (arXiv:2312.00752).

    Per time step, with input-dependent B, C and Δ:
        Δ   = softplus(dt_proj(·))
        Ā   = exp(Δ · A),  A = -exp(A_log)
        B̄   = Δ · B
        h_t = Ā · h_{t-1} + B̄ · x_t
        y_t = C · h_t + D · x_t

    The recurrence is evaluated with a Python loop over time steps, so the
    cost is linear in sequence length but not parallel over time.
    """

    def __init__(self, d_model: int, d_state: int = 16, d_conv: int = 4,
                 expand: int = 2, sigma_reparam: bool = True):
        """
        Args:
            d_model: Input and output width.
            d_state: State size N per inner channel.
            d_conv: Kernel size of the depthwise causal convolution.
            expand: Inner width is ``d_model * expand``.
            sigma_reparam: Use σReparam for the input and output projections.
        """
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_inner = d_model * expand
        self.d_conv = d_conv

        # One projection produces both the SSM branch and the gate branch.
        self.in_proj = make_linear(d_model, self.d_inner * 2, bias=False, sigma_reparam=sigma_reparam)

        # Depthwise conv; left context comes from the padding, and forward()
        # trims the extra trailing outputs.
        self.conv1d = nn.Conv1d(
            self.d_inner, self.d_inner,
            kernel_size=d_conv,
            padding=d_conv - 1,
            groups=self.d_inner,
        )

        # A is stored as log(1..N) per channel and used as -exp(A_log).
        state_index = torch.arange(1, d_state + 1, dtype=torch.float32)
        A = state_index.unsqueeze(0).expand(self.d_inner, -1)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(self.d_inner))

        # x_proj emits [B (N) | C (N) | dt (1)]; dt_proj expands dt to d_inner.
        self.x_proj = nn.Linear(self.d_inner, d_state * 2 + 1, bias=False)
        self.dt_proj = nn.Linear(1, self.d_inner, bias=True)

        # Bias drawn uniformly from [log 0.001, log 0.1] before the softplus.
        nn.init.uniform_(self.dt_proj.bias, math.log(0.001), math.log(0.1))

        self.out_proj = make_linear(self.d_inner, d_model, bias=False, sigma_reparam=sigma_reparam)

    def _ssm_scan(self, x: torch.Tensor, A: torch.Tensor, B: torch.Tensor,
                  C: torch.Tensor, D: torch.Tensor, dt: torch.Tensor) -> torch.Tensor:
        """
        Run the selective-SSM recurrence sequentially over time.

        Args:
            x: (B, L, D_inner) input to the state update
            A: (D_inner, N) state transition; ``forward`` passes ``-exp(A_log)``
            B: (B, L, N) input-dependent input matrix
            C: (B, L, N) input-dependent output matrix
            D: (D_inner,) skip-connection weights
            dt: (B, L, D_inner) input-dependent discretization step
        Returns:
            (B, L, D_inner)
        """
        batch, seq_len, d_inner = x.shape

        # Discretise for all steps at once: both tensors are (B, L, D_inner, N).
        decay = torch.exp(dt.unsqueeze(-1) * A)
        drive = (dt.unsqueeze(-1) * B.unsqueeze(2)) * x.unsqueeze(-1)

        state = torch.zeros(batch, d_inner, self.d_state, device=x.device, dtype=x.dtype)
        per_step = []
        for step in range(seq_len):
            state = decay[:, step] * state + drive[:, step]
            # Contract the state dimension against C for this step.
            per_step.append(torch.sum(state * C[:, step].unsqueeze(1), dim=-1))

        # Skip connection straight from the input.
        return torch.stack(per_step, dim=1) + x * D

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, seq_len, d_model)
        Returns:
            (batch, seq_len, d_model)
        """
        seq_len = x.shape[1]
        n_state = self.d_state

        ssm_in, gate = self.in_proj(x).chunk(2, dim=-1)

        # Causal depthwise conv over time, then SiLU.
        conv_out = self.conv1d(ssm_in.transpose(1, 2))[:, :, :seq_len].transpose(1, 2)
        u = F.silu(conv_out)

        # Input-dependent SSM parameters.
        params = self.x_proj(u)
        B_param = params[:, :, :n_state]
        C_param = params[:, :, n_state:2 * n_state]
        dt = F.softplus(self.dt_proj(params[:, :, -1:]))

        A = -torch.exp(self.A_log)

        y = self._ssm_scan(u, A, B_param, C_param, self.D, dt)

        # Multiplicative gate, then project back to d_model.
        return self.out_proj(y * F.silu(gate))
|
|
|
|
class MambaBlock(nn.Module):
    """
    Residual Mamba block: ``x + Dropout(SelectiveSSM(LayerNorm(x)))``.
    """

    def __init__(self, d_model: int, d_state: int = 16, d_conv: int = 4,
                 expand: int = 2, dropout: float = 0.1, sigma_reparam: bool = True):
        """
        Args:
            d_model: Block width.
            d_state: SSM state size.
            d_conv: Depthwise convolution kernel size.
            expand: Inner-width multiplier of the SSM.
            dropout: Dropout probability on the SSM branch.
            sigma_reparam: Use σReparam projections inside the SSM.
        """
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.ssm = SelectiveSSM(d_model, d_state, d_conv, expand, sigma_reparam)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the dropped-out SSM output of the normalised input."""
        branch = self.ssm(self.norm(x))
        return x + self.dropout(branch)
|
|
|
|
| |
| |
| |
|
|
class ControlEmbedding(nn.Module):
    """
    Embed musical control parameters into a single d_model vector.

    Controls: tempo, key, time_signature, note_density, style.
    The embeddings of all supplied controls are summed, passed through a
    two-layer MLP, and layer-normalised.
    """

    def __init__(self, config: "MuseMorphicConfig"):
        super().__init__()
        d = config.mamba_d_model

        self.tempo_embed = nn.Embedding(config.n_tempo_bins, d)
        self.key_embed = nn.Embedding(config.n_key_classes, d)
        self.time_sig_embed = nn.Embedding(config.n_time_sig_classes, d)
        self.density_embed = nn.Embedding(config.n_density_bins, d)
        self.style_embed = nn.Embedding(config.n_style_classes, d)

        # MLP applied to the summed control embedding.
        self.control_proj = nn.Sequential(
            nn.Linear(d, d),
            nn.SiLU(),
            nn.Linear(d, d),
        )
        self.norm = nn.LayerNorm(d)

    def forward(
        self,
        tempo: Optional[torch.Tensor] = None,
        key: Optional[torch.Tensor] = None,
        time_sig: Optional[torch.Tensor] = None,
        density: Optional[torch.Tensor] = None,
        style: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            tempo, key, time_sig, density, style: Optional integer index
                tensors of shape (batch,). Any subset may be supplied, but at
                least one is required because the batch size is read from it.
        Returns:
            Control embedding of shape (batch, 1, d_model).
        Raises:
            ValueError: If every control is None.
        """
        candidates = (
            (tempo, self.tempo_embed),
            (key, self.key_embed),
            (time_sig, self.time_sig_embed),
            (density, self.density_embed),
            (style, self.style_embed),
        )
        supplied = [(idx, table) for idx, table in candidates if idx is not None]
        if not supplied:
            raise ValueError("ControlEmbedding requires at least one control tensor")

        # Take the batch size from whichever control is present, not only
        # from tempo/key.
        batch = supplied[0][0].shape[0]
        d = self.tempo_embed.embedding_dim
        device = next(self.parameters()).device

        ctrl = torch.zeros(batch, d, device=device)
        for idx, table in supplied:
            ctrl = ctrl + table(idx)

        ctrl = self.norm(self.control_proj(ctrl))
        return ctrl.unsqueeze(1)
|
|
|
|
class LatentMamba(nn.Module):
    """
    Generate sequences of phrase latent vectors using Selective SSM (Mamba).

    Architecture:
        Input: [control_embed, z_1, z_2, ..., z_T]
        → Linear projection (latent_dim → d_model) + learned positions
        → MambaBlock × N
        → Linear projection (d_model → latent_dim)
        → Output: one predicted latent per input phrase position

    ``forward`` processes a whole sequence; ``generate`` steps the same layers
    recurrently, carrying a fixed-size convolution window and SSM state per
    layer, so per-step cost and memory do not grow with the number of phrases.
    """

    def __init__(self, config: MuseMorphicConfig):
        super().__init__()
        self.config = config
        d = config.mamba_d_model

        # Conditioning token placed in front of the latent sequence.
        self.control_embed = ControlEmbedding(config)

        # Latent vector → model width.
        self.latent_in = nn.Sequential(
            nn.Linear(config.latent_dim, d),
            nn.LayerNorm(d),
        )

        # +1 leaves room for the control token at position 0.
        self.pos_embed = nn.Embedding(config.max_phrases + 1, d)

        self.layers = nn.ModuleList([
            MambaBlock(
                d, config.mamba_d_state, config.mamba_d_conv,
                config.mamba_expand, config.mamba_dropout,
                config.use_sigma_reparam
            )
            for _ in range(config.mamba_n_layers)
        ])

        self.final_norm = nn.LayerNorm(d)

        # Model width → latent vector.
        self.latent_out = nn.Linear(d, config.latent_dim)

    def forward(
        self,
        z_seq: torch.Tensor,
        controls: Optional[Dict[str, torch.Tensor]] = None,
    ) -> torch.Tensor:
        """
        Args:
            z_seq: Sequence of phrase latents (batch, n_phrases, latent_dim)
            controls: Dict of control tensors (each (batch,) integer indices),
                passed as keyword arguments to ``ControlEmbedding``.
        Returns:
            z_pred: Predicted next phrase latents (batch, n_phrases, latent_dim)
        """
        n_phrases = z_seq.shape[1]
        has_controls = controls is not None

        x = self.latent_in(z_seq)

        if has_controls:
            # Prepend the control token; it occupies position 0.
            x = torch.cat([self.control_embed(**controls), x], dim=1)

        total_len = n_phrases + 1 if has_controls else n_phrases
        positions = torch.arange(total_len, device=z_seq.device).unsqueeze(0)
        x = x + self.pos_embed(positions)

        for layer in self.layers:
            x = layer(x)

        x = self.final_norm(x)

        if has_controls:
            # Drop the output at the control position.
            x = x[:, 1:]

        return self.latent_out(x)

    @staticmethod
    def _ssm_step(
        ssm: nn.Module,
        x_t: torch.Tensor,
        conv_state: torch.Tensor,
        ssm_state: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Advance one ``SelectiveSSM`` by a single time step.

        Mirrors ``SelectiveSSM.forward`` for a length-1 input while carrying
        state across calls.

        Args:
            ssm: The ``SelectiveSSM`` module to step.
            x_t: (batch, 1, d_model) normalised input for this step.
            conv_state: (batch, d_inner, d_conv) most recent conv inputs,
                oldest first.
            ssm_state: (batch, d_inner, d_state) recurrent state.
        Returns:
            ``(y_t, conv_state, ssm_state)`` with ``y_t`` of shape
            (batch, 1, d_model) and the updated states.
        """
        n_state = ssm.d_state

        ssm_in, gate = ssm.in_proj(x_t).chunk(2, dim=-1)

        # Slide the window: drop the oldest input, append the newest.
        newest = ssm_in.transpose(1, 2).to(conv_state.dtype)
        conv_state = torch.cat([conv_state[:, :, 1:], newest], dim=-1)
        # A kernel-wide window without padding yields exactly one output,
        # equal to the causal conv output at the current step.
        conv_out = F.conv1d(conv_state, ssm.conv1d.weight, ssm.conv1d.bias, groups=ssm.d_inner)
        u = F.silu(conv_out.transpose(1, 2))

        params = ssm.x_proj(u)
        b_t = params[:, 0, :n_state].unsqueeze(1)
        c_t = params[:, 0, n_state:2 * n_state].unsqueeze(1)
        dt = F.softplus(ssm.dt_proj(params[:, :, -1:]))[:, 0].unsqueeze(-1)

        A = -torch.exp(ssm.A_log)
        u_t = u[:, 0]

        ssm_state = torch.exp(dt * A) * ssm_state + (dt * b_t) * u_t.unsqueeze(-1)
        y = torch.sum(ssm_state * c_t, dim=-1) + u_t * ssm.D

        y = y.unsqueeze(1) * F.silu(gate)
        return ssm.out_proj(y), conv_state, ssm_state

    def generate(
        self,
        n_phrases: int,
        controls: Optional[Dict[str, torch.Tensor]] = None,
        temperature: float = 0.8,
        batch_size: int = 1,
    ) -> torch.Tensor:
        """
        Generate a sequence of phrase latents autoregressively.

        Each layer is stepped recurrently with a fixed-size state, so the
        hidden computation at step t depends on all earlier steps exactly as
        in ``forward``.

        Args:
            n_phrases: Number of latents to produce.
            controls: Optional control tensors; when given, the batch size is
                taken from them and ``batch_size`` is ignored.
            temperature: Standard deviation of Gaussian noise added to each
                predicted latent; values <= 0 disable the noise.
            batch_size: Batch size used when ``controls`` is None.
        Returns:
            (batch, n_phrases, latent_dim). Gradients are not disabled here;
            wrap the call in ``torch.no_grad()`` if they are not needed.
        """
        device = next(self.parameters()).device
        d = self.config.mamba_d_model

        # First input: control token, or zeros when unconditioned.
        if controls is not None:
            x = self.control_embed(**controls)
        else:
            x = torch.zeros(batch_size, 1, d, device=device)
        x = x + self.pos_embed(torch.tensor([0], device=device))

        batch = x.shape[0]
        if n_phrases <= 0:
            return torch.empty(batch, 0, self.config.latent_dim, device=device)

        # Per-layer recurrent state: conv input window and SSM hidden state.
        conv_states = [
            torch.zeros(batch, layer.ssm.d_inner, layer.ssm.d_conv, device=device, dtype=x.dtype)
            for layer in self.layers
        ]
        ssm_states = [
            torch.zeros(batch, layer.ssm.d_inner, layer.ssm.d_state, device=device, dtype=x.dtype)
            for layer in self.layers
        ]

        last_position = self.config.max_phrases - 1
        generated = []

        for t in range(n_phrases):
            h = x
            for i, layer in enumerate(self.layers):
                # Same residual form as MambaBlock.forward.
                y, conv_states[i], ssm_states[i] = self._ssm_step(
                    layer.ssm, layer.norm(h), conv_states[i], ssm_states[i]
                )
                h = h + layer.dropout(y)

            z_t = self.latent_out(self.final_norm(h))

            if temperature > 0:
                z_t = z_t + temperature * torch.randn_like(z_t)

            generated.append(z_t)

            # Feed the sampled latent back in; positions saturate so that
            # generation can run past the learned position table.
            position = torch.tensor([min(t + 1, last_position)], device=device)
            x = self.latent_in(z_t) + self.pos_embed(position)

        return torch.cat(generated, dim=1)
|
|
|
|
| |
| |
| |
|
|
class MuseMorphic(nn.Module):
    """
    Complete MuseMorphic model combining PhraseVAE and LatentMamba.

    Two-stage training (driven by the caller):
        Stage 1: Train PhraseVAE (encode/decode individual phrases)
        Stage 2: Train LatentMamba on latent sequences from the PhraseVAE

    Inference pipeline:
        Controls → LatentMamba.generate() → PhraseVAE.decode() → REMI+ tokens
    """

    def __init__(self, config: MuseMorphicConfig):
        super().__init__()
        self.config = config
        self.phrase_vae = PhraseVAE(config)
        self.latent_mamba = LatentMamba(config)

    def encode_phrases(self, phrases: List[torch.Tensor], **kwargs) -> torch.Tensor:
        """
        Encode a list of phrase token sequences to latent vectors.

        Side effect: switches ``phrase_vae`` to eval mode and leaves it there.

        Args:
            phrases: List of (batch, phrase_len) token tensors
            **kwargs: Forwarded to ``PhraseVAE.encode`` for every phrase.
        Returns:
            z_seq: (batch, n_phrases, latent_dim)
        """
        self.phrase_vae.eval()
        with torch.no_grad():
            # encode() returns (z, mu, log_var); only z is kept.
            latents = [
                self.phrase_vae.encode(phrase_tokens, **kwargs)[0].unsqueeze(1)
                for phrase_tokens in phrases
            ]
        return torch.cat(latents, dim=1)

    def decode_phrases(self, z_seq: torch.Tensor, max_len: int = 256) -> List[torch.Tensor]:
        """
        Decode latent vectors back to token sequences.

        Side effect: switches ``phrase_vae`` to eval mode and leaves it there.

        Args:
            z_seq: (batch, n_phrases, latent_dim)
            max_len: Maximum number of tokens per phrase, including BOS.
        Returns:
            List of (batch, phrase_len) token tensors, one per phrase.
        """
        n_phrases = z_seq.shape[1]

        self.phrase_vae.eval()
        with torch.no_grad():
            return [self._ar_decode(z_seq[:, t], max_len) for t in range(n_phrases)]

    def _ar_decode(self, z: torch.Tensor, max_len: int) -> torch.Tensor:
        """Greedy autoregressive decoding from a latent vector.

        Each batch row stops independently: once a row has emitted EOS, every
        later position in that row is filled with the pad token. Decoding ends
        when all rows have finished or the length limit is reached.

        Args:
            z: (batch, latent_dim) latent vectors.
            max_len: Maximum output length including BOS. The number of
                decoding steps is also capped at ``vae_max_seq_len`` so the
                decoder's position table is never indexed out of range.
        Returns:
            (batch, length) token ids starting with BOS.
        """
        cfg = self.config
        batch = z.shape[0]
        device = z.device

        tokens = torch.full((batch, 1), cfg.bos_token_id, dtype=torch.long, device=device)
        finished = torch.zeros(batch, 1, dtype=torch.bool, device=device)

        # The decoder input at step i has length i + 1.
        n_steps = min(max_len - 1, cfg.vae_max_seq_len)
        for _ in range(n_steps):
            logits = self.phrase_vae.decode(z, tokens)
            next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)

            # Rows that already ended only receive padding.
            next_token = next_token.masked_fill(finished, cfg.pad_token_id)
            tokens = torch.cat([tokens, next_token], dim=1)

            finished = finished | (next_token == cfg.eos_token_id)
            if finished.all():
                break

        return tokens

    @torch.no_grad()
    def generate(
        self,
        n_phrases: int = 32,
        controls: Optional[Dict[str, torch.Tensor]] = None,
        temperature: float = 0.8,
        max_phrase_len: int = 256,
        batch_size: int = 1,
    ) -> List[torch.Tensor]:
        """
        Full generation pipeline.

        Controls → LatentMamba → PhraseVAE.decode → REMI+ tokens

        Side effect: puts the whole model in eval mode.

        Args:
            n_phrases: Number of phrases to generate.
            controls: Optional control tensors for ``LatentMamba.generate``.
            temperature: Latent sampling noise scale.
            max_phrase_len: Maximum tokens per decoded phrase.
            batch_size: Batch size passed to ``LatentMamba.generate``.
        Returns:
            List of (batch, phrase_len) token tensors, one per phrase.
        """
        self.eval()

        z_seq = self.latent_mamba.generate(
            n_phrases, controls, temperature, batch_size
        )
        return self.decode_phrases(z_seq, max_phrase_len)

    def count_parameters(self) -> Dict[str, int]:
        """Count parameters by component (keys: 'vae_encoder', 'vae_decoder',
        'vae_embedding', 'latent_mamba', 'total')."""
        def n_params(module: nn.Module) -> int:
            return sum(p.numel() for p in module.parameters())

        return {
            'vae_encoder': n_params(self.phrase_vae.encoder),
            'vae_decoder': n_params(self.phrase_vae.decoder),
            'vae_embedding': n_params(self.phrase_vae.embedding),
            'latent_mamba': n_params(self.latent_mamba),
            'total': n_params(self),
        }

    def get_vram_estimate(self, batch_size: int = 1, seq_len: int = 256,
                          dtype_bytes: int = 2) -> Dict[str, str]:
        """Rough VRAM estimate derived from the parameter count alone.

        Args:
            batch_size: Accepted for interface compatibility; not used.
            seq_len: Accepted for interface compatibility; not used.
            dtype_bytes: Bytes per parameter.
        Returns:
            Dict of formatted strings with keys 'parameters_mb',
            'training_vram_mb' and 'inference_vram_mb'.
        """
        n_total = self.count_parameters()['total']

        param_mem = n_total * dtype_bytes

        # Heuristic: activations taken as twice the parameter memory.
        act_mem = param_mem * 2

        # Two 4-byte optimizer values per parameter.
        opt_mem = n_total * 4 * 2

        training_mem = param_mem + act_mem + opt_mem
        inference_mem = param_mem + act_mem // 4

        return {
            'parameters_mb': f"{param_mem / 1e6:.1f} MB",
            'training_vram_mb': f"{training_mem / 1e6:.1f} MB",
            'inference_vram_mb': f"{inference_mem / 1e6:.1f} MB",
        }
|
|
|
|
| |
| |
| |
|
|
class ZClip:
    """
    Adaptive gradient clipping via z-score thresholding.

    From ZClip (2025) "Adaptive Spike Mitigation for LLM Pre-Training"
    (arXiv:2504.02507).

    Tracks an exponential moving average of the gradient-norm mean and
    variance and clips the global norm to ``mu + z_thresh * sigma`` whenever
    the current norm exceeds that threshold.
    """

    def __init__(self, z_thresh: float = 2.5, alpha: float = 0.99):
        """
        Args:
            z_thresh: Number of standard deviations above the running mean at
                which a gradient norm counts as a spike.
            alpha: EMA decay for the running mean and variance.
        """
        self.z_thresh = z_thresh
        self.alpha = alpha
        self.mu = 0.0
        self.var = 1.0
        self.initialized = False

    def __call__(self, model: nn.Module) -> float:
        """Clip ``model``'s gradients in place and return the pre-clip global norm.

        A non-finite norm (NaN/inf) is returned as-is without clipping and
        without touching the running statistics, so one bad step cannot
        poison the EMA and disable clipping for the rest of training.
        """
        # max_norm=inf computes the global norm without rescaling anything.
        total_norm = torch.nn.utils.clip_grad_norm_(
            model.parameters(), float('inf')
        ).item()

        if not math.isfinite(total_norm):
            return total_norm

        if not self.initialized:
            # The first finite observation seeds the statistics; nothing is clipped.
            self.mu = total_norm
            self.var = 0.0
            self.initialized = True
            return total_norm

        # Floor sigma so a zero variance still gives a finite threshold.
        sigma = max(math.sqrt(self.var), 1e-8)
        threshold = self.mu + self.z_thresh * sigma

        if total_norm > threshold:
            torch.nn.utils.clip_grad_norm_(model.parameters(), threshold)

        # EMA update; the variance uses the already-updated mean.
        self.mu = self.alpha * self.mu + (1 - self.alpha) * total_norm
        self.var = self.alpha * self.var + (1 - self.alpha) * (total_norm - self.mu) ** 2

        return total_norm
|
|
|
|
| |
| |
| |
|
|
def model_summary(config: Optional[MuseMorphicConfig] = None):
    """Build a MuseMorphic model, print its statistics, and return it.

    Args:
        config: Configuration to instantiate; the default ``MuseMorphicConfig()``
            is used when omitted.
    Returns:
        The freshly constructed ``MuseMorphic`` instance.
    """
    cfg = MuseMorphicConfig() if config is None else config

    model = MuseMorphic(cfg)
    param_counts = model.count_parameters()
    vram_estimates = model.get_vram_estimate()

    rule = "=" * 60
    print(rule)
    print("MuseMorphic Model Summary")
    print(rule)

    print(f"\nParameter Counts:")
    for name in param_counts:
        count = param_counts[name]
        print(f" {name:20s}: {count:>10,d} ({count/1e6:.2f}M)")

    print(f"\nVRAM Estimates (BF16):")
    for name in vram_estimates:
        est = vram_estimates[name]
        print(f" {name:20s}: {est}")

    print(f"\nArchitecture:")
    print(f" d_model: {cfg.d_model}")
    print(f" Vocab size: {cfg.vocab_size}")
    print(f" Latent dim: {cfg.latent_dim}")
    print(f" VAE layers: {cfg.vae_encoder_layers}+{cfg.vae_decoder_layers}")
    print(f" Mamba layers: {cfg.mamba_n_layers}")
    print(f" Mamba state dim: {cfg.mamba_d_state}")
    print(f" Max phrase tokens: {cfg.vae_max_seq_len}")
    print(f" Max phrases: {cfg.max_phrases}")
    print(rule)

    return model
|
|
|
|
# Script entry point: build the default model and print its summary.
if __name__ == "__main__":
    model = model_summary()
|
|