File size: 12,499 Bytes
"""
MuseMorphic Data Pipeline
==========================
Automatic MIDI dataset discovery, download, and preprocessing.
Supports multiple dataset sources with automatic format detection.

Datasets (auto-selected by availability and size):
1. MAESTRO v3 (piano, ~1200 pieces, HQ performances)
2. POP909 (pop, ~800 songs, multi-track)
3. Los Angeles MIDI Dataset (diverse, large)
4. Custom MIDI file directories
"""
import os
import glob
import json
import random
import logging
from typing import List, Dict, Tuple, Optional
from pathlib import Path
import numpy as np
import torch
from torch.utils.data import Dataset
from tokenizer import REMIPlusTokenizer, TokenizerConfig
logger = logging.getLogger(__name__)
# ============================================================================
# Dataset Discovery & Download
# ============================================================================
# Registry of known Hugging Face datasets, keyed by a short local name.
# Fields of each entry, as used elsewhere in this module:
#   hf_id       - identifier handed to `datasets.load_dataset`
#   description - human-readable label, only used in log messages
#   format      - selects the row parser in load_dataset_notes
#                 ('note_events' or 'midi_bytes'; anything else is skipped)
#   priority    - sort key for auto_select_dataset; lower values are tried first
#   genre       - compared with `preferred_genre`; 'mixed' is accepted for
#                 every preference
DATASET_REGISTRY = {
    'maestro_v1_sustain': {
        'hf_id': 'roszcz/maestro-v1-sustain',
        'description': 'MAESTRO piano performances with sustain',
        'format': 'note_events',  # Has 'notes' column with {pitch, start, duration, velocity}
        'priority': 1,
        'genre': 'classical',
    },
    'maestro_v3': {
        'hf_id': 'roszcz/maestro-v3-public',
        'description': 'MAESTRO v3 piano performances',
        'format': 'note_events',
        'priority': 2,
        'genre': 'classical',
    },
    'midi_dataset_1': {
        'hf_id': 'B-K/midi-dataset',
        'description': 'Aria MIDI dataset with MIDI files',
        'format': 'midi_bytes',
        'priority': 3,
        'genre': 'mixed',
    },
    'midi_dataset_2': {
        'hf_id': 'B-K/midi-dataset-2',
        'description': 'MidiCaps dataset with MIDI files',
        'format': 'midi_bytes',
        'priority': 4,
        'genre': 'mixed',
    },
}
def auto_select_dataset(preferred_genre: str = 'any', max_size_gb: float = 2.0) -> str:
    """
    Pick a dataset name from DATASET_REGISTRY.

    Entries are visited in ascending ``priority`` order and the first one
    with an acceptable genre wins. A genre is acceptable when no preference
    was given (``'any'``), when it equals ``preferred_genre``, or when the
    entry is tagged ``'mixed'``.

    Args:
        preferred_genre: Desired genre, or ``'any'`` to accept every entry.
        max_size_gb: Part of the signature but not consulted by this function.

    Returns:
        The registry key of the selected dataset. When no entry is
        acceptable, the first key of the registry (insertion order).
    """
    ordered_names = sorted(
        DATASET_REGISTRY,
        key=lambda key: DATASET_REGISTRY[key]['priority'],
    )
    for candidate in ordered_names:
        entry = DATASET_REGISTRY[candidate]
        # Positive form of the original "skip unless it matches" test.
        if (
            preferred_genre == 'any'
            or entry['genre'] == preferred_genre
            or entry['genre'] == 'mixed'
        ):
            logger.info(f"Selected dataset: {candidate} ({entry['description']})")
            return candidate

    # Nothing matched: fall back to whatever was registered first.
    return list(DATASET_REGISTRY.keys())[0]
