| """ |
| MuseMorphic Training Pipeline |
| ============================== |
| |
| Two-stage training with curriculum and stability guarantees: |
| |
| Stage 1 — PhraseVAE Training: |
| 1a. Span-infilling pretraining (learn REMI grammar) |
| 1b. Autoencoder training (KL weight = 0, pure reconstruction) |
| 1c. VAE fine-tuning (KL weight = 0.01) |
| |
| Stage 2 — LatentMamba Training: |
| Freeze PhraseVAE encoder, train LatentMamba on latent phrase sequences. |
| Uses MSE loss on predicted vs actual latent vectors. |
| |
| Training Stability Stack: |
| - σReparam on all linear layers (prevents attention entropy collapse) |
| - ZClip adaptive gradient clipping (clips only genuine spikes) |
| - Pre-LayerNorm (bounded gradients, no warmup needed) |
| - BFloat16 mixed precision (no loss scaling needed, no overflow) |
| - Label smoothing ε=0.1 (prevents overconfident predictions) |
| - Cosine annealing with warm restarts (SGDR) |
| - Per-step NaN/Inf monitoring with automatic recovery |
| """ |
import os
import sys
import math
import time
import json
import random
import logging
from pathlib import Path
from typing import Optional, Dict, List, Tuple
from dataclasses import dataclass, asdict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from model import MuseMorphicConfig, MuseMorphic, PhraseVAE, LatentMamba, ZClip

logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s')
logger = logging.getLogger(__name__)


@dataclass
class TrainConfig:
    """Training hyperparameters."""

    # general
    seed: int = 42
    device: str = "auto"          # "auto", "cuda", "mps", or "cpu"
    dtype: str = "bf16"           # "bf16", "fp16", or "fp32"

    # Stage 1: PhraseVAE
    vae_epochs_pretrain: int = 5
    vae_epochs_ae: int = 20
    vae_epochs_vae: int = 10
    vae_batch_size: int = 64
    vae_lr: float = 3e-4
    vae_weight_decay: float = 0.01
    vae_max_seq_len: int = 256

    # Stage 2: LatentMamba
    mamba_epochs: int = 50
    mamba_batch_size: int = 32
    mamba_lr: float = 1e-4
    mamba_weight_decay: float = 0.01
    mamba_max_phrases: int = 128

    # optimization (gradient_accumulation_steps and warmup_steps are not yet
    # consumed by the loops below; Pre-LN + SGDR handle the schedule)
    gradient_accumulation_steps: int = 1
    max_grad_norm: float = 1.0
    warmup_steps: int = 500

    # SGDR schedule
    sgdr_t0: int = 1000
    sgdr_t_mult: int = 2
    sgdr_eta_min: float = 1e-6

    # stability
    use_zclip: bool = True
    zclip_z_thresh: float = 2.5
    zclip_alpha: float = 0.99
    label_smoothing: float = 0.1
    kl_beta: float = 0.01

    # logging / checkpoint cadence
    log_every_n_steps: int = 10
    eval_every_n_steps: int = 500
    save_every_n_steps: int = 1000

    # paths
    output_dir: str = "./checkpoints"
    data_dir: str = "./data"

    # Hugging Face Hub (optional)
    push_to_hub: bool = True
    hub_model_id: str = ""


class PhraseDataset(Dataset):
    """
    Dataset of tokenized REMI+ phrases for PhraseVAE training.

    Each item is a padded sequence of token IDs representing one phrase
    (one bar of one track).
    """

    def __init__(self, phrases: List[List[int]], max_len: int = 256, pad_id: int = 0):
        self.phrases = phrases
        self.max_len = max_len
        self.pad_id = pad_id

    def __len__(self):
        return len(self.phrases)

    def __getitem__(self, idx):
        # truncate to max_len, then right-pad with pad_id up to the fixed length
        ids = self.phrases[idx][:self.max_len]
        padded = ids + [self.pad_id] * (self.max_len - len(ids))

        return {
            'token_ids': torch.tensor(padded, dtype=torch.long),
            'length': len(ids),  # already capped at max_len by the slice above
        }
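
# Shape sketch (for orientation; batch size and max_len are arbitrary here): with
# max_len=256 and batch_size=4, a DataLoader over PhraseDataset yields
#
#     batch = next(iter(DataLoader(PhraseDataset(phrases, max_len=256), batch_size=4)))
#     batch['token_ids'].shape   # torch.Size([4, 256]), dtype torch.long
#     batch['length'].shape      # torch.Size([4]), unpadded length of each phrase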


class LatentSequenceDataset(Dataset):
    """
    Dataset of latent phrase sequences for LatentMamba training.

    Each item is a sequence of latent vectors (encoded by PhraseVAE)
    with associated control attributes.
    """

    def __init__(self, latent_sequences: List[torch.Tensor],
                 controls: Optional[List[Dict[str, int]]] = None,
                 max_phrases: int = 128):
        self.latent_sequences = latent_sequences
        self.controls = controls
        self.max_phrases = max_phrases

    def __len__(self):
        return len(self.latent_sequences)

    def __getitem__(self, idx):
        z_seq = self.latent_sequences[idx][:self.max_phrases]
        T = z_seq.shape[0]

        # zero-pad the time dimension up to max_phrases (matching the latent dtype)
        if T < self.max_phrases:
            pad = torch.zeros(self.max_phrases - T, z_seq.shape[-1], dtype=z_seq.dtype)
            z_seq = torch.cat([z_seq, pad], dim=0)

        item = {
            'z_seq': z_seq,
            'length': T,
        }

        if self.controls:
            ctrl = self.controls[idx]
            item['controls'] = {k: torch.tensor(v, dtype=torch.long) for k, v in ctrl.items()}

        return item


def get_device(config: TrainConfig) -> torch.device:
    """Auto-detect the best available device."""
    if config.device == "auto":
        if torch.cuda.is_available():
            return torch.device("cuda")
        elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")
    return torch.device(config.device)


def get_dtype(config: TrainConfig) -> torch.dtype:
    """Get torch dtype from the config string (falls back to float32)."""
    if config.dtype == "bf16":
        if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
            return torch.bfloat16
        return torch.float32
    elif config.dtype == "fp16":
        return torch.float16
    return torch.float32


def set_seed(seed: int):
    """Set all random seeds for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


class NaNMonitor:
    """
    Monitor for NaN/Inf in the loss.

    If NaN/Inf is detected:
      1. The caller is expected to skip the optimization step
      2. The learning rate is reduced by 50%
      3. A warning is logged
      4. After `max_consecutive` consecutive NaNs, training is stopped
    """

    def __init__(self, max_consecutive: int = 5):
        self.max_consecutive = max_consecutive
        self.consecutive_nan = 0
        self.total_nan = 0

    def check(self, loss: torch.Tensor, optimizer: torch.optim.Optimizer,
              scheduler=None) -> bool:
        """
        Check for NaN/Inf. Returns True if training should continue.
        """
        if torch.isnan(loss) or torch.isinf(loss):
            self.consecutive_nan += 1
            self.total_nan += 1

            logger.warning(f"NaN/Inf detected! Consecutive: {self.consecutive_nan}, "
                           f"Total: {self.total_nan}")

            if self.consecutive_nan >= self.max_consecutive:
                logger.error(f"Training stopped: {self.max_consecutive} consecutive NaN/Inf")
                return False

            # halve the LR; the scheduler's base_lrs must be scaled too, otherwise
            # CosineAnnealingWarmRestarts would overwrite the reduction on its next step
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.5
                logger.info(f"Reduced LR to {param_group['lr']:.2e}")
            if scheduler is not None:
                scheduler.base_lrs = [lr * 0.5 for lr in scheduler.base_lrs]

            # drop any gradients from the bad step; the caller should skip optimizer.step()
            optimizer.zero_grad()
            return True

        self.consecutive_nan = 0
        return True


class MetricsTracker:
    """Simple metrics tracking with an exponential moving average."""

    def __init__(self, alpha: float = 0.99):
        self.alpha = alpha
        self.metrics = {}
        self.step_count = 0

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            if k not in self.metrics:
                self.metrics[k] = v
            else:
                self.metrics[k] = self.alpha * self.metrics[k] + (1 - self.alpha) * v
        self.step_count += 1

    def get(self) -> Dict[str, float]:
        return {k: round(v, 6) for k, v in self.metrics.items()}

    def log(self, prefix: str = ""):
        metrics = self.get()
        parts = [f"{k}={v:.6f}" for k, v in metrics.items()]
        logger.info(f"{prefix}step={self.step_count} | {' | '.join(parts)}")


def train_phrase_vae(
    model: PhraseVAE,
    train_dataset: PhraseDataset,
    val_dataset: Optional[PhraseDataset],
    config: TrainConfig,
    device: torch.device,
    dtype: torch.dtype,
) -> PhraseVAE:
    """
    Three-stage PhraseVAE training curriculum.

    Stage 1a: Span-infilling pretraining (learn REMI grammar)
    Stage 1b: Autoencoder (KL=0, pure reconstruction)
    Stage 1c: VAE fine-tuning (KL=0.01)
    """
    logger.info("=" * 60)
    logger.info("Stage 1: PhraseVAE Training")
    logger.info("=" * 60)

    model = model.to(device)

    # no weight decay on biases, norm parameters, or the b_sin/b_cos parameters
    no_decay = ['bias', 'LayerNorm', 'layer_norm', 'b_sin', 'b_cos']
    param_groups = [
        {'params': [p for n, p in model.named_parameters()
                    if not any(nd in n for nd in no_decay)],
         'weight_decay': config.vae_weight_decay},
        {'params': [p for n, p in model.named_parameters()
                    if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}
    ]
    optimizer = torch.optim.AdamW(param_groups, lr=config.vae_lr, betas=(0.9, 0.999))

    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=config.sgdr_t0, T_mult=config.sgdr_t_mult,
        eta_min=config.sgdr_eta_min
    )

    zclip = ZClip(config.zclip_z_thresh, config.zclip_alpha) if config.use_zclip else None
    nan_monitor = NaNMonitor()
    metrics = MetricsTracker()

    # only enable autocast when actually running in reduced precision
    use_amp = dtype != torch.float32

    train_loader = DataLoader(
        train_dataset, batch_size=config.vae_batch_size,
        shuffle=True, num_workers=2, pin_memory=True, drop_last=True
    )

    # --- Stage 1a: span-infilling pretraining ---
    logger.info("\n--- Stage 1a: Span-infilling pretraining ---")
    for epoch in range(config.vae_epochs_pretrain):
        model.train()
        for batch_idx, batch in enumerate(train_loader):
            token_ids = batch['token_ids'].to(device)

            # mask contiguous spans; the model reconstructs the original tokens
            masked_ids, _ = _apply_span_mask(token_ids, mask_prob=0.15,
                                             mask_id=model.config.mask_token_id)

            with torch.autocast(device_type=device.type, dtype=dtype, enabled=use_amp):
                outputs = model(masked_ids, target_tokens=token_ids, kl_weight=0.0)

            loss = outputs['loss']

            if not nan_monitor.check(loss, optimizer, scheduler):
                return model
            if not torch.isfinite(loss):
                continue  # NaN/Inf this step: skip backward/step (LR already halved)

            loss.backward()

            if zclip:
                grad_norm = zclip(model)
            else:
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm).item()

            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            metrics.update(loss=loss, recon=outputs['recon_loss'], grad_norm=grad_norm)

            if batch_idx % config.log_every_n_steps == 0:
                metrics.log(prefix=f"[1a] Epoch {epoch+1}/{config.vae_epochs_pretrain} ")

    # --- Stage 1b: autoencoder training (KL weight = 0) ---
    logger.info("\n--- Stage 1b: Autoencoder training (KL weight = 0) ---")
    for epoch in range(config.vae_epochs_ae):
        model.train()
        for batch_idx, batch in enumerate(train_loader):
            token_ids = batch['token_ids'].to(device)

            with torch.autocast(device_type=device.type, dtype=dtype, enabled=use_amp):
                outputs = model(token_ids, kl_weight=0.0)

            loss = outputs['loss']

            if not nan_monitor.check(loss, optimizer, scheduler):
                return model
            if not torch.isfinite(loss):
                continue

            loss.backward()

            if zclip:
                zclip(model)

            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            metrics.update(loss=loss, recon=outputs['recon_loss'], kl=outputs['kl_loss'])

            if batch_idx % config.log_every_n_steps == 0:
                metrics.log(prefix=f"[1b] Epoch {epoch+1}/{config.vae_epochs_ae} ")

    # --- Stage 1c: VAE fine-tuning (KL weight = kl_beta) ---
    logger.info("\n--- Stage 1c: VAE fine-tuning (KL weight = 0.01) ---")

    # drop the LR to 10% for fine-tuning; base_lrs must be scaled as well, or the
    # warm-restart scheduler would restore the old LR on its next step
    for pg in optimizer.param_groups:
        pg['lr'] = config.vae_lr * 0.1
    scheduler.base_lrs = [config.vae_lr * 0.1 for _ in scheduler.base_lrs]

    for epoch in range(config.vae_epochs_vae):
        model.train()
        for batch_idx, batch in enumerate(train_loader):
            token_ids = batch['token_ids'].to(device)

            with torch.autocast(device_type=device.type, dtype=dtype, enabled=use_amp):
                outputs = model(token_ids, kl_weight=config.kl_beta)

            loss = outputs['loss']

            if not nan_monitor.check(loss, optimizer, scheduler):
                return model
            if not torch.isfinite(loss):
                continue

            loss.backward()

            if zclip:
                zclip(model)

            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            metrics.update(loss=loss, recon=outputs['recon_loss'], kl=outputs['kl_loss'])

            if batch_idx % config.log_every_n_steps == 0:
                metrics.log(prefix=f"[1c] Epoch {epoch+1}/{config.vae_epochs_vae} ")

    logger.info("Stage 1 complete!")
    return model


def train_latent_mamba(
    mamba_model: LatentMamba,
    vae_model: PhraseVAE,
    train_dataset: PhraseDataset,
    config: TrainConfig,
    device: torch.device,
    dtype: torch.dtype,
) -> LatentMamba:
    """
    Train LatentMamba on phrase latent sequences.

    1. Freeze the PhraseVAE encoder
    2. Encode all training phrases into latent sequences
    3. Train LatentMamba to predict the next phrase latent
    """
    logger.info("=" * 60)
    logger.info("Stage 2: LatentMamba Training")
    logger.info("=" * 60)

    # freeze the PhraseVAE so its latent space stays fixed
    vae_model.eval()
    for p in vae_model.parameters():
        p.requires_grad = False

    mamba_model = mamba_model.to(device)

    optimizer = torch.optim.AdamW(
        mamba_model.parameters(), lr=config.mamba_lr,
        weight_decay=config.mamba_weight_decay, betas=(0.9, 0.999)
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=config.sgdr_t0, T_mult=config.sgdr_t_mult,
        eta_min=config.sgdr_eta_min
    )

    zclip = ZClip(config.zclip_z_thresh, config.zclip_alpha) if config.use_zclip else None
    nan_monitor = NaNMonitor()
    metrics = MetricsTracker()

    use_amp = dtype != torch.float32

    # encode every phrase with the frozen VAE and chunk into training sequences
    logger.info("Encoding training phrases to latent space...")
    latent_sequences = _encode_all_phrases(vae_model, train_dataset, device, dtype,
                                           config.mamba_batch_size)

    latent_dataset = LatentSequenceDataset(latent_sequences, max_phrases=config.mamba_max_phrases)
    train_loader = DataLoader(
        latent_dataset, batch_size=config.mamba_batch_size,
        shuffle=True, num_workers=2, pin_memory=True, drop_last=True
    )

    for epoch in range(config.mamba_epochs):
        mamba_model.train()
        for batch_idx, batch in enumerate(train_loader):
            z_seq = batch['z_seq'].to(device)
            lengths = batch['length'].to(device)

            # teacher forcing: predict z_{t+1} from z_{<=t}
            z_input = z_seq[:, :-1]
            z_target = z_seq[:, 1:]

            with torch.autocast(device_type=device.type, dtype=dtype, enabled=use_amp):
                z_pred = mamba_model(z_input)

            # mask out padded positions so they do not contribute to either loss term
            mask = torch.arange(z_target.shape[1], device=device).unsqueeze(0) < (lengths.unsqueeze(1) - 1)
            mask = mask.unsqueeze(-1).float()
            denom = mask.sum().clamp(min=1.0)

            # masked MSE, averaged over valid positions only
            loss = ((z_pred - z_target) ** 2 * mask).sum() / (denom * z_target.shape[-1])

            # cosine loss, also restricted to valid positions
            cos = F.cosine_similarity(z_pred, z_target, dim=-1)
            cos_loss = ((1.0 - cos) * mask.squeeze(-1)).sum() / denom

            total_loss = loss + 0.1 * cos_loss

            if not nan_monitor.check(total_loss, optimizer, scheduler):
                return mamba_model
            if not torch.isfinite(total_loss):
                continue  # skip backward/step on a NaN/Inf step

            total_loss.backward()

            if zclip:
                zclip(mamba_model)

            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            metrics.update(loss=loss, cos_loss=cos_loss, total=total_loss)

            if batch_idx % config.log_every_n_steps == 0:
                metrics.log(prefix=f"[S2] Epoch {epoch+1}/{config.mamba_epochs} ")

    logger.info("Stage 2 complete!")
    return mamba_model


def _apply_span_mask(token_ids: torch.Tensor, mask_prob: float = 0.15,
                     mask_id: int = 3, span_length: int = 3) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply span masking for pretraining (like T5/BART).
    Masks contiguous spans of tokens.
    """
    masked = token_ids.clone()
    B, L = masked.shape
    mask = torch.zeros_like(masked, dtype=torch.bool)

    for b in range(B):
        n_masks = max(1, int(L * mask_prob / span_length))
        for _ in range(n_masks):
            start = random.randint(1, max(1, L - span_length - 1))
            end = min(start + span_length, L)
            masked[b, start:end] = mask_id
            mask[b, start:end] = True

    return masked, mask
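
# Illustration (one possible outcome; spans are random and mask_id=3 is assumed to be
# a reserved [MASK] id): for a 12-token phrase with span_length=3,
#
#     tokens: [1, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 2]
#     masked: [1, 45,  3,  3,  3, 49, 50, 51, 52, 53, 54, 2]
#
# and during stage 1a the model is asked to reconstruct the original tokens.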


def _encode_all_phrases(vae_model: PhraseVAE, dataset: PhraseDataset,
                        device: torch.device, dtype: torch.dtype,
                        batch_size: int = 64) -> List[torch.Tensor]:
    """Encode all phrases in the dataset to latent vectors."""
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    all_latents = []
    with torch.no_grad():
        for batch in loader:
            token_ids = batch['token_ids'].to(device)
            with torch.autocast(device_type=device.type, dtype=dtype,
                                enabled=dtype != torch.float32):
                z, _, _ = vae_model.encode(token_ids)
            # store in float32 so later padding/concatenation is dtype-consistent
            all_latents.append(z.float().cpu())

    all_z = torch.cat(all_latents, dim=0)

    # group consecutive phrase latents into fixed-length sequences
    # (note: chunking is positional and does not respect song boundaries)
    chunk_size = 32
    sequences = []
    for i in range(0, len(all_z) - chunk_size + 1, chunk_size):
        sequences.append(all_z[i:i + chunk_size])

    logger.info(f"Encoded {len(all_z)} phrases into {len(sequences)} sequences")
    return sequences


def save_checkpoint(model: MuseMorphic, config: TrainConfig,
                    model_config: MuseMorphicConfig, step: int, path: str):
    """Save a model checkpoint (a per-step file plus 'checkpoint_latest.pt')."""
    os.makedirs(path, exist_ok=True)

    ckpt = {
        'model_state_dict': model.state_dict(),
        'step': step,
        'model_config': asdict(model_config),
        'train_config': asdict(config),
    }
    torch.save(ckpt, os.path.join(path, f'checkpoint_{step}.pt'))
    torch.save(ckpt, os.path.join(path, 'checkpoint_latest.pt'))

    logger.info(f"Saved checkpoint at step {step} to {path}")


def load_checkpoint(path: str, device: torch.device) -> Tuple[MuseMorphic, Dict]:
    """Load the latest model checkpoint from `path`."""
    ckpt = torch.load(os.path.join(path, 'checkpoint_latest.pt'), map_location=device)

    model_config = MuseMorphicConfig(**ckpt['model_config'])
    model = MuseMorphic(model_config)
    model.load_state_dict(ckpt['model_state_dict'])

    return model, ckpt
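
# Loading sketch (assumes the layout written by save_checkpoint above):
#
#     model, ckpt = load_checkpoint("./checkpoints", torch.device("cpu"))
#     model.eval()
#     train_cfg = TrainConfig(**ckpt['train_config'])   # recover the training settings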


def train_musemorphic(
    model_config: Optional[MuseMorphicConfig] = None,
    train_config: Optional[TrainConfig] = None,
    train_phrases: Optional[List[List[int]]] = None,
):
    """
    Complete MuseMorphic training pipeline.

    If train_phrases is None, synthetic data is generated for testing.
    """
    if model_config is None:
        model_config = MuseMorphicConfig()
    if train_config is None:
        train_config = TrainConfig()

    set_seed(train_config.seed)
    device = get_device(train_config)
    dtype = get_dtype(train_config)

    logger.info(f"Device: {device}, Dtype: {dtype}")

    model = MuseMorphic(model_config)
    params = model.count_parameters()
    logger.info(f"Model parameters: {params}")

    # fall back to synthetic phrases so the pipeline can be smoke-tested end to end
    if train_phrases is None:
        logger.info("No training data provided. Generating synthetic data for testing...")
        train_phrases = _generate_synthetic_phrases(1000, model_config.vae_max_seq_len,
                                                    model_config.vocab_size)

    train_dataset = PhraseDataset(train_phrases, model_config.vae_max_seq_len, model_config.pad_token_id)
    logger.info(f"Training dataset: {len(train_dataset)} phrases")

    # Stage 1: PhraseVAE
    model.phrase_vae = train_phrase_vae(
        model.phrase_vae, train_dataset, None, train_config, device, dtype
    )

    # Stage 2: LatentMamba on the frozen VAE's latents
    model.latent_mamba = train_latent_mamba(
        model.latent_mamba, model.phrase_vae, train_dataset,
        train_config, device, dtype
    )

    # final checkpoint (step=-1 marks end of training; written as checkpoint_-1.pt
    # and checkpoint_latest.pt)
    save_checkpoint(model, train_config, model_config, -1, train_config.output_dir)

    return model
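
# Running on real data just means passing pre-tokenized phrases. A hedged sketch (the
# JSON layout here is hypothetical; this module does not define a data format):
#
#     with open("./data/phrases.json") as f:
#         phrases = json.load(f)            # List[List[int]] of REMI token IDs
#     model = train_musemorphic(train_phrases=phrases)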


def _generate_synthetic_phrases(n: int, max_len: int, vocab_size: int) -> List[List[int]]:
    """Generate synthetic REMI-like phrases for testing."""
    phrases = []
    for _ in range(n):
        length = random.randint(10, max_len)
        # IDs 1/2 serve as BOS/EOS here; 0-3 are assumed reserved for special tokens
        phrase = [1]
        for _ in range(length - 2):
            tok = random.randint(4, vocab_size - 1)
            phrase.append(tok)
        phrase.append(2)
        phrases.append(phrase)
    return phrases


if __name__ == "__main__":
    model = train_musemorphic()