asdf98 committed on
Commit
39cadac
·
verified ·
1 Parent(s): 893db59

Upload musemorphic/data_pipeline.py

Browse files
Files changed (1) hide show
  1. musemorphic/data_pipeline.py +386 -0
musemorphic/data_pipeline.py ADDED
@@ -0,0 +1,386 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MuseMorphic Data Pipeline
3
+ ==========================
4
+
5
+ Automatic MIDI dataset discovery, download, and preprocessing.
6
+ Supports multiple dataset sources with automatic format detection.
7
+
8
+ Datasets (auto-selected by availability and size):
9
+ 1. MAESTRO v3 (piano, ~1200 pieces, HQ performances)
10
+ 2. POP909 (pop, ~800 songs, multi-track)
11
+ 3. Los Angeles MIDI Dataset (diverse, large)
12
+ 4. Custom MIDI file directories
13
+ """
14
+
15
+ import os
16
+ import glob
17
+ import json
18
+ import random
19
+ import logging
20
+ from typing import List, Dict, Tuple, Optional
21
+ from pathlib import Path
22
+
23
+ import numpy as np
24
+ import torch
25
+ from torch.utils.data import Dataset
26
+
27
+ from tokenizer import REMIPlusTokenizer, TokenizerConfig
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+
32
+ # ============================================================================
33
+ # Dataset Discovery & Download
34
+ # ============================================================================
35
+
36
# Registry of Hugging Face Hub datasets this pipeline knows how to load.
# Every entry maps a short local name to:
#   hf_id       -- repository id handed to `datasets.load_dataset`
#   description -- human-readable label used in log messages
#   format      -- selects the row parser ('note_events' or 'midi_bytes')
#   priority    -- lower numbers are considered first by auto_select_dataset
#   genre       -- compared against the caller's preferred genre
DATASET_REGISTRY = {
    'maestro_v1_sustain': dict(
        hf_id='roszcz/maestro-v1-sustain',
        description='MAESTRO piano performances with sustain',
        # Has 'notes' column with {pitch, start, duration, velocity}
        format='note_events',
        priority=1,
        genre='classical',
    ),
    'maestro_v3': dict(
        hf_id='roszcz/maestro-v3-public',
        description='MAESTRO v3 piano performances',
        format='note_events',
        priority=2,
        genre='classical',
    ),
    'midi_dataset_1': dict(
        hf_id='B-K/midi-dataset',
        description='Aria MIDI dataset with MIDI files',
        format='midi_bytes',
        priority=3,
        genre='mixed',
    ),
    'midi_dataset_2': dict(
        hf_id='B-K/midi-dataset-2',
        description='MidiCaps dataset with MIDI files',
        format='midi_bytes',
        priority=4,
        genre='mixed',
    ),
}
66
+
67
+
68
def auto_select_dataset(preferred_genre: str = 'any', max_size_gb: float = 2.0) -> str:
    """
    Pick the registry entry with the lowest priority number that fits the genre.

    An entry fits when ``preferred_genre`` is 'any', when the entry's genre
    equals ``preferred_genre``, or when the entry's genre is 'mixed'.

    Args:
        preferred_genre: Genre to match, or 'any' to accept every entry.
        max_size_gb: Part of the public signature; not consulted by this
            function.

    Returns:
        The DATASET_REGISTRY key of the chosen dataset. When no entry fits,
        the first key of DATASET_REGISTRY is returned instead.
    """
    by_priority = sorted(
        DATASET_REGISTRY.items(),
        key=lambda entry: entry[1]['priority'],
    )

    for name, info in by_priority:
        genre_fits = (
            preferred_genre == 'any'
            or info['genre'] == preferred_genre
            or info['genre'] == 'mixed'
        )
        if genre_fits:
            logger.info(f"Selected dataset: {name} ({info['description']})")
            return name

    # Nothing matched: fall back to whichever entry is declared first.
    return list(DATASET_REGISTRY.keys())[0]
85
+
86
+
87
def load_dataset_notes(dataset_name: str, split: str = 'train',
                       max_pieces: Optional[int] = None) -> List[Dict]:
    """
    Load a registered dataset and return it as a list of note-event pieces.

    Each returned piece is a dict with:
    - notes: List[Dict] with pitch, start, duration, velocity
    - tempo: float
    - time_sig: Tuple[int, int]
    - metadata: Dict (composer, title, etc.)

    Args:
        dataset_name: Key into DATASET_REGISTRY.
        split: Dataset split forwarded to ``datasets.load_dataset``.
        max_pieces: Upper bound on the number of rows read from the dataset.
            ``None`` means no limit; ``0`` (or a negative value) yields an
            empty result.

    Returns:
        The successfully parsed pieces that contain at least one note. Rows
        that fail to parse are dropped, so the result may be shorter than
        ``max_pieces``. If the dataset cannot be loaded at all, synthetic
        pieces are returned instead (100 when ``max_pieces`` is ``None``).

    Raises:
        KeyError: If ``dataset_name`` is not in DATASET_REGISTRY.
    """
    from datasets import load_dataset

    info = DATASET_REGISTRY[dataset_name]
    hf_id = info['hf_id']

    logger.info(f"Loading dataset: {hf_id} (split={split})")

    try:
        # NOTE(review): trust_remote_code=True lets the dataset repository run
        # its own loading script on this machine. Confirm it is still needed
        # for the registered datasets and supported by the installed
        # `datasets` version.
        ds = load_dataset(hf_id, split=split, trust_remote_code=True)
    except Exception as e:
        # Deliberate best-effort: any loading problem (network, missing repo,
        # incompatible library version) degrades to synthetic data.
        logger.warning(f"Failed to load {hf_id}: {e}")
        logger.info("Falling back to synthetic data generation")
        return _generate_synthetic_dataset(
            max_pieces if max_pieces is not None else 100
        )

    # The row format is fixed per dataset, so resolve the parser once instead
    # of re-testing the format string for every row.
    parsers = {
        'note_events': _parse_note_events_format,
        'midi_bytes': _parse_midi_bytes_format,
    }
    parse_item = parsers.get(info['format'])
    if parse_item is None:
        # No parser exists for this format, so no row could ever be kept;
        # skip fetching them.
        logger.warning(
            f"Unknown dataset format {info['format']!r} for {dataset_name}"
        )
        return []

    # `is None` rather than truthiness, so that max_pieces=0 means "no rows"
    # and not "no limit".
    if max_pieces is None:
        n = len(ds)
    else:
        n = max(0, min(len(ds), max_pieces))

    pieces = []
    for i in range(n):
        piece = parse_item(ds[i])
        # Parsers return None on failure; also drop pieces with no notes.
        if piece and len(piece.get('notes', [])) > 0:
            pieces.append(piece)

    logger.info(f"Loaded {len(pieces)} pieces from {dataset_name}")
    return pieces
130
+
131
+
132
+ def _parse_note_events_format(item: Dict) -> Optional[Dict]:
133
+ """Parse note events format (MAESTRO-style)."""
134
+ try:
135
+ notes_data = item.get('notes', {})
136
+
137
+ if isinstance(notes_data, dict):
138
+ # Columnar format: {pitch: [...], start: [...], duration: [...], velocity: [...]}
139
+ pitches = notes_data.get('pitch', [])
140
+ starts = notes_data.get('start', [])
141
+ durations = notes_data.get('duration', [])
142
+ velocities = notes_data.get('velocity', [])
143
+
144
+ notes = []
145
+ for j in range(len(pitches)):
146
+ notes.append({
147
+ 'pitch': int(pitches[j]),
148
+ 'start': int(float(starts[j]) * 480), # Convert to ticks
149
+ 'duration': max(1, int(float(durations[j]) * 480)),
150
+ 'velocity': int(velocities[j]) if j < len(velocities) else 80,
151
+ })
152
+ else:
153
+ return None
154
+
155
+ return {
156
+ 'notes': notes,
157
+ 'tempo': 120.0, # Default, could extract from MIDI
158
+ 'time_sig': (4, 4),
159
+ 'metadata': {
160
+ 'composer': item.get('composer', 'Unknown'),
161
+ 'title': item.get('title', 'Untitled'),
162
+ }
163
+ }
164
+ except Exception as e:
165
+ logger.debug(f"Failed to parse note events: {e}")
166
+ return None
167
+
168
+
169
def _parse_midi_bytes_format(item: Dict) -> Optional[Dict]:
    """Parse one dataset row whose ``midi`` field holds a raw MIDI file.

    Args:
        item: A single dataset row.

    Returns:
        A piece dict with ``notes`` (pitch, start, duration, velocity; start
        and duration in ticks at 480 ticks per beat), ``tempo`` (as reported
        by ``PrettyMIDI.estimate_tempo``), ``time_sig`` (first time-signature
        change, else (4, 4)) and an empty ``metadata`` dict. Returns ``None``
        when ``midi`` is absent or not ``bytes``, or when anything raises
        while importing the dependencies or parsing.
    """
    try:
        import pretty_midi
        import io

        raw = item.get('midi', None)
        # Covers both a missing field (None) and any non-bytes payload.
        if not isinstance(raw, bytes):
            return None

        pm = pretty_midi.PrettyMIDI(io.BytesIO(raw))
        tempo = pm.estimate_tempo()

        signature_changes = pm.time_signature_changes
        if signature_changes:
            first = signature_changes[0]
            time_sig = (first.numerator, first.denominator)
        else:
            time_sig = (4, 4)

        tpb = 480

        def to_ticks(seconds):
            # seconds -> beats at the estimated tempo -> ticks
            return int(seconds * tempo / 60.0 * tpb)

        # Drum tracks are skipped; all other instruments are merged.
        notes = [
            {
                'pitch': note.pitch,
                'start': to_ticks(note.start),
                'duration': max(1, to_ticks(note.end - note.start)),
                'velocity': note.velocity,
            }
            for instrument in pm.instruments
            if not instrument.is_drum
            for note in instrument.notes
        ]

        return {
            'notes': notes,
            'tempo': tempo,
            'time_sig': time_sig,
            'metadata': {},
        }
    except Exception as e:
        logger.debug(f"Failed to parse MIDI bytes: {e}")
        return None
215
+
216
+
217
def _generate_synthetic_dataset(n_pieces: int = 100) -> List[Dict]:
    """Build random scale-based pieces for testing or as a loading fallback.

    Every piece draws a scale (major, minor or pentatonic), a root pitch, a
    tempo and a time signature, then fills 8-32 bars with 4-16 random notes
    each, placed on a sixteen-slot grid within the bar.

    Args:
        n_pieces: Number of pieces to generate.

    Returns:
        A list of piece dicts with ``notes``, ``tempo``, ``time_sig`` and
        ``metadata`` (the chosen scale name and root).
    """
    logger.info(f"Generating {n_pieces} synthetic pieces...")

    scales = {
        'major': [0, 2, 4, 5, 7, 9, 11],
        'minor': [0, 2, 3, 5, 7, 8, 10],
        'pentatonic': [0, 2, 4, 7, 9],
    }
    scale_names = list(scales.keys())
    tpb = 480

    def random_note(bar_index, root, scale, ticks_per_bar):
        # The sequence of random draws is kept fixed so that a seeded run
        # reproduces the same notes.
        degree = random.choice(scale)
        octave_offset = random.choice([-12, 0, 0, 0, 12])
        # Clamp into the pitch range 21..108.
        pitch = max(21, min(108, root + degree + octave_offset))
        slot = random.randint(0, 15)
        length = random.choice([tpb // 4, tpb // 2, tpb, tpb * 2])
        loudness = random.randint(40, 110)
        return {
            'pitch': pitch,
            'start': bar_index * ticks_per_bar + slot * (ticks_per_bar // 16),
            'duration': length,
            'velocity': loudness,
        }

    pieces = []
    for _piece in range(n_pieces):
        scale_name = random.choice(scale_names)
        root = random.randint(48, 72)  # Middle range
        tempo = random.choice([80, 100, 120, 140, 160])
        time_sig = random.choice([(4, 4), (3, 4), (6, 8)])

        numerator, denominator = time_sig
        # Bar length in quarter-note beats, then in ticks.
        ticks_per_bar = int(tpb * (numerator * (4.0 / denominator)))
        n_bars = random.randint(8, 32)

        notes = []
        for bar in range(n_bars):
            n_notes = random.randint(4, 16)
            notes.extend([
                random_note(bar, root, scales[scale_name], ticks_per_bar)
                for _note in range(n_notes)
            ])

        pieces.append({
            'notes': notes,
            'tempo': tempo,
            'time_sig': time_sig,
            'metadata': {'scale': scale_name, 'root': root},
        })

    return pieces
270
+
271
+
272
+ # ============================================================================
273
+ # Preprocessing Pipeline
274
+ # ============================================================================
275
+
276
def preprocess_dataset(
    pieces: List[Dict],
    tokenizer: REMIPlusTokenizer,
    max_phrase_len: int = 256,
    bars_per_phrase: int = 1,
) -> Tuple[List[List[int]], List[Dict]]:
    """
    Turn note-event pieces into encoded phrases plus per-phrase controls.

    For every piece:
    1. ``tokenizer.midi_to_remi_tokens`` converts the notes to tokens;
       pieces yielding fewer than 5 tokens are skipped.
    2. ``tokenizer.segment_into_phrases`` splits the tokens into phrases;
       phrases with fewer than 3 tokens are skipped.
    3. ``tokenizer.encode`` maps each phrase to integer IDs. Sequences longer
       than ``max_phrase_len`` are cut to that length, with the last ID
       replaced by ``tokenizer.eos_id``.
    4. ``tokenizer.compute_controls`` is called on the phrase tokens (the
       untruncated ones).

    Args:
        pieces: Piece dicts; ``notes`` is required, ``tempo`` and ``time_sig``
            default to 120.0 and (4, 4).
        tokenizer: Tokenizer providing the four methods used above.
        max_phrase_len: Maximum number of IDs kept per phrase.
        bars_per_phrase: Forwarded to ``segment_into_phrases``.

    Returns:
        - phrases: List of token ID sequences
        - controls: List of control values, one per phrase, in the same order
    """
    all_phrases = []
    all_controls = []

    for piece in pieces:
        tokens = tokenizer.midi_to_remi_tokens(
            piece['notes'],
            piece.get('tempo', 120.0),
            piece.get('time_sig', (4, 4)),
        )
        if len(tokens) < 5:
            continue

        for phrase_tokens in tokenizer.segment_into_phrases(tokens, bars_per_phrase):
            if len(phrase_tokens) < 3:
                continue

            ids = tokenizer.encode(phrase_tokens)
            too_long = len(ids) > max_phrase_len
            if too_long:
                # Keep the length at max_phrase_len and end on EOS.
                ids = ids[:max_phrase_len - 1] + [tokenizer.eos_id]

            all_phrases.append(ids)
            all_controls.append(tokenizer.compute_controls(phrase_tokens))

    logger.info(f"Preprocessed {len(pieces)} pieces → {len(all_phrases)} phrases")
    return all_phrases, all_controls
328
+
329
+
330
+ # ============================================================================
331
+ # Complete Data Pipeline
332
+ # ============================================================================
333
+
334
def prepare_training_data(
    dataset_name: Optional[str] = None,
    max_pieces: int = None,
    max_phrase_len: int = 256,
    data_dir: str = './data',
) -> Tuple[List[List[int]], List[Dict], REMIPlusTokenizer]:
    """
    Run the whole pipeline: select → load → tokenize/preprocess → save.

    Args:
        dataset_name: Registry key to load. None = pick one automatically.
        max_pieces: Limit on the number of pieces to load.
        max_phrase_len: Max tokens per phrase.
        data_dir: Directory the preprocessed data and tokenizer are written
            to (created if missing).

    Returns:
        phrases: List of token ID sequences
        controls: List of control dicts
        tokenizer: The tokenizer used for preprocessing

    Side effects:
        Writes ``<data_dir>/<dataset_name>_phrases.pt`` via ``torch.save`` and
        calls ``tokenizer.save`` with ``<data_dir>/tokenizer``.
    """
    if dataset_name is None:
        dataset_name = auto_select_dataset()

    pieces = load_dataset_notes(dataset_name, max_pieces=max_pieces)

    tokenizer = REMIPlusTokenizer()
    phrases, controls = preprocess_dataset(
        pieces, tokenizer, max_phrase_len=max_phrase_len
    )

    # Persist the preprocessed phrases and controls.
    os.makedirs(data_dir, exist_ok=True)
    cache_path = os.path.join(data_dir, f'{dataset_name}_phrases.pt')
    payload = {
        'phrases': phrases,
        'controls': controls,
    }
    torch.save(payload, cache_path)
    logger.info(f"Cached preprocessed data to {cache_path}")

    # Persist the tokenizer next to the data.
    tokenizer_path = os.path.join(data_dir, 'tokenizer')
    tokenizer.save(tokenizer_path)

    return phrases, controls, tokenizer
380
+
381
+
382
if __name__ == "__main__":
    # Smoke run: push up to 50 pieces through the full pipeline and report
    # the resulting sizes.
    phrases, controls, tokenizer = prepare_training_data(max_pieces=50)
    print(f"Prepared {len(phrases)} phrases")
    # Preprocessing skips short pieces and short phrases, so the result can be
    # empty; indexing phrases[0] unconditionally would raise IndexError.
    if phrases:
        print(f"Sample phrase length: {len(phrases[0])}")
    else:
        print("No phrases were produced; nothing to sample")
    print(f"Tokenizer vocab size: {tokenizer.vocab_size}")