def load_dataset_notes(dataset_name: str, split: str = 'train',
                       max_pieces: int = None) -> List[Dict]:
    """
    Fetch a registered dataset and convert its rows into piece dicts.

    Each returned piece is a dict with:
    - notes: List[Dict] with pitch, start, duration, velocity
    - tempo: float
    - time_sig: Tuple[int, int]
    - metadata: Dict (composer, title, etc.)

    Args:
        dataset_name: Key into DATASET_REGISTRY.
        split: Split name forwarded to ``datasets.load_dataset``.
        max_pieces: Upper bound on the number of rows read. A falsy value
            (``None`` or ``0``) means "read every row".

    Returns:
        The successfully parsed pieces that contain at least one note. If
        the dataset cannot be loaded, synthetic pieces are returned instead
        (``max_pieces`` of them, or 100 when ``max_pieces`` is falsy).

    Raises:
        KeyError: ``dataset_name`` is not in DATASET_REGISTRY.
        ImportError: the ``datasets`` package is not installed.
    """
    from datasets import load_dataset

    info = DATASET_REGISTRY[dataset_name]
    hf_id = info['hf_id']
    logger.info(f"Loading dataset: {hf_id} (split={split})")

    # NOTE(review): trust_remote_code=True allows the hub repository's own
    # loading script to run locally - confirm this is intended.
    try:
        ds = load_dataset(hf_id, split=split, trust_remote_code=True)
    except Exception as e:
        # Best-effort: any load failure switches to generated data.
        logger.warning(f"Failed to load {hf_id}: {e}")
        logger.info("Falling back to synthetic data generation")
        return _generate_synthetic_dataset(max_pieces or 100)

    # One parser per registry 'format' value; unknown formats have no parser.
    parsers = {
        'note_events': _parse_note_events_format,
        'midi_bytes': _parse_midi_bytes_format,
    }

    total_rows = len(ds)
    limit = min(total_rows, max_pieces) if max_pieces else total_rows

    pieces = []
    for index in range(limit):
        row = ds[index]
        parse = parsers.get(info['format'])
        if parse is None:
            continue
        piece = parse(row)
        # Parsers return None on failure; also drop pieces without notes.
        if piece and len(piece.get('notes', [])) > 0:
            pieces.append(piece)

    logger.info(f"Loaded {len(pieces)} pieces from {dataset_name}")
    return pieces
def _parse_note_events_format(item: Dict) -> Optional[Dict]:
"""Parse note events format (MAESTRO-style)."""
try:
notes_data = item.get('notes', {})
if isinstance(notes_data, dict):
# Columnar format: {pitch: [...], start: [...], duration: [...], velocity: [...]}
pitches = notes_data.get('pitch', [])
starts = notes_data.get('start', [])
durations = notes_data.get('duration', [])
velocities = notes_data.get('velocity', [])
notes = []
for j in range(len(pitches)):
notes.append({
'pitch': int(pitches[j]),
'start': int(float(starts[j]) * 480), # Convert to ticks
'duration': max(1, int(float(durations[j]) * 480)),
'velocity': int(velocities[j]) if j < len(velocities) else 80,
})
else:
return None
return {
'notes': notes,
'tempo': 120.0, # Default, could extract from MIDI
'time_sig': (4, 4),
'metadata': {
'composer': item.get('composer', 'Unknown'),
'title': item.get('title', 'Untitled'),
}
}
except Exception as e:
logger.debug(f"Failed to parse note events: {e}")
return None
def _parse_midi_bytes_format(item: Dict) -> Optional[Dict]:
    """
    Parse a row whose ``'midi'`` value holds raw MIDI file bytes.

    The bytes are decoded with ``pretty_midi``. The tempo comes from
    ``PrettyMIDI.estimate_tempo()`` and the time signature from the first
    time-signature change (``(4, 4)`` if there is none). Notes of every
    non-drum instrument are collected, with start and length converted to
    ticks as ``seconds * tempo / 60 * 480``.

    Args:
        item: Dataset row. Only the ``'midi'`` key is read.

    Returns:
        A dict with ``notes``, ``tempo``, ``time_sig`` and an empty
        ``metadata`` dict. ``None`` when ``'midi'`` is missing or not
        ``bytes``, or when anything raises while parsing (including a
        missing ``pretty_midi`` installation).
    """
    try:
        import pretty_midi
        import io

        raw = item.get('midi', None)
        if raw is None or not isinstance(raw, bytes):
            return None

        pm = pretty_midi.PrettyMIDI(io.BytesIO(raw))
        tempo = pm.estimate_tempo()

        signature_changes = pm.time_signature_changes
        if signature_changes:
            first = signature_changes[0]
            time_sig = (first.numerator, first.denominator)
        else:
            time_sig = (4, 4)

        tpb = 480

        def to_ticks(seconds):
            # seconds -> beats (via the estimated tempo) -> ticks
            return int(seconds * tempo / 60.0 * tpb)

        notes = [
            {
                'pitch': note.pitch,
                'start': to_ticks(note.start),
                'duration': max(1, to_ticks(note.end - note.start)),
                'velocity': note.velocity,
            }
            for instrument in pm.instruments
            if not instrument.is_drum
            for note in instrument.notes
        ]

        return {
            'notes': notes,
            'tempo': tempo,
            'time_sig': time_sig,
            'metadata': {},
        }
    except Exception as e:
        logger.debug(f"Failed to parse MIDI bytes: {e}")
        return None
