""" MuseMorphic: Lightweight Consumer-Grade MIDI Generation Architecture ==================================================================== A novel two-stage hierarchical architecture combining: Stage 1 - PhraseVAE: Compress REMI+ tokens → 64-dim latent vectors Stage 2 - LatentMamba: Generate latent sequences with O(n) complexity Key innovations: - O(n) complexity everywhere (Selective SSM backbone) - Music-native FME embeddings (translational invariance, transposability) - ~33M parameters, trains on free Colab T4, inference <1GB VRAM - Controllable via multi-attribute conditioning - Infinite generation via fixed-size recurrent state - Training stability by design (σReparam, ZClip, Pre-LN, BF16, label smoothing) """ import math import torch import torch.nn as nn import torch.nn.functional as F from dataclasses import dataclass, field from typing import Optional, List, Tuple, Dict from einops import rearrange # ============================================================================ # Configuration # ============================================================================ @dataclass class MuseMorphicConfig: """Complete configuration for MuseMorphic architecture.""" # --- Tokenizer --- vocab_size: int = 8192 # BPE vocabulary size pad_token_id: int = 0 bos_token_id: int = 1 eos_token_id: int = 2 mask_token_id: int = 3 # --- FME Embeddings --- d_model: int = 256 # Model dimension throughout fme_base_pitch: float = 10000.0 # Base B for pitch FME fme_base_duration: float = 1000.0 # Base B for duration FME fme_base_onset: float = 5000.0 # Base B for onset FME use_log_frequency: bool = True # Encode pitch as log-frequency # --- PhraseVAE --- vae_encoder_layers: int = 3 vae_decoder_layers: int = 3 vae_n_heads: int = 4 vae_d_ff: int = 512 # Feed-forward dim vae_n_queries: int = 4 # Multi-query bottleneck queries latent_dim: int = 64 # VAE latent dimension vae_dropout: float = 0.1 vae_max_seq_len: int = 256 # Max tokens per phrase kl_beta: float = 0.01 # KL weight (low to prevent posterior collapse) label_smoothing: float = 0.1 # --- LatentMamba --- mamba_d_model: int = 256 mamba_n_layers: int = 8 mamba_d_state: int = 16 # SSM state dimension N mamba_d_conv: int = 4 # Local convolution width mamba_expand: int = 2 # Inner dimension expansion factor mamba_dropout: float = 0.1 max_phrases: int = 512 # Max phrases in a piece # --- Control --- n_tempo_bins: int = 45 # (30-210 BPM, step 4) n_key_classes: int = 24 # 12 keys × major/minor n_time_sig_classes: int = 8 # Common time signatures n_density_bins: int = 10 # Note density percentile bins n_style_classes: int = 32 # Style/genre categories # --- Training Stability --- use_sigma_reparam: bool = True use_pre_ln: bool = True zclip_z_thresh: float = 2.5 zclip_alpha: float = 0.99 # --- Training --- learning_rate: float = 3e-4 weight_decay: float = 0.01 warmup_steps: int = 500 max_steps: int = 100000 batch_size: int = 32 gradient_accumulation_steps: int = 1 # ============================================================================ # Fundamental Music Embedding (FME) — Physics-Aware # ============================================================================ class FundamentalMusicEmbedding(nn.Module): """ Translational-invariant, transposable pitch/duration/onset embedding. From Liang et al. (2022) "Domain-Knowledge-Inspired Music Embedding" Extended with log-frequency pitch encoding for harmonic series awareness. Properties: 1. |f_a - f_b| = |f_c - f_d| => ||FME(f_a) - FME(f_b)|| = ||FME(f_c) - FME(f_d)|| 2. 
Transposition is a linear operation in embedding space 3. Pitch, duration, onset are orthogonal via different base B values """ def __init__(self, d_model: int, base_B: float = 10000.0, use_log_freq: bool = False): super().__init__() self.d_model = d_model self.use_log_freq = use_log_freq half_d = d_model // 2 # Exponentially decaying frequencies k = torch.arange(half_d, dtype=torch.float32) w_k = base_B ** (-2.0 * k / d_model) self.register_buffer('w_k', w_k) # Learnable biases (enable fine-tuning of embedding geometry) self.b_sin = nn.Parameter(torch.zeros(half_d)) self.b_cos = nn.Parameter(torch.zeros(half_d)) def forward(self, values: torch.Tensor) -> torch.Tensor: """ Args: values: Integer or float values, shape (batch, seq_len) Returns: Embedding, shape (batch, seq_len, d_model) """ f = values.float() if self.use_log_freq: # Convert MIDI pitch to log-frequency (respects harmonic series) # f_hz = 440 * 2^((p-69)/12), log2(f_hz) = log2(440) + (p-69)/12 f = torch.log2(440.0 * (2.0 ** ((f - 69.0) / 12.0)) + 1e-8) f = f.unsqueeze(-1) # (B, L, 1) sin_enc = torch.sin(self.w_k * f) + self.b_sin # (B, L, d/2) cos_enc = torch.cos(self.w_k * f) + self.b_cos # (B, L, d/2) return torch.cat([sin_enc, cos_enc], dim=-1) # (B, L, d) class MusicTokenEmbedding(nn.Module): """ Combined embedding for REMI+ tokens using FME for musically-meaningful tokens and standard learned embeddings for structural tokens. """ def __init__(self, config: MuseMorphicConfig): super().__init__() self.config = config d = config.d_model # Standard token embedding (for BPE tokens) self.token_embed = nn.Embedding(config.vocab_size, d, padding_idx=config.pad_token_id) # FME components (used as additive bias for pitch/duration/onset tokens) self.pitch_fme = FundamentalMusicEmbedding(d, config.fme_base_pitch, config.use_log_frequency) self.duration_fme = FundamentalMusicEmbedding(d, config.fme_base_duration, False) self.onset_fme = FundamentalMusicEmbedding(d, config.fme_base_onset, False) # Positional embedding (within-bar position, learnable) self.pos_embed = nn.Embedding(config.vae_max_seq_len, d) # Layer norm for embedding output stability self.embed_ln = nn.LayerNorm(d) self.embed_dropout = nn.Dropout(config.vae_dropout) # Scale factor self.scale = math.sqrt(d) def forward( self, token_ids: torch.Tensor, pitch_values: Optional[torch.Tensor] = None, duration_values: Optional[torch.Tensor] = None, onset_values: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Args: token_ids: (batch, seq_len) BPE token indices pitch_values: (batch, seq_len) MIDI pitch values (0 where not applicable) duration_values: (batch, seq_len) duration ticks (0 where not applicable) onset_values: (batch, seq_len) onset positions (0 where not applicable) """ B, L = token_ids.shape # Base token embedding x = self.token_embed(token_ids) * self.scale # Add FME for musically-meaningful attributes (when available) if pitch_values is not None: mask = (pitch_values > 0).float().unsqueeze(-1) x = x + self.pitch_fme(pitch_values) * mask if duration_values is not None: mask = (duration_values > 0).float().unsqueeze(-1) x = x + self.duration_fme(duration_values) * mask if onset_values is not None: mask = (onset_values > 0).float().unsqueeze(-1) x = x + self.onset_fme(onset_values) * mask # Add positional embedding positions = torch.arange(L, device=token_ids.device).unsqueeze(0).expand(B, -1) x = x + self.pos_embed(positions) return self.embed_dropout(self.embed_ln(x)) # ============================================================================ # σReparam 
(Spectral Reparameterization) — Training Stability # ============================================================================ class SigmaReparamLinear(nn.Module): """ Linear layer with spectral reparameterization (σReparam). From Zhai et al. (2023) "Stabilizing Transformer Training by Preventing Attention Entropy Collapse" (arXiv:2303.06296). W_hat = (γ / σ(W)) * W where σ(W) is the spectral norm (largest singular value). Prevents attention entropy collapse — the #1 source of training instability. """ def __init__(self, in_features: int, out_features: int, bias: bool = True): super().__init__() self.linear = nn.Linear(in_features, out_features, bias=bias) # Apply spectral normalization self.linear = nn.utils.parametrizations.spectral_norm(self.linear) # Learnable scaling factor (initialized to 1) self.gamma = nn.Parameter(torch.ones(1)) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.gamma * self.linear(x) def make_linear(in_f: int, out_f: int, bias: bool = True, sigma_reparam: bool = True) -> nn.Module: """Factory for linear layers with optional σReparam.""" if sigma_reparam: return SigmaReparamLinear(in_f, out_f, bias) return nn.Linear(in_f, out_f, bias) # ============================================================================ # Pre-LN Transformer Block (for PhraseVAE encoder/decoder) # ============================================================================ class PreLNMultiHeadAttention(nn.Module): """Multi-head attention with Pre-LayerNorm and σReparam.""" def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1, sigma_reparam: bool = True, is_cross_attention: bool = False): super().__init__() assert d_model % n_heads == 0 self.n_heads = n_heads self.d_head = d_model // n_heads self.scale = math.sqrt(self.d_head) self.q_proj = make_linear(d_model, d_model, sigma_reparam=sigma_reparam) self.k_proj = make_linear(d_model, d_model, sigma_reparam=sigma_reparam) self.v_proj = make_linear(d_model, d_model, sigma_reparam=sigma_reparam) self.out_proj = make_linear(d_model, d_model, sigma_reparam=sigma_reparam) self.attn_dropout = nn.Dropout(dropout) self.is_cross_attention = is_cross_attention def forward( self, x: torch.Tensor, context: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None, is_causal: bool = False, ) -> torch.Tensor: B, L, D = x.shape q = self.q_proj(x) kv_input = context if self.is_cross_attention and context is not None else x k = self.k_proj(kv_input) v = self.v_proj(kv_input) # Reshape for multi-head q = rearrange(q, 'b l (h d) -> b h l d', h=self.n_heads) k = rearrange(k, 'b s (h d) -> b h s d', h=self.n_heads) v = rearrange(v, 'b s (h d) -> b h s d', h=self.n_heads) # Scaled dot-product attention (using PyTorch's efficient implementation) attn_out = F.scaled_dot_product_attention( q, k, v, attn_mask=mask, dropout_p=self.attn_dropout.p if self.training else 0.0, is_causal=is_causal, ) attn_out = rearrange(attn_out, 'b h l d -> b l (h d)') return self.out_proj(attn_out) class PreLNFeedForward(nn.Module): """Feed-forward network with Pre-LN, SiLU activation, and σReparam.""" def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1, sigma_reparam: bool = True): super().__init__() self.w1 = make_linear(d_model, d_ff, sigma_reparam=sigma_reparam) self.w2 = make_linear(d_ff, d_model, sigma_reparam=sigma_reparam) self.gate = make_linear(d_model, d_ff, sigma_reparam=sigma_reparam) self.dropout = nn.Dropout(dropout) def forward(self, x: torch.Tensor) -> torch.Tensor: # SwiGLU-style gating (used in LLaMA, Mamba) return 
self.dropout(self.w2(F.silu(self.gate(x)) * self.w1(x))) class PreLNTransformerBlock(nn.Module): """ Transformer block with Pre-LayerNorm for guaranteed training stability. Pre-LN: x → LayerNorm → Sublayer → + residual (vs Post-LN: x → Sublayer → + residual → LayerNorm, which is UNSTABLE) Pre-LN has analytically bounded gradient norms, eliminates need for LR warmup. """ def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1, sigma_reparam: bool = True, has_cross_attention: bool = False): super().__init__() self.norm1 = nn.LayerNorm(d_model) self.self_attn = PreLNMultiHeadAttention(d_model, n_heads, dropout, sigma_reparam) self.has_cross_attention = has_cross_attention if has_cross_attention: self.norm_cross = nn.LayerNorm(d_model) self.cross_attn = PreLNMultiHeadAttention( d_model, n_heads, dropout, sigma_reparam, is_cross_attention=True ) self.norm2 = nn.LayerNorm(d_model) self.ffn = PreLNFeedForward(d_model, d_ff, dropout, sigma_reparam) def forward( self, x: torch.Tensor, context: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None, is_causal: bool = False, ) -> torch.Tensor: # Pre-LN self-attention x = x + self.self_attn(self.norm1(x), mask=mask, is_causal=is_causal) # Pre-LN cross-attention (if applicable) if self.has_cross_attention and context is not None: x = x + self.cross_attn(self.norm_cross(x), context=context) # Pre-LN feed-forward x = x + self.ffn(self.norm2(x)) return x # ============================================================================ # PhraseVAE — Stage 1: Compress REMI+ phrases to latent vectors # ============================================================================ class PhraseVAEEncoder(nn.Module): """ Encode a sequence of REMI+ tokens into a latent vector using multi-query cross-attention bottleneck. Architecture: TransformerEncoder → MultiQueryBottleneck → μ, log_var """ def __init__(self, config: MuseMorphicConfig): super().__init__() self.config = config d = config.d_model # Transformer encoder layers self.layers = nn.ModuleList([ PreLNTransformerBlock( d, config.vae_n_heads, config.vae_d_ff, config.vae_dropout, config.use_sigma_reparam ) for _ in range(config.vae_encoder_layers) ]) self.final_norm = nn.LayerNorm(d) # Multi-query bottleneck (m learned queries) self.query_tokens = nn.Parameter(torch.randn(config.vae_n_queries, d) * 0.02) self.bottleneck_attn = PreLNMultiHeadAttention( d, config.vae_n_heads, config.vae_dropout, config.use_sigma_reparam, is_cross_attention=True ) self.bottleneck_norm = nn.LayerNorm(d) # Project to latent space bottleneck_dim = config.vae_n_queries * d self.to_mu = nn.Linear(bottleneck_dim, config.latent_dim) self.to_log_var = nn.Linear(bottleneck_dim, config.latent_dim) def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: x: Embedded tokens (batch, seq_len, d_model) Returns: mu: (batch, latent_dim) log_var: (batch, latent_dim) """ B = x.shape[0] # Encode through transformer layers for layer in self.layers: x = layer(x, mask=mask) x = self.final_norm(x) # Multi-query bottleneck queries = self.query_tokens.unsqueeze(0).expand(B, -1, -1) # (B, m, d) z_queries = self.bottleneck_attn( self.bottleneck_norm(queries), context=x ) # (B, m, d) # Flatten and project z_flat = z_queries.reshape(B, -1) # (B, m*d) mu = self.to_mu(z_flat) log_var = self.to_log_var(z_flat) return mu, log_var class PhraseVAEDecoder(nn.Module): """ Decode a latent vector back to REMI+ token sequence (autoregressive). 
Architecture: LatentProjection → CrossAttention with latent → AR generation """ def __init__(self, config: MuseMorphicConfig): super().__init__() self.config = config d = config.d_model # Project latent to key/value for cross-attention self.latent_proj = nn.Linear(config.latent_dim, config.vae_n_queries * d) # Token embedding for autoregressive decoding self.token_embed = nn.Embedding(config.vocab_size, d, padding_idx=config.pad_token_id) self.pos_embed = nn.Embedding(config.vae_max_seq_len, d) self.embed_scale = math.sqrt(d) # Decoder layers (with cross-attention to latent) self.layers = nn.ModuleList([ PreLNTransformerBlock( d, config.vae_n_heads, config.vae_d_ff, config.vae_dropout, config.use_sigma_reparam, has_cross_attention=True ) for _ in range(config.vae_decoder_layers) ]) self.final_norm = nn.LayerNorm(d) self.output_proj = nn.Linear(d, config.vocab_size, bias=False) def forward( self, z: torch.Tensor, target_tokens: torch.Tensor, ) -> torch.Tensor: """ Args: z: Latent vector (batch, latent_dim) target_tokens: Target token ids for teacher forcing (batch, seq_len) Returns: logits: (batch, seq_len, vocab_size) """ B, L = target_tokens.shape d = self.config.d_model # Project latent to cross-attention context latent_ctx = self.latent_proj(z).reshape(B, self.config.vae_n_queries, d) # Embed target tokens positions = torch.arange(L, device=target_tokens.device).unsqueeze(0) x = self.token_embed(target_tokens) * self.embed_scale + self.pos_embed(positions) # Decode with causal masking for layer in self.layers: x = layer(x, context=latent_ctx, is_causal=True) x = self.final_norm(x) logits = self.output_proj(x) return logits class PhraseVAE(nn.Module): """ Complete PhraseVAE: Encode REMI+ token phrases → latent vectors → decode back. Three-stage training curriculum: Stage 1: Span-infilling pretraining (learn REMI grammar) Stage 2: Autoencoder (KL weight = 0, pure reconstruction) Stage 3: VAE fine-tuning (KL weight = β = 0.01) """ def __init__(self, config: MuseMorphicConfig): super().__init__() self.config = config # Shared embedding (encoder input) self.embedding = MusicTokenEmbedding(config) # Encoder and decoder self.encoder = PhraseVAEEncoder(config) self.decoder = PhraseVAEDecoder(config) def reparameterize(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor: """Reparameterization trick: z = μ + σ * ε""" if self.training: std = torch.exp(0.5 * log_var) eps = torch.randn_like(std) return mu + std * eps return mu # At inference, just use the mean def encode(self, token_ids: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Encode tokens to latent space.""" x = self.embedding(token_ids, **kwargs) mu, log_var = self.encoder(x) z = self.reparameterize(mu, log_var) return z, mu, log_var def decode(self, z: torch.Tensor, target_tokens: torch.Tensor) -> torch.Tensor: """Decode latent vector to token logits.""" return self.decoder(z, target_tokens) def forward( self, token_ids: torch.Tensor, target_tokens: Optional[torch.Tensor] = None, kl_weight: float = 0.01, **kwargs ) -> Dict[str, torch.Tensor]: """ Full forward pass with loss computation. 
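
        The objective computed below is

            loss = recon_loss + kl_weight * kl_loss

        where recon_loss is token-level cross-entropy (with label smoothing)
        against the right-shifted targets, and kl_loss is the analytic KL
        divergence between q(z|x) = N(mu, sigma^2) and the standard normal
        prior, averaged over the batch.
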
Args: token_ids: Input tokens (batch, seq_len) target_tokens: Target tokens for reconstruction (batch, seq_len), defaults to token_ids shifted right kl_weight: β for KL loss weighting (0 for AE stage, 0.01 for VAE stage) """ B, L = token_ids.shape if target_tokens is None: target_tokens = token_ids # Encode z, mu, log_var = self.encode(token_ids, **kwargs) # Decode (teacher forcing with shifted input) decoder_input = target_tokens[:, :-1] # Remove last token decoder_target = target_tokens[:, 1:] # Remove first token (shift right) logits = self.decode(z, decoder_input) # Reconstruction loss with label smoothing recon_loss = F.cross_entropy( logits.reshape(-1, self.config.vocab_size), decoder_target.reshape(-1), ignore_index=self.config.pad_token_id, label_smoothing=self.config.label_smoothing, ) # KL divergence (per-sample, averaged) kl_loss = -0.5 * torch.mean( torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=-1) ) total_loss = recon_loss + kl_weight * kl_loss return { 'loss': total_loss, 'recon_loss': recon_loss, 'kl_loss': kl_loss, 'z': z, 'mu': mu, 'log_var': log_var, 'logits': logits, } # ============================================================================ # Selective SSM (Mamba) Block — O(n) Sequence Modeling # ============================================================================ class SelectiveSSM(nn.Module): """ Selective State Space Model (Mamba core). From Gu & Dao (2023) "Mamba: Linear-Time Sequence Modeling with Selective State Spaces" (arXiv:2312.00752). Key equations: B(x) = Linear_N(x) -- input-dependent C(x) = Linear_N(x) -- input-dependent Δ(x) = softplus(Linear_1(x) + param) -- input-dependent discretization Ā = exp(Δ · A) -- discretized state matrix B̄ = Δ · B(x) -- simplified discretized input matrix h_t = Ā · h_{t-1} + B̄ · x_t -- state update y_t = C(x_t) · h_t -- output Training: parallel scan O(BLD·N) Inference: O(BD·N) per step, state is O(D·N) fixed """ def __init__(self, d_model: int, d_state: int = 16, d_conv: int = 4, expand: int = 2, sigma_reparam: bool = True): super().__init__() self.d_model = d_model self.d_state = d_state self.d_inner = d_model * expand self.d_conv = d_conv # Input projection (expand dimension) self.in_proj = make_linear(d_model, self.d_inner * 2, bias=False, sigma_reparam=sigma_reparam) # Depthwise convolution (local context) self.conv1d = nn.Conv1d( self.d_inner, self.d_inner, kernel_size=d_conv, padding=d_conv - 1, groups=self.d_inner, ) # SSM parameters # A is initialized as negative log-spaced values (HiPPO-inspired) A = torch.arange(1, d_state + 1, dtype=torch.float32).unsqueeze(0).expand(self.d_inner, -1) self.A_log = nn.Parameter(torch.log(A)) # Learn in log space for stability self.D = nn.Parameter(torch.ones(self.d_inner)) # Skip connection # Input-dependent projections self.x_proj = nn.Linear(self.d_inner, d_state * 2 + 1, bias=False) # B, C, dt self.dt_proj = nn.Linear(1, self.d_inner, bias=True) # Initialize dt bias for proper timescales dt_init_std = 0.02 nn.init.uniform_(self.dt_proj.bias, math.log(0.001), math.log(0.1)) # Output projection self.out_proj = make_linear(self.d_inner, d_model, bias=False, sigma_reparam=sigma_reparam) def _ssm_scan(self, x: torch.Tensor, A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, D: torch.Tensor, dt: torch.Tensor) -> torch.Tensor: """ Parallel associative scan for training efficiency. This is a pure PyTorch implementation using sequential scan. For production, use the CUDA kernel from mamba-ssm package. 
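
        Shape bookkeeping for the loop below: the discretized Ā = exp(dt · A)
        and B̄ = dt · B are materialized as (batch, L, D_inner, N) tensors, the
        carried hidden state h is (batch, D_inner, N), and each step's output
        y_t = Σ_n C_t[n] · h[:, :, n] has shape (batch, D_inner) before the
        skip connection y = y + D · x is added.
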
Args: x: (B, L, D_inner) A: (D_inner, N) — state transition (negative, in log space) B: (B, L, N) — input-dependent input matrix C: (B, L, N) — input-dependent output matrix D: (D_inner,) — skip connection dt: (B, L, D_inner) — input-dependent discretization step """ batch, seq_len, d_inner = x.shape N = self.d_state # Discretize: Ā = exp(dt * A), B̄ = dt * B A_discrete = torch.exp(dt.unsqueeze(-1) * A.unsqueeze(0).unsqueeze(0)) # (B, L, D, N) B_discrete = dt.unsqueeze(-1) * B.unsqueeze(2) # (B, L, D, N) # Sequential scan (can be parallelized with associative scan) h = torch.zeros(batch, d_inner, N, device=x.device, dtype=x.dtype) outputs = [] for t in range(seq_len): h = A_discrete[:, t] * h + B_discrete[:, t] * x[:, t].unsqueeze(-1) y_t = torch.sum(h * C[:, t].unsqueeze(1), dim=-1) # (B, D) outputs.append(y_t) y = torch.stack(outputs, dim=1) # (B, L, D) # Skip connection y = y + x * D.unsqueeze(0).unsqueeze(0) return y def forward(self, x: torch.Tensor) -> torch.Tensor: """ Args: x: (batch, seq_len, d_model) Returns: (batch, seq_len, d_model) """ B, L, D = x.shape # Input projection with gating xz = self.in_proj(x) # (B, L, 2*D_inner) x_inner, z = xz.chunk(2, dim=-1) # Each: (B, L, D_inner) # Depthwise convolution for local context x_conv = self.conv1d(x_inner.transpose(1, 2))[:, :, :L].transpose(1, 2) x_conv = F.silu(x_conv) # Compute input-dependent SSM parameters x_proj = self.x_proj(x_conv) # (B, L, 2N+1) B_param = x_proj[:, :, :self.d_state] # (B, L, N) C_param = x_proj[:, :, self.d_state:2*self.d_state] # (B, L, N) dt_param = x_proj[:, :, -1:] # (B, L, 1) # Discretization step dt = F.softplus(self.dt_proj(dt_param)) # (B, L, D_inner) # Get A from log space A = -torch.exp(self.A_log) # (D_inner, N), negative for stability # Run SSM y = self._ssm_scan(x_conv, A, B_param, C_param, self.D, dt) # Gate and output y = y * F.silu(z) y = self.out_proj(y) return y class MambaBlock(nn.Module): """ Complete Mamba block with Pre-LN and residual connection. x → Pre-LN → SelectiveSSM → + residual """ def __init__(self, d_model: int, d_state: int = 16, d_conv: int = 4, expand: int = 2, dropout: float = 0.1, sigma_reparam: bool = True): super().__init__() self.norm = nn.LayerNorm(d_model) self.ssm = SelectiveSSM(d_model, d_state, d_conv, expand, sigma_reparam) self.dropout = nn.Dropout(dropout) def forward(self, x: torch.Tensor) -> torch.Tensor: return x + self.dropout(self.ssm(self.norm(x))) # ============================================================================ # LatentMamba — Stage 2: Generate phrase latent sequences # ============================================================================ class ControlEmbedding(nn.Module): """ Embed musical control parameters into d_model vectors. Controls: tempo, key, time_signature, note_density, style Each control is embedded and summed, then projected. 
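
    Example (a minimal sketch; the bin/class indices below are illustrative and
    depend on how the dataset maps raw values to bins):

        controls = {
            "tempo":    torch.tensor([12]),  # e.g. ~78 BPM if bin i = 30 + 4*i
            "key":      torch.tensor([0]),
            "time_sig": torch.tensor([1]),
            "density":  torch.tensor([5]),
            "style":    torch.tensor([3]),
        }
        ctrl = ControlEmbedding(MuseMorphicConfig())(**controls)  # (1, 1, mamba_d_model)
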
""" def __init__(self, config: MuseMorphicConfig): super().__init__() d = config.mamba_d_model self.tempo_embed = nn.Embedding(config.n_tempo_bins, d) self.key_embed = nn.Embedding(config.n_key_classes, d) self.time_sig_embed = nn.Embedding(config.n_time_sig_classes, d) self.density_embed = nn.Embedding(config.n_density_bins, d) self.style_embed = nn.Embedding(config.n_style_classes, d) # Project combined controls self.control_proj = nn.Sequential( nn.Linear(d, d), nn.SiLU(), nn.Linear(d, d), ) self.norm = nn.LayerNorm(d) def forward( self, tempo: Optional[torch.Tensor] = None, key: Optional[torch.Tensor] = None, time_sig: Optional[torch.Tensor] = None, density: Optional[torch.Tensor] = None, style: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Returns control embedding of shape (batch, 1, d_model).""" B = tempo.shape[0] if tempo is not None else key.shape[0] d = self.tempo_embed.embedding_dim device = next(self.parameters()).device ctrl = torch.zeros(B, d, device=device) if tempo is not None: ctrl = ctrl + self.tempo_embed(tempo) if key is not None: ctrl = ctrl + self.key_embed(key) if time_sig is not None: ctrl = ctrl + self.time_sig_embed(time_sig) if density is not None: ctrl = ctrl + self.density_embed(density) if style is not None: ctrl = ctrl + self.style_embed(style) ctrl = self.norm(self.control_proj(ctrl)) return ctrl.unsqueeze(1) # (B, 1, d) class LatentMamba(nn.Module): """ Generate sequences of phrase latent vectors using Selective SSM (Mamba). Architecture: Input: [control_embed, z_1, z_2, ..., z_T] → Linear projection (latent_dim → d_model) → MambaBlock × N → Linear projection (d_model → latent_dim) → Output: predicted z_2, z_3, ..., z_{T+1} Complexity: O(T·D·N) — linear in sequence length Inference: O(D·N) per phrase — constant, enables infinite generation """ def __init__(self, config: MuseMorphicConfig): super().__init__() self.config = config d = config.mamba_d_model # Control embedding self.control_embed = ControlEmbedding(config) # Project latent to model dimension self.latent_in = nn.Sequential( nn.Linear(config.latent_dim, d), nn.LayerNorm(d), ) # Positional embedding for phrase positions self.pos_embed = nn.Embedding(config.max_phrases + 1, d) # +1 for control token # Mamba layers self.layers = nn.ModuleList([ MambaBlock( d, config.mamba_d_state, config.mamba_d_conv, config.mamba_expand, config.mamba_dropout, config.use_sigma_reparam ) for _ in range(config.mamba_n_layers) ]) self.final_norm = nn.LayerNorm(d) # Project back to latent space self.latent_out = nn.Linear(d, config.latent_dim) def forward( self, z_seq: torch.Tensor, controls: Optional[Dict[str, torch.Tensor]] = None, ) -> torch.Tensor: """ Args: z_seq: Sequence of phrase latents (batch, n_phrases, latent_dim) controls: Dict of control tensors (each (batch,) integer indices) Returns: z_pred: Predicted next phrase latents (batch, n_phrases, latent_dim) """ B, T, _ = z_seq.shape device = z_seq.device # Project latents to model dimension x = self.latent_in(z_seq) # (B, T, d) # Add control embedding at position 0 if controls is not None: ctrl = self.control_embed(**controls) # (B, 1, d) x = torch.cat([ctrl, x], dim=1) # (B, T+1, d) T_total = T + 1 else: T_total = T # Add positional embeddings positions = torch.arange(T_total, device=device).unsqueeze(0) x = x + self.pos_embed(positions) # Process through Mamba layers for layer in self.layers: x = layer(x) x = self.final_norm(x) # Remove control token position, project to latent space if controls is not None: x = x[:, 1:] # Remove control position 
        z_pred = self.latent_out(x)  # (B, T, latent_dim)
        return z_pred

    def generate(
        self,
        n_phrases: int,
        controls: Optional[Dict[str, torch.Tensor]] = None,
        temperature: float = 0.8,
        batch_size: int = 1,
    ) -> torch.Tensor:
        """
        Generate a sequence of phrase latents autoregressively.

        This reference implementation re-runs the Mamba stack over the growing
        prefix at every step: simple and correct, but O(T^2) overall. A
        production version would cache each layer's recurrent SSM state and
        advance it one step per phrase (Mamba's recurrent mode), giving O(1)
        memory and compute per step and enabling effectively unbounded
        generation.
        """
        device = next(self.parameters()).device
        d = self.config.mamba_d_model

        # Initialize the prefix with the control embedding (or zeros)
        if controls is not None:
            z_init = self.control_embed(**controls)  # (B, 1, d)
        else:
            z_init = torch.zeros(batch_size, 1, d, device=device)
        prefix = z_init + self.pos_embed(torch.tensor([0], device=device))

        # Generate phrase latents one by one
        generated = []
        for t in range(n_phrases):
            # Run the full stack over the current prefix; MambaBlock applies
            # Pre-LN, the selective SSM, dropout, and the residual connection.
            h = prefix
            for layer in self.layers:
                h = layer(h)
            h = self.final_norm(h)

            # Predict the next phrase latent from the last position
            z_t = self.latent_out(h[:, -1:])  # (B, 1, latent_dim)

            # Add noise for diversity (controlled by temperature)
            if temperature > 0:
                z_t = z_t + temperature * torch.randn_like(z_t)
            generated.append(z_t)

            # Append the new latent to the prefix for the next step
            next_pos = torch.tensor([t + 1], device=device).clamp(max=self.config.max_phrases - 1)
            prefix = torch.cat([prefix, self.latent_in(z_t) + self.pos_embed(next_pos)], dim=1)

        return torch.cat(generated, dim=1)  # (B, n_phrases, latent_dim)


# ============================================================================
# Complete MuseMorphic Model
# ============================================================================

class MuseMorphic(nn.Module):
    """
    Complete MuseMorphic model combining PhraseVAE and LatentMamba.

    Two-stage training:
        Stage 1: Train PhraseVAE (encode/decode individual phrases)
        Stage 2: Freeze PhraseVAE encoder, train LatentMamba on latent sequences

    Inference pipeline:
        Controls → LatentMamba.generate() → PhraseVAE.decode() → REMI+ tokens → MIDI
    """

    def __init__(self, config: MuseMorphicConfig):
        super().__init__()
        self.config = config
        self.phrase_vae = PhraseVAE(config)
        self.latent_mamba = LatentMamba(config)

    def encode_phrases(self, phrases: List[torch.Tensor], **kwargs) -> torch.Tensor:
        """
        Encode a list of phrase token sequences to latent vectors.

        Args:
            phrases: List of (batch, phrase_len) token tensors
        Returns:
            z_seq: (batch, n_phrases, latent_dim)
        """
        z_list = []
        self.phrase_vae.eval()
        with torch.no_grad():
            for phrase_tokens in phrases:
                z, _, _ = self.phrase_vae.encode(phrase_tokens, **kwargs)
                z_list.append(z.unsqueeze(1))
        return torch.cat(z_list, dim=1)

    def decode_phrases(self, z_seq: torch.Tensor, max_len: int = 256) -> List[torch.Tensor]:
        """
        Decode latent vectors back to token sequences.
Args: z_seq: (batch, n_phrases, latent_dim) Returns: List of (batch, phrase_len) token tensors """ B, T, _ = z_seq.shape decoded = [] self.phrase_vae.eval() with torch.no_grad(): for t in range(T): z = z_seq[:, t] # Autoregressive decoding tokens = self._ar_decode(z, max_len) decoded.append(tokens) return decoded def _ar_decode(self, z: torch.Tensor, max_len: int) -> torch.Tensor: """Autoregressive decoding from latent vector.""" B = z.shape[0] device = z.device # Start with BOS token tokens = torch.full((B, 1), self.config.bos_token_id, dtype=torch.long, device=device) for _ in range(max_len - 1): logits = self.phrase_vae.decode(z, tokens) next_token_logits = logits[:, -1, :] # (B, vocab_size) # Greedy or sample next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True) tokens = torch.cat([tokens, next_token], dim=1) # Stop if all sequences have generated EOS if (next_token == self.config.eos_token_id).all(): break return tokens @torch.no_grad() def generate( self, n_phrases: int = 32, controls: Optional[Dict[str, torch.Tensor]] = None, temperature: float = 0.8, max_phrase_len: int = 256, batch_size: int = 1, ) -> List[torch.Tensor]: """ Full generation pipeline. Controls → LatentMamba → PhraseVAE.decode → REMI+ tokens Memory: O(D·N) fixed during generation — truly infinite. """ self.eval() # Stage 2: Generate phrase latent sequence z_seq = self.latent_mamba.generate( n_phrases, controls, temperature, batch_size ) # Stage 1 (decode): Latent → REMI+ tokens decoded_phrases = self.decode_phrases(z_seq, max_phrase_len) return decoded_phrases def count_parameters(self) -> Dict[str, int]: """Count parameters by component.""" vae_enc = sum(p.numel() for p in self.phrase_vae.encoder.parameters()) vae_dec = sum(p.numel() for p in self.phrase_vae.decoder.parameters()) vae_emb = sum(p.numel() for p in self.phrase_vae.embedding.parameters()) mamba = sum(p.numel() for p in self.latent_mamba.parameters()) total = sum(p.numel() for p in self.parameters()) return { 'vae_encoder': vae_enc, 'vae_decoder': vae_dec, 'vae_embedding': vae_emb, 'latent_mamba': mamba, 'total': total, } def get_vram_estimate(self, batch_size: int = 1, seq_len: int = 256, dtype_bytes: int = 2) -> Dict[str, str]: """Estimate VRAM usage.""" params = self.count_parameters() # Parameters param_mem = params['total'] * dtype_bytes # Activations (rough estimate: 2x parameters for forward pass) act_mem = param_mem * 2 # Optimizer states (AdamW: 2 states per param) opt_mem = params['total'] * 4 * 2 # FP32 optimizer states training_mem = param_mem + act_mem + opt_mem inference_mem = param_mem + act_mem // 4 # Much less activations return { 'parameters_mb': f"{param_mem / 1e6:.1f} MB", 'training_vram_mb': f"{training_mem / 1e6:.1f} MB", 'inference_vram_mb': f"{inference_mem / 1e6:.1f} MB", } # ============================================================================ # ZClip — Adaptive Gradient Clipping # ============================================================================ class ZClip: """ Adaptive gradient clipping via z-score thresholding. From ZClip (2025) "Adaptive Spike Mitigation for LLM Pre-Training" (arXiv:2504.02507). Only clips genuine gradient spikes, not normal gradients. Optimal z_thresh: 2.0-3.0 (Table 6 in paper). 
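
    Typical usage in a training loop (a minimal sketch):

        zclip = ZClip(z_thresh=2.5, alpha=0.99)
        loss.backward()
        grad_norm = zclip(model)   # clips in place only on detected spikes
        optimizer.step()
        optimizer.zero_grad()
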
""" def __init__(self, z_thresh: float = 2.5, alpha: float = 0.99): self.z_thresh = z_thresh self.alpha = alpha self.mu = 0.0 self.var = 1.0 self.initialized = False def __call__(self, model: nn.Module) -> float: """Clip gradients and return the original norm.""" total_norm = torch.nn.utils.clip_grad_norm_( model.parameters(), float('inf') ).item() if not self.initialized: self.mu = total_norm self.var = 0.0 self.initialized = True return total_norm # Compute adaptive threshold sigma = max(math.sqrt(self.var), 1e-8) threshold = self.mu + self.z_thresh * sigma # Clip only if genuine spike if total_norm > threshold: torch.nn.utils.clip_grad_norm_(model.parameters(), threshold) # Update EMA statistics self.mu = self.alpha * self.mu + (1 - self.alpha) * total_norm self.var = self.alpha * self.var + (1 - self.alpha) * (total_norm - self.mu) ** 2 return total_norm # ============================================================================ # Utility: Model summary # ============================================================================ def model_summary(config: Optional[MuseMorphicConfig] = None): """Print model summary with parameter counts and VRAM estimates.""" if config is None: config = MuseMorphicConfig() model = MuseMorphic(config) params = model.count_parameters() vram = model.get_vram_estimate() print("=" * 60) print("MuseMorphic Model Summary") print("=" * 60) print(f"\nParameter Counts:") for name, count in params.items(): print(f" {name:20s}: {count:>10,d} ({count/1e6:.2f}M)") print(f"\nVRAM Estimates (BF16):") for name, est in vram.items(): print(f" {name:20s}: {est}") print(f"\nArchitecture:") print(f" d_model: {config.d_model}") print(f" Vocab size: {config.vocab_size}") print(f" Latent dim: {config.latent_dim}") print(f" VAE layers: {config.vae_encoder_layers}+{config.vae_decoder_layers}") print(f" Mamba layers: {config.mamba_n_layers}") print(f" Mamba state dim: {config.mamba_d_state}") print(f" Max phrase tokens: {config.vae_max_seq_len}") print(f" Max phrases: {config.max_phrases}") print("=" * 60) return model if __name__ == "__main__": model = model_summary()