asdf98's picture
Upload musemorphic/model.py
f88bd6f verified
raw
history blame
44.6 kB
"""
MuseMorphic: Lightweight Consumer-Grade MIDI Generation Architecture
====================================================================
A novel two-stage hierarchical architecture combining:
Stage 1 - PhraseVAE: Compress REMI+ tokens → 64-dim latent vectors
Stage 2 - LatentMamba: Generate latent sequences with O(n) complexity
Key innovations:
- O(n) complexity in the number of phrases (Selective SSM backbone); attention is only used inside bounded-length phrases
- Music-native FME embeddings (translational invariance, transposability)
- Compact model (roughly 15M parameters with the default config; run model_summary() for exact counts), aimed at free Colab T4 training and <1GB inference VRAM
- Controllable via multi-attribute conditioning
- Infinite generation via fixed-size recurrent state
- Training stability by design (σReparam, ZClip, Pre-LN, BF16, label smoothing)
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass, field
from typing import Optional, List, Tuple, Dict
from einops import rearrange
# ============================================================================
# Configuration
# ============================================================================
@dataclass
class MuseMorphicConfig:
    """Complete configuration for the MuseMorphic architecture.

    Fields are grouped by the component that reads them. This is a dataclass,
    so field order defines the positional constructor; append new fields at
    the end only.

    ``__post_init__`` checks the cross-field constraints the modules rely on.
    A bad configuration therefore fails immediately with a clear message
    instead of deep inside module construction.
    """

    # --- Tokenizer ---
    vocab_size: int = 8192  # BPE vocabulary size
    pad_token_id: int = 0
    bos_token_id: int = 1
    eos_token_id: int = 2
    mask_token_id: int = 3

    # --- FME Embeddings ---
    d_model: int = 256  # PhraseVAE model dimension (must be even; LatentMamba uses mamba_d_model)
    fme_base_pitch: float = 10000.0  # Base B for pitch FME
    fme_base_duration: float = 1000.0  # Base B for duration FME
    fme_base_onset: float = 5000.0  # Base B for onset FME
    use_log_frequency: bool = True  # Encode pitch as log-frequency

    # --- PhraseVAE ---
    vae_encoder_layers: int = 3
    vae_decoder_layers: int = 3
    vae_n_heads: int = 4  # Must divide d_model
    vae_d_ff: int = 512  # Feed-forward dim
    vae_n_queries: int = 4  # Multi-query bottleneck queries
    latent_dim: int = 64  # VAE latent dimension
    vae_dropout: float = 0.1
    vae_max_seq_len: int = 256  # Max tokens per phrase (size of the positional tables)
    kl_beta: float = 0.01  # KL weight (low to prevent posterior collapse)
    label_smoothing: float = 0.1

    # --- LatentMamba ---
    mamba_d_model: int = 256
    mamba_n_layers: int = 8
    mamba_d_state: int = 16  # SSM state dimension N
    mamba_d_conv: int = 4  # Local convolution width
    mamba_expand: int = 2  # Inner dimension expansion factor
    mamba_dropout: float = 0.1
    max_phrases: int = 512  # Max phrases in a piece

    # --- Control ---
    # NOTE(review): 30-210 BPM in steps of 4 would be 46 values; confirm the binning.
    n_tempo_bins: int = 45
    n_key_classes: int = 24  # 12 keys x major/minor
    n_time_sig_classes: int = 8  # Common time signatures
    n_density_bins: int = 10  # Note density percentile bins
    n_style_classes: int = 32  # Style/genre categories

    # --- Training Stability ---
    use_sigma_reparam: bool = True
    use_pre_ln: bool = True
    zclip_z_thresh: float = 2.5
    zclip_alpha: float = 0.99

    # --- Training ---
    learning_rate: float = 3e-4
    weight_decay: float = 0.01
    warmup_steps: int = 500
    max_steps: int = 100000
    batch_size: int = 32
    gradient_accumulation_steps: int = 1

    def __post_init__(self) -> None:
        """Validate constraints shared by several modules.

        Raises:
            ValueError: if ``d_model`` is odd (the FME splits it into sin and
                cos halves), or if ``d_model`` is not divisible by a positive
                ``vae_n_heads`` (attention splits it evenly across heads).
        """
        if self.d_model % 2 != 0:
            raise ValueError(f"d_model must be even, got {self.d_model}")
        if self.vae_n_heads <= 0 or self.d_model % self.vae_n_heads != 0:
            raise ValueError(
                f"d_model ({self.d_model}) must be divisible by "
                f"vae_n_heads ({self.vae_n_heads})"
            )
# ============================================================================
# Fundamental Music Embedding (FME) — Physics-Aware
# ============================================================================
class FundamentalMusicEmbedding(nn.Module):
    """
    Translation-invariant, transposable embedding for scalar musical values.

    From Liang et al. (2022) "Domain-Knowledge-Inspired Music Embedding",
    extended with an optional log-frequency pitch encoding.

    A value ``f`` maps to ``[sin(w_k f) + b_sin, cos(w_k f) + b_cos]`` with
    ``w_k = B ** (-2k / d_model)`` for ``k = 0 .. d_model/2 - 1``.

    Properties:
        1. |f_a - f_b| = |f_c - f_d| => ||FME(f_a) - FME(f_b)|| = ||FME(f_c) - FME(f_d)||
           (the learnable biases cancel in the difference).
        2. Shifting ``f`` (e.g. transposition) rotates each unbiased
           (sin, cos) pair by a fixed angle, i.e. a linear map.
        3. Pitch, duration and onset use different bases B, so they occupy
           different frequency ranges.

    Args:
        d_model: Output dimension; must be even (half sin, half cos).
        base_B: Base B controlling the frequency spectrum.
        use_log_freq: If True, inputs are MIDI pitch numbers and the value
            embedded is log2 of their frequency in Hz.

    Raises:
        ValueError: If ``d_model`` is odd.
    """

    def __init__(self, d_model: int, base_B: float = 10000.0, use_log_freq: bool = False):
        super().__init__()
        if d_model % 2 != 0:
            raise ValueError(
                f"FundamentalMusicEmbedding needs an even d_model, got {d_model}"
            )
        self.d_model = d_model
        self.use_log_freq = use_log_freq
        half_d = d_model // 2
        # Exponentially decaying angular frequencies w_k = B^(-2k/d).
        k = torch.arange(half_d, dtype=torch.float32)
        self.register_buffer('w_k', base_B ** (-2.0 * k / d_model))
        # Learnable biases let training adjust the embedding geometry.
        self.b_sin = nn.Parameter(torch.zeros(half_d))
        self.b_cos = nn.Parameter(torch.zeros(half_d))

    def forward(self, values: torch.Tensor) -> torch.Tensor:
        """
        Args:
            values: Integer or float tensor of any shape, e.g. (batch, seq_len).
        Returns:
            Float embedding with a trailing ``d_model`` axis,
            e.g. (batch, seq_len, d_model).
        """
        f = values.float()
        if self.use_log_freq:
            # log2(440 * 2**((p - 69) / 12)) == log2(440) + (p - 69) / 12.
            # The closed form is exact and cannot overflow for large inputs.
            f = math.log2(440.0) + (f - 69.0) / 12.0
        angles = self.w_k * f.unsqueeze(-1)  # (..., d/2)
        sin_enc = torch.sin(angles) + self.b_sin
        cos_enc = torch.cos(angles) + self.b_cos
        return torch.cat([sin_enc, cos_enc], dim=-1)  # (..., d)
class MusicTokenEmbedding(nn.Module):
    """
    Input embedding for REMI+ token sequences.

    Sums:
      * a learned token embedding scaled by sqrt(d_model),
      * optional FME terms for pitch, duration and onset, and
      * a learned absolute-position embedding.
    LayerNorm and dropout are then applied.
    """

    def __init__(self, config: MuseMorphicConfig):
        super().__init__()
        self.config = config
        d = config.d_model
        # Standard token embedding (for BPE tokens)
        self.token_embed = nn.Embedding(config.vocab_size, d, padding_idx=config.pad_token_id)
        # FME components, added only where the matching attribute is present
        self.pitch_fme = FundamentalMusicEmbedding(d, config.fme_base_pitch, config.use_log_frequency)
        self.duration_fme = FundamentalMusicEmbedding(d, config.fme_base_duration, False)
        self.onset_fme = FundamentalMusicEmbedding(d, config.fme_base_onset, False)
        # Learned embedding of the token index in the sequence (0 .. vae_max_seq_len - 1)
        self.pos_embed = nn.Embedding(config.vae_max_seq_len, d)
        # Layer norm for embedding output stability
        self.embed_ln = nn.LayerNorm(d)
        self.embed_dropout = nn.Dropout(config.vae_dropout)
        # Scale factor applied to the token embedding
        self.scale = math.sqrt(d)

    def forward(
        self,
        token_ids: torch.Tensor,
        pitch_values: Optional[torch.Tensor] = None,
        duration_values: Optional[torch.Tensor] = None,
        onset_values: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            token_ids: (batch, seq_len) BPE token indices
            pitch_values: (batch, seq_len) MIDI pitch values (0 where not applicable)
            duration_values: (batch, seq_len) duration ticks (0 where not applicable)
            onset_values: (batch, seq_len) onset positions (0 where not applicable)

        Only strictly positive attribute entries receive an FME term.
        NOTE(review): a genuine onset of 0 (or MIDI pitch 0) is therefore
        indistinguishable from "not applicable"; confirm this matches the
        tokenizer's convention.

        Returns:
            (batch, seq_len, d_model) embeddings.

        Raises:
            ValueError: if seq_len exceeds the positional table
                (``config.vae_max_seq_len``).
        """
        seq_len = token_ids.shape[1]
        max_len = self.pos_embed.num_embeddings
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds vae_max_seq_len={max_len}"
            )

        x = self.token_embed(token_ids) * self.scale
        for values, fme in (
            (pitch_values, self.pitch_fme),
            (duration_values, self.duration_fme),
            (onset_values, self.onset_fme),
        ):
            if values is not None:
                present = (values > 0).float().unsqueeze(-1)
                x = x + fme(values) * present

        # (1, seq_len, d) broadcasts over the batch.
        positions = torch.arange(seq_len, device=token_ids.device).unsqueeze(0)
        x = x + self.pos_embed(positions)
        return self.embed_dropout(self.embed_ln(x))
