r"""
MuseMorphic Inference Script
=============================
Generate MIDI music from a trained MuseMorphic model.

Usage:
    python generate.py --checkpoint ./checkpoints/musemorphic_model.pt \
                       --output generated.mid \
                       --n_phrases 32 \
                       --temperature 0.7
"""

import argparse
import os
import sys

import torch
import torch.nn.functional as F

# Make the sibling ``musemorphic`` package importable when this file is run
# directly as a script; this must happen before the imports below.
sys.path.insert(0, os.path.dirname(__file__))

from musemorphic.model import MuseMorphicConfig, MuseMorphic, PhraseVAE, LatentMamba
from musemorphic.tokenizer import REMIPlusTokenizer, notes_to_midi_file


def load_model(checkpoint_path: str, device: torch.device):
    """Load a trained MuseMorphic model from a checkpoint file.

    The checkpoint is read as a mapping with the keys ``'config'`` (keyword
    arguments for ``MuseMorphicConfig``), ``'vae_state_dict'`` and
    ``'mamba_state_dict'``.

    Args:
        checkpoint_path: Path to the checkpoint file.
        device: Device the weights are mapped to and the model is moved to.

    Returns:
        A ``(model, config)`` tuple; ``model`` is a ``MuseMorphic`` instance
        in eval mode on ``device``.
    """
    # SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary
    # Python objects, which can execute code. Only load checkpoints from a
    # trusted source.
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
    config = MuseMorphicConfig(**ckpt['config'])

    # Restore the two trained sub-networks from their own state dicts.
    vae = PhraseVAE(config).to(device)
    mamba = LatentMamba(config).to(device)
    vae.load_state_dict(ckpt['vae_state_dict'])
    mamba.load_state_dict(ckpt['mamba_state_dict'])

    # Build the wrapper model and swap in the restored sub-networks.
    model = MuseMorphic(config)
    model.phrase_vae = vae
    model.latent_mamba = mamba
    model.to(device)
    model.eval()
    return model, config


def generate_midi(
    model: MuseMorphic,
    config: MuseMorphicConfig,
    tokenizer: REMIPlusTokenizer,
    n_phrases: int = 32,
    temperature: float = 0.7,
    max_decode_len: int = 128,
    device: torch.device = torch.device('cpu'),
) -> tuple:
    """Generate MIDI notes from the model.

    First samples a sequence of phrase latents from ``model.latent_mamba``,
    then decodes every latent into tokens by sampling autoregressively from
    ``model.phrase_vae``.

    Args:
        model: Model providing ``latent_mamba`` and ``phrase_vae``.
        config: Supplies ``bos_token_id`` and ``eos_token_id``.
        tokenizer: Turns token ids into tokens and tokens into notes.
        n_phrases: Number of phrase latents requested from the latent model.
        temperature: Sampling temperature. It is passed unchanged to the
            latent generator and floored at 0.1 for token sampling.
        max_decode_len: Maximum number of tokens sampled per phrase.
        device: Device on which the decoder input tensors are created.
            NOTE(review): presumably this has to match the model's device.

    Returns:
        A ``(notes, all_tokens)`` tuple. ``notes`` is the result of
        ``tokenizer.tokens_to_midi_notes`` (note dicts with pitch, start,
        duration, velocity); ``all_tokens`` is the concatenation of the
        decoded tokens of every phrase.
    """
    # Loop-invariant: floor the token-sampling temperature so the logits are
    # never divided by zero or a near-zero value.
    decode_temperature = max(temperature, 0.1)

    with torch.no_grad():
        # Stage 2: Generate latent phrase sequence
        z_generated = model.latent_mamba.generate(
            n_phrases=n_phrases,
            temperature=temperature,
            batch_size=1,
        )

        # Stage 1: Decode each phrase latent to tokens
        all_tokens = []
        for t in range(z_generated.shape[1]):
            z = z_generated[:, t]

            # Autoregressive decode: feed the whole prefix, sample one token.
            generated_ids = [config.bos_token_id]
            for _ in range(max_decode_len):
                input_tensor = torch.tensor(
                    [generated_ids], dtype=torch.long, device=device
                )
                logits = model.phrase_vae.decode(z, input_tensor)
                # Only the prediction for the last position is needed.
                next_logits = logits[0, -1] / decode_temperature
                probs = F.softmax(next_logits, dim=-1)
                next_token = torch.multinomial(probs, 1).item()
                generated_ids.append(next_token)
                if next_token == config.eos_token_id:
                    break

            phrase_tokens = tokenizer.decode(generated_ids)
            all_tokens.extend(phrase_tokens)

    # Convert tokens to MIDI notes
    notes = tokenizer.tokens_to_midi_notes(all_tokens)
    return notes, all_tokens


def main():
    """Parse the command line, generate music and write it to a MIDI file."""
    parser = argparse.ArgumentParser(description='MuseMorphic MIDI Generator')
    parser.add_argument('--checkpoint', type=str, required=True,
                        help='Path to model checkpoint')
    parser.add_argument('--output', type=str, default='generated.mid',
                        help='Output MIDI file path')
    parser.add_argument('--n_phrases', type=int, default=32,
                        help='Number of phrases to generate')
    parser.add_argument('--temperature', type=float, default=0.7,
                        help='Sampling temperature')
    parser.add_argument('--tempo', type=float, default=120.0,
                        help='Output MIDI tempo')
    parser.add_argument('--device', type=str, default='auto',
                        help='Device (auto/cuda/cpu)')
    args = parser.parse_args()

    # The duration summary below divides by the tempo, so reject unusable
    # values before spending time on generation.
    if args.tempo <= 0:
        parser.error('--tempo must be greater than 0')

    # Device
    if args.device == 'auto':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(args.device)
    print(f'Device: {device}')

    # Load model
    print(f'Loading model from {args.checkpoint}...')
    model, config = load_model(args.checkpoint, device)
    params = sum(p.numel() for p in model.parameters())
    print(f'Model parameters: {params:,} ({params/1e6:.2f}M)')

    # Tokenizer
    tokenizer = REMIPlusTokenizer()

    # Generate
    print(f'Generating {args.n_phrases} phrases at temperature {args.temperature}...')
    notes, tokens = generate_midi(
        model, config, tokenizer,
        n_phrases=args.n_phrases,
        temperature=args.temperature,
        device=device,
    )
    print(f'Generated {len(notes)} notes, {len(tokens)} tokens')

    if notes:
        # Write MIDI
        success = notes_to_midi_file(notes, args.output, tempo=args.tempo)
        if success:
            print(f'\n🎵 MIDI saved to: {args.output}')
            total_dur = max(n['start'] + n['duration'] for n in notes)
            # Note times are divided by 480 to report beats.
            # NOTE(review): presumably 480 is the tokenizer's ticks-per-beat
            # resolution - confirm against the tokenizer.
            print(f' Duration: ~{total_dur/480:.1f} beats ({total_dur/480/args.tempo*60:.1f} seconds at {args.tempo} BPM)')
            pitches = [n['pitch'] for n in notes]
            print(f' Pitch range: {min(pitches)}-{max(pitches)}')
            print(f' Note count: {len(notes)}')
        else:
            print('Failed to write MIDI. Install midiutil: pip install midiutil')
    else:
        print('No notes generated. Model may need more training.')


if __name__ == '__main__':
    main()