# MuseMorphic / generate.py
# Uploaded by asdf98 — commit 561ecde ("Upload generate.py", verified)
"""
MuseMorphic Inference Script
=============================
Generate MIDI music from a trained MuseMorphic model.
Usage:
python generate.py --checkpoint ./checkpoints/musemorphic_model.pt \
--output generated.mid \
--n_phrases 32 \
--temperature 0.7
"""
import argparse
import os
import sys
import torch
import torch.nn.functional as F
sys.path.insert(0, os.path.dirname(__file__))
from musemorphic.model import MuseMorphicConfig, MuseMorphic, PhraseVAE, LatentMamba
from musemorphic.tokenizer import REMIPlusTokenizer, notes_to_midi_file
def load_model(checkpoint_path: str, device: torch.device):
    """Rebuild a trained MuseMorphic model from a checkpoint file.

    The checkpoint is read as a mapping with the keys ``'config'`` (keyword
    arguments for ``MuseMorphicConfig``), ``'vae_state_dict'`` and
    ``'mamba_state_dict'``.

    Args:
        checkpoint_path: Path of the checkpoint to load.
        device: Device that all tensors and modules are mapped onto.

    Returns:
        ``(model, config)``: the assembled ``MuseMorphic`` (moved to
        ``device`` and switched to eval mode) and its ``MuseMorphicConfig``.
    """
    # SECURITY NOTE(review): weights_only=False lets torch.load un-pickle
    # arbitrary objects; only pass checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    cfg = MuseMorphicConfig(**checkpoint['config'])

    # The two stages are constructed first, then their trained weights are restored.
    phrase_vae = PhraseVAE(cfg).to(device)
    latent_mamba = LatentMamba(cfg).to(device)
    phrase_vae.load_state_dict(checkpoint['vae_state_dict'])
    latent_mamba.load_state_dict(checkpoint['mamba_state_dict'])

    # Wrap the restored stages in the top-level model, replacing whatever
    # submodules MuseMorphic(cfg) created for these attributes.
    model = MuseMorphic(cfg)
    model.phrase_vae = phrase_vae
    model.latent_mamba = latent_mamba
    model.to(device).eval()
    return model, cfg
def _decode_phrase(
    model: MuseMorphic,
    config: MuseMorphicConfig,
    z: torch.Tensor,
    temperature: float,
    max_decode_len: int,
    device: torch.device,
) -> list:
    """Sample the token ids of one phrase from its latent, one token at a time.

    Decoding starts from ``config.bos_token_id`` and draws at most
    ``max_decode_len`` further tokens. It stops early as soon as
    ``config.eos_token_id`` is sampled. The returned ids keep the leading BOS
    and, when one was drawn, the trailing EOS.
    """
    # Floor the decode temperature at 0.1 so a zero (or tiny) temperature never
    # divides the logits by zero.
    scale = max(temperature, 0.1)
    generated_ids = [config.bos_token_id]
    for _ in range(max_decode_len):
        # The full prefix is re-fed on every step; no incremental cache is used here.
        input_tensor = torch.tensor([generated_ids], dtype=torch.long, device=device)
        logits = model.phrase_vae.decode(z, input_tensor)
        # Sample from the prediction at the last position of the single batch row.
        probs = F.softmax(logits[0, -1] / scale, dim=-1)
        next_token = torch.multinomial(probs, 1).item()
        generated_ids.append(next_token)
        if next_token == config.eos_token_id:
            break
    return generated_ids


