asdf98 committed
Commit f88bd6f · verified · 1 Parent(s): f7d254f

Upload musemorphic/model.py

Files changed (1):
  1. musemorphic/model.py +1199 −0
musemorphic/model.py ADDED
"""
MuseMorphic: Lightweight Consumer-Grade MIDI Generation Architecture
====================================================================

A novel two-stage hierarchical architecture combining:
Stage 1 - PhraseVAE: Compress REMI+ tokens → 64-dim latent vectors
Stage 2 - LatentMamba: Generate latent sequences with O(n) complexity

Key innovations:
- O(n) complexity everywhere (Selective SSM backbone)
- Music-native FME embeddings (translational invariance, transposability)
- ~33M parameters, trains on free Colab T4, inference <1GB VRAM
- Controllable via multi-attribute conditioning
- Infinite generation via fixed-size recurrent state
- Training stability by design (σReparam, ZClip, Pre-LN, BF16, label smoothing)
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass, field
from typing import Optional, List, Tuple, Dict
from einops import rearrange

# ============================================================================
# Configuration
# ============================================================================

@dataclass
class MuseMorphicConfig:
    """Complete configuration for MuseMorphic architecture."""

    # --- Tokenizer ---
    vocab_size: int = 8192            # BPE vocabulary size
    pad_token_id: int = 0
    bos_token_id: int = 1
    eos_token_id: int = 2
    mask_token_id: int = 3

    # --- FME Embeddings ---
    d_model: int = 256                # Model dimension throughout
    fme_base_pitch: float = 10000.0   # Base B for pitch FME
    fme_base_duration: float = 1000.0 # Base B for duration FME
    fme_base_onset: float = 5000.0    # Base B for onset FME
    use_log_frequency: bool = True    # Encode pitch as log-frequency

    # --- PhraseVAE ---
    vae_encoder_layers: int = 3
    vae_decoder_layers: int = 3
    vae_n_heads: int = 4
    vae_d_ff: int = 512               # Feed-forward dim
    vae_n_queries: int = 4            # Multi-query bottleneck queries
    latent_dim: int = 64              # VAE latent dimension
    vae_dropout: float = 0.1
    vae_max_seq_len: int = 256        # Max tokens per phrase
    kl_beta: float = 0.01             # KL weight (low to prevent posterior collapse)
    label_smoothing: float = 0.1

    # --- LatentMamba ---
    mamba_d_model: int = 256
    mamba_n_layers: int = 8
    mamba_d_state: int = 16           # SSM state dimension N
    mamba_d_conv: int = 4             # Local convolution width
    mamba_expand: int = 2             # Inner dimension expansion factor
    mamba_dropout: float = 0.1
    max_phrases: int = 512            # Max phrases in a piece

    # --- Control ---
    n_tempo_bins: int = 45            # (30-210 BPM, step 4)
    n_key_classes: int = 24           # 12 keys × major/minor
    n_time_sig_classes: int = 8       # Common time signatures
    n_density_bins: int = 10          # Note density percentile bins
    n_style_classes: int = 32         # Style/genre categories

    # --- Training Stability ---
    use_sigma_reparam: bool = True
    use_pre_ln: bool = True
    zclip_z_thresh: float = 2.5
    zclip_alpha: float = 0.99

    # --- Training ---
    learning_rate: float = 3e-4
    weight_decay: float = 0.01
    warmup_steps: int = 500
    max_steps: int = 100000
    batch_size: int = 32
    gradient_accumulation_steps: int = 1


# ============================================================================
# Fundamental Music Embedding (FME) — Physics-Aware
# ============================================================================
class FundamentalMusicEmbedding(nn.Module):
    """
    Translational-invariant, transposable pitch/duration/onset embedding.

    From Liang et al. (2022) "Domain-Knowledge-Inspired Music Embedding"
    Extended with log-frequency pitch encoding for harmonic series awareness.

    Properties:
    1. |f_a - f_b| = |f_c - f_d|  =>  ||FME(f_a) - FME(f_b)|| = ||FME(f_c) - FME(f_d)||
    2. Transposition is a linear operation in embedding space
    3. Pitch, duration, onset are orthogonal via different base B values
    """

    def __init__(self, d_model: int, base_B: float = 10000.0, use_log_freq: bool = False):
        super().__init__()
        self.d_model = d_model
        self.use_log_freq = use_log_freq
        half_d = d_model // 2

        # Exponentially decaying frequencies
        k = torch.arange(half_d, dtype=torch.float32)
        w_k = base_B ** (-2.0 * k / d_model)
        self.register_buffer('w_k', w_k)

        # Learnable biases (enable fine-tuning of embedding geometry)
        self.b_sin = nn.Parameter(torch.zeros(half_d))
        self.b_cos = nn.Parameter(torch.zeros(half_d))

    def forward(self, values: torch.Tensor) -> torch.Tensor:
        """
        Args:
            values: Integer or float values, shape (batch, seq_len)
        Returns:
            Embedding, shape (batch, seq_len, d_model)
        """
        f = values.float()

        if self.use_log_freq:
            # Convert MIDI pitch to log-frequency (respects harmonic series)
            # f_hz = 440 * 2^((p-69)/12), log2(f_hz) = log2(440) + (p-69)/12
            f = torch.log2(440.0 * (2.0 ** ((f - 69.0) / 12.0)) + 1e-8)

        f = f.unsqueeze(-1)  # (B, L, 1)

        sin_enc = torch.sin(self.w_k * f) + self.b_sin  # (B, L, d/2)
        cos_enc = torch.cos(self.w_k * f) + self.b_cos  # (B, L, d/2)

        return torch.cat([sin_enc, cos_enc], dim=-1)  # (B, L, d)
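
# Illustrative sketch (an addition for exposition, not used anywhere in the
# model): a quick numerical check of property 1 above. With use_log_freq=False,
# equal pitch intervals map to equal embedding distances, since the learnable
# biases cancel when two embeddings are subtracted.
def _fme_interval_invariance_demo(d_model: int = 64) -> None:
    fme = FundamentalMusicEmbedding(d_model, base_B=10000.0, use_log_freq=False)
    pitches = torch.tensor([[60, 64, 67, 71]])  # C4, E4, G4, B4
    emb = fme(pitches)                          # (1, 4, d_model)
    # Both pairs span a major third (4 semitones), so the distances agree.
    d_c_e = torch.norm(emb[0, 0] - emb[0, 1])
    d_g_b = torch.norm(emb[0, 2] - emb[0, 3])
    assert torch.allclose(d_c_e, d_g_b, atol=1e-5)
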

class MusicTokenEmbedding(nn.Module):
    """
    Combined embedding for REMI+ tokens using FME for musically-meaningful tokens
    and standard learned embeddings for structural tokens.
    """

    def __init__(self, config: MuseMorphicConfig):
        super().__init__()
        self.config = config
        d = config.d_model

        # Standard token embedding (for BPE tokens)
        self.token_embed = nn.Embedding(config.vocab_size, d, padding_idx=config.pad_token_id)

        # FME components (used as additive bias for pitch/duration/onset tokens)
        self.pitch_fme = FundamentalMusicEmbedding(d, config.fme_base_pitch, config.use_log_frequency)
        self.duration_fme = FundamentalMusicEmbedding(d, config.fme_base_duration, False)
        self.onset_fme = FundamentalMusicEmbedding(d, config.fme_base_onset, False)

        # Positional embedding (within-bar position, learnable)
        self.pos_embed = nn.Embedding(config.vae_max_seq_len, d)

        # Layer norm for embedding output stability
        self.embed_ln = nn.LayerNorm(d)
        self.embed_dropout = nn.Dropout(config.vae_dropout)

        # Scale factor
        self.scale = math.sqrt(d)

    def forward(
        self,
        token_ids: torch.Tensor,
        pitch_values: Optional[torch.Tensor] = None,
        duration_values: Optional[torch.Tensor] = None,
        onset_values: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            token_ids: (batch, seq_len) BPE token indices
            pitch_values: (batch, seq_len) MIDI pitch values (0 where not applicable)
            duration_values: (batch, seq_len) duration ticks (0 where not applicable)
            onset_values: (batch, seq_len) onset positions (0 where not applicable)
        """
        B, L = token_ids.shape

        # Base token embedding
        x = self.token_embed(token_ids) * self.scale

        # Add FME for musically-meaningful attributes (when available)
        if pitch_values is not None:
            mask = (pitch_values > 0).float().unsqueeze(-1)
            x = x + self.pitch_fme(pitch_values) * mask

        if duration_values is not None:
            mask = (duration_values > 0).float().unsqueeze(-1)
            x = x + self.duration_fme(duration_values) * mask

        if onset_values is not None:
            mask = (onset_values > 0).float().unsqueeze(-1)
            x = x + self.onset_fme(onset_values) * mask

        # Add positional embedding
        positions = torch.arange(L, device=token_ids.device).unsqueeze(0).expand(B, -1)
        x = x + self.pos_embed(positions)

        return self.embed_dropout(self.embed_ln(x))

# ============================================================================
# σReparam (Spectral Reparameterization) — Training Stability
# ============================================================================

class SigmaReparamLinear(nn.Module):
    """
    Linear layer with spectral reparameterization (σReparam).

    From Zhai et al. (2023) "Stabilizing Transformer Training by Preventing
    Attention Entropy Collapse" (arXiv:2303.06296).

        W_hat = (γ / σ(W)) * W

    where σ(W) is the spectral norm (largest singular value).
    Prevents attention entropy collapse — the #1 source of training instability.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        # Apply spectral normalization
        self.linear = nn.utils.parametrizations.spectral_norm(self.linear)
        # Learnable scaling factor (initialized to 1)
        self.gamma = nn.Parameter(torch.ones(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.gamma * self.linear(x)


def make_linear(in_f: int, out_f: int, bias: bool = True, sigma_reparam: bool = True) -> nn.Module:
    """Factory for linear layers with optional σReparam."""
    if sigma_reparam:
        return SigmaReparamLinear(in_f, out_f, bias)
    return nn.Linear(in_f, out_f, bias)

# ============================================================================
# Pre-LN Transformer Block (for PhraseVAE encoder/decoder)
# ============================================================================

class PreLNMultiHeadAttention(nn.Module):
    """Multi-head attention with Pre-LayerNorm and σReparam."""

    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1,
                 sigma_reparam: bool = True, is_cross_attention: bool = False):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.scale = math.sqrt(self.d_head)

        self.q_proj = make_linear(d_model, d_model, sigma_reparam=sigma_reparam)
        self.k_proj = make_linear(d_model, d_model, sigma_reparam=sigma_reparam)
        self.v_proj = make_linear(d_model, d_model, sigma_reparam=sigma_reparam)
        self.out_proj = make_linear(d_model, d_model, sigma_reparam=sigma_reparam)

        self.attn_dropout = nn.Dropout(dropout)
        self.is_cross_attention = is_cross_attention

    def forward(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
        is_causal: bool = False,
    ) -> torch.Tensor:
        B, L, D = x.shape

        q = self.q_proj(x)
        kv_input = context if self.is_cross_attention and context is not None else x
        k = self.k_proj(kv_input)
        v = self.v_proj(kv_input)

        # Reshape for multi-head
        q = rearrange(q, 'b l (h d) -> b h l d', h=self.n_heads)
        k = rearrange(k, 'b s (h d) -> b h s d', h=self.n_heads)
        v = rearrange(v, 'b s (h d) -> b h s d', h=self.n_heads)

        # Scaled dot-product attention (using PyTorch's efficient implementation)
        attn_out = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=mask,
            dropout_p=self.attn_dropout.p if self.training else 0.0,
            is_causal=is_causal,
        )

        attn_out = rearrange(attn_out, 'b h l d -> b l (h d)')
        return self.out_proj(attn_out)


class PreLNFeedForward(nn.Module):
    """Feed-forward network with Pre-LN, SiLU activation, and σReparam."""

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1,
                 sigma_reparam: bool = True):
        super().__init__()
        self.w1 = make_linear(d_model, d_ff, sigma_reparam=sigma_reparam)
        self.w2 = make_linear(d_ff, d_model, sigma_reparam=sigma_reparam)
        self.gate = make_linear(d_model, d_ff, sigma_reparam=sigma_reparam)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU-style gating (used in LLaMA, Mamba)
        return self.dropout(self.w2(F.silu(self.gate(x)) * self.w1(x)))


class PreLNTransformerBlock(nn.Module):
    """
    Transformer block with Pre-LayerNorm for guaranteed training stability.

    Pre-LN:  x → LayerNorm → Sublayer → + residual
    (vs Post-LN: x → Sublayer → + residual → LayerNorm, which is UNSTABLE)

    Pre-LN has analytically bounded gradient norms, eliminates need for LR warmup.
    """

    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1,
                 sigma_reparam: bool = True, has_cross_attention: bool = False):
        super().__init__()

        self.norm1 = nn.LayerNorm(d_model)
        self.self_attn = PreLNMultiHeadAttention(d_model, n_heads, dropout, sigma_reparam)

        self.has_cross_attention = has_cross_attention
        if has_cross_attention:
            self.norm_cross = nn.LayerNorm(d_model)
            self.cross_attn = PreLNMultiHeadAttention(
                d_model, n_heads, dropout, sigma_reparam, is_cross_attention=True
            )

        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = PreLNFeedForward(d_model, d_ff, dropout, sigma_reparam)

    def forward(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
        is_causal: bool = False,
    ) -> torch.Tensor:
        # Pre-LN self-attention
        x = x + self.self_attn(self.norm1(x), mask=mask, is_causal=is_causal)

        # Pre-LN cross-attention (if applicable)
        if self.has_cross_attention and context is not None:
            x = x + self.cross_attn(self.norm_cross(x), context=context)

        # Pre-LN feed-forward
        x = x + self.ffn(self.norm2(x))

        return x

# ============================================================================
# PhraseVAE — Stage 1: Compress REMI+ phrases to latent vectors
# ============================================================================

class PhraseVAEEncoder(nn.Module):
    """
    Encode a sequence of REMI+ tokens into a latent vector using a
    multi-query cross-attention bottleneck.

    Architecture: TransformerEncoder → MultiQueryBottleneck → μ, log_var
    """

    def __init__(self, config: MuseMorphicConfig):
        super().__init__()
        self.config = config
        d = config.d_model

        # Transformer encoder layers
        self.layers = nn.ModuleList([
            PreLNTransformerBlock(
                d, config.vae_n_heads, config.vae_d_ff,
                config.vae_dropout, config.use_sigma_reparam
            )
            for _ in range(config.vae_encoder_layers)
        ])

        self.final_norm = nn.LayerNorm(d)

        # Multi-query bottleneck (m learned queries)
        self.query_tokens = nn.Parameter(torch.randn(config.vae_n_queries, d) * 0.02)
        self.bottleneck_attn = PreLNMultiHeadAttention(
            d, config.vae_n_heads, config.vae_dropout,
            config.use_sigma_reparam, is_cross_attention=True
        )
        self.bottleneck_norm = nn.LayerNorm(d)

        # Project to latent space
        bottleneck_dim = config.vae_n_queries * d
        self.to_mu = nn.Linear(bottleneck_dim, config.latent_dim)
        self.to_log_var = nn.Linear(bottleneck_dim, config.latent_dim)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: Embedded tokens (batch, seq_len, d_model)
            mask: Optional attention mask
        Returns:
            mu: (batch, latent_dim)
            log_var: (batch, latent_dim)
        """
        B = x.shape[0]

        # Encode through transformer layers
        for layer in self.layers:
            x = layer(x, mask=mask)
        x = self.final_norm(x)

        # Multi-query bottleneck
        queries = self.query_tokens.unsqueeze(0).expand(B, -1, -1)  # (B, m, d)
        z_queries = self.bottleneck_attn(
            self.bottleneck_norm(queries), context=x
        )  # (B, m, d)

        # Flatten and project
        z_flat = z_queries.reshape(B, -1)  # (B, m*d)
        mu = self.to_mu(z_flat)
        log_var = self.to_log_var(z_flat)

        return mu, log_var

class PhraseVAEDecoder(nn.Module):
    """
    Decode a latent vector back to a REMI+ token sequence (autoregressive).

    Architecture: LatentProjection → CrossAttention with latent → AR generation
    """

    def __init__(self, config: MuseMorphicConfig):
        super().__init__()
        self.config = config
        d = config.d_model

        # Project latent to key/value for cross-attention
        self.latent_proj = nn.Linear(config.latent_dim, config.vae_n_queries * d)

        # Token embedding for autoregressive decoding
        self.token_embed = nn.Embedding(config.vocab_size, d, padding_idx=config.pad_token_id)
        self.pos_embed = nn.Embedding(config.vae_max_seq_len, d)
        self.embed_scale = math.sqrt(d)

        # Decoder layers (with cross-attention to latent)
        self.layers = nn.ModuleList([
            PreLNTransformerBlock(
                d, config.vae_n_heads, config.vae_d_ff,
                config.vae_dropout, config.use_sigma_reparam,
                has_cross_attention=True
            )
            for _ in range(config.vae_decoder_layers)
        ])

        self.final_norm = nn.LayerNorm(d)
        self.output_proj = nn.Linear(d, config.vocab_size, bias=False)

    def forward(
        self,
        z: torch.Tensor,
        target_tokens: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            z: Latent vector (batch, latent_dim)
            target_tokens: Target token ids for teacher forcing (batch, seq_len)
        Returns:
            logits: (batch, seq_len, vocab_size)
        """
        B, L = target_tokens.shape
        d = self.config.d_model

        # Project latent to cross-attention context
        latent_ctx = self.latent_proj(z).reshape(B, self.config.vae_n_queries, d)

        # Embed target tokens
        positions = torch.arange(L, device=target_tokens.device).unsqueeze(0)
        x = self.token_embed(target_tokens) * self.embed_scale + self.pos_embed(positions)

        # Decode with causal masking
        for layer in self.layers:
            x = layer(x, context=latent_ctx, is_causal=True)

        x = self.final_norm(x)
        logits = self.output_proj(x)

        return logits

class PhraseVAE(nn.Module):
    """
    Complete PhraseVAE: Encode REMI+ token phrases → latent vectors → decode back.

    Three-stage training curriculum:
      Stage 1: Span-infilling pretraining (learn REMI grammar)
      Stage 2: Autoencoder (KL weight = 0, pure reconstruction)
      Stage 3: VAE fine-tuning (KL weight = β = 0.01)
    """

    def __init__(self, config: MuseMorphicConfig):
        super().__init__()
        self.config = config

        # Shared embedding (encoder input)
        self.embedding = MusicTokenEmbedding(config)

        # Encoder and decoder
        self.encoder = PhraseVAEEncoder(config)
        self.decoder = PhraseVAEDecoder(config)

    def reparameterize(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
        """Reparameterization trick: z = μ + σ * ε"""
        if self.training:
            std = torch.exp(0.5 * log_var)
            eps = torch.randn_like(std)
            return mu + std * eps
        return mu  # At inference, just use the mean

    def encode(self, token_ids: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode tokens to latent space."""
        x = self.embedding(token_ids, **kwargs)
        mu, log_var = self.encoder(x)
        z = self.reparameterize(mu, log_var)
        return z, mu, log_var

    def decode(self, z: torch.Tensor, target_tokens: torch.Tensor) -> torch.Tensor:
        """Decode latent vector to token logits."""
        return self.decoder(z, target_tokens)

    def forward(
        self,
        token_ids: torch.Tensor,
        target_tokens: Optional[torch.Tensor] = None,
        kl_weight: float = 0.01,
        **kwargs
    ) -> Dict[str, torch.Tensor]:
        """
        Full forward pass with loss computation.

        Args:
            token_ids: Input tokens (batch, seq_len)
            target_tokens: Target tokens for reconstruction (batch, seq_len);
                defaults to token_ids and is shifted internally for teacher forcing
            kl_weight: β for KL loss weighting (0 for AE stage, 0.01 for VAE stage)
        """
        if target_tokens is None:
            target_tokens = token_ids

        # Encode
        z, mu, log_var = self.encode(token_ids, **kwargs)

        # Decode (teacher forcing with shifted input)
        decoder_input = target_tokens[:, :-1]   # Remove last token
        decoder_target = target_tokens[:, 1:]   # Remove first token (shift right)
        logits = self.decode(z, decoder_input)

        # Reconstruction loss with label smoothing
        recon_loss = F.cross_entropy(
            logits.reshape(-1, self.config.vocab_size),
            decoder_target.reshape(-1),
            ignore_index=self.config.pad_token_id,
            label_smoothing=self.config.label_smoothing,
        )

        # KL divergence (per-sample, averaged)
        kl_loss = -0.5 * torch.mean(
            torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=-1)
        )

        total_loss = recon_loss + kl_weight * kl_loss

        return {
            'loss': total_loss,
            'recon_loss': recon_loss,
            'kl_loss': kl_loss,
            'z': z,
            'mu': mu,
            'log_var': log_var,
            'logits': logits,
        }
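
# Illustrative curriculum sketch (the step count below is an assumption, not
# taken from a training script in this file): stage 2 trains the PhraseVAE as a
# pure autoencoder with kl_weight=0.0, and stage 3 switches to the VAE beta,
# e.g. loss = phrase_vae(tokens, kl_weight=kl_weight_schedule(step))['loss'].
def kl_weight_schedule(step: int, ae_steps: int = 20_000, beta: float = 0.01) -> float:
    """Return 0.0 during the autoencoder stage, then the VAE KL weight."""
    return 0.0 if step < ae_steps else beta
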

# ============================================================================
# Selective SSM (Mamba) Block — O(n) Sequence Modeling
# ============================================================================

class SelectiveSSM(nn.Module):
    """
    Selective State Space Model (Mamba core).

    From Gu & Dao (2023) "Mamba: Linear-Time Sequence Modeling with Selective
    State Spaces" (arXiv:2312.00752).

    Key equations:
        B(x) = Linear_N(x)                     -- input-dependent
        C(x) = Linear_N(x)                     -- input-dependent
        Δ(x) = softplus(Linear_1(x) + param)   -- input-dependent discretization
        Ā = exp(Δ · A)                         -- discretized state matrix
        B̄ = Δ · B(x)                           -- simplified discretized input matrix
        h_t = Ā · h_{t-1} + B̄ · x_t            -- state update
        y_t = C(x_t) · h_t                     -- output

    Training: parallel scan O(B·L·D·N)
    Inference: O(B·D·N) per step, state is O(D·N) fixed
    """

    def __init__(self, d_model: int, d_state: int = 16, d_conv: int = 4,
                 expand: int = 2, sigma_reparam: bool = True):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_inner = d_model * expand
        self.d_conv = d_conv

        # Input projection (expand dimension)
        self.in_proj = make_linear(d_model, self.d_inner * 2, bias=False, sigma_reparam=sigma_reparam)

        # Depthwise convolution (local context)
        self.conv1d = nn.Conv1d(
            self.d_inner, self.d_inner,
            kernel_size=d_conv,
            padding=d_conv - 1,
            groups=self.d_inner,
        )

        # SSM parameters
        # A is initialized to the range 1..N per channel (HiPPO-inspired),
        # stored in log space and negated in forward
        A = torch.arange(1, d_state + 1, dtype=torch.float32).unsqueeze(0).expand(self.d_inner, -1)
        self.A_log = nn.Parameter(torch.log(A))  # Learn in log space for stability
        self.D = nn.Parameter(torch.ones(self.d_inner))  # Skip connection

        # Input-dependent projections
        self.x_proj = nn.Linear(self.d_inner, d_state * 2 + 1, bias=False)  # B, C, dt
        self.dt_proj = nn.Linear(1, self.d_inner, bias=True)

        # Initialize dt bias for proper timescales
        nn.init.uniform_(self.dt_proj.bias, math.log(0.001), math.log(0.1))

        # Output projection
        self.out_proj = make_linear(self.d_inner, d_model, bias=False, sigma_reparam=sigma_reparam)

    def _ssm_scan(self, x: torch.Tensor, A: torch.Tensor, B: torch.Tensor,
                  C: torch.Tensor, D: torch.Tensor, dt: torch.Tensor) -> torch.Tensor:
        """
        Reference scan over the sequence dimension.

        This is a pure PyTorch sequential scan; it can be replaced by a parallel
        associative scan for training efficiency. For production, use the CUDA
        kernel from the mamba-ssm package.

        Args:
            x: (B, L, D_inner)
            A: (D_inner, N) — state transition (negative)
            B: (B, L, N) — input-dependent input matrix
            C: (B, L, N) — input-dependent output matrix
            D: (D_inner,) — skip connection
            dt: (B, L, D_inner) — input-dependent discretization step
        """
        batch, seq_len, d_inner = x.shape
        N = self.d_state

        # Discretize: Ā = exp(dt * A), B̄ = dt * B
        A_discrete = torch.exp(dt.unsqueeze(-1) * A.unsqueeze(0).unsqueeze(0))  # (B, L, D, N)
        B_discrete = dt.unsqueeze(-1) * B.unsqueeze(2)                          # (B, L, D, N)

        # Sequential scan (can be parallelized with an associative scan)
        h = torch.zeros(batch, d_inner, N, device=x.device, dtype=x.dtype)
        outputs = []

        for t in range(seq_len):
            h = A_discrete[:, t] * h + B_discrete[:, t] * x[:, t].unsqueeze(-1)
            y_t = torch.sum(h * C[:, t].unsqueeze(1), dim=-1)  # (B, D)
            outputs.append(y_t)

        y = torch.stack(outputs, dim=1)  # (B, L, D)

        # Skip connection
        y = y + x * D.unsqueeze(0).unsqueeze(0)

        return y

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, seq_len, d_model)
        Returns:
            (batch, seq_len, d_model)
        """
        B, L, D = x.shape

        # Input projection with gating
        xz = self.in_proj(x)              # (B, L, 2*D_inner)
        x_inner, z = xz.chunk(2, dim=-1)  # Each: (B, L, D_inner)

        # Depthwise convolution for local context
        x_conv = self.conv1d(x_inner.transpose(1, 2))[:, :, :L].transpose(1, 2)
        x_conv = F.silu(x_conv)

        # Compute input-dependent SSM parameters
        x_proj = self.x_proj(x_conv)                           # (B, L, 2N+1)
        B_param = x_proj[:, :, :self.d_state]                  # (B, L, N)
        C_param = x_proj[:, :, self.d_state:2 * self.d_state]  # (B, L, N)
        dt_param = x_proj[:, :, -1:]                           # (B, L, 1)

        # Discretization step
        dt = F.softplus(self.dt_proj(dt_param))  # (B, L, D_inner)

        # Get A from log space
        A = -torch.exp(self.A_log)  # (D_inner, N), negative for stability

        # Run SSM
        y = self._ssm_scan(x_conv, A, B_param, C_param, self.D, dt)

        # Gate and output
        y = y * F.silu(z)
        y = self.out_proj(y)

        return y

class MambaBlock(nn.Module):
    """
    Complete Mamba block with Pre-LN and residual connection.

        x → Pre-LN → SelectiveSSM → + residual
    """

    def __init__(self, d_model: int, d_state: int = 16, d_conv: int = 4,
                 expand: int = 2, dropout: float = 0.1, sigma_reparam: bool = True):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.ssm = SelectiveSSM(d_model, d_state, d_conv, expand, sigma_reparam)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.dropout(self.ssm(self.norm(x)))


# ============================================================================
# LatentMamba — Stage 2: Generate phrase latent sequences
# ============================================================================
class ControlEmbedding(nn.Module):
    """
    Embed musical control parameters into d_model vectors.

    Controls: tempo, key, time_signature, note_density, style
    Each control is embedded and summed, then projected.
    """

    def __init__(self, config: MuseMorphicConfig):
        super().__init__()
        d = config.mamba_d_model

        self.tempo_embed = nn.Embedding(config.n_tempo_bins, d)
        self.key_embed = nn.Embedding(config.n_key_classes, d)
        self.time_sig_embed = nn.Embedding(config.n_time_sig_classes, d)
        self.density_embed = nn.Embedding(config.n_density_bins, d)
        self.style_embed = nn.Embedding(config.n_style_classes, d)

        # Project combined controls
        self.control_proj = nn.Sequential(
            nn.Linear(d, d),
            nn.SiLU(),
            nn.Linear(d, d),
        )
        self.norm = nn.LayerNorm(d)

    def forward(
        self,
        tempo: Optional[torch.Tensor] = None,
        key: Optional[torch.Tensor] = None,
        time_sig: Optional[torch.Tensor] = None,
        density: Optional[torch.Tensor] = None,
        style: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Returns control embedding of shape (batch, 1, d_model)."""
        # Derive the batch size from whichever control is provided
        provided = [c for c in (tempo, key, time_sig, density, style) if c is not None]
        if not provided:
            raise ValueError("At least one control must be provided.")
        B = provided[0].shape[0]
        d = self.tempo_embed.embedding_dim
        device = next(self.parameters()).device

        ctrl = torch.zeros(B, d, device=device)

        if tempo is not None:
            ctrl = ctrl + self.tempo_embed(tempo)
        if key is not None:
            ctrl = ctrl + self.key_embed(key)
        if time_sig is not None:
            ctrl = ctrl + self.time_sig_embed(time_sig)
        if density is not None:
            ctrl = ctrl + self.density_embed(density)
        if style is not None:
            ctrl = ctrl + self.style_embed(style)

        ctrl = self.norm(self.control_proj(ctrl))
        return ctrl.unsqueeze(1)  # (B, 1, d)

class LatentMamba(nn.Module):
    """
    Generate sequences of phrase latent vectors using Selective SSM (Mamba).

    Architecture:
        Input: [control_embed, z_1, z_2, ..., z_T]
        → Linear projection (latent_dim → d_model)
        → MambaBlock × N
        → Linear projection (d_model → latent_dim)
        → Output: predicted z_2, z_3, ..., z_{T+1}

    Complexity: O(T·D·N) — linear in sequence length
    Inference: O(D·N) per phrase — constant, enables infinite generation
    """

    def __init__(self, config: MuseMorphicConfig):
        super().__init__()
        self.config = config
        d = config.mamba_d_model

        # Control embedding
        self.control_embed = ControlEmbedding(config)

        # Project latent to model dimension
        self.latent_in = nn.Sequential(
            nn.Linear(config.latent_dim, d),
            nn.LayerNorm(d),
        )

        # Positional embedding for phrase positions
        self.pos_embed = nn.Embedding(config.max_phrases + 1, d)  # +1 for control token

        # Mamba layers
        self.layers = nn.ModuleList([
            MambaBlock(
                d, config.mamba_d_state, config.mamba_d_conv,
                config.mamba_expand, config.mamba_dropout,
                config.use_sigma_reparam
            )
            for _ in range(config.mamba_n_layers)
        ])

        self.final_norm = nn.LayerNorm(d)

        # Project back to latent space
        self.latent_out = nn.Linear(d, config.latent_dim)

    def forward(
        self,
        z_seq: torch.Tensor,
        controls: Optional[Dict[str, torch.Tensor]] = None,
    ) -> torch.Tensor:
        """
        Args:
            z_seq: Sequence of phrase latents (batch, n_phrases, latent_dim)
            controls: Dict of control tensors (each (batch,) integer indices)
        Returns:
            z_pred: Predicted next phrase latents (batch, n_phrases, latent_dim)
        """
        B, T, _ = z_seq.shape
        device = z_seq.device

        # Project latents to model dimension
        x = self.latent_in(z_seq)  # (B, T, d)

        # Add control embedding at position 0
        if controls is not None:
            ctrl = self.control_embed(**controls)  # (B, 1, d)
            x = torch.cat([ctrl, x], dim=1)        # (B, T+1, d)
            T_total = T + 1
        else:
            T_total = T

        # Add positional embeddings
        positions = torch.arange(T_total, device=device).unsqueeze(0)
        x = x + self.pos_embed(positions)

        # Process through Mamba layers
        for layer in self.layers:
            x = layer(x)

        x = self.final_norm(x)

        # Remove control token position, project to latent space
        if controls is not None:
            x = x[:, 1:]  # Remove control position

        z_pred = self.latent_out(x)  # (B, T, latent_dim)

        return z_pred

    def generate(
        self,
        n_phrases: int,
        controls: Optional[Dict[str, torch.Tensor]] = None,
        temperature: float = 0.8,
        batch_size: int = 1,
    ) -> torch.Tensor:
        """
        Generate a sequence of phrase latents autoregressively.

        Note: this reference implementation re-runs the growing prefix through
        the Mamba layers at every step. A production version would use the SSM's
        recurrent step so each phrase costs O(D·N) with a fixed-size state,
        enabling unbounded generation without memory growth.
        """
        device = next(self.parameters()).device

        # Initialize the input sequence with the control embedding or zeros
        if controls is not None:
            x = self.control_embed(**controls)  # (B, 1, d)
            batch_size = x.shape[0]
        else:
            x = torch.zeros(batch_size, 1, self.config.mamba_d_model, device=device)

        x = x + self.pos_embed(torch.tensor([0], device=device))
        generated = []

        for t in range(n_phrases):
            # Run the full prefix through the Mamba stack
            h = x
            for layer in self.layers:
                h = layer(h)
            h = self.final_norm(h)
            z_t = self.latent_out(h[:, -1:])  # (B, 1, latent_dim)

            # Add noise for diversity (controlled by temperature)
            if temperature > 0:
                z_t = z_t + temperature * torch.randn_like(z_t)

            generated.append(z_t)

            # Append the new phrase latent to the running input sequence
            pos = torch.tensor([t + 1], device=device).clamp(max=self.config.max_phrases)
            x = torch.cat([x, self.latent_in(z_t) + self.pos_embed(pos)], dim=1)

        return torch.cat(generated, dim=1)  # (B, n_phrases, latent_dim)

# ============================================================================
# Complete MuseMorphic Model
# ============================================================================

class MuseMorphic(nn.Module):
    """
    Complete MuseMorphic model combining PhraseVAE and LatentMamba.

    Two-stage training:
        Stage 1: Train PhraseVAE (encode/decode individual phrases)
        Stage 2: Freeze PhraseVAE encoder, train LatentMamba on latent sequences

    Inference pipeline:
        Controls → LatentMamba.generate() → PhraseVAE.decode() → REMI+ tokens → MIDI
    """

    def __init__(self, config: MuseMorphicConfig):
        super().__init__()
        self.config = config
        self.phrase_vae = PhraseVAE(config)
        self.latent_mamba = LatentMamba(config)

    def encode_phrases(self, phrases: List[torch.Tensor], **kwargs) -> torch.Tensor:
        """
        Encode a list of phrase token sequences to latent vectors.

        Args:
            phrases: List of (batch, phrase_len) token tensors
        Returns:
            z_seq: (batch, n_phrases, latent_dim)
        """
        z_list = []
        self.phrase_vae.eval()
        with torch.no_grad():
            for phrase_tokens in phrases:
                z, _, _ = self.phrase_vae.encode(phrase_tokens, **kwargs)
                z_list.append(z.unsqueeze(1))
        return torch.cat(z_list, dim=1)

    def decode_phrases(self, z_seq: torch.Tensor, max_len: int = 256) -> List[torch.Tensor]:
        """
        Decode latent vectors back to token sequences.

        Args:
            z_seq: (batch, n_phrases, latent_dim)
        Returns:
            List of (batch, phrase_len) token tensors
        """
        B, T, _ = z_seq.shape
        decoded = []

        self.phrase_vae.eval()
        with torch.no_grad():
            for t in range(T):
                z = z_seq[:, t]
                # Autoregressive decoding
                tokens = self._ar_decode(z, max_len)
                decoded.append(tokens)

        return decoded

    def _ar_decode(self, z: torch.Tensor, max_len: int) -> torch.Tensor:
        """Autoregressive decoding from latent vector."""
        B = z.shape[0]
        device = z.device

        # Start with BOS token
        tokens = torch.full((B, 1), self.config.bos_token_id, dtype=torch.long, device=device)

        for _ in range(max_len - 1):
            logits = self.phrase_vae.decode(z, tokens)
            next_token_logits = logits[:, -1, :]  # (B, vocab_size)

            # Greedy decoding (sampling could be substituted here)
            next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
            tokens = torch.cat([tokens, next_token], dim=1)

            # Stop if all sequences have generated EOS
            if (next_token == self.config.eos_token_id).all():
                break

        return tokens

    @torch.no_grad()
    def generate(
        self,
        n_phrases: int = 32,
        controls: Optional[Dict[str, torch.Tensor]] = None,
        temperature: float = 0.8,
        max_phrase_len: int = 256,
        batch_size: int = 1,
    ) -> List[torch.Tensor]:
        """
        Full generation pipeline.

        Controls → LatentMamba → PhraseVAE.decode → REMI+ tokens

        With a recurrent SSM implementation the generation state stays at a
        fixed O(D·N), which is what enables arbitrarily long pieces.
        """
        self.eval()

        # Stage 2: Generate phrase latent sequence
        z_seq = self.latent_mamba.generate(
            n_phrases, controls, temperature, batch_size
        )

        # Stage 1 (decode): Latent → REMI+ tokens
        decoded_phrases = self.decode_phrases(z_seq, max_phrase_len)

        return decoded_phrases

    def count_parameters(self) -> Dict[str, int]:
        """Count parameters by component."""
        vae_enc = sum(p.numel() for p in self.phrase_vae.encoder.parameters())
        vae_dec = sum(p.numel() for p in self.phrase_vae.decoder.parameters())
        vae_emb = sum(p.numel() for p in self.phrase_vae.embedding.parameters())
        mamba = sum(p.numel() for p in self.latent_mamba.parameters())
        total = sum(p.numel() for p in self.parameters())

        return {
            'vae_encoder': vae_enc,
            'vae_decoder': vae_dec,
            'vae_embedding': vae_emb,
            'latent_mamba': mamba,
            'total': total,
        }

    def get_vram_estimate(self, batch_size: int = 1, seq_len: int = 256,
                          dtype_bytes: int = 2) -> Dict[str, str]:
        """Rough VRAM estimate (parameters plus coarse activation/optimizer terms)."""
        params = self.count_parameters()

        # Parameters
        param_mem = params['total'] * dtype_bytes

        # Activations (rough estimate: 2x parameters for forward pass)
        act_mem = param_mem * 2

        # Optimizer states (AdamW: 2 FP32 states per param)
        opt_mem = params['total'] * 4 * 2

        training_mem = param_mem + act_mem + opt_mem
        inference_mem = param_mem + act_mem // 4  # Far fewer activations at inference

        return {
            'parameters_mb': f"{param_mem / 1e6:.1f} MB",
            'training_vram_mb': f"{training_mem / 1e6:.1f} MB",
            'inference_vram_mb': f"{inference_mem / 1e6:.1f} MB",
        }

# ============================================================================
# ZClip — Adaptive Gradient Clipping
# ============================================================================

class ZClip:
    """
    Adaptive gradient clipping via z-score thresholding.

    From ZClip (2025) "Adaptive Spike Mitigation for LLM Pre-Training"
    (arXiv:2504.02507).

    Only clips genuine gradient spikes, not normal gradients.
    Optimal z_thresh: 2.0-3.0 (Table 6 in the paper).
    """

    def __init__(self, z_thresh: float = 2.5, alpha: float = 0.99):
        self.z_thresh = z_thresh
        self.alpha = alpha
        self.mu = 0.0
        self.var = 1.0
        self.initialized = False

    def __call__(self, model: nn.Module) -> float:
        """Clip gradients and return the original norm."""
        total_norm = torch.nn.utils.clip_grad_norm_(
            model.parameters(), float('inf')
        ).item()

        if not self.initialized:
            self.mu = total_norm
            self.var = 0.0
            self.initialized = True
            return total_norm

        # Compute adaptive threshold
        sigma = max(math.sqrt(self.var), 1e-8)
        threshold = self.mu + self.z_thresh * sigma

        # Clip only if genuine spike
        if total_norm > threshold:
            torch.nn.utils.clip_grad_norm_(model.parameters(), threshold)

        # Update EMA statistics
        self.mu = self.alpha * self.mu + (1 - self.alpha) * total_norm
        self.var = self.alpha * self.var + (1 - self.alpha) * (total_norm - self.mu) ** 2

        return total_norm
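
# Illustrative sketch of how ZClip slots into a training step (the optimizer,
# batch layout, and helper name are assumptions; this file defines no training
# loop). Shown for the PhraseVAE stage, whose forward returns a loss dict.
def _training_step_with_zclip(model: PhraseVAE, token_ids: torch.Tensor,
                              optimizer: torch.optim.Optimizer, zclip: ZClip,
                              kl_weight: float = 0.0) -> float:
    optimizer.zero_grad(set_to_none=True)
    out = model(token_ids, kl_weight=kl_weight)
    out['loss'].backward()
    grad_norm = zclip(model)  # adaptive clip: only genuine spikes are trimmed
    optimizer.step()
    return grad_norm
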

# ============================================================================
# Utility: Model summary
# ============================================================================

def model_summary(config: Optional[MuseMorphicConfig] = None):
    """Print model summary with parameter counts and VRAM estimates."""
    if config is None:
        config = MuseMorphicConfig()

    model = MuseMorphic(config)
    params = model.count_parameters()
    vram = model.get_vram_estimate()

    print("=" * 60)
    print("MuseMorphic Model Summary")
    print("=" * 60)
    print("\nParameter Counts:")
    for name, count in params.items():
        print(f"  {name:20s}: {count:>10,d} ({count/1e6:.2f}M)")

    print("\nVRAM Estimates (BF16):")
    for name, est in vram.items():
        print(f"  {name:20s}: {est}")

    print("\nArchitecture:")
    print(f"  d_model:           {config.d_model}")
    print(f"  Vocab size:        {config.vocab_size}")
    print(f"  Latent dim:        {config.latent_dim}")
    print(f"  VAE layers:        {config.vae_encoder_layers}+{config.vae_decoder_layers}")
    print(f"  Mamba layers:      {config.mamba_n_layers}")
    print(f"  Mamba state dim:   {config.mamba_d_state}")
    print(f"  Max phrase tokens: {config.vae_max_seq_len}")
    print(f"  Max phrases:       {config.max_phrases}")
    print("=" * 60)

    return model


if __name__ == "__main__":
    model = model_summary()
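
# Illustrative end-to-end usage sketch (left commented out: it assumes a trained
# checkpoint file and a REMI+ tokenizer, neither of which ships with this file):
#
#   config = MuseMorphicConfig()
#   model = MuseMorphic(config)
#   model.load_state_dict(torch.load("musemorphic.pt", map_location="cpu"))
#   controls = {
#       "tempo": torch.tensor([20]),   # bin index (not BPM)
#       "key": torch.tensor([0]),
#       "density": torch.tensor([5]),
#   }
#   phrases = model.generate(n_phrases=16, controls=controls, temperature=0.8)
#   # `phrases` is a list of (batch, phrase_len) token tensors; converting them
#   # to MIDI requires the same tokenizer used during training.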