# ============================================================================
# σReparam (Spectral Reparameterization) — Training Stability
# ============================================================================
class SigmaReparamLinear(nn.Module):
    """
    Linear layer with spectrally normalized weight and a learnable gain.

    Implements σReparam from Zhai et al. (2023), "Stabilizing Transformer
    Training by Preventing Attention Entropy Collapse" (arXiv:2303.06296):

        W_hat = (gamma / sigma(W)) * W

    ``sigma(W)`` (largest singular value) is estimated by PyTorch's
    ``spectral_norm`` parametrization on ``self.linear``. ``gamma`` is a
    learnable scalar that starts at 1.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        base_layer = nn.Linear(in_features, out_features, bias=bias)
        # The parametrization makes ``weight`` read back as W / sigma(W).
        self.linear = nn.utils.parametrizations.spectral_norm(base_layer)
        self.gamma = nn.Parameter(torch.ones(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        normalized = self.linear(x)
        return self.gamma * normalized
def make_linear(in_f: int, out_f: int, bias: bool = True, sigma_reparam: bool = True) -> nn.Module:
    """Build an ``in_f -> out_f`` linear layer, optionally wrapped in σReparam.

    Returns:
        ``SigmaReparamLinear(in_f, out_f, bias)`` when ``sigma_reparam`` is
        true, otherwise a plain ``nn.Linear(in_f, out_f, bias)``.
    """
    layer_cls = SigmaReparamLinear if sigma_reparam else nn.Linear
    return layer_cls(in_f, out_f, bias)
# ============================================================================
# Pre-LN Transformer Block (for PhraseVAE encoder/decoder)
# ============================================================================
class PreLNMultiHeadAttention(nn.Module):
    """Multi-head self- or cross-attention with (optionally) σReparam projections.

    No LayerNorm is applied here; callers pass already-normalized inputs.
    Scaling, masking and dropout are delegated to
    ``F.scaled_dot_product_attention``.
    """

    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1,
                 sigma_reparam: bool = True, is_cross_attention: bool = False):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        # Kept for reference; SDPA applies its own 1/sqrt(d_head) scaling.
        self.scale = math.sqrt(self.d_head)
        self.q_proj = make_linear(d_model, d_model, sigma_reparam=sigma_reparam)
        self.k_proj = make_linear(d_model, d_model, sigma_reparam=sigma_reparam)
        self.v_proj = make_linear(d_model, d_model, sigma_reparam=sigma_reparam)
        self.out_proj = make_linear(d_model, d_model, sigma_reparam=sigma_reparam)
        self.attn_dropout = nn.Dropout(dropout)
        self.is_cross_attention = is_cross_attention

    def forward(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
        is_causal: bool = False,
    ) -> torch.Tensor:
        """
        Args:
            x: Queries source, (batch, len, d_model).
            context: Keys/values source for cross-attention (batch, src_len,
                d_model). Used only when this module was built with
                ``is_cross_attention=True``; otherwise ``x`` is used.
            mask: Optional ``attn_mask`` forwarded unchanged to SDPA.
            is_causal: Forwarded unchanged to SDPA.
        Returns:
            (batch, len, d_model)
        """
        _batch, _length, _width = x.shape
        source = x
        if self.is_cross_attention and context is not None:
            source = context

        def to_heads(t: torch.Tensor) -> torch.Tensor:
            # (batch, len, heads*d_head) -> (batch, heads, len, d_head)
            return t.unflatten(-1, (self.n_heads, self.d_head)).transpose(1, 2)

        query = to_heads(self.q_proj(x))
        key = to_heads(self.k_proj(source))
        value = to_heads(self.v_proj(source))

        drop_p = self.attn_dropout.p if self.training else 0.0
        mixed = F.scaled_dot_product_attention(
            query, key, value,
            attn_mask=mask,
            dropout_p=drop_p,
            is_causal=is_causal,
        )
        # (batch, heads, len, d_head) -> (batch, len, heads*d_head)
        merged = mixed.transpose(1, 2).flatten(2)
        return self.out_proj(merged)
class PreLNFeedForward(nn.Module):
    """SwiGLU-style feed-forward: ``dropout(w2(silu(gate(x)) * w1(x)))``.

    No LayerNorm is applied here; the enclosing block normalizes the input
    first. Projections optionally use σReparam via ``make_linear``.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1,
                 sigma_reparam: bool = True):
        super().__init__()
        self.w1 = make_linear(d_model, d_ff, sigma_reparam=sigma_reparam)
        self.w2 = make_linear(d_ff, d_model, sigma_reparam=sigma_reparam)
        self.gate = make_linear(d_model, d_ff, sigma_reparam=sigma_reparam)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate_act = F.silu(self.gate(x))
        hidden = gate_act * self.w1(x)
        return self.dropout(self.w2(hidden))
