| """ |
| MuseMorphic Inference Script |
| ============================= |
| |
| Generate MIDI music from a trained MuseMorphic model. |
| |
| Usage: |
| python generate.py --checkpoint ./checkpoints/musemorphic_model.pt \ |
| --output generated.mid \ |
| --n_phrases 32 \ |
| --temperature 0.7 |
| """ |
|
|
| import argparse |
| import os |
| import sys |
| import torch |
| import torch.nn.functional as F |
|
|
| sys.path.insert(0, os.path.dirname(__file__)) |
| from musemorphic.model import MuseMorphicConfig, MuseMorphic, PhraseVAE, LatentMamba |
| from musemorphic.tokenizer import REMIPlusTokenizer, notes_to_midi_file |
|
|
|
|
def load_model(checkpoint_path: str, device: torch.device):
    """Restore a MuseMorphic model (phrase VAE + latent Mamba) from a checkpoint.

    Rebuilds each sub-module from the saved config, loads its trained weights,
    and assembles them into a single ``MuseMorphic`` wrapper in eval mode.

    NOTE(review): ``weights_only=False`` unpickles arbitrary objects — only
    load checkpoints from trusted sources.

    Returns:
        (model, config): the assembled model on *device* and its config.
    """
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    cfg = MuseMorphicConfig(**checkpoint['config'])

    # Reconstruct each trained component and restore its parameters.
    phrase_vae = PhraseVAE(cfg).to(device)
    phrase_vae.load_state_dict(checkpoint['vae_state_dict'])

    latent_mamba = LatentMamba(cfg).to(device)
    latent_mamba.load_state_dict(checkpoint['mamba_state_dict'])

    # Wire the restored components into the wrapper model.
    model = MuseMorphic(cfg)
    model.phrase_vae = phrase_vae
    model.latent_mamba = latent_mamba
    model.to(device)
    model.eval()

    return model, cfg
|
|
|
|
def generate_midi(
    model: MuseMorphic,
    config: MuseMorphicConfig,
    tokenizer: REMIPlusTokenizer,
    n_phrases: int = 32,
    temperature: float = 0.7,
    max_decode_len: int = 128,
    device: torch.device = torch.device('cpu'),
) -> tuple:
    """
    Generate MIDI notes from the model.

    Samples ``n_phrases`` phrase-level latents from the latent Mamba, then
    decodes each latent into a token sequence with the phrase VAE via
    autoregressive temperature sampling (capped at ``max_decode_len`` tokens
    per phrase, stopping early on EOS).

    Args:
        model: Trained MuseMorphic wrapper exposing ``latent_mamba`` and
            ``phrase_vae``.
        config: Model config; supplies ``bos_token_id`` / ``eos_token_id``.
        tokenizer: REMI+ tokenizer used to decode ids and build note events.
        n_phrases: Number of latent phrases to sample.
        temperature: Sampling temperature for both the latent sampler and
            token decoding (clamped to >= 0.1 during decoding).
        max_decode_len: Hard cap on decoded tokens per phrase.
        device: Device on which decoding tensors are created.

    Returns:
        (notes, tokens): ``notes`` is a list of dicts with keys
        {pitch, start, duration, velocity}; ``tokens`` is the flat list of
        decoded tokens across all phrases.
        (Fixed: annotation/docstring previously claimed a single list was
        returned, but the function returns this 2-tuple.)
    """
    with torch.no_grad():
        # Phrase-level latent sequence: shape assumed (1, n_phrases, latent_dim)
        # based on the [:, t] indexing below — TODO confirm against LatentMamba.
        z_generated = model.latent_mamba.generate(
            n_phrases=n_phrases,
            temperature=temperature,
            batch_size=1,
        )

        all_tokens = []
        for t in range(z_generated.shape[1]):
            z = z_generated[:, t]  # latent for phrase t (batch of 1)

            # Autoregressive token decoding, seeded with BOS.
            generated_ids = [config.bos_token_id]

            for _ in range(max_decode_len):
                input_tensor = torch.tensor([generated_ids], dtype=torch.long, device=device)
                logits = model.phrase_vae.decode(z, input_tensor)

                # Clamp temperature so near-zero values cannot blow up the logits.
                next_logits = logits[0, -1] / max(temperature, 0.1)
                probs = F.softmax(next_logits, dim=-1)
                next_token = torch.multinomial(probs, 1).item()
                generated_ids.append(next_token)

                if next_token == config.eos_token_id:
                    break

            # NOTE(review): generated_ids still includes BOS (and EOS when hit);
            # assumes tokenizer.decode skips/handles special ids — confirm.
            phrase_tokens = tokenizer.decode(generated_ids)
            all_tokens.extend(phrase_tokens)

    # Convert the flat token stream into concrete note events.
    notes = tokenizer.tokens_to_midi_notes(all_tokens)
    return notes, all_tokens
|
|
|
|
def _parse_args():
    """Build and evaluate the generator's command-line interface."""
    parser = argparse.ArgumentParser(description='MuseMorphic MIDI Generator')
    parser.add_argument('--checkpoint', type=str, required=True, help='Path to model checkpoint')
    parser.add_argument('--output', type=str, default='generated.mid', help='Output MIDI file path')
    parser.add_argument('--n_phrases', type=int, default=32, help='Number of phrases to generate')
    parser.add_argument('--temperature', type=float, default=0.7, help='Sampling temperature')
    parser.add_argument('--tempo', type=float, default=120.0, help='Output MIDI tempo')
    parser.add_argument('--device', type=str, default='auto', help='Device (auto/cuda/cpu)')
    return parser.parse_args()


def _resolve_device(requested: str) -> torch.device:
    """Translate the --device flag; 'auto' prefers CUDA when available."""
    if requested == 'auto':
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    return torch.device(requested)


def main():
    """CLI entry point: load a checkpoint, sample phrases, write a MIDI file."""
    args = _parse_args()

    device = _resolve_device(args.device)
    print(f'Device: {device}')

    print(f'Loading model from {args.checkpoint}...')
    model, config = load_model(args.checkpoint, device)

    params = sum(p.numel() for p in model.parameters())
    print(f'Model parameters: {params:,} ({params/1e6:.2f}M)')

    tokenizer = REMIPlusTokenizer()

    print(f'Generating {args.n_phrases} phrases at temperature {args.temperature}...')
    notes, tokens = generate_midi(
        model, config, tokenizer,
        n_phrases=args.n_phrases,
        temperature=args.temperature,
        device=device,
    )
    print(f'Generated {len(notes)} notes, {len(tokens)} tokens')

    # Guard clauses: bail out early on an empty result or a failed write.
    if not notes:
        print('No notes generated. Model may need more training.')
        return

    if not notes_to_midi_file(notes, args.output, tempo=args.tempo):
        print('Failed to write MIDI. Install midiutil: pip install midiutil')
        return

    # Summary stats; the /480 conversion treats 480 ticks as one beat
    # (presumably the tokenizer's resolution — confirm in REMIPlusTokenizer).
    print(f'\n🎵 MIDI saved to: {args.output}')
    total_dur = max(n['start'] + n['duration'] for n in notes)
    print(f' Duration: ~{total_dur/480:.1f} beats ({total_dur/480/args.tempo*60:.1f} seconds at {args.tempo} BPM)')
    pitches = [n['pitch'] for n in notes]
    print(f' Pitch range: {min(pitches)}-{max(pitches)}')
    print(f' Note count: {len(notes)}')


if __name__ == '__main__':
    main()
|
|