asdf98 committed
Commit 2f53a7d · verified · 1 Parent(s): 09ff16c

Upload musemorphic/train.py

Files changed (1)
  musemorphic/train.py +713 -0
musemorphic/train.py ADDED
@@ -0,0 +1,713 @@
"""
MuseMorphic Training Pipeline
==============================

Two-stage training with a curriculum and stability safeguards:

Stage 1 — PhraseVAE Training:
    1a. Span-infilling pretraining (learn REMI grammar)
    1b. Autoencoder training (KL weight = 0, pure reconstruction)
    1c. VAE fine-tuning (KL weight = 0.01)

Stage 2 — LatentMamba Training:
    Freeze the PhraseVAE encoder, then train LatentMamba on latent phrase
    sequences with an MSE loss on predicted vs. actual latent vectors.

Training Stability Stack:
- σReparam on all linear layers (prevents attention entropy collapse)
- ZClip adaptive gradient clipping (clips only genuine spikes)
- Pre-LayerNorm (bounded gradients, no warmup needed)
- BFloat16 mixed precision (no loss scaling needed, no overflow)
- Label smoothing ε = 0.1 (prevents overconfident predictions)
- Cosine annealing with warm restarts (SGDR)
- Per-step NaN/Inf monitoring with automatic recovery
"""

import os
import sys
import math
import time
import json
import random
import logging
from pathlib import Path
from typing import Optional, Dict, List, Tuple
from dataclasses import dataclass, asdict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from model import MuseMorphicConfig, MuseMorphic, PhraseVAE, LatentMamba, ZClip

logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s')
logger = logging.getLogger(__name__)


# ============================================================================
# Training Configuration
# ============================================================================

@dataclass
class TrainConfig:
    """Training hyperparameters."""

    # General
    seed: int = 42
    device: str = "auto"  # auto, cuda, cpu
    dtype: str = "bf16"   # bf16, fp16, fp32

    # Stage 1: PhraseVAE
    vae_epochs_pretrain: int = 5  # 1a: span-infilling
    vae_epochs_ae: int = 20       # 1b: autoencoder (KL=0)
    vae_epochs_vae: int = 10      # 1c: VAE fine-tune (KL=0.01)
    vae_batch_size: int = 64
    vae_lr: float = 3e-4
    vae_weight_decay: float = 0.01
    vae_max_seq_len: int = 256

    # Stage 2: LatentMamba
    mamba_epochs: int = 50
    mamba_batch_size: int = 32
    mamba_lr: float = 1e-4
    mamba_weight_decay: float = 0.01
    mamba_max_phrases: int = 128

    # Optimization
    gradient_accumulation_steps: int = 1
    max_grad_norm: float = 1.0  # Fallback fixed clip (ZClip adapts on top)
    warmup_steps: int = 500     # Reserved; the loops below run without warmup (Pre-LN)

    # Scheduler: Cosine Annealing with Warm Restarts (SGDR)
    sgdr_t0: int = 1000
    sgdr_t_mult: int = 2
    sgdr_eta_min: float = 1e-6

    # Stability
    use_zclip: bool = True
    zclip_z_thresh: float = 2.5
    zclip_alpha: float = 0.99
    label_smoothing: float = 0.1
    kl_beta: float = 0.01

    # Monitoring
    log_every_n_steps: int = 10
    eval_every_n_steps: int = 500
    save_every_n_steps: int = 1000

    # Paths
    output_dir: str = "./checkpoints"
    data_dir: str = "./data"

    # Hub
    push_to_hub: bool = True
    hub_model_id: str = ""

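# Example (sketch): a smoke-test configuration with shortened stages. The
# values below are illustrative only, but every field exists in TrainConfig:
#
#   smoke_cfg = TrainConfig(vae_epochs_pretrain=1, vae_epochs_ae=2,
#                           vae_epochs_vae=1, mamba_epochs=2,
#                           vae_batch_size=8, mamba_batch_size=4)
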
# ============================================================================
# Dataset
# ============================================================================

class PhraseDataset(Dataset):
    """
    Dataset of tokenized REMI+ phrases for PhraseVAE training.

    Each item is a padded sequence of token IDs representing one phrase
    (one bar of one track).
    """

    def __init__(self, phrases: List[List[int]], max_len: int = 256, pad_id: int = 0):
        self.phrases = phrases
        self.max_len = max_len
        self.pad_id = pad_id

    def __len__(self):
        return len(self.phrases)

    def __getitem__(self, idx):
        ids = self.phrases[idx][:self.max_len]

        # Pad to max_len
        padded = ids + [self.pad_id] * (self.max_len - len(ids))

        return {
            'token_ids': torch.tensor(padded, dtype=torch.long),
            'length': min(len(ids), self.max_len),
        }


class LatentSequenceDataset(Dataset):
    """
    Dataset of latent phrase sequences for LatentMamba training.

    Each item is a sequence of latent vectors (encoded by PhraseVAE)
    with associated control attributes.
    """

    def __init__(self, latent_sequences: List[torch.Tensor],
                 controls: Optional[List[Dict[str, int]]] = None,
                 max_phrases: int = 128):
        self.latent_sequences = latent_sequences
        self.controls = controls
        self.max_phrases = max_phrases

    def __len__(self):
        return len(self.latent_sequences)

    def __getitem__(self, idx):
        z_seq = self.latent_sequences[idx][:self.max_phrases]
        T = z_seq.shape[0]

        # Pad with zero vectors if the sequence is shorter than max_phrases
        if T < self.max_phrases:
            pad = torch.zeros(self.max_phrases - T, z_seq.shape[-1])
            z_seq = torch.cat([z_seq, pad], dim=0)

        item = {
            'z_seq': z_seq,
            'length': T,
        }

        if self.controls:
            ctrl = self.controls[idx]
            item['controls'] = {k: torch.tensor(v, dtype=torch.long) for k, v in ctrl.items()}

        return item

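# Sketch of the item contract above, with illustrative values:
#
#   ds = PhraseDataset([[1, 42, 57, 2]], max_len=8, pad_id=0)
#   item = ds[0]
#   item['token_ids']  # tensor([ 1, 42, 57,  2,  0,  0,  0,  0])
#   item['length']     # 4
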
# ============================================================================
# Training Utilities
# ============================================================================

def get_device(config: TrainConfig) -> torch.device:
    """Auto-detect the best available device."""
    if config.device == "auto":
        if torch.cuda.is_available():
            return torch.device("cuda")
        elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")
    return torch.device(config.device)


def get_dtype(config: TrainConfig) -> torch.dtype:
    """Get torch dtype from config string."""
    if config.dtype == "bf16":
        if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
            return torch.bfloat16
        return torch.float32  # Fallback when bf16 is unsupported
    elif config.dtype == "fp16":
        return torch.float16
    return torch.float32


def set_seed(seed: int):
    """Set all random seeds for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

class NaNMonitor:
    """
    Monitor for NaN/Inf in the loss.

    If NaN/Inf is detected:
    1. Signal the caller to skip the optimization step
    2. Reduce the learning rate by 50%
    3. Log a warning
    4. After 5 consecutive NaNs, stop training
    """

    def __init__(self, max_consecutive: int = 5):
        self.max_consecutive = max_consecutive
        self.consecutive_nan = 0
        self.total_nan = 0

    def check(self, loss: torch.Tensor, optimizer: torch.optim.Optimizer) -> bool:
        """
        Check for NaN/Inf. Returns True if training should continue.
        The caller must still skip backward/step for this batch when the
        loss is non-finite (see the training loops below).
        """
        if torch.isnan(loss) or torch.isinf(loss):
            self.consecutive_nan += 1
            self.total_nan += 1

            logger.warning(f"NaN/Inf detected! Consecutive: {self.consecutive_nan}, "
                           f"Total: {self.total_nan}")

            if self.consecutive_nan >= self.max_consecutive:
                logger.error(f"Training stopped: {self.max_consecutive} consecutive NaN/Inf")
                return False

            # Reduce learning rate
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.5
                logger.info(f"Reduced LR to {param_group['lr']:.2e}")

            # Clear any gradients so the skipped step cannot leak into the next one
            optimizer.zero_grad()
            return True

        self.consecutive_nan = 0
        return True

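# Intended per-batch contract (mirrored verbatim in the training loops below):
#
#   if not nan_monitor.check(loss, optimizer):
#       return model            # too many consecutive NaN/Inf: abort
#   if not torch.isfinite(loss):
#       continue                # non-finite loss: skip backward/step
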
class MetricsTracker:
    """Simple metrics tracking with an exponential moving average."""

    def __init__(self, alpha: float = 0.99):
        self.alpha = alpha
        self.metrics = {}
        self.step_count = 0

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            if k not in self.metrics:
                self.metrics[k] = v
            else:
                self.metrics[k] = self.alpha * self.metrics[k] + (1 - self.alpha) * v
        self.step_count += 1

    def get(self) -> Dict[str, float]:
        return {k: round(v, 6) for k, v in self.metrics.items()}

    def log(self, prefix: str = ""):
        metrics = self.get()
        parts = [f"{k}={v:.6f}" for k, v in metrics.items()]
        logger.info(f"{prefix}step={self.step_count} | {' | '.join(parts)}")

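# The EMA update is m_t = alpha * m_{t-1} + (1 - alpha) * v_t. For example,
# with alpha = 0.99, a running value of 2.0 and a new batch value of 1.0:
#   m = 0.99 * 2.0 + 0.01 * 1.0 = 1.99
# so logged metrics smooth out single-batch noise.
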
# ============================================================================
# Stage 1: PhraseVAE Training
# ============================================================================

def train_phrase_vae(
    model: PhraseVAE,
    train_dataset: PhraseDataset,
    val_dataset: Optional[PhraseDataset],
    config: TrainConfig,
    device: torch.device,
    dtype: torch.dtype,
) -> PhraseVAE:
    """
    Three-stage PhraseVAE training curriculum.

    Stage 1a: Span-infilling pretraining (learn REMI grammar)
    Stage 1b: Autoencoder (KL=0, pure reconstruction)
    Stage 1c: VAE fine-tuning (KL=0.01)
    """

    logger.info("=" * 60)
    logger.info("Stage 1: PhraseVAE Training")
    logger.info("=" * 60)

    model = model.to(device)

    # Optimizer with weight decay (excluding biases and LayerNorm params)
    no_decay = ['bias', 'LayerNorm', 'layer_norm', 'b_sin', 'b_cos']
    param_groups = [
        {'params': [p for n, p in model.named_parameters()
                    if not any(nd in n for nd in no_decay)],
         'weight_decay': config.vae_weight_decay},
        {'params': [p for n, p in model.named_parameters()
                    if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}
    ]
    optimizer = torch.optim.AdamW(param_groups, lr=config.vae_lr, betas=(0.9, 0.999))

    # SGDR scheduler
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=config.sgdr_t0, T_mult=config.sgdr_t_mult,
        eta_min=config.sgdr_eta_min
    )

    # Stability tools
    zclip = ZClip(config.zclip_z_thresh, config.zclip_alpha) if config.use_zclip else None
    nan_monitor = NaNMonitor()
    metrics = MetricsTracker()

    # Autocast is only valid for reduced-precision dtypes; disable it on the
    # fp32 fallback path so CPU/MPS runs do not crash
    use_amp = dtype != torch.float32

    train_loader = DataLoader(
        train_dataset, batch_size=config.vae_batch_size,
        shuffle=True, num_workers=2, pin_memory=True, drop_last=True
    )

    # ---- Stage 1a: Span-infilling pretraining ----
    logger.info("\n--- Stage 1a: Span-infilling pretraining ---")
    for epoch in range(config.vae_epochs_pretrain):
        model.train()
        for batch_idx, batch in enumerate(train_loader):
            token_ids = batch['token_ids'].to(device)

            # Apply span masking (mask ~15% of tokens)
            masked_ids, _ = _apply_span_mask(token_ids, mask_prob=0.15,
                                             mask_id=model.config.mask_token_id)

            with torch.autocast(device_type=device.type, dtype=dtype, enabled=use_amp):
                outputs = model(masked_ids, target_tokens=token_ids, kl_weight=0.0)

            loss = outputs['loss']

            if not nan_monitor.check(loss, optimizer):
                return model
            if not torch.isfinite(loss):
                continue  # Non-finite loss: skip backward/step for this batch

            loss.backward()

            if zclip:
                grad_norm = zclip(model)
            else:
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm).item()

            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            metrics.update(loss=loss, recon=outputs['recon_loss'], grad_norm=grad_norm)

            if batch_idx % config.log_every_n_steps == 0:
                metrics.log(prefix=f"[1a] Epoch {epoch+1}/{config.vae_epochs_pretrain} ")

    # ---- Stage 1b: Autoencoder training (KL=0) ----
    logger.info("\n--- Stage 1b: Autoencoder training (KL weight = 0) ---")
    for epoch in range(config.vae_epochs_ae):
        model.train()
        for batch_idx, batch in enumerate(train_loader):
            token_ids = batch['token_ids'].to(device)

            with torch.autocast(device_type=device.type, dtype=dtype, enabled=use_amp):
                outputs = model(token_ids, kl_weight=0.0)  # Pure reconstruction

            loss = outputs['loss']

            if not nan_monitor.check(loss, optimizer):
                return model
            if not torch.isfinite(loss):
                continue  # Non-finite loss: skip backward/step for this batch

            loss.backward()

            if zclip:
                zclip(model)

            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            metrics.update(loss=loss, recon=outputs['recon_loss'], kl=outputs['kl_loss'])

            if batch_idx % config.log_every_n_steps == 0:
                metrics.log(prefix=f"[1b] Epoch {epoch+1}/{config.vae_epochs_ae} ")

    # ---- Stage 1c: VAE fine-tuning (KL weight = β = 0.01) ----
    logger.info("\n--- Stage 1c: VAE fine-tuning (KL weight = 0.01) ---")
    # Lower learning rate for fine-tuning
    for pg in optimizer.param_groups:
        pg['lr'] = config.vae_lr * 0.1

    for epoch in range(config.vae_epochs_vae):
        model.train()
        for batch_idx, batch in enumerate(train_loader):
            token_ids = batch['token_ids'].to(device)

            with torch.autocast(device_type=device.type, dtype=dtype, enabled=use_amp):
                outputs = model(token_ids, kl_weight=config.kl_beta)

            loss = outputs['loss']

            if not nan_monitor.check(loss, optimizer):
                return model
            if not torch.isfinite(loss):
                continue  # Non-finite loss: skip backward/step for this batch

            loss.backward()

            if zclip:
                zclip(model)

            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            metrics.update(loss=loss, recon=outputs['recon_loss'], kl=outputs['kl_loss'])

            if batch_idx % config.log_every_n_steps == 0:
                metrics.log(prefix=f"[1c] Epoch {epoch+1}/{config.vae_epochs_vae} ")

    logger.info("Stage 1 complete!")
    return model

# ============================================================================
# Stage 2: LatentMamba Training
# ============================================================================

def train_latent_mamba(
    mamba_model: LatentMamba,
    vae_model: PhraseVAE,
    train_dataset: PhraseDataset,
    config: TrainConfig,
    device: torch.device,
    dtype: torch.dtype,
) -> LatentMamba:
    """
    Train LatentMamba on phrase latent sequences.

    1. Freeze the PhraseVAE encoder
    2. Encode all training phrases into latent sequences
    3. Train LatentMamba to predict the next phrase latent
    """

    logger.info("=" * 60)
    logger.info("Stage 2: LatentMamba Training")
    logger.info("=" * 60)

    # Freeze VAE
    vae_model.eval()
    for p in vae_model.parameters():
        p.requires_grad = False

    mamba_model = mamba_model.to(device)

    # Optimizer
    optimizer = torch.optim.AdamW(
        mamba_model.parameters(), lr=config.mamba_lr,
        weight_decay=config.mamba_weight_decay, betas=(0.9, 0.999)
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=config.sgdr_t0, T_mult=config.sgdr_t_mult,
        eta_min=config.sgdr_eta_min
    )

    zclip = ZClip(config.zclip_z_thresh, config.zclip_alpha) if config.use_zclip else None
    nan_monitor = NaNMonitor()
    metrics = MetricsTracker()
    use_amp = dtype != torch.float32

    # Encode all phrases to latent vectors first
    logger.info("Encoding training phrases to latent space...")
    latent_sequences = _encode_all_phrases(vae_model, train_dataset, device, dtype,
                                           config.mamba_batch_size)

    latent_dataset = LatentSequenceDataset(latent_sequences, max_phrases=config.mamba_max_phrases)
    train_loader = DataLoader(
        latent_dataset, batch_size=config.mamba_batch_size,
        shuffle=True, num_workers=2, pin_memory=True, drop_last=True
    )

    # Training loop
    for epoch in range(config.mamba_epochs):
        mamba_model.train()
        for batch_idx, batch in enumerate(train_loader):
            z_seq = batch['z_seq'].to(device)
            lengths = batch['length'].to(device)

            # Input:  z_1, ..., z_{T-1}
            # Target: z_2, ..., z_T (shifted by 1)
            z_input = z_seq[:, :-1]
            z_target = z_seq[:, 1:]

            with torch.autocast(device_type=device.type, dtype=dtype, enabled=use_amp):
                z_pred = mamba_model(z_input)

            # Length mask: position t is valid iff t < length - 1
            mask = torch.arange(z_target.shape[1], device=device).unsqueeze(0) < (lengths.unsqueeze(1) - 1)
            mask = mask.unsqueeze(-1).float()

            # MSE averaged over valid (unpadded) positions only, so padding
            # does not dilute the loss
            sq_err = (z_pred - z_target).pow(2) * mask
            loss = sq_err.sum() / (mask.sum() * z_target.shape[-1] + 1e-8)

            # Cosine similarity loss for direction matching, also restricted
            # to valid positions (padded targets are all-zero and would bias
            # the mean toward 1.0)
            valid = mask.squeeze(-1).reshape(-1).bool()
            cos = F.cosine_similarity(
                z_pred.reshape(-1, z_pred.shape[-1])[valid],
                z_target.reshape(-1, z_target.shape[-1])[valid],
                dim=-1
            )
            cos_loss = 1.0 - cos.mean()

            total_loss = loss + 0.1 * cos_loss

            if not nan_monitor.check(total_loss, optimizer):
                return mamba_model
            if not torch.isfinite(total_loss):
                continue  # Non-finite loss: skip backward/step for this batch

            total_loss.backward()

            if zclip:
                zclip(mamba_model)

            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            metrics.update(loss=loss, cos_loss=cos_loss, total=total_loss)

            if batch_idx % config.log_every_n_steps == 0:
                metrics.log(prefix=f"[S2] Epoch {epoch+1}/{config.mamba_epochs} ")

    logger.info("Stage 2 complete!")
    return mamba_model

# ============================================================================
# Helper Functions
# ============================================================================

def _apply_span_mask(token_ids: torch.Tensor, mask_prob: float = 0.15,
                     mask_id: int = 3, span_length: int = 3) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply span masking for pretraining (as in T5/BART).
    Masks contiguous spans of tokens; spans are sampled independently and may
    overlap, so the effective mask rate is at most mask_prob.
    """
    masked = token_ids.clone()
    B, L = masked.shape
    mask = torch.zeros_like(masked, dtype=torch.bool)

    for b in range(B):
        n_masks = max(1, int(L * mask_prob / span_length))
        for _ in range(n_masks):
            start = random.randint(1, max(1, L - span_length - 1))  # Don't mask BOS
            end = min(start + span_length, L)
            masked[b, start:end] = mask_id
            mask[b, start:end] = True

    return masked, mask

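# Illustration (hypothetical values): with L = 8, mask_prob = 0.15 and
# span_length = 3, n_masks = max(1, int(8 * 0.15 / 3)) = 1, so one 3-token
# span is replaced by mask_id, e.g. with mask_id = 3 and start = 2:
#   [1, 11, 12, 13, 14, 15, 16, 2] -> [1, 11, 3, 3, 3, 15, 16, 2]
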
def _encode_all_phrases(vae_model: PhraseVAE, dataset: PhraseDataset,
                        device: torch.device, dtype: torch.dtype,
                        batch_size: int = 64) -> List[torch.Tensor]:
    """Encode all phrases in the dataset to latent vectors."""
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    use_amp = dtype != torch.float32

    all_latents = []
    with torch.no_grad():
        for batch in loader:
            token_ids = batch['token_ids'].to(device)
            with torch.autocast(device_type=device.type, dtype=dtype, enabled=use_amp):
                z, _, _ = vae_model.encode(token_ids)
            # Cast to fp32 so latents concatenate cleanly with fp32 padding later
            all_latents.append(z.float().cpu())

    # Concatenate and reshape into sequences
    all_z = torch.cat(all_latents, dim=0)  # (N_total, latent_dim)

    # Group into sequences (simple: fixed-length chunks).
    # In practice, you'd group by song/piece.
    chunk_size = 32  # phrases per sequence
    sequences = []
    # `+ 1` keeps the final full chunk; the stop index is exclusive
    for i in range(0, len(all_z) - chunk_size + 1, chunk_size):
        sequences.append(all_z[i:i + chunk_size])

    logger.info(f"Encoded {len(all_z)} phrases into {len(sequences)} sequences")
    return sequences

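# Sketch of grouping by piece instead of fixed chunks, assuming the caller can
# supply a parallel list `song_ids` (one ID per phrase; hypothetical —
# PhraseDataset does not currently carry song metadata):
#
#   from collections import defaultdict
#   by_song = defaultdict(list)
#   for z, sid in zip(all_z, song_ids):
#       by_song[sid].append(z)
#   sequences = [torch.stack(v) for v in by_song.values()]
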
# ============================================================================
# Save/Load
# ============================================================================

def save_checkpoint(model: MuseMorphic, config: TrainConfig,
                    model_config: MuseMorphicConfig, step: int, path: str):
    """Save a model checkpoint (per-step file plus a 'latest' copy)."""
    os.makedirs(path, exist_ok=True)

    ckpt = {
        'model_state_dict': model.state_dict(),
        'step': step,
        'model_config': asdict(model_config),
        'train_config': asdict(config),
    }
    torch.save(ckpt, os.path.join(path, f'checkpoint_{step}.pt'))

    # Also save latest
    torch.save(ckpt, os.path.join(path, 'checkpoint_latest.pt'))

    logger.info(f"Saved checkpoint at step {step} to {path}")


def load_checkpoint(path: str, device: torch.device) -> Tuple[MuseMorphic, Dict]:
    """Load a model from the latest checkpoint in `path`."""
    ckpt = torch.load(os.path.join(path, 'checkpoint_latest.pt'), map_location=device)

    model_config = MuseMorphicConfig(**ckpt['model_config'])
    model = MuseMorphic(model_config)
    model.load_state_dict(ckpt['model_state_dict'])
    model = model.to(device)

    return model, ckpt

# ============================================================================
# Main Training Pipeline
# ============================================================================

def train_musemorphic(
    model_config: Optional[MuseMorphicConfig] = None,
    train_config: Optional[TrainConfig] = None,
    train_phrases: Optional[List[List[int]]] = None,
):
    """
    Complete MuseMorphic training pipeline.

    If train_phrases is None, generates synthetic data for testing.
    """
    if model_config is None:
        model_config = MuseMorphicConfig()
    if train_config is None:
        train_config = TrainConfig()

    set_seed(train_config.seed)
    device = get_device(train_config)
    dtype = get_dtype(train_config)

    logger.info(f"Device: {device}, Dtype: {dtype}")

    # Create model
    model = MuseMorphic(model_config)
    params = model.count_parameters()
    logger.info(f"Model parameters: {params}")

    # Generate synthetic data if none provided
    if train_phrases is None:
        logger.info("No training data provided. Generating synthetic data for testing...")
        train_phrases = _generate_synthetic_phrases(1000, model_config.vae_max_seq_len,
                                                    model_config.vocab_size)

    # Create dataset
    train_dataset = PhraseDataset(train_phrases, model_config.vae_max_seq_len, model_config.pad_token_id)
    logger.info(f"Training dataset: {len(train_dataset)} phrases")

    # Stage 1: Train PhraseVAE
    model.phrase_vae = train_phrase_vae(
        model.phrase_vae, train_dataset, None, train_config, device, dtype
    )

    # Stage 2: Train LatentMamba
    model.latent_mamba = train_latent_mamba(
        model.latent_mamba, model.phrase_vae, train_dataset,
        train_config, device, dtype
    )

    # Save final model (step=-1 marks the final checkpoint)
    save_checkpoint(model, train_config, model_config, -1, train_config.output_dir)

    return model


def _generate_synthetic_phrases(n: int, max_len: int, vocab_size: int) -> List[List[int]]:
    """Generate synthetic REMI-like phrases for testing."""
    phrases = []
    for _ in range(n):
        length = random.randint(10, max_len)
        phrase = [1]  # BOS
        for _ in range(length - 2):
            # Stand-in for REMI structure (position, pitch, velocity, duration):
            # tokens drawn uniformly from the non-special vocabulary
            tok = random.randint(4, vocab_size - 1)
            phrase.append(tok)
        phrase.append(2)  # EOS
        phrases.append(phrase)
    return phrases


if __name__ == "__main__":
    model = train_musemorphic()
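
# Example (sketch): training on real data instead of the synthetic fallback.
# `load_remi_phrases` is hypothetical; any loader returning List[List[int]]
# of REMI token IDs (BOS/EOS included) fits the PhraseDataset contract above.
#
#   phrases = load_remi_phrases("./data/remi_tokens.json")
#   cfg = TrainConfig(output_dir="./checkpoints/run1")
#   model = train_musemorphic(train_config=cfg, train_phrases=phrases)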