class PreLNTransformerBlock(nn.Module):
    """
    Pre-LayerNorm transformer block:

        x = x + SelfAttn(LN(x))
        x = x + CrossAttn(LN(x), context)   # only with has_cross_attention and a context
        x = x + FFN(LN(x))

    Normalizing before each sublayer (rather than after the residual add, as
    in Post-LN) keeps the residual path free of normalization.
    """

    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1,
                 sigma_reparam: bool = True, has_cross_attention: bool = False):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.self_attn = PreLNMultiHeadAttention(d_model, n_heads, dropout, sigma_reparam)
        self.has_cross_attention = has_cross_attention
        if has_cross_attention:
            self.norm_cross = nn.LayerNorm(d_model)
            self.cross_attn = PreLNMultiHeadAttention(
                d_model, n_heads, dropout, sigma_reparam, is_cross_attention=True
            )
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = PreLNFeedForward(d_model, d_ff, dropout, sigma_reparam)

    def forward(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
        is_causal: bool = False,
    ) -> torch.Tensor:
        """``mask`` and ``is_causal`` apply to self-attention only; cross-attention is unmasked."""
        hidden = x + self.self_attn(self.norm1(x), mask=mask, is_causal=is_causal)
        if context is not None and self.has_cross_attention:
            hidden = hidden + self.cross_attn(self.norm_cross(hidden), context=context)
        return hidden + self.ffn(self.norm2(hidden))
# ============================================================================
# PhraseVAE — Stage 1: Compress REMI+ phrases to latent vectors
# ============================================================================
class PhraseVAEEncoder(nn.Module):
    """
    Encode embedded REMI+ tokens into the parameters of a latent Gaussian.

    Architecture: Pre-LN transformer encoder -> multi-query cross-attention
    bottleneck (``vae_n_queries`` learned queries) -> flatten -> linear heads
    for mu and log_var.
    """

    def __init__(self, config: MuseMorphicConfig):
        super().__init__()
        self.config = config
        d = config.d_model
        # Transformer encoder layers
        self.layers = nn.ModuleList([
            PreLNTransformerBlock(
                d, config.vae_n_heads, config.vae_d_ff,
                config.vae_dropout, config.use_sigma_reparam
            )
            for _ in range(config.vae_encoder_layers)
        ])
        self.final_norm = nn.LayerNorm(d)
        # Multi-query bottleneck (m learned queries)
        self.query_tokens = nn.Parameter(torch.randn(config.vae_n_queries, d) * 0.02)
        self.bottleneck_attn = PreLNMultiHeadAttention(
            d, config.vae_n_heads, config.vae_dropout,
            config.use_sigma_reparam, is_cross_attention=True
        )
        self.bottleneck_norm = nn.LayerNorm(d)
        # Project the flattened (m * d) bottleneck to latent space
        bottleneck_dim = config.vae_n_queries * d
        self.to_mu = nn.Linear(bottleneck_dim, config.latent_dim)
        self.to_log_var = nn.Linear(bottleneck_dim, config.latent_dim)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: Embedded tokens (batch, seq_len, d_model)
            mask: Optional attention mask for encoder self-attention, passed
                to ``F.scaled_dot_product_attention`` (boolean: True = attend).
                If its query axis is a singleton, e.g. a key-padding mask of
                shape (batch, 1, 1, seq_len), it is also applied to the
                bottleneck cross-attention. Padded positions then cannot leak
                into the latent.
        Returns:
            mu: (batch, latent_dim)
            log_var: (batch, latent_dim)
        """
        batch = x.shape[0]
        for layer in self.layers:
            x = layer(x, mask=mask)
        x = self.final_norm(x)

        queries = self.query_tokens.unsqueeze(0).expand(batch, -1, -1)  # (B, m, d)
        # Only a key-only mask broadcasts to the (B, heads, m, seq_len) bottleneck scores.
        key_mask = None
        if mask is not None and mask.dim() >= 2 and mask.shape[-2] == 1:
            key_mask = mask
        z_queries = self.bottleneck_attn(
            self.bottleneck_norm(queries), context=x, mask=key_mask
        )  # (B, m, d)

        z_flat = z_queries.reshape(batch, -1)  # (B, m*d)
        mu = self.to_mu(z_flat)
        log_var = self.to_log_var(z_flat)
        return mu, log_var
class PhraseVAEDecoder(nn.Module):
    """
    Autoregressive decoder from a phrase latent back to REMI+ token logits.

    The latent is projected to ``vae_n_queries`` context vectors. Causal
    Pre-LN transformer layers cross-attend to them. The decoder owns its own
    token and position embeddings, separate from the encoder's input
    embedding.
    """

    def __init__(self, config: MuseMorphicConfig):
        super().__init__()
        self.config = config
        d = config.d_model
        # Latent -> vae_n_queries cross-attention context vectors
        self.latent_proj = nn.Linear(config.latent_dim, config.vae_n_queries * d)
        # Token/position embeddings for the (teacher-forced) decoder input
        self.token_embed = nn.Embedding(config.vocab_size, d, padding_idx=config.pad_token_id)
        self.pos_embed = nn.Embedding(config.vae_max_seq_len, d)
        self.embed_scale = math.sqrt(d)
        self.layers = nn.ModuleList([
            PreLNTransformerBlock(
                d, config.vae_n_heads, config.vae_d_ff,
                config.vae_dropout, config.use_sigma_reparam,
                has_cross_attention=True
            )
            for _ in range(config.vae_decoder_layers)
        ])
        self.final_norm = nn.LayerNorm(d)
        self.output_proj = nn.Linear(d, config.vocab_size, bias=False)

    def forward(
        self,
        z: torch.Tensor,
        target_tokens: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            z: Latent vector (batch, latent_dim)
            target_tokens: Decoder input token ids (batch, seq_len); seq_len
                must not exceed ``vae_max_seq_len``.
        Returns:
            logits: (batch, seq_len, vocab_size)
        """
        batch, length = target_tokens.shape
        width = self.config.d_model
        context = self.latent_proj(z).reshape(batch, self.config.vae_n_queries, width)

        pos_ids = torch.arange(length, device=target_tokens.device).unsqueeze(0)
        hidden = self.token_embed(target_tokens) * self.embed_scale + self.pos_embed(pos_ids)

        for block in self.layers:
            hidden = block(hidden, context=context, is_causal=True)
        return self.output_proj(self.final_norm(hidden))
