asdf98 committed on
Commit
561ecde
·
verified ·
1 Parent(s): 9e6cfe4

Upload generate.py

Browse files
Files changed (1) hide show
  1. generate.py +153 -0
generate.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MuseMorphic Inference Script
3
+ =============================
4
+
5
+ Generate MIDI music from a trained MuseMorphic model.
6
+
7
+ Usage:
8
+ python generate.py --checkpoint ./checkpoints/musemorphic_model.pt \
9
+ --output generated.mid \
10
+ --n_phrases 32 \
11
+ --temperature 0.7
12
+ """
13
+
14
+ import argparse
15
+ import os
16
+ import sys
17
+ import torch
18
+ import torch.nn.functional as F
19
+
20
+ sys.path.insert(0, os.path.dirname(__file__))
21
+ from musemorphic.model import MuseMorphicConfig, MuseMorphic, PhraseVAE, LatentMamba
22
+ from musemorphic.tokenizer import REMIPlusTokenizer, notes_to_midi_file
23
+
24
+
25
def load_model(checkpoint_path: str, device: torch.device):
    """Load a trained MuseMorphic model from a checkpoint file.

    The checkpoint must be a mapping with three entries: ``'config'``
    (keyword arguments for ``MuseMorphicConfig``), ``'vae_state_dict'`` and
    ``'mamba_state_dict'``.

    Args:
        checkpoint_path: Path to the checkpoint file.
        device: Device the checkpoint tensors are mapped to and the
            assembled model is moved to.

    Returns:
        A ``(model, config)`` tuple. The model has been switched to eval
        mode.

    Raises:
        KeyError: If any of the required checkpoint entries is missing.
    """
    # SECURITY: weights_only=False makes torch.load unpickle arbitrary Python
    # objects, which can execute code. Only load checkpoints from a trusted
    # source. NOTE(review): if the stored config is a plain dict of
    # primitives, weights_only=True may be sufficient -- confirm and switch.
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)

    # Fail early with a readable message instead of a bare KeyError raised
    # halfway through model construction.
    required_keys = ('config', 'vae_state_dict', 'mamba_state_dict')
    missing = [key for key in required_keys if key not in ckpt]
    if missing:
        missing_list = ', '.join(missing)
        raise KeyError(
            f'Checkpoint {checkpoint_path!r} is missing required entries: '
            f'{missing_list}'
        )

    config = MuseMorphicConfig(**ckpt['config'])

    # The two sub-models are built and restored separately because the
    # checkpoint stores one state dict per sub-model.
    vae = PhraseVAE(config).to(device)
    mamba = LatentMamba(config).to(device)

    vae.load_state_dict(ckpt['vae_state_dict'])
    mamba.load_state_dict(ckpt['mamba_state_dict'])

    # Attach the restored sub-models to a freshly constructed wrapper.
    model = MuseMorphic(config)
    model.phrase_vae = vae
    model.latent_mamba = mamba
    model.to(device)
    model.eval()

    return model, config
44
+
45
+
46
def generate_midi(
    model: MuseMorphic,
    config: MuseMorphicConfig,
    tokenizer: REMIPlusTokenizer,
    n_phrases: int = 32,
    temperature: float = 0.7,
    max_decode_len: int = 128,
    device: torch.device = torch.device('cpu'),
) -> tuple:
    """Generate MIDI notes from the model.

    Generation runs in two stages: ``model.latent_mamba`` samples a sequence
    of phrase latents, then ``model.phrase_vae`` decodes every latent into
    token ids one token at a time.

    Args:
        model: Model providing ``latent_mamba.generate`` and
            ``phrase_vae.decode``.
        config: Supplies ``bos_token_id`` and ``eos_token_id``.
        tokenizer: Turns token ids into tokens (``decode``) and tokens into
            notes (``tokens_to_midi_notes``).
        n_phrases: Number of phrase latents requested from the latent model.
        temperature: Sampling temperature. It is passed unchanged to the
            latent model; for token sampling it is floored at 0.1.
        max_decode_len: Maximum number of tokens sampled per phrase.
        device: Device on which the decoder input tensors are created.

    Returns:
        A ``(notes, all_tokens)`` tuple: the notes returned by
        ``tokenizer.tokens_to_midi_notes`` (note dicts with pitch, start,
        duration, velocity) and the flat list of decoded tokens of all
        phrases.
    """
    # Loop-invariant: floor the temperature once so the logits are never
    # divided by zero or by a negative value.
    token_temperature = max(temperature, 0.1)
    all_tokens = []

    with torch.no_grad():
        # Stage 2: generate the latent phrase sequence.
        z_generated = model.latent_mamba.generate(
            n_phrases=n_phrases,
            temperature=temperature,
            batch_size=1,
        )

        # Stage 1: decode each phrase latent (axis 1 indexes the phrases).
        for phrase_idx in range(z_generated.shape[1]):
            phrase_ids = _sample_phrase_ids(
                model,
                config,
                z_generated[:, phrase_idx],
                token_temperature,
                max_decode_len,
                device,
            )
            all_tokens.extend(tokenizer.decode(phrase_ids))

    # Convert tokens to MIDI notes.
    notes = tokenizer.tokens_to_midi_notes(all_tokens)
    return notes, all_tokens


def _sample_phrase_ids(model, config, z, token_temperature, max_decode_len, device):
    """Autoregressively sample the token ids of one phrase from latent ``z``.

    Sampling starts from ``config.bos_token_id`` and stops after
    ``max_decode_len`` sampled tokens or as soon as ``config.eos_token_id``
    is drawn; the EOS id is kept in the returned list. The caller is
    expected to have disabled gradient tracking.
    """
    generated_ids = [config.bos_token_id]

    for _ in range(max_decode_len):
        # The whole prefix is fed back to the decoder on every step.
        input_tensor = torch.tensor([generated_ids], dtype=torch.long, device=device)
        logits = model.phrase_vae.decode(z, input_tensor)

        # Last position of the first batch item
        # (presumably logits are (batch, seq, vocab) -- TODO confirm).
        next_logits = logits[0, -1] / token_temperature
        probs = F.softmax(next_logits, dim=-1)
        next_token = torch.multinomial(probs, 1).item()
        generated_ids.append(next_token)

        if next_token == config.eos_token_id:
            break

    return generated_ids
94
+
95
+
96
def main(argv=None):
    """Command-line entry point: load a checkpoint, generate, write a MIDI file.

    Args:
        argv: Optional list of command-line arguments. When ``None``
            (the default) argparse reads ``sys.argv[1:]``.

    Exits with status 2 (via ``parser.error``) when ``--tempo`` is not a
    positive number.
    """
    parser = argparse.ArgumentParser(description='MuseMorphic MIDI Generator')
    parser.add_argument('--checkpoint', type=str, required=True, help='Path to model checkpoint')
    parser.add_argument('--output', type=str, default='generated.mid', help='Output MIDI file path')
    parser.add_argument('--n_phrases', type=int, default=32, help='Number of phrases to generate')
    parser.add_argument('--temperature', type=float, default=0.7, help='Sampling temperature')
    parser.add_argument('--tempo', type=float, default=120.0, help='Output MIDI tempo')
    parser.add_argument('--device', type=str, default='auto', help='Device (auto/cuda/cpu)')
    parser.add_argument('--max_decode_len', type=int, default=128,
                        help='Maximum number of tokens decoded per phrase')

    args = parser.parse_args(argv)

    # Reject unusable tempos before any expensive work. The duration summary
    # divides by the tempo, so 0 used to crash only after the MIDI file had
    # already been written. `not x > 0` also rejects NaN.
    if not args.tempo > 0:
        parser.error('--tempo must be greater than 0')

    # Device
    if args.device == 'auto':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(args.device)

    print(f'Device: {device}')

    # Load model
    print(f'Loading model from {args.checkpoint}...')
    model, config = load_model(args.checkpoint, device)

    params = sum(p.numel() for p in model.parameters())
    print(f'Model parameters: {params:,} ({params/1e6:.2f}M)')

    # Tokenizer
    tokenizer = REMIPlusTokenizer()

    # Generate
    print(f'Generating {args.n_phrases} phrases at temperature {args.temperature}...')
    notes, tokens = generate_midi(
        model, config, tokenizer,
        n_phrases=args.n_phrases,
        temperature=args.temperature,
        max_decode_len=args.max_decode_len,
        device=device,
    )

    print(f'Generated {len(notes)} notes, {len(tokens)} tokens')

    if not notes:
        print('No notes generated. Model may need more training.')
        return

    # Write MIDI
    if not notes_to_midi_file(notes, args.output, tempo=args.tempo):
        print('Failed to write MIDI. Install midiutil: pip install midiutil')
        return

    # NOTE(review): 480 is presumably the tokenizer's ticks-per-beat
    # resolution -- confirm against musemorphic.tokenizer.
    ticks_per_beat = 480
    total_dur = max(n['start'] + n['duration'] for n in notes)
    beats = total_dur / ticks_per_beat
    seconds = beats / args.tempo * 60
    pitches = [n['pitch'] for n in notes]

    print(f'\n🎵 MIDI saved to: {args.output}')
    print(f' Duration: ~{beats:.1f} beats ({seconds:.1f} seconds at {args.tempo} BPM)')
    print(f' Pitch range: {min(pitches)}-{max(pitches)}')
    print(f' Note count: {len(notes)}')
150
+
151
+
152
# Run the generator only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()