def generate_midi(
    model: MuseMorphic,
    config: MuseMorphicConfig,
    tokenizer: REMIPlusTokenizer,
    n_phrases: int = 32,
    temperature: float = 0.7,
    max_decode_len: int = 128,
    device: torch.device = torch.device('cpu'),
) -> tuple:
    """
    Generate MIDI notes from the model.

    Both stages run under ``torch.no_grad()``. First ``model.latent_mamba``
    samples a sequence of ``n_phrases`` phrase latents (batch size 1). Then
    ``model.phrase_vae`` decodes each latent into tokens autoregressively.

    Args:
        model: Model exposing ``latent_mamba`` and ``phrase_vae``.
        config: Supplies ``bos_token_id`` / ``eos_token_id`` for decoding.
        tokenizer: Converts sampled ids to tokens and tokens to notes.
        n_phrases: Number of phrase latents to sample.
        temperature: Passed unchanged to the latent sampler. Token decoding
            uses ``max(temperature, 0.1)``.
        max_decode_len: Maximum number of tokens sampled per phrase (after BOS).
        device: Device on which decoder input tensors are created.

    Returns:
        ``(notes, all_tokens)``.
        - ``notes``: the list from ``tokenizer.tokens_to_midi_notes``. These are
          note dicts, presumably ``{pitch, start, duration, velocity}``.
        - ``all_tokens``: the per-phrase ``tokenizer.decode`` outputs
          concatenated in order.
    """
    with torch.no_grad():
        # Stage 2: generate the latent phrase sequence.
        z_generated = model.latent_mamba.generate(
            n_phrases=n_phrases,
            temperature=temperature,
            batch_size=1,
        )
        # Stage 1: decode each phrase latent (indexed along dim 1) to tokens.
        all_tokens = []
        for t in range(z_generated.shape[1]):
            phrase_ids = _decode_phrase(
                model, config, z_generated[:, t], temperature, max_decode_len, device,
            )
            # NOTE(review): phrase_ids still include BOS (and EOS when sampled);
            # this assumes the tokenizer tolerates those special ids.
            all_tokens.extend(tokenizer.decode(phrase_ids))
    # Convert tokens to MIDI notes.
    notes = tokenizer.tokens_to_midi_notes(all_tokens)
    return notes, all_tokens
def main():
    """Command-line entry point.

    Parses arguments and picks a device. Loads the checkpoint with
    ``load_model``, samples notes with ``generate_midi``, writes them with
    ``notes_to_midi_file`` and prints a short summary.
    """
    parser = argparse.ArgumentParser(description='MuseMorphic MIDI Generator')
    parser.add_argument('--checkpoint', type=str, required=True, help='Path to model checkpoint')
    parser.add_argument('--output', type=str, default='generated.mid', help='Output MIDI file path')
    parser.add_argument('--n_phrases', type=int, default=32, help='Number of phrases to generate')
    parser.add_argument('--temperature', type=float, default=0.7, help='Sampling temperature')
    parser.add_argument('--tempo', type=float, default=120.0, help='Output MIDI tempo')
    parser.add_argument('--device', type=str, default='auto', help='Device (auto/cuda/cpu)')
    args = parser.parse_args()

    # Fail fast on arguments that cannot give a sensible result.
    # - Zero phrases produces no music.
    # - A non-positive tempo would divide by zero in the duration summary
    #   below, and only after the whole model run.
    if args.n_phrases < 1:
        parser.error('--n_phrases must be >= 1')
    if args.tempo <= 0:
        parser.error('--tempo must be > 0')

    # Device
    if args.device == 'auto':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(args.device)
    print(f'Device: {device}')

    # Load model
    print(f'Loading model from {args.checkpoint}...')
    model, config = load_model(args.checkpoint, device)
    params = sum(p.numel() for p in model.parameters())
    print(f'Model parameters: {params:,} ({params/1e6:.2f}M)')

    # Tokenizer
    tokenizer = REMIPlusTokenizer()

    # Generate
    print(f'Generating {args.n_phrases} phrases at temperature {args.temperature}...')
    notes, tokens = generate_midi(
        model, config, tokenizer,
        n_phrases=args.n_phrases,
        temperature=args.temperature,
        device=device,
    )
    print(f'Generated {len(notes)} notes, {len(tokens)} tokens')

    if not notes:
        print('No notes generated. Model may need more training.')
        return

    # Write MIDI
    success = notes_to_midi_file(notes, args.output, tempo=args.tempo)
    if not success:
        print('Failed to write MIDI. Install midiutil: pip install midiutil')
        return

    print(f'\n🎵 MIDI saved to: {args.output}')
    # NOTE(review): assumes note start/duration are in ticks at 480 per beat;
    # confirm against the tokenizer / notes_to_midi_file resolution.
    ticks_per_beat = 480
    total_dur = max(n['start'] + n['duration'] for n in notes)
    beats = total_dur / ticks_per_beat
    print(f' Duration: ~{beats:.1f} beats ({beats/args.tempo*60:.1f} seconds at {args.tempo} BPM)')
    pitches = [n['pitch'] for n in notes]
    print(f' Pitch range: {min(pitches)}-{max(pitches)}')
    print(f' Note count: {len(notes)}')


if __name__ == '__main__':
    main()