class PhraseVAE(nn.Module):
    """
    Complete PhraseVAE: REMI+ token phrase -> latent vector -> reconstructed tokens.

    Intended curriculum (driven by the training loop, not implemented here):
        Stage 1: Span-infilling pretraining (learn REMI grammar)
        Stage 2: Autoencoder (kl_weight = 0, pure reconstruction)
        Stage 3: VAE fine-tuning (kl_weight = beta)
    """

    def __init__(self, config: MuseMorphicConfig):
        super().__init__()
        self.config = config
        # Encoder-side input embedding (the decoder has its own token embedding)
        self.embedding = MusicTokenEmbedding(config)
        self.encoder = PhraseVAEEncoder(config)
        self.decoder = PhraseVAEDecoder(config)

    def reparameterize(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
        """Reparameterization trick: z = mu + sigma * eps in training, z = mu in eval."""
        if self.training:
            std = torch.exp(0.5 * log_var)
            eps = torch.randn_like(std)
            return mu + std * eps
        return mu

    def _key_padding_mask(self, token_ids: torch.Tensor) -> Optional[torch.Tensor]:
        """Boolean (batch, 1, 1, seq_len) mask hiding pad tokens, or None if nothing is padded."""
        keep = token_ids != self.config.pad_token_id
        if bool(keep.all()):
            return None
        # A row with no real tokens would mask every key and make SDPA return
        # NaN; let such rows attend everywhere instead.
        keep = keep | ~keep.any(dim=-1, keepdim=True)
        return keep[:, None, None, :]

    def encode(self, token_ids: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode tokens to latent space.

        Pad tokens (``config.pad_token_id``) are excluded from attention.

        Args:
            token_ids: (batch, seq_len) token ids.
            **kwargs: Forwarded to ``MusicTokenEmbedding`` (``pitch_values``,
                ``duration_values``, ``onset_values``).
        Returns:
            (z, mu, log_var), each (batch, latent_dim).
        """
        x = self.embedding(token_ids, **kwargs)
        mu, log_var = self.encoder(x, mask=self._key_padding_mask(token_ids))
        z = self.reparameterize(mu, log_var)
        return z, mu, log_var

    def decode(self, z: torch.Tensor, target_tokens: torch.Tensor) -> torch.Tensor:
        """Decode latent vector to token logits (teacher forcing on ``target_tokens``)."""
        return self.decoder(z, target_tokens)

    def forward(
        self,
        token_ids: torch.Tensor,
        target_tokens: Optional[torch.Tensor] = None,
        kl_weight: Optional[float] = None,
        **kwargs
    ) -> Dict[str, torch.Tensor]:
        """
        Full forward pass with loss computation.

        Args:
            token_ids: Input tokens (batch, seq_len).
            target_tokens: Reconstruction targets (batch, seq_len); defaults to
                ``token_ids``. They are shifted internally: positions
                ``[:-1]`` feed the decoder, and positions ``[1:]`` are predicted.
            kl_weight: Weight for the KL term (0 for the AE stage). Defaults to
                ``config.kl_beta`` when None.
            **kwargs: Forwarded to ``encode``.
        Returns:
            Dict with 'loss', 'recon_loss', 'kl_loss', 'z', 'mu', 'log_var', 'logits'.
        """
        if kl_weight is None:
            kl_weight = self.config.kl_beta
        if target_tokens is None:
            target_tokens = token_ids

        z, mu, log_var = self.encode(token_ids, **kwargs)

        decoder_input = target_tokens[:, :-1]
        decoder_target = target_tokens[:, 1:]
        logits = self.decode(z, decoder_input)

        # Token-level cross-entropy; pad targets are ignored.
        recon_loss = F.cross_entropy(
            logits.reshape(-1, self.config.vocab_size),
            decoder_target.reshape(-1),
            ignore_index=self.config.pad_token_id,
            label_smoothing=self.config.label_smoothing,
        )
        # KL(q(z|x) || N(0, I)), summed over latent dims, averaged over the batch.
        kl_loss = -0.5 * torch.mean(
            torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=-1)
        )
        total_loss = recon_loss + kl_weight * kl_loss
        return {
            'loss': total_loss,
            'recon_loss': recon_loss,
            'kl_loss': kl_loss,
            'z': z,
            'mu': mu,
            'log_var': log_var,
            'logits': logits,
        }
# ============================================================================
# Selective SSM (Mamba) Block — O(n) Sequence Modeling
# ============================================================================
class SelectiveSSM(nn.Module):
    """
    Selective State Space Model (Mamba core), pure-PyTorch reference version.

    From Gu & Dao (2023) "Mamba: Linear-Time Sequence Modeling with Selective
    State Spaces" (arXiv:2312.00752).

    Per position t, with u_t the SiLU of a causal depthwise convolution over
    the expanded input:
        B_t, C_t, dt_raw = split(x_proj(u_t))      -- sizes N, N, 1
        dt_t = softplus(dt_proj(dt_raw))            -- (D_inner,)
        A    = -exp(A_log)                          -- (D_inner, N), negative
        h_t  = exp(dt_t * A) * h_{t-1} + (dt_t * B_t) * u_t
        y_t  = C_t . h_t + D * u_t
    y is then gated by SiLU(z) and projected back to d_model.

    ``forward`` processes a whole sequence with a sequential Python loop over
    time. ``init_state``/``step`` advance one position at a time with
    fixed-size state, for recurrent inference.
    """

    def __init__(self, d_model: int, d_state: int = 16, d_conv: int = 4,
                 expand: int = 2, sigma_reparam: bool = True):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_inner = d_model * expand
        self.d_conv = d_conv
        # Input projection to [x_inner, z] (each d_inner wide)
        self.in_proj = make_linear(d_model, self.d_inner * 2, bias=False, sigma_reparam=sigma_reparam)
        # Depthwise causal convolution; padding d_conv-1 plus truncation to L keeps it causal
        self.conv1d = nn.Conv1d(
            self.d_inner, self.d_inner,
            kernel_size=d_conv,
            padding=d_conv - 1,
            groups=self.d_inner,
        )
        # A_log = log([1..N]) per channel, so A = -exp(A_log) = -[1..N]
        A = torch.arange(1, d_state + 1, dtype=torch.float32).unsqueeze(0).expand(self.d_inner, -1)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(self.d_inner))  # Skip connection
        # Input-dependent projections: B, C (N each) and a rank-1 dt
        self.x_proj = nn.Linear(self.d_inner, d_state * 2 + 1, bias=False)
        self.dt_proj = nn.Linear(1, self.d_inner, bias=True)
        # Bias in [log 0.001, log 0.1], so initial softplus(dt) is roughly 0.001-0.1
        nn.init.uniform_(self.dt_proj.bias, math.log(0.001), math.log(0.1))
        # Output projection
        self.out_proj = make_linear(self.d_inner, d_model, bias=False, sigma_reparam=sigma_reparam)

    def _ssm_scan(self, x: torch.Tensor, A: torch.Tensor, B: torch.Tensor,
                  C: torch.Tensor, D: torch.Tensor, dt: torch.Tensor) -> torch.Tensor:
        """
        Sequential (not parallel) scan of the discretized recurrence.

        Pure-PyTorch reference: loops over time in Python and materializes
        (batch, L, D_inner, N) discretized tensors. A fused kernel (e.g. the
        mamba-ssm package) would be much faster.

        Args:
            x: (B, L, D_inner)
            A: (D_inner, N) state matrix, already negative (not log space)
            B: (B, L, N) input-dependent input matrix
            C: (B, L, N) input-dependent output matrix
            D: (D_inner,) skip connection
            dt: (B, L, D_inner) input-dependent discretization step
        Returns:
            (B, L, D_inner)
        """
        batch, seq_len, d_inner = x.shape
        N = self.d_state
        # Discretize: A_bar = exp(dt * A), B_bar = dt * B
        A_discrete = torch.exp(dt.unsqueeze(-1) * A.unsqueeze(0).unsqueeze(0))  # (B, L, D, N)
        B_discrete = dt.unsqueeze(-1) * B.unsqueeze(2)  # (B, L, D, N)
        h = torch.zeros(batch, d_inner, N, device=x.device, dtype=x.dtype)
        outputs = []
        for t in range(seq_len):
            h = A_discrete[:, t] * h + B_discrete[:, t] * x[:, t].unsqueeze(-1)
            outputs.append(torch.sum(h * C[:, t].unsqueeze(1), dim=-1))  # (B, D)
        y = torch.stack(outputs, dim=1)  # (B, L, D)
        return y + x * D.unsqueeze(0).unsqueeze(0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, seq_len, d_model)
        Returns:
            (batch, seq_len, d_model)
        """
        seq_len = x.shape[1]
        x_inner, z = self.in_proj(x).chunk(2, dim=-1)  # Each: (B, L, D_inner)
        # Causal depthwise convolution (drop the right-hand padding outputs)
        x_conv = self.conv1d(x_inner.transpose(1, 2))[:, :, :seq_len].transpose(1, 2)
        x_conv = F.silu(x_conv)
        x_proj = self.x_proj(x_conv)  # (B, L, 2N+1)
        B_param = x_proj[:, :, :self.d_state]
        C_param = x_proj[:, :, self.d_state:2 * self.d_state]
        dt_param = x_proj[:, :, -1:]
        dt = F.softplus(self.dt_proj(dt_param))  # (B, L, D_inner)
        A = -torch.exp(self.A_log)  # (D_inner, N), negative for stability
        y = self._ssm_scan(x_conv, A, B_param, C_param, self.D, dt)
        y = y * F.silu(z)
        return self.out_proj(y)

    def init_state(self, batch_size: int, device: Optional[torch.device] = None,
                   dtype: Optional[torch.dtype] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return zeroed ``(conv_state, ssm_state)`` for ``step``.

        conv_state: (batch, d_inner, d_conv) last d_conv conv inputs, oldest first.
        ssm_state: (batch, d_inner, d_state) hidden state h.
        Zeros match the implicit left zero-padding used by ``forward``.
        """
        conv_state = torch.zeros(batch_size, self.d_inner, self.d_conv, device=device, dtype=dtype)
        ssm_state = torch.zeros(batch_size, self.d_inner, self.d_state, device=device, dtype=dtype)
        return conv_state, ssm_state

    def step(self, x_t: torch.Tensor, conv_state: torch.Tensor,
             ssm_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Advance the SSM by one position (recurrent mode).

        Computes, for a single position, the same quantities as ``forward``
        does for a sequence, so iterating ``step`` over a sequence reproduces
        ``forward`` up to floating-point differences.

        Args:
            x_t: (batch, d_model) input at the current position.
            conv_state: (batch, d_inner, d_conv) from ``init_state``/previous step.
            ssm_state: (batch, d_inner, d_state) from ``init_state``/previous step.
        Returns:
            (y_t, conv_state, ssm_state), with y_t of shape (batch, d_model).
        """
        x_in, z = self.in_proj(x_t).chunk(2, dim=-1)  # (batch, d_inner) each
        conv_state = torch.cat(
            [conv_state[:, :, 1:], x_in.unsqueeze(-1).to(conv_state.dtype)], dim=-1
        )
        x_conv = torch.sum(conv_state * self.conv1d.weight.squeeze(1), dim=-1)
        if self.conv1d.bias is not None:
            x_conv = x_conv + self.conv1d.bias
        x_conv = F.silu(x_conv)

        proj = self.x_proj(x_conv)  # (batch, 2N+1)
        B_t = proj[:, :self.d_state]
        C_t = proj[:, self.d_state:2 * self.d_state]
        dt = F.softplus(self.dt_proj(proj[:, -1:]))  # (batch, d_inner)
        A = -torch.exp(self.A_log)

        ssm_state = (torch.exp(dt.unsqueeze(-1) * A) * ssm_state
                     + (dt.unsqueeze(-1) * B_t.unsqueeze(1)) * x_conv.unsqueeze(-1))
        y = torch.sum(ssm_state * C_t.unsqueeze(1), dim=-1) + x_conv * self.D
        y = y * F.silu(z)
        return self.out_proj(y), conv_state, ssm_state
class MambaBlock(nn.Module):
    """Pre-LN residual wrapper around ``SelectiveSSM``.

    Computes ``x + dropout(ssm(layer_norm(x)))``.
    """

    def __init__(self, d_model: int, d_state: int = 16, d_conv: int = 4,
                 expand: int = 2, dropout: float = 0.1, sigma_reparam: bool = True):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.ssm = SelectiveSSM(d_model, d_state, d_conv, expand, sigma_reparam)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        update = self.ssm(self.norm(x))
        return residual + self.dropout(update)
# ============================================================================
# LatentMamba — Stage 2: Generate phrase latent sequences
# ============================================================================
class ControlEmbedding(nn.Module):
    """
    Embed musical control indices into a single conditioning vector.

    Controls: tempo, key, time_signature, note_density, style. Each supplied
    control is looked up in its own table. The embeddings are summed, passed
    through a small MLP and normalized by LayerNorm.
    """

    def __init__(self, config: MuseMorphicConfig):
        super().__init__()
        d = config.mamba_d_model
        self.tempo_embed = nn.Embedding(config.n_tempo_bins, d)
        self.key_embed = nn.Embedding(config.n_key_classes, d)
        self.time_sig_embed = nn.Embedding(config.n_time_sig_classes, d)
        self.density_embed = nn.Embedding(config.n_density_bins, d)
        self.style_embed = nn.Embedding(config.n_style_classes, d)
        # Project combined controls
        self.control_proj = nn.Sequential(
            nn.Linear(d, d),
            nn.SiLU(),
            nn.Linear(d, d),
        )
        self.norm = nn.LayerNorm(d)

    def forward(
        self,
        tempo: Optional[torch.Tensor] = None,
        key: Optional[torch.Tensor] = None,
        time_sig: Optional[torch.Tensor] = None,
        density: Optional[torch.Tensor] = None,
        style: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            tempo, key, time_sig, density, style: optional (batch,) integer
                index tensors; any subset may be given.
        Returns:
            Control embedding of shape (batch, 1, d_model).
        Raises:
            ValueError: if no control tensor is provided.
        """
        supplied = [
            (indices, table)
            for indices, table in (
                (tempo, self.tempo_embed),
                (key, self.key_embed),
                (time_sig, self.time_sig_embed),
                (density, self.density_embed),
                (style, self.style_embed),
            )
            if indices is not None
        ]
        if not supplied:
            raise ValueError("ControlEmbedding requires at least one control tensor")

        # Batch size comes from the first supplied control, in the order above.
        batch = supplied[0][0].shape[0]
        d = self.tempo_embed.embedding_dim
        ctrl = torch.zeros(batch, d, device=self.tempo_embed.weight.device)
        for indices, table in supplied:
            ctrl = ctrl + table(indices)
        ctrl = self.norm(self.control_proj(ctrl))
        return ctrl.unsqueeze(1)  # (B, 1, d)
class LatentMamba(nn.Module):
    """
    Model sequences of phrase latent vectors with a stack of Mamba blocks.

    Architecture:
        Input: [control_embed (optional), z_1, z_2, ..., z_T]
        -> Linear + LayerNorm (latent_dim -> d_model), + learned positions
        -> MambaBlock x N -> LayerNorm
        -> Linear (d_model -> latent_dim)
        -> Output at the position of z_t is the prediction for z_{t+1}
           (the control position's output is dropped by ``forward``).

    Training cost is linear in the number of phrases. ``generate`` runs the
    SSMs in recurrent mode with fixed-size per-layer state.
    """

    def __init__(self, config: MuseMorphicConfig):
        super().__init__()
        self.config = config
        d = config.mamba_d_model
        self.control_embed = ControlEmbedding(config)
        # Project latent to model dimension
        self.latent_in = nn.Sequential(
            nn.Linear(config.latent_dim, d),
            nn.LayerNorm(d),
        )
        # Positional embedding for phrase positions (+1 for the control token)
        self.pos_embed = nn.Embedding(config.max_phrases + 1, d)
        self.layers = nn.ModuleList([
            MambaBlock(
                d, config.mamba_d_state, config.mamba_d_conv,
                config.mamba_expand, config.mamba_dropout,
                config.use_sigma_reparam
            )
            for _ in range(config.mamba_n_layers)
        ])
        self.final_norm = nn.LayerNorm(d)
        self.latent_out = nn.Linear(d, config.latent_dim)

    def forward(
        self,
        z_seq: torch.Tensor,
        controls: Optional[Dict[str, torch.Tensor]] = None,
    ) -> torch.Tensor:
        """
        Args:
            z_seq: Sequence of phrase latents (batch, n_phrases, latent_dim)
            controls: Optional dict of ControlEmbedding kwargs (each (batch,) indices)
        Returns:
            z_pred: (batch, n_phrases, latent_dim); row t predicts z_{t+1}.
        Raises:
            ValueError: if the sequence (plus control token) exceeds the
                positional table (``max_phrases + 1`` positions).
        """
        x = self.latent_in(z_seq)  # (B, T, d)
        if controls is not None:
            ctrl = self.control_embed(**controls)  # (B, 1, d)
            x = torch.cat([ctrl, x], dim=1)  # (B, T+1, d)

        total = x.shape[1]
        if total > self.pos_embed.num_embeddings:
            raise ValueError(
                f"{total} positions exceed the positional table "
                f"({self.pos_embed.num_embeddings}); reduce n_phrases or raise max_phrases"
            )
        positions = torch.arange(total, device=z_seq.device).unsqueeze(0)
        x = x + self.pos_embed(positions)

        for layer in self.layers:
            x = layer(x)
        x = self.final_norm(x)

        if controls is not None:
            x = x[:, 1:]  # Drop the control position
        return self.latent_out(x)  # (B, T, latent_dim)

    @staticmethod
    def _init_ssm_state(ssm: nn.Module, batch: int, device: torch.device,
                        dtype: torch.dtype) -> Tuple[torch.Tensor, torch.Tensor]:
        """Zeroed (conv_state, hidden_state) for one SelectiveSSM: (B, D_inner, d_conv), (B, D_inner, N)."""
        conv_state = torch.zeros(batch, ssm.d_inner, ssm.d_conv, device=device, dtype=dtype)
        hidden = torch.zeros(batch, ssm.d_inner, ssm.d_state, device=device, dtype=dtype)
        return conv_state, hidden

    @staticmethod
    def _ssm_step(ssm: nn.Module, u: torch.Tensor,
                  state: Tuple[torch.Tensor, torch.Tensor]):
        """One recurrent step of a SelectiveSSM for input u (batch, d_model).

        Computes the same per-position quantities as ``SelectiveSSM.forward``
        (causal conv, input-dependent B/C/dt, discretized state update, skip,
        SiLU gate, output projection). Returns (output (batch, d_model), new_state).
        """
        conv_state, h = state
        x_in, z = ssm.in_proj(u).chunk(2, dim=-1)  # (B, D_inner) each
        # Shift the convolution window left and append the newest input.
        conv_state = torch.cat(
            [conv_state[:, :, 1:], x_in.unsqueeze(-1).to(conv_state.dtype)], dim=-1
        )
        x_conv = torch.sum(conv_state * ssm.conv1d.weight.squeeze(1), dim=-1)
        if ssm.conv1d.bias is not None:
            x_conv = x_conv + ssm.conv1d.bias
        x_conv = F.silu(x_conv)

        proj = ssm.x_proj(x_conv)  # (B, 2N+1)
        n = ssm.d_state
        b_t, c_t, dt_raw = proj[:, :n], proj[:, n:2 * n], proj[:, -1:]
        dt = F.softplus(ssm.dt_proj(dt_raw))  # (B, D_inner)
        a = -torch.exp(ssm.A_log)  # (D_inner, N)

        h = (torch.exp(dt.unsqueeze(-1) * a) * h
             + (dt.unsqueeze(-1) * b_t.unsqueeze(1)) * x_conv.unsqueeze(-1))
        y = torch.sum(h * c_t.unsqueeze(1), dim=-1) + x_conv * ssm.D
        y = y * F.silu(z)
        return ssm.out_proj(y), (conv_state, h)

    def generate(
        self,
        n_phrases: int,
        controls: Optional[Dict[str, torch.Tensor]] = None,
        temperature: float = 0.8,
        batch_size: int = 1,
    ) -> torch.Tensor:
        """
        Generate phrase latents autoregressively in recurrent mode.

        Each layer keeps a fixed-size (conv, SSM) state, so per-phrase cost
        and memory do not grow with the number of phrases generated. Runs
        without autograd.

        The first step is fed the control embedding (or zeros) at position 0.
        Each generated latent z_t is fed back at position t + 1, capped at
        the last positional index. NOTE(review): ``forward`` discards the
        output at the control position, so z_1 comes from an output that is
        only trained if the training loop supervises it separately. Confirm
        against the trainer.

        Args:
            n_phrases: Number of latents to generate.
            controls: Optional ControlEmbedding kwargs; their batch size wins
                over ``batch_size``.
            temperature: Std of Gaussian noise added to each predicted latent
                (0 disables noise).
            batch_size: Batch size when ``controls`` is None.
        Returns:
            (batch, n_phrases, latent_dim)
        """
        device = next(self.parameters()).device
        d = self.config.mamba_d_model
        max_pos = self.pos_embed.num_embeddings - 1

        with torch.no_grad():
            if controls is not None:
                start = self.control_embed(**controls)  # (B, 1, d)
            else:
                start = torch.zeros(batch_size, 1, d, device=device)
            batch = start.shape[0]
            x = (start + self.pos_embed(torch.tensor([0], device=device)))[:, 0]  # (B, d)

            states = [
                self._init_ssm_state(layer.ssm, batch, device, layer.ssm.A_log.dtype)
                for layer in self.layers
            ]

            generated = []
            for t in range(n_phrases):
                h = x
                for i, layer in enumerate(self.layers):
                    update, states[i] = self._ssm_step(layer.ssm, layer.norm(h), states[i])
                    h = h + layer.dropout(update)  # Pre-LN residual, as in MambaBlock
                z_t = self.latent_out(self.final_norm(h))  # (B, latent_dim)
                if temperature > 0:
                    z_t = z_t + temperature * torch.randn_like(z_t)
                generated.append(z_t.unsqueeze(1))

                next_pos = torch.tensor([min(t + 1, max_pos)], device=device)
                x = self.latent_in(z_t) + self.pos_embed(next_pos)  # (B, d)

        return torch.cat(generated, dim=1)  # (B, n_phrases, latent_dim)
# ============================================================================
# Complete MuseMorphic Model
# ============================================================================
class MuseMorphic(nn.Module):
    """
    Complete MuseMorphic model combining PhraseVAE and LatentMamba.

    Intended two-stage training (driven by the training loop):
        Stage 1: Train PhraseVAE (encode/decode individual phrases)
        Stage 2: Freeze PhraseVAE encoder, train LatentMamba on latent sequences

    Inference pipeline:
        Controls -> LatentMamba.generate() -> PhraseVAE decoding -> REMI+ tokens
    """

    def __init__(self, config: MuseMorphicConfig):
        super().__init__()
        self.config = config
        self.phrase_vae = PhraseVAE(config)
        self.latent_mamba = LatentMamba(config)

    def encode_phrases(self, phrases: List[torch.Tensor], **kwargs) -> torch.Tensor:
        """
        Encode phrase token sequences to latent vectors (eval mode, no grad).

        The VAE's previous train/eval mode is restored afterwards.

        Args:
            phrases: List of (batch, phrase_len) token tensors
            **kwargs: Forwarded to ``PhraseVAE.encode``.
        Returns:
            z_seq: (batch, n_phrases, latent_dim)
        """
        was_training = self.phrase_vae.training
        self.phrase_vae.eval()
        try:
            with torch.no_grad():
                z_list = [
                    self.phrase_vae.encode(phrase_tokens, **kwargs)[0].unsqueeze(1)
                    for phrase_tokens in phrases
                ]
        finally:
            self.phrase_vae.train(was_training)
        return torch.cat(z_list, dim=1)

    def decode_phrases(self, z_seq: torch.Tensor, max_len: int = 256) -> List[torch.Tensor]:
        """
        Greedily decode latent vectors back to token sequences (eval mode, no grad).

        The VAE's previous train/eval mode is restored afterwards.

        Args:
            z_seq: (batch, n_phrases, latent_dim)
            max_len: Maximum tokens per phrase, including BOS.
        Returns:
            List of n_phrases tensors of shape (batch, phrase_len)
        """
        was_training = self.phrase_vae.training
        self.phrase_vae.eval()
        try:
            with torch.no_grad():
                decoded = [
                    self._ar_decode(z_seq[:, t], max_len)
                    for t in range(z_seq.shape[1])
                ]
        finally:
            self.phrase_vae.train(was_training)
        return decoded

    def _ar_decode(self, z: torch.Tensor, max_len: int) -> torch.Tensor:
        """Greedy autoregressive decoding from a (batch, latent_dim) latent.

        Starts from BOS. Once a row emits EOS, its remaining positions are
        filled with ``pad_token_id``. Decoding stops when every row has
        finished or ``max_len`` tokens exist. Each step re-runs the decoder
        on the full prefix.
        """
        batch = z.shape[0]
        device = z.device
        # The last token is never fed back, so at most vae_max_seq_len tokens
        # go through the decoder's positional table.
        max_len = min(max_len, self.config.vae_max_seq_len + 1)

        tokens = torch.full((batch, 1), self.config.bos_token_id, dtype=torch.long, device=device)
        finished = torch.zeros(batch, dtype=torch.bool, device=device)
        for _ in range(max_len - 1):
            logits = self.phrase_vae.decode(z, tokens)
            next_token = torch.argmax(logits[:, -1, :], dim=-1)  # (B,)
            next_token = next_token.masked_fill(finished, self.config.pad_token_id)
            tokens = torch.cat([tokens, next_token.unsqueeze(1)], dim=1)
            finished = finished | (next_token == self.config.eos_token_id)
            if bool(finished.all()):
                break
        return tokens

    @torch.no_grad()
    def generate(
        self,
        n_phrases: int = 32,
        controls: Optional[Dict[str, torch.Tensor]] = None,
        temperature: float = 0.8,
        max_phrase_len: int = 256,
        batch_size: int = 1,
    ) -> List[torch.Tensor]:
        """
        Full generation pipeline: Controls -> LatentMamba -> greedy PhraseVAE decoding.

        Runs in eval mode and restores the previous train/eval mode
        afterwards. The returned list grows with ``n_phrases``.

        Returns:
            List of n_phrases tensors of shape (batch, phrase_len).
        """
        was_training = self.training
        self.eval()
        try:
            z_seq = self.latent_mamba.generate(
                n_phrases, controls, temperature, batch_size
            )
            return self.decode_phrases(z_seq, max_phrase_len)
        finally:
            self.train(was_training)

    def count_parameters(self) -> Dict[str, int]:
        """Count parameters by component ('total' counts the whole model)."""
        vae_enc = sum(p.numel() for p in self.phrase_vae.encoder.parameters())
        vae_dec = sum(p.numel() for p in self.phrase_vae.decoder.parameters())
        vae_emb = sum(p.numel() for p in self.phrase_vae.embedding.parameters())
        mamba = sum(p.numel() for p in self.latent_mamba.parameters())
        total = sum(p.numel() for p in self.parameters())
        return {
            'vae_encoder': vae_enc,
            'vae_decoder': vae_dec,
            'vae_embedding': vae_emb,
            'latent_mamba': mamba,
            'total': total,
        }

    def get_vram_estimate(self, batch_size: int = 1, seq_len: int = 256,
                          dtype_bytes: int = 2) -> Dict[str, str]:
        """Very rough VRAM estimate, in decimal MB.

        Training = weights + gradients + a 2x-weights activation heuristic +
        two FP32 AdamW states per parameter. ``batch_size`` and ``seq_len``
        are accepted but not yet used by the heuristic.
        """
        total = self.count_parameters()['total']
        param_mem = total * dtype_bytes
        grad_mem = total * dtype_bytes  # one gradient per parameter
        act_mem = param_mem * 2  # heuristic, independent of batch/sequence size
        opt_mem = total * 4 * 2  # AdamW exp_avg + exp_avg_sq in FP32
        training_mem = param_mem + grad_mem + act_mem + opt_mem
        inference_mem = param_mem + act_mem // 4
        return {
            'parameters_mb': f"{param_mem / 1e6:.1f} MB",
            'training_vram_mb': f"{training_mem / 1e6:.1f} MB",
            'inference_vram_mb': f"{inference_mem / 1e6:.1f} MB",
        }
# ============================================================================
# ZClip — Adaptive Gradient Clipping
# ============================================================================
class ZClip:
    """
    Adaptive gradient clipping via z-score thresholding.

    From ZClip (2025) "Adaptive Spike Mitigation for LLM Pre-Training"
    (arXiv:2504.02507).

    Tracks an EMA mean/variance of the global gradient L2 norm. A step is
    clipped only when its norm exceeds ``mu + z_thresh * sigma``.

    Warm-up: the first ``warmup_steps`` finite norms are only recorded and
    never clipped; their mean and variance seed the EMA. Seeding from a
    single step (zero variance) would make the threshold equal to that norm
    and clip almost every later step.

    After warm-up, the EMA is updated with the clipped norm so that spikes
    do not inflate the statistics. Non-finite norms are returned untouched
    and excluded from the statistics; they are left to the caller (e.g. an
    AMP GradScaler).
    """

    def __init__(self, z_thresh: float = 2.5, alpha: float = 0.99, warmup_steps: int = 25):
        if warmup_steps < 1:
            raise ValueError(f"warmup_steps must be >= 1, got {warmup_steps}")
        self.z_thresh = z_thresh
        self.alpha = alpha
        self.warmup_steps = warmup_steps
        self.mu = 0.0
        self.var = 1.0
        self.initialized = False
        self._warmup_norms: List[float] = []

    @staticmethod
    def _total_grad_norm(model: nn.Module) -> float:
        """Global L2 norm of all parameter gradients, computed without modifying them."""
        grads = [p.grad.detach() for p in model.parameters() if p.grad is not None]
        if not grads:
            return 0.0
        device = grads[0].device
        per_tensor = torch.stack([torch.linalg.vector_norm(g, 2.0).to(device) for g in grads])
        return torch.linalg.vector_norm(per_tensor, 2.0).item()

    def __call__(self, model: nn.Module) -> float:
        """Clip ``model``'s gradients in place if they spike; return the original norm."""
        total_norm = self._total_grad_norm(model)
        if not math.isfinite(total_norm):
            return total_norm

        if not self.initialized:
            self._warmup_norms.append(total_norm)
            if len(self._warmup_norms) >= self.warmup_steps:
                n = len(self._warmup_norms)
                self.mu = sum(self._warmup_norms) / n
                self.var = sum((g - self.mu) ** 2 for g in self._warmup_norms) / n
                self._warmup_norms.clear()
                self.initialized = True
            return total_norm

        sigma = max(math.sqrt(self.var), 1e-8)
        threshold = self.mu + self.z_thresh * sigma

        effective_norm = total_norm
        if total_norm > threshold:
            torch.nn.utils.clip_grad_norm_(model.parameters(), threshold)
            effective_norm = threshold

        self.mu = self.alpha * self.mu + (1 - self.alpha) * effective_norm
        self.var = self.alpha * self.var + (1 - self.alpha) * (effective_norm - self.mu) ** 2
        return total_norm
# ============================================================================
# Utility: Model summary
# ============================================================================
def model_summary(config: Optional[MuseMorphicConfig] = None):
    """Build a MuseMorphic model, print a summary of it, and return it.

    The summary lists per-component parameter counts, the VRAM estimates
    from ``get_vram_estimate`` (default 2-byte dtype), and key architecture
    settings.

    Args:
        config: Model configuration; ``MuseMorphicConfig()`` when None.
    Returns:
        The constructed ``MuseMorphic`` model.
    """
    config = MuseMorphicConfig() if config is None else config
    model = MuseMorphic(config)
    params = model.count_parameters()
    vram = model.get_vram_estimate()
    divider = "=" * 60

    print(divider)
    print("MuseMorphic Model Summary")
    print(divider)

    print(f"\nParameter Counts:")
    for label, n_params in params.items():
        print(f" {label:20s}: {n_params:>10,d} ({n_params/1e6:.2f}M)")

    print(f"\nVRAM Estimates (BF16):")
    for label, estimate in vram.items():
        print(f" {label:20s}: {estimate}")

    architecture_lines = (
        f" d_model: {config.d_model}",
        f" Vocab size: {config.vocab_size}",
        f" Latent dim: {config.latent_dim}",
        f" VAE layers: {config.vae_encoder_layers}+{config.vae_decoder_layers}",
        f" Mamba layers: {config.mamba_n_layers}",
        f" Mamba state dim: {config.mamba_d_state}",
        f" Max phrase tokens: {config.vae_max_seq_len}",
        f" Max phrases: {config.max_phrases}",
    )
    print(f"\nArchitecture:")
    for line in architecture_lines:
        print(line)
    print(divider)
    return model
if __name__ == "__main__":
    # Script entry point: build the default-config model and print its summary.
    model = model_summary()