File size: 5,255 Bytes
561ecde
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
"""
MuseMorphic Inference Script
=============================

Generate MIDI music from a trained MuseMorphic model.

Usage:
    python generate.py --checkpoint ./checkpoints/musemorphic_model.pt \
                       --output generated.mid \
                       --n_phrases 32 \
                       --temperature 0.7
"""

import argparse
import os
import sys
import torch
import torch.nn.functional as F

sys.path.insert(0, os.path.dirname(__file__))
from musemorphic.model import MuseMorphicConfig, MuseMorphic, PhraseVAE, LatentMamba
from musemorphic.tokenizer import REMIPlusTokenizer, notes_to_midi_file


def load_model(checkpoint_path: str, device: torch.device):
    """Restore a trained MuseMorphic model from a checkpoint file.

    Rebuilds the config, the phrase VAE, and the latent Mamba from the
    saved state dicts, wires them into a fresh MuseMorphic instance, and
    returns ``(model, config)`` with the model in eval mode on *device*.
    """
    # weights_only=False: checkpoint stores a plain dict with config + state dicts.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    config = MuseMorphicConfig(**checkpoint['config'])

    # Rebuild the two sub-modules and restore their trained weights.
    phrase_vae = PhraseVAE(config).to(device)
    latent_mamba = LatentMamba(config).to(device)
    phrase_vae.load_state_dict(checkpoint['vae_state_dict'])
    latent_mamba.load_state_dict(checkpoint['mamba_state_dict'])

    # Assemble the composite model around the restored components.
    model = MuseMorphic(config)
    model.phrase_vae = phrase_vae
    model.latent_mamba = latent_mamba
    model = model.to(device).eval()

    return model, config


def generate_midi(
    model: MuseMorphic,
    config: MuseMorphicConfig,
    tokenizer: REMIPlusTokenizer,
    n_phrases: int = 32,
    temperature: float = 0.7,
    max_decode_len: int = 128,
    device: torch.device = torch.device('cpu'),
) -> tuple:
    """
    Generate MIDI notes from the model.

    Stage 2 (latent Mamba) samples a sequence of phrase latents; Stage 1
    (phrase VAE) then autoregressively decodes each latent into tokens.

    Args:
        model: Trained MuseMorphic model (expected to be in eval mode).
        config: Model config; supplies ``bos_token_id`` / ``eos_token_id``.
        tokenizer: Maps decoded ids back to tokens and tokens to notes.
        n_phrases: Number of phrase latents to sample.
        temperature: Sampling temperature (clamped to >= 0.1 for decoding).
        max_decode_len: Max tokens decoded per phrase before a forced stop.
        device: Device for the decoder input tensors.

    Returns:
        (notes, tokens) tuple: ``notes`` is a list of dicts with keys
        {pitch, start, duration, velocity}; ``tokens`` is the flat list of
        all decoded tokens across phrases.
    """
    # BUG FIX: annotation/docstring previously claimed a bare list, but the
    # function has always returned a (notes, tokens) 2-tuple.

    # Clamp once, outside the loops — loop-invariant; floor avoids
    # division blow-up at temperature ~0.
    decode_temp = max(temperature, 0.1)

    with torch.no_grad():
        # Stage 2: sample the latent phrase sequence.
        z_generated = model.latent_mamba.generate(
            n_phrases=n_phrases,
            temperature=temperature,
            batch_size=1,
        )

        # Stage 1: decode each phrase latent to tokens, autoregressively.
        all_tokens = []
        for t in range(z_generated.shape[1]):
            z = z_generated[:, t]

            generated_ids = [config.bos_token_id]

            for _ in range(max_decode_len):
                input_tensor = torch.tensor([generated_ids], dtype=torch.long, device=device)
                logits = model.phrase_vae.decode(z, input_tensor)

                # Sample the next token from the temperature-scaled distribution.
                next_logits = logits[0, -1] / decode_temp
                probs = F.softmax(next_logits, dim=-1)
                next_token = torch.multinomial(probs, 1).item()
                generated_ids.append(next_token)

                if next_token == config.eos_token_id:
                    break

            all_tokens.extend(tokenizer.decode(generated_ids))

    # Convert the flat token stream to MIDI note dicts.
    notes = tokenizer.tokens_to_midi_notes(all_tokens)
    return notes, all_tokens


def main():
    """CLI entry point: parse args, load the model, generate, write MIDI."""
    parser = argparse.ArgumentParser(description='MuseMorphic MIDI Generator')
    parser.add_argument('--checkpoint', type=str, required=True, help='Path to model checkpoint')
    parser.add_argument('--output', type=str, default='generated.mid', help='Output MIDI file path')
    parser.add_argument('--n_phrases', type=int, default=32, help='Number of phrases to generate')
    parser.add_argument('--temperature', type=float, default=0.7, help='Sampling temperature')
    parser.add_argument('--tempo', type=float, default=120.0, help='Output MIDI tempo')
    parser.add_argument('--device', type=str, default='auto', help='Device (auto/cuda/cpu)')
    args = parser.parse_args()

    # Resolve the compute device; 'auto' prefers CUDA when available.
    if args.device != 'auto':
        device = torch.device(args.device)
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Device: {device}')

    # Restore model + config from the checkpoint and report its size.
    print(f'Loading model from {args.checkpoint}...')
    model, config = load_model(args.checkpoint, device)
    params = sum(p.numel() for p in model.parameters())
    print(f'Model parameters: {params:,} ({params/1e6:.2f}M)')

    tokenizer = REMIPlusTokenizer()

    # Run the two-stage generation pipeline.
    print(f'Generating {args.n_phrases} phrases at temperature {args.temperature}...')
    notes, tokens = generate_midi(
        model, config, tokenizer,
        n_phrases=args.n_phrases,
        temperature=args.temperature,
        device=device,
    )
    print(f'Generated {len(notes)} notes, {len(tokens)} tokens')

    # Guard clause: nothing to write if the model produced no notes.
    if not notes:
        print('No notes generated. Model may need more training.')
        return

    # Write the MIDI file and print a short summary on success.
    if not notes_to_midi_file(notes, args.output, tempo=args.tempo):
        print('Failed to write MIDI. Install midiutil: pip install midiutil')
        return

    print(f'\n🎵 MIDI saved to: {args.output}')
    total_dur = max(n['start'] + n['duration'] for n in notes)
    # NOTE(review): the /480 here presumably converts ticks to beats at
    # 480 ticks-per-beat — confirm against the tokenizer's resolution.
    print(f'  Duration: ~{total_dur/480:.1f} beats ({total_dur/480/args.tempo*60:.1f} seconds at {args.tempo} BPM)')
    pitches = [n['pitch'] for n in notes]
    print(f'  Pitch range: {min(pitches)}-{max(pitches)}')
    print(f'  Note count: {len(notes)}')


# Run the CLI only when executed as a script, not when imported.
if __name__ == '__main__':
    main()