"""
MuseMorphic Training Pipeline
==============================

Two-stage training with curriculum and stability guarantees:

Stage 1 — PhraseVAE Training:
    1a. Span-infilling pretraining (learn REMI grammar)
    1b. Autoencoder training (KL weight = 0, pure reconstruction)
    1c. VAE fine-tuning (KL weight = 0.01)

Stage 2 — LatentMamba Training:
    Freeze PhraseVAE encoder, train LatentMamba on latent phrase sequences.
    Uses MSE loss on predicted vs actual latent vectors.

Training Stability Stack:
    - σReparam on all linear layers (prevents attention entropy collapse)
    - ZClip adaptive gradient clipping (clips only genuine spikes)
    - Pre-LayerNorm (bounded gradients, no warmup needed)
    - BFloat16 mixed precision (no loss scaling needed, no overflow)
    - Label smoothing ε=0.1 (prevents overconfident predictions)
    - Cosine annealing with warm restarts (SGDR)
    - Per-step NaN/Inf monitoring with automatic recovery
"""

import os
import sys
import math
import time
import json
import random
import logging
from pathlib import Path
from typing import Optional, Dict, List, Tuple
from dataclasses import dataclass, asdict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from model import MuseMorphicConfig, MuseMorphic, PhraseVAE, LatentMamba, ZClip

logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s')
logger = logging.getLogger(__name__)


# ============================================================================
# Training Configuration
# ============================================================================

@dataclass
class TrainConfig:
    """Training hyperparameters."""

    # General
    seed: int = 42
    device: str = "auto"  # auto, cuda, cpu
    dtype: str = "bf16"   # bf16, fp16, fp32

    # Stage 1: PhraseVAE
    vae_epochs_pretrain: int = 5   # 1a: span-infilling
    vae_epochs_ae: int = 20        # 1b: autoencoder (KL=0)
    vae_epochs_vae: int = 10       # 1c: VAE fine-tune (KL=0.01)
    vae_batch_size: int = 64
    vae_lr: float = 3e-4
    vae_weight_decay: float = 0.01
    vae_max_seq_len: int = 256

    # Stage 2: LatentMamba
    mamba_epochs: int = 50
    mamba_batch_size: int = 32
    mamba_lr: float = 1e-4
    mamba_weight_decay: float = 0.01
    mamba_max_phrases: int = 128

    # Optimization
    gradient_accumulation_steps: int = 1
    max_grad_norm: float = 1.0  # Fallback fixed clip (used when ZClip is disabled)
    warmup_steps: int = 500

    # Scheduler: Cosine Annealing with Warm Restarts (SGDR)
    sgdr_t0: int = 1000
    sgdr_t_mult: int = 2
    sgdr_eta_min: float = 1e-6

    # Stability
    use_zclip: bool = True
    zclip_z_thresh: float = 2.5
    zclip_alpha: float = 0.99
    label_smoothing: float = 0.1
    kl_beta: float = 0.01

    # Monitoring
    log_every_n_steps: int = 10
    eval_every_n_steps: int = 500
    save_every_n_steps: int = 1000

    # Paths
    output_dir: str = "./checkpoints"
    data_dir: str = "./data"

    # Hub
    push_to_hub: bool = True
    hub_model_id: str = ""


# ============================================================================
# Dataset
# ============================================================================

class PhraseDataset(Dataset):
    """
    Dataset of tokenized REMI+ phrases for PhraseVAE training.

    Each item is a padded sequence of token IDs representing one phrase
    (one bar of one track).
    """

    def __init__(self, phrases: List[List[int]], max_len: int = 256, pad_id: int = 0):
        self.phrases = phrases
        self.max_len = max_len
        self.pad_id = pad_id

    def __len__(self):
        return len(self.phrases)

    def __getitem__(self, idx):
        """Return the phrase truncated/padded to ``max_len`` plus its unpadded length."""
        # list(...) so that tuple (or other sequence) phrases can be padded too.
        ids = list(self.phrases[idx][:self.max_len])
        padded = ids + [self.pad_id] * (self.max_len - len(ids))
        return {
            'token_ids': torch.tensor(padded, dtype=torch.long),
            'length': len(ids),  # already capped at max_len by the slice above
        }


