"""
MuseMorphic Inference Script
============================

Generate MIDI music from a trained MuseMorphic model.

Usage:
    python generate.py --checkpoint ./checkpoints/musemorphic_model.pt \
        --output generated.mid \
        --n_phrases 32 \
        --temperature 0.7
"""
import argparse
import os
import sys

import torch
import torch.nn.functional as F

sys.path.insert(0, os.path.dirname(__file__))

from musemorphic.model import MuseMorphicConfig, MuseMorphic, PhraseVAE, LatentMamba
from musemorphic.tokenizer import REMIPlusTokenizer, notes_to_midi_file
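

# The checkpoint is expected to hold the model config plus separate state dicts
# for the phrase VAE and the latent Mamba ('config', 'vae_state_dict',
# 'mamba_state_dict'), matching the keys read below.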
def load_model(checkpoint_path: str, device: torch.device):
    """Load trained model from checkpoint."""
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
    config = MuseMorphicConfig(**ckpt['config'])
    vae = PhraseVAE(config).to(device)
    mamba = LatentMamba(config).to(device)
    vae.load_state_dict(ckpt['vae_state_dict'])
    mamba.load_state_dict(ckpt['mamba_state_dict'])
    model = MuseMorphic(config)
    model.phrase_vae = vae
    model.latent_mamba = mamba
    model.to(device)
    model.eval()
    return model, config
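

# Generation is two-stage: the latent Mamba samples a sequence of phrase latents
# (Stage 2), the PhraseVAE decoder then autoregressively expands each latent into
# REMI+ tokens (Stage 1), and the tokens are finally converted to MIDI notes.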
def generate_midi(
    model: MuseMorphic,
    config: MuseMorphicConfig,
    tokenizer: REMIPlusTokenizer,
    n_phrases: int = 32,
    temperature: float = 0.7,
    max_decode_len: int = 128,
    device: torch.device = torch.device('cpu'),
) -> tuple:
    """
    Generate MIDI notes from the model.

    Returns (notes, tokens): notes is a list of note dicts
    {pitch, start, duration, velocity}; tokens is the flat list of decoded
    REMI+ tokens.
    """
    with torch.no_grad():
        # Stage 2: Generate latent phrase sequence
        z_generated = model.latent_mamba.generate(
            n_phrases=n_phrases,
            temperature=temperature,
            batch_size=1,
        )
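        # z_generated holds one latent vector per phrase, indexed along dim 1
        # (batch of 1, n_phrases steps), as consumed by the loop below.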
        # Stage 1: Decode each phrase latent to tokens
        all_tokens = []
        for t in range(z_generated.shape[1]):
            z = z_generated[:, t]
            # Autoregressive decode
            generated_ids = [config.bos_token_id]
            for _ in range(max_decode_len):
                input_tensor = torch.tensor([generated_ids], dtype=torch.long, device=device)
                logits = model.phrase_vae.decode(z, input_tensor)
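                # Sample the next token from the temperature-scaled distribution over
                # the final position; the temperature is clamped at 0.1 to avoid
                # near-zero division.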
                next_logits = logits[0, -1] / max(temperature, 0.1)
                probs = F.softmax(next_logits, dim=-1)
                next_token = torch.multinomial(probs, 1).item()
                generated_ids.append(next_token)
                if next_token == config.eos_token_id:
                    break
            phrase_tokens = tokenizer.decode(generated_ids)
            all_tokens.extend(phrase_tokens)
    # Convert tokens to MIDI notes
    notes = tokenizer.tokens_to_midi_notes(all_tokens)
    return notes, all_tokens


def main():
    parser = argparse.ArgumentParser(description='MuseMorphic MIDI Generator')
    parser.add_argument('--checkpoint', type=str, required=True, help='Path to model checkpoint')
    parser.add_argument('--output', type=str, default='generated.mid', help='Output MIDI file path')
    parser.add_argument('--n_phrases', type=int, default=32, help='Number of phrases to generate')
    parser.add_argument('--temperature', type=float, default=0.7, help='Sampling temperature')
    parser.add_argument('--tempo', type=float, default=120.0, help='Output MIDI tempo')
    parser.add_argument('--device', type=str, default='auto', help='Device (auto/cuda/cpu)')
    args = parser.parse_args()

    # Device
    if args.device == 'auto':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(args.device)
    print(f'Device: {device}')

    # Load model
    print(f'Loading model from {args.checkpoint}...')
    model, config = load_model(args.checkpoint, device)
    params = sum(p.numel() for p in model.parameters())
    print(f'Model parameters: {params:,} ({params/1e6:.2f}M)')

    # Tokenizer
    tokenizer = REMIPlusTokenizer()
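    # Note: the tokenizer is constructed with its default settings here; these are
    # assumed to match the vocabulary the checkpoint was trained with.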

    # Generate
    print(f'Generating {args.n_phrases} phrases at temperature {args.temperature}...')
    notes, tokens = generate_midi(
        model, config, tokenizer,
        n_phrases=args.n_phrases,
        temperature=args.temperature,
        device=device,
    )
    print(f'Generated {len(notes)} notes, {len(tokens)} tokens')

    if notes:
        # Write MIDI
        success = notes_to_midi_file(notes, args.output, tempo=args.tempo)
        if success:
            print(f'\n🎵 MIDI saved to: {args.output}')
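            # Note times are in ticks; the division by 480 below assumes the
            # tokenizer's resolution of 480 ticks per beat.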
            total_dur = max(n['start'] + n['duration'] for n in notes)
            print(f' Duration: ~{total_dur/480:.1f} beats ({total_dur/480/args.tempo*60:.1f} seconds at {args.tempo} BPM)')
            pitches = [n['pitch'] for n in notes]
            print(f' Pitch range: {min(pitches)}-{max(pitches)}')
            print(f' Note count: {len(notes)}')
        else:
            print('Failed to write MIDI. Install midiutil: pip install midiutil')
    else:
        print('No notes generated. Model may need more training.')


if __name__ == '__main__':
    main()