def _generate_synthetic_dataset(n_pieces: int = 100) -> List[Dict]:
    """
    Build random scale-based pieces for testing or as a load fallback.

    Every piece picks a scale, a root pitch, a tempo and a time signature,
    then fills 8-32 bars with 4-16 notes each. Note starts sit on a grid of
    one sixteenth of a bar. Randomness comes from the module-level
    ``random`` generator.

    Args:
        n_pieces: Number of pieces to generate.

    Returns:
        A list of piece dicts (``notes``, ``tempo``, ``time_sig``,
        ``metadata`` holding the scale name and root pitch).
    """
    logger.info(f"Generating {n_pieces} synthetic pieces...")

    scales = {
        'major': [0, 2, 4, 5, 7, 9, 11],
        'minor': [0, 2, 3, 5, 7, 8, 10],
        'pentatonic': [0, 2, 4, 7, 9],
    }
    tpb = 480
    generated = []

    for _piece_index in range(n_pieces):
        # The order of the random draws below is kept as-is so that a seeded
        # run produces the same data.
        scale_name = random.choice(list(scales.keys()))
        intervals = scales[scale_name]
        root = random.randint(48, 72)  # Middle range
        tempo = random.choice([80, 100, 120, 140, 160])
        time_sig = random.choice([(4, 4), (3, 4), (6, 8)])

        numerator, denominator = time_sig
        ticks_per_bar = int(tpb * (numerator * (4.0 / denominator)))
        grid_step = ticks_per_bar // 16
        n_bars = random.randint(8, 32)

        notes = []
        for bar in range(n_bars):
            bar_offset = bar * ticks_per_bar
            for _note_index in range(random.randint(4, 16)):
                degree = random.choice(intervals)
                octave_offset = random.choice([-12, 0, 0, 0, 12])
                # Clamp the pitch to 21..108.
                pitch = max(21, min(108, root + degree + octave_offset))
                start = bar_offset + random.randint(0, 15) * grid_step
                duration = random.choice([tpb // 4, tpb // 2, tpb, tpb * 2])
                velocity = random.randint(40, 110)
                notes.append({
                    'pitch': pitch,
                    'start': start,
                    'duration': duration,
                    'velocity': velocity,
                })

        generated.append({
            'notes': notes,
            'tempo': tempo,
            'time_sig': time_sig,
            'metadata': {'scale': scale_name, 'root': root},
        })

    return generated
# ============================================================================
# Preprocessing Pipeline
# ============================================================================
def preprocess_dataset(
    pieces: List[Dict],
    tokenizer: REMIPlusTokenizer,
    max_phrase_len: int = 256,
    bars_per_phrase: int = 1,
) -> Tuple[List[List[int]], List[Dict]]:
    """
    Turn parsed pieces into encoded phrases plus per-phrase controls.

    For every piece:
    1. ``tokenizer.midi_to_remi_tokens`` converts notes to a token sequence
       (pieces yielding fewer than 5 tokens are dropped).
    2. ``tokenizer.segment_into_phrases`` splits it into phrases
       (phrases with fewer than 3 tokens are dropped).
    3. ``tokenizer.encode`` maps each phrase to integer IDs; sequences longer
       than ``max_phrase_len`` are cut and closed with ``tokenizer.eos_id``.
    4. ``tokenizer.compute_controls`` is evaluated on the phrase tokens.

    Args:
        pieces: Piece dicts with ``'notes'`` and optional ``'tempo'``
            (default 120.0) and ``'time_sig'`` (default (4, 4)).
        tokenizer: Tokenizer providing the four methods above and ``eos_id``.
        max_phrase_len: Maximum number of IDs kept per phrase.
        bars_per_phrase: Forwarded to ``segment_into_phrases``.

    Returns:
        - phrases: List of token ID sequences
        - controls: List of control dicts, aligned index-by-index with phrases
    """
    phrase_ids = []
    phrase_controls = []

    for piece in pieces:
        remi_tokens = tokenizer.midi_to_remi_tokens(
            piece['notes'],
            piece.get('tempo', 120.0),
            piece.get('time_sig', (4, 4)),
        )
        if len(remi_tokens) >= 5:
            for segment in tokenizer.segment_into_phrases(remi_tokens, bars_per_phrase):
                if len(segment) >= 3:
                    encoded = tokenizer.encode(segment)
                    if len(encoded) > max_phrase_len:
                        # Keep room for the end-of-sequence marker.
                        encoded = encoded[:max_phrase_len - 1] + [tokenizer.eos_id]
                    phrase_ids.append(encoded)
                    phrase_controls.append(tokenizer.compute_controls(segment))

    logger.info(f"Preprocessed {len(pieces)} pieces → {len(phrase_ids)} phrases")
    return phrase_ids, phrase_controls
# ============================================================================
# Complete Data Pipeline
# ============================================================================
def prepare_training_data(
    dataset_name: Optional[str] = None,
    max_pieces: int = None,
    max_phrase_len: int = 256,
    data_dir: str = './data',
) -> Tuple[List[List[int]], List[Dict], REMIPlusTokenizer]:
    """
    Run the whole pipeline: select → load → tokenize → write to disk.

    The encoded phrases and controls are saved with ``torch.save`` to
    ``<data_dir>/<dataset_name>_phrases.pt`` and the tokenizer is saved to
    ``<data_dir>/tokenizer``. This function only writes those files; it does
    not read an existing cache back.

    Args:
        dataset_name: Registry key to use. None = auto-select.
        max_pieces: Limit on the number of pieces to load.
        max_phrase_len: Max tokens per phrase.
        data_dir: Output directory (created if missing).

    Returns:
        phrases: List of token ID sequences
        controls: List of control dicts
        tokenizer: The tokenizer instance used for encoding
    """
    selected = dataset_name if dataset_name is not None else auto_select_dataset()

    pieces = load_dataset_notes(selected, max_pieces=max_pieces)
    tokenizer = REMIPlusTokenizer()
    phrases, controls = preprocess_dataset(pieces, tokenizer, max_phrase_len)

    # Persist the preprocessed data next to the tokenizer.
    os.makedirs(data_dir, exist_ok=True)
    cache_path = os.path.join(data_dir, f'{selected}_phrases.pt')
    payload = {
        'phrases': phrases,
        'controls': controls,
    }
    torch.save(payload, cache_path)
    logger.info(f"Cached preprocessed data to {cache_path}")

    tokenizer.save(os.path.join(data_dir, 'tokenizer'))

    return phrases, controls, tokenizer
if __name__ == "__main__":
    # Smoke run: push a small sample through the whole pipeline and report sizes.
    phrases, controls, tokenizer = prepare_training_data(max_pieces=50)
    print(f"Prepared {len(phrases)} phrases")
    # The pipeline can yield no phrases at all (every piece or phrase filtered
    # out); indexing phrases[0] unguarded would then raise IndexError.
    if phrases:
        print(f"Sample phrase length: {len(phrases[0])}")
    print(f"Tokenizer vocab size: {tokenizer.vocab_size}")
|