class LatentSequenceDataset(Dataset):
    """
    Dataset of latent phrase sequences for LatentMamba training.

    Each item is a sequence of latent vectors (encoded by PhraseVAE)
    with associated control attributes.
    """

    def __init__(self, latent_sequences: List[torch.Tensor],
                 controls: Optional[List[Dict[str, int]]] = None,
                 max_phrases: int = 128):
        self.latent_sequences = latent_sequences
        self.controls = controls
        self.max_phrases = max_phrases

    def __len__(self):
        return len(self.latent_sequences)

    def __getitem__(self, idx):
        """Return the sequence truncated/zero-padded to ``max_phrases`` rows.

        ``length`` is the number of real (non-padding) rows.
        """
        z_seq = self.latent_sequences[idx][:self.max_phrases]
        T = z_seq.shape[0]

        # Pad if needed; new_zeros keeps the dtype/device of the latents so the
        # concatenation never mixes dtypes (e.g. bf16 latents with fp32 zeros).
        if T < self.max_phrases:
            pad = z_seq.new_zeros(self.max_phrases - T, z_seq.shape[-1])
            z_seq = torch.cat([z_seq, pad], dim=0)

        item = {
            'z_seq': z_seq,
            'length': T,
        }

        if self.controls:
            ctrl = self.controls[idx]
            item['controls'] = {k: torch.tensor(v, dtype=torch.long) for k, v in ctrl.items()}

        return item


# ============================================================================
# Training Utilities
# ============================================================================

def get_device(config: TrainConfig) -> torch.device:
    """Resolve ``config.device``; "auto" prefers CUDA, then MPS, then CPU."""
    if config.device != "auto":
        return torch.device(config.device)
    if torch.cuda.is_available():
        return torch.device("cuda")
    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def get_dtype(config: TrainConfig) -> torch.dtype:
    """Get torch dtype from config string.

    "bf16" falls back to float32 when CUDA bf16 is unavailable; any unknown
    string also yields float32.
    """
    if config.dtype == "bf16":
        if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
            return torch.bfloat16
        return torch.float32  # Fallback
    if config.dtype == "fp16":
        return torch.float16
    return torch.float32


def set_seed(seed: int):
    """Set all random seeds (Python, NumPy, torch CPU and CUDA) for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


class NaNMonitor:
    """
    Monitor for NaN/Inf in the loss.

    If NaN detected:
    1. Skip the optimization step (``last_step_skipped`` is set so the caller can ``continue``)
    2. Reduce learning rate by 50%
    3. Log warning
    4. If 5 consecutive NaNs, stop training
    """

    def __init__(self, max_consecutive: int = 5):
        self.max_consecutive = max_consecutive
        self.consecutive_nan = 0
        self.total_nan = 0
        # True when the most recent check() saw a non-finite loss. check() returns
        # True both for "healthy" and "skip this step", so callers need this flag
        # to avoid back-propagating a NaN loss.
        self.last_step_skipped = False

    def check(self, loss: torch.Tensor, optimizer: torch.optim.Optimizer,
              scheduler=None) -> bool:
        """
        Check for NaN/Inf. Returns True if training should continue.

        Args:
            loss: Loss tensor for the current step.
            optimizer: Optimizer whose learning rates are halved and whose
                gradients are cleared when the loss is non-finite.
            scheduler: Optional LR scheduler. Closed-form schedulers recompute the
                LR from ``base_lrs`` on every ``step()``, which would undo the
                halving; when given, its ``base_lrs`` are halved as well.

        Returns:
            False once ``max_consecutive`` non-finite losses occurred in a row,
            True otherwise. Inspect ``last_step_skipped`` to know whether the
            current step must be skipped.
        """
        if bool(torch.isfinite(loss).all()):
            self.consecutive_nan = 0
            self.last_step_skipped = False
            return True

        self.last_step_skipped = True
        self.consecutive_nan += 1
        self.total_nan += 1
        logger.warning(f"NaN/Inf detected! Consecutive: {self.consecutive_nan}, "
                       f"Total: {self.total_nan}")

        if self.consecutive_nan >= self.max_consecutive:
            logger.error(f"Training stopped: {self.max_consecutive} consecutive NaN/Inf")
            return False

        # Reduce learning rate
        for param_group in optimizer.param_groups:
            param_group['lr'] *= 0.5
            logger.info(f"Reduced LR to {param_group['lr']:.2e}")
        if scheduler is not None and hasattr(scheduler, 'base_lrs'):
            scheduler.base_lrs = [lr * 0.5 for lr in scheduler.base_lrs]

        # Zero gradients (skip this step)
        optimizer.zero_grad()
        return True


class MetricsTracker:
    """Simple metrics tracking with exponential moving average."""

    def __init__(self, alpha: float = 0.99):
        self.alpha = alpha
        self.metrics = {}
        self.step_count = 0

    def update(self, **kwargs):
        """Fold each keyword value into its EMA; the first value seeds the average."""
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            if k in self.metrics:
                self.metrics[k] = self.alpha * self.metrics[k] + (1 - self.alpha) * v
            else:
                self.metrics[k] = v
        self.step_count += 1

    def get(self) -> Dict[str, float]:
        """Return the current averages rounded to 6 decimals."""
        return {k: round(v, 6) for k, v in self.metrics.items()}

    def log(self, prefix: str = ""):
        """Log all tracked averages on one line."""
        metrics = self.get()
        parts = [f"{k}={v:.6f}" for k, v in metrics.items()]
        logger.info(f"{prefix}step={self.step_count} | {' | '.join(parts)}")


# ============================================================================
# Stage 1: PhraseVAE Training
# ============================================================================

def _run_vae_stage(
    model: PhraseVAE,
    train_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    scheduler,
    zclip: Optional[ZClip],
    nan_monitor: NaNMonitor,
    metrics: MetricsTracker,
    config: TrainConfig,
    device: torch.device,
    dtype: torch.dtype,
    *,
    tag: str,
    num_epochs: int,
    kl_weight: float,
    span_infill: bool = False,
) -> bool:
    """Run one PhraseVAE curriculum stage; return False if the NaN monitor aborted.

    Args:
        tag: Short stage label used in log prefixes ("1a", "1b", "1c").
        num_epochs: Number of passes over ``train_loader``.
        kl_weight: Value forwarded to the model as ``kl_weight``.
        span_infill: When True, inputs are span-masked and the clean tokens are
            passed as ``target_tokens`` (stage 1a); otherwise the model
            reconstructs its own input.
    """
    for epoch in range(num_epochs):
        model.train()
        for batch_idx, batch in enumerate(train_loader):
            token_ids = batch['token_ids'].to(device)

            if span_infill:
                # Apply span masking (mask 15% of tokens)
                masked_ids, _ = _apply_span_mask(token_ids, mask_prob=0.15,
                                                 mask_id=model.config.mask_token_id)

            with torch.autocast(device_type=device.type, dtype=dtype):
                if span_infill:
                    outputs = model(masked_ids, target_tokens=token_ids, kl_weight=kl_weight)
                else:
                    outputs = model(token_ids, kl_weight=kl_weight)
                loss = outputs['loss']

            if not nan_monitor.check(loss, optimizer, scheduler):
                return False
            if nan_monitor.last_step_skipped:
                # Non-finite loss: gradients were cleared and the LR halved;
                # back-propagating it would poison the weights.
                continue

            loss.backward()

            if zclip is not None:
                # ZClip is called on the model; its return value is logged as the
                # gradient norm — NOTE(review): confirm against model.ZClip.
                grad_norm = zclip(model)
            else:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), config.max_grad_norm
                ).item()

            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            step_metrics = {'loss': loss, 'recon': outputs['recon_loss'],
                            'grad_norm': grad_norm}
            if not span_infill:
                step_metrics['kl'] = outputs['kl_loss']
            metrics.update(**step_metrics)

            if batch_idx % config.log_every_n_steps == 0:
                metrics.log(prefix=f"[{tag}] Epoch {epoch+1}/{num_epochs} ")
    return True


def train_phrase_vae(
    model: PhraseVAE,
    train_dataset: PhraseDataset,
    val_dataset: Optional[PhraseDataset],
    config: TrainConfig,
    device: torch.device,
    dtype: torch.dtype,
) -> PhraseVAE:
    """
    Three-stage PhraseVAE training curriculum.

    Stage 1a: Span-infilling pretraining (learn REMI grammar)
    Stage 1b: Autoencoder (KL=0, pure reconstruction)
    Stage 1c: VAE fine-tuning (KL=0.01)

    Args:
        model: PhraseVAE to train; moved to ``device``.
        train_dataset: Phrases to train on.
        val_dataset: Currently unused (no evaluation loop is run).
        config: Training hyperparameters.
        device: Device to train on.
        dtype: Autocast dtype.

    Returns:
        The trained model. If the NaN monitor aborts training, the model is
        returned in its state at that point.
    """
    logger.info("=" * 60)
    logger.info("Stage 1: PhraseVAE Training")
    logger.info("=" * 60)

    model = model.to(device)

    # Optimizer with weight decay (excluding biases and LN params, matched by name)
    no_decay = ['bias', 'LayerNorm', 'layer_norm', 'b_sin', 'b_cos']
    decay_params, no_decay_params = [], []
    for name, param in model.named_parameters():
        if any(nd in name for nd in no_decay):
            no_decay_params.append(param)
        else:
            decay_params.append(param)
    param_groups = [
        {'params': decay_params, 'weight_decay': config.vae_weight_decay},
        {'params': no_decay_params, 'weight_decay': 0.0},
    ]

    optimizer = torch.optim.AdamW(param_groups, lr=config.vae_lr, betas=(0.9, 0.999))

    # SGDR scheduler
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=config.sgdr_t0, T_mult=config.sgdr_t_mult, eta_min=config.sgdr_eta_min
    )

    # Stability tools
    zclip = ZClip(config.zclip_z_thresh, config.zclip_alpha) if config.use_zclip else None
    nan_monitor = NaNMonitor()
    metrics = MetricsTracker()

    train_loader = DataLoader(
        train_dataset, batch_size=config.vae_batch_size,
        shuffle=True, num_workers=2,
        pin_memory=device.type == "cuda",
        # Dropping the last batch of a dataset smaller than one batch would leave
        # zero batches and silently train nothing.
        drop_last=len(train_dataset) >= config.vae_batch_size,
    )

    def run(tag: str, num_epochs: int, kl_weight: float, span_infill: bool = False) -> bool:
        return _run_vae_stage(
            model, train_loader, optimizer, scheduler, zclip, nan_monitor, metrics,
            config, device, dtype,
            tag=tag, num_epochs=num_epochs, kl_weight=kl_weight, span_infill=span_infill,
        )

    # ---- Stage 1a: Span-infilling pretraining ----
    logger.info("\n--- Stage 1a: Span-infilling pretraining ---")
    if not run("1a", config.vae_epochs_pretrain, 0.0, span_infill=True):
        return model

    # ---- Stage 1b: Autoencoder training (KL=0) ----
    logger.info("\n--- Stage 1b: Autoencoder training (KL weight = 0) ---")
    if not run("1b", config.vae_epochs_ae, 0.0):  # Pure reconstruction
        return model

    # ---- Stage 1c: VAE fine-tuning (KL=β=0.01) ----
    logger.info("\n--- Stage 1c: VAE fine-tuning (KL weight = 0.01) ---")

    # Lower learning rate for fine-tuning. The scheduler recomputes the LR from
    # base_lrs on every step, so base_lrs must be lowered too or the change is
    # overwritten by the first scheduler.step().
    finetune_lr = config.vae_lr * 0.1
    for pg in optimizer.param_groups:
        pg['lr'] = finetune_lr
    scheduler.base_lrs = [finetune_lr for _ in scheduler.base_lrs]

    if not run("1c", config.vae_epochs_vae, config.kl_beta):
        return model

    logger.info("Stage 1 complete!")
    return model


# ============================================================================
# Stage 2: LatentMamba Training
# ============================================================================ def train_latent_mamba( mamba_model: LatentMamba, vae_model: PhraseVAE, train_dataset: PhraseDataset, config: TrainConfig, device: torch.device, dtype: torch.dtype, ) -> LatentMamba: """ Train LatentMamba on phrase latent sequences. 1. Freeze PhraseVAE encoder 2. Encode all training phrases into latent sequences 3. Train LatentMamba to predict next phrase latents """ logger.info("=" * 60) logger.info("Stage 2: LatentMamba Training") logger.info("=" * 60) # Freeze VAE vae_model.eval() for p in vae_model.parameters(): p.requires_grad = False mamba_model = mamba_model.to(device) # Optimizer optimizer = torch.optim.AdamW( mamba_model.parameters(), lr=config.mamba_lr, weight_decay=config.mamba_weight_decay, betas=(0.9, 0.999) ) scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( optimizer, T_0=config.sgdr_t0, T_mult=config.sgdr_t_mult, eta_min=config.sgdr_eta_min ) zclip = ZClip(config.zclip_z_thresh, config.zclip_alpha) if config.use_zclip else None nan_monitor = NaNMonitor() metrics = MetricsTracker() # Encode all phrases to latent vectors first logger.info("Encoding training phrases to latent space...") latent_sequences = _encode_all_phrases(vae_model, train_dataset, device, dtype, config.mamba_batch_size) latent_dataset = LatentSequenceDataset(latent_sequences, max_phrases=config.mamba_max_phrases) train_loader = DataLoader( latent_dataset, batch_size=config.mamba_batch_size, shuffle=True, num_workers=2, pin_memory=True, drop_last=True ) # Training loop for epoch in range(config.mamba_epochs): mamba_model.train() for batch_idx, batch in enumerate(train_loader): z_seq = batch['z_seq'].to(device) lengths = batch['length'] # Input: z_1, ..., z_{T-1} # Target: z_2, ..., z_T (shifted by 1) z_input = z_seq[:, :-1] z_target = z_seq[:, 1:] with torch.autocast(device_type=device.type, dtype=dtype): z_pred = mamba_model(z_input) # MSE loss on latent vectors (with length masking) mask 
= torch.arange(z_target.shape[1], device=device).unsqueeze(0) < (lengths.unsqueeze(1) - 1).to(device) mask = mask.unsqueeze(-1).float() loss = F.mse_loss(z_pred * mask, z_target * mask) # Optional: Add cosine similarity loss for direction matching cos_loss = 1.0 - F.cosine_similarity( z_pred.reshape(-1, z_pred.shape[-1]), z_target.reshape(-1, z_target.shape[-1]), dim=-1 ).mean() total_loss = loss + 0.1 * cos_loss if not nan_monitor.check(total_loss, optimizer): return mamba_model total_loss.backward() if zclip: zclip(mamba_model) optimizer.step() scheduler.step() optimizer.zero_grad() metrics.update(loss=loss, cos_loss=cos_loss, total=total_loss) if batch_idx % config.log_every_n_steps == 0: metrics.log(prefix=f"[S2] Epoch {epoch+1}/{config.mamba_epochs} ") logger.info("Stage 2 complete!") return mamba_model # ============================================================================ # Helper Functions # ============================================================================ def _apply_span_mask(token_ids: torch.Tensor, mask_prob: float = 0.15, mask_id: int = 3, span_length: int = 3) -> Tuple[torch.Tensor, torch.Tensor]: """ Apply span masking for pretraining (like T5/BART). Masks contiguous spans of tokens. 
""" masked = token_ids.clone() B, L = masked.shape mask = torch.zeros_like(masked, dtype=torch.bool) for b in range(B): n_masks = max(1, int(L * mask_prob / span_length)) for _ in range(n_masks): start = random.randint(1, max(1, L - span_length - 1)) # Don't mask BOS end = min(start + span_length, L) masked[b, start:end] = mask_id mask[b, start:end] = True return masked, mask def _encode_all_phrases(vae_model: PhraseVAE, dataset: PhraseDataset, device: torch.device, dtype: torch.dtype, batch_size: int = 64) -> List[torch.Tensor]: """Encode all phrases in dataset to latent vectors.""" loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=2) all_latents = [] with torch.no_grad(): for batch in loader: token_ids = batch['token_ids'].to(device) with torch.autocast(device_type=device.type, dtype=dtype): z, _, _ = vae_model.encode(token_ids) all_latents.append(z.cpu()) # Concatenate and reshape into sequences all_z = torch.cat(all_latents, dim=0) # (N_total, latent_dim) # Group into sequences (simple: fixed-length chunks) # In practice, you'd group by song/piece chunk_size = 32 # phrases per sequence sequences = [] for i in range(0, len(all_z) - chunk_size, chunk_size): sequences.append(all_z[i:i+chunk_size]) logger.info(f"Encoded {len(all_z)} phrases into {len(sequences)} sequences") return sequences # ============================================================================ # Save/Load # ============================================================================ def save_checkpoint(model: MuseMorphic, config: TrainConfig, model_config: MuseMorphicConfig, step: int, path: str): """Save model checkpoint.""" os.makedirs(path, exist_ok=True) torch.save({ 'model_state_dict': model.state_dict(), 'step': step, 'model_config': asdict(model_config), 'train_config': asdict(config), }, os.path.join(path, f'checkpoint_{step}.pt')) # Also save latest torch.save({ 'model_state_dict': model.state_dict(), 'step': step, 'model_config': asdict(model_config), 
'train_config': asdict(config), }, os.path.join(path, 'checkpoint_latest.pt')) logger.info(f"Saved checkpoint at step {step} to {path}") def load_checkpoint(path: str, device: torch.device) -> Tuple[MuseMorphic, Dict]: """Load model from checkpoint.""" ckpt = torch.load(os.path.join(path, 'checkpoint_latest.pt'), map_location=device) model_config = MuseMorphicConfig(**ckpt['model_config']) model = MuseMorphic(model_config) model.load_state_dict(ckpt['model_state_dict']) return model, ckpt # ============================================================================ # Main Training Pipeline # ============================================================================ def train_musemorphic( model_config: Optional[MuseMorphicConfig] = None, train_config: Optional[TrainConfig] = None, train_phrases: Optional[List[List[int]]] = None, ): """ Complete MuseMorphic training pipeline. If train_phrases is None, generates synthetic data for testing. """ if model_config is None: model_config = MuseMorphicConfig() if train_config is None: train_config = TrainConfig() set_seed(train_config.seed) device = get_device(train_config) dtype = get_dtype(train_config) logger.info(f"Device: {device}, Dtype: {dtype}") # Create model model = MuseMorphic(model_config) params = model.count_parameters() logger.info(f"Model parameters: {params}") # Generate synthetic data if none provided if train_phrases is None: logger.info("No training data provided. 
Generating synthetic data for testing...") train_phrases = _generate_synthetic_phrases(1000, model_config.vae_max_seq_len, model_config.vocab_size) # Create dataset train_dataset = PhraseDataset(train_phrases, model_config.vae_max_seq_len, model_config.pad_token_id) logger.info(f"Training dataset: {len(train_dataset)} phrases") # Stage 1: Train PhraseVAE model.phrase_vae = train_phrase_vae( model.phrase_vae, train_dataset, None, train_config, device, dtype ) # Stage 2: Train LatentMamba model.latent_mamba = train_latent_mamba( model.latent_mamba, model.phrase_vae, train_dataset, train_config, device, dtype ) # Save final model save_checkpoint(model, train_config, model_config, -1, train_config.output_dir) return model def _generate_synthetic_phrases(n: int, max_len: int, vocab_size: int) -> List[List[int]]: """Generate synthetic REMI-like phrases for testing.""" phrases = [] for _ in range(n): length = random.randint(10, max_len) # Generate somewhat structured sequences (not purely random) phrase = [1] # BOS for _ in range(length - 2): # Simulate REMI structure: position, pitch, velocity, duration pattern tok = random.randint(4, vocab_size - 1) phrase.append(tok) phrase.append(2) # EOS phrases.append(phrase) return phrases if __name__ == "__main__": model = train_musemorphic()