MuseMorphic / musemorphic /data_pipeline.py
asdf98's picture
Upload musemorphic/data_pipeline.py
39cadac verified
"""
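MuseMorphic Data Pipeline
=========================
MIDI dataset selection, download, and preprocessing.
Supports multiple dataset sources, each tagged in DATASET_REGISTRY with its
storage format ('note_events' or 'midi_bytes').
Registered datasets (chosen by registry priority and optional genre preference):
1. MAESTRO v1 with sustain (piano, classical; note events)
2. MAESTRO v3 (piano, ~1200 pieces, HQ performances; note events)
3. B-K/midi-dataset (Aria MIDI; mixed genres; raw MIDI bytes)
4. B-K/midi-dataset-2 (MidiCaps; mixed genres; raw MIDI bytes)
Not currently registered: POP909 (pop, ~800 songs, multi-track), the
Los Angeles MIDI Dataset, and custom MIDI file directories.
If a dataset cannot be loaded, synthetic data is generated instead.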
"""
import os
import glob
import json
import random
import logging
from typing import List, Dict, Tuple, Optional
from pathlib import Path
import numpy as np
import torch
from torch.utils.data import Dataset
from tokenizer import REMIPlusTokenizer, TokenizerConfig
logger = logging.getLogger(__name__)
# ============================================================================
# Dataset Discovery & Download
# ============================================================================
# Known Hugging Face datasets, keyed by local name.
#   hf_id       -> repository id passed to datasets.load_dataset()
#   format      -> selects the record parser ('note_events' or 'midi_bytes')
#   priority    -> lower values are preferred by auto_select_dataset()
#   genre       -> used for genre filtering; 'mixed' matches any request
DATASET_REGISTRY = {
    # Has 'notes' column with {pitch, start, duration, velocity}
    'maestro_v1_sustain': dict(
        hf_id='roszcz/maestro-v1-sustain',
        description='MAESTRO piano performances with sustain',
        format='note_events',
        priority=1,
        genre='classical',
    ),
    'maestro_v3': dict(
        hf_id='roszcz/maestro-v3-public',
        description='MAESTRO v3 piano performances',
        format='note_events',
        priority=2,
        genre='classical',
    ),
    # Raw MIDI files stored as bytes in a 'midi' column.
    'midi_dataset_1': dict(
        hf_id='B-K/midi-dataset',
        description='Aria MIDI dataset with MIDI files',
        format='midi_bytes',
        priority=3,
        genre='mixed',
    ),
    'midi_dataset_2': dict(
        hf_id='B-K/midi-dataset-2',
        description='MidiCaps dataset with MIDI files',
        format='midi_bytes',
        priority=4,
        genre='mixed',
    ),
}
def auto_select_dataset(preferred_genre: str = 'any', max_size_gb: float = 2.0) -> str:
    """
    Select a dataset name from DATASET_REGISTRY.

    Entries are considered in ascending 'priority' order and the first
    acceptable one is returned. An entry is acceptable when:
      * preferred_genre is 'any', equals the entry's 'genre', or the entry's
        genre is 'mixed'; and
      * the entry has no 'size_gb' field, or its 'size_gb' <= max_size_gb.
    Availability is not probed here; load_dataset_notes() handles load
    failures itself.

    Args:
        preferred_genre: Genre to prefer, or 'any' to accept every entry.
        max_size_gb: Size cap applied to entries that declare 'size_gb'.
            Entries without a 'size_gb' field are never filtered by size.

    Returns:
        The registry key of the selected dataset. If no entry is acceptable,
        the highest-priority entry is returned and a warning is logged.
    """
    ranked = sorted(DATASET_REGISTRY.items(), key=lambda kv: kv[1]['priority'])
    for name, info in ranked:
        # 'mixed' datasets are treated as matching any requested genre.
        if preferred_genre != 'any' and info['genre'] not in (preferred_genre, 'mixed'):
            continue
        size_gb = info.get('size_gb')
        if size_gb is not None and size_gb > max_size_gb:
            logger.debug(f"Skipping {name}: {size_gb} GB exceeds limit of {max_size_gb} GB")
            continue
        logger.info(f"Selected dataset: {name} ({info['description']})")
        return name
    fallback = ranked[0][0]
    logger.warning(
        f"No dataset matched genre={preferred_genre!r}, max_size_gb={max_size_gb}; "
        f"falling back to {fallback}"
    )
    return fallback
def load_dataset_notes(dataset_name: str, split: str = 'train',
                       max_pieces: Optional[int] = None) -> List[Dict]:
    """
    Load a registered dataset and return it as a list of note-event pieces.

    Each piece is a dict with:
        - notes: List[Dict] with pitch, start, duration, velocity (in ticks)
        - tempo: float
        - time_sig: Tuple[int, int]
        - metadata: Dict (composer, title, etc.)

    Args:
        dataset_name: Key into DATASET_REGISTRY.
        split: Split name passed to `datasets.load_dataset`.
        max_pieces: Maximum number of raw records to read (None/0 = all).
            Records that fail to parse or contain no notes are dropped, so
            fewer pieces than this may be returned.

    Returns:
        List of piece dicts. If the `datasets` package is unavailable or the
        dataset cannot be loaded, a synthetic dataset of `max_pieces or 100`
        pieces is returned instead.

    Raises:
        KeyError: if `dataset_name` is not in DATASET_REGISTRY.
    """
    try:
        info = DATASET_REGISTRY[dataset_name]
    except KeyError:
        known = ', '.join(sorted(DATASET_REGISTRY))
        raise KeyError(f"Unknown dataset {dataset_name!r}; known datasets: {known}") from None

    # Resolve the record parser once instead of re-checking the format per item.
    parsers = {
        'note_events': _parse_note_events_format,
        'midi_bytes': _parse_midi_bytes_format,
    }
    parse = parsers.get(info['format'])

    hf_id = info['hf_id']
    logger.info(f"Loading dataset: {hf_id} (split={split})")
    try:
        # Imported here so a missing `datasets` package also takes the fallback path.
        from datasets import load_dataset
        # NOTE(review): trust_remote_code=True lets the Hub repository run its own
        # loading script; keep it only for trusted sources.
        ds = load_dataset(hf_id, split=split, trust_remote_code=True)
    except Exception as e:
        logger.warning(f"Failed to load {hf_id}: {e}")
        logger.info("Falling back to synthetic data generation")
        return _generate_synthetic_dataset(max_pieces or 100)

    if parse is None:
        logger.warning(f"Unknown format {info['format']!r} for {dataset_name}; no pieces loaded")
        return []

    n = min(len(ds), max_pieces) if max_pieces else len(ds)
    pieces = []
    for i in range(n):
        piece = parse(ds[i])
        # Drop unparseable records and records without any notes.
        if piece and len(piece.get('notes', [])) > 0:
            pieces.append(piece)
    logger.info(f"Loaded {len(pieces)} pieces from {dataset_name}")
    return pieces
def _parse_note_events_format(item: Dict) -> Optional[Dict]:
"""Parse note events format (MAESTRO-style)."""
try:
notes_data = item.get('notes', {})
if isinstance(notes_data, dict):
# Columnar format: {pitch: [...], start: [...], duration: [...], velocity: [...]}
pitches = notes_data.get('pitch', [])
starts = notes_data.get('start', [])
durations = notes_data.get('duration', [])
velocities = notes_data.get('velocity', [])
notes = []
for j in range(len(pitches)):
notes.append({
'pitch': int(pitches[j]),
'start': int(float(starts[j]) * 480), # Convert to ticks
'duration': max(1, int(float(durations[j]) * 480)),
'velocity': int(velocities[j]) if j < len(velocities) else 80,
})
else:
return None
return {
'notes': notes,
'tempo': 120.0, # Default, could extract from MIDI
'time_sig': (4, 4),
'metadata': {
'composer': item.get('composer', 'Unknown'),
'title': item.get('title', 'Untitled'),
}
}
except Exception as e:
logger.debug(f"Failed to parse note events: {e}")
return None
def _parse_midi_bytes_format(item: Dict) -> Optional[Dict]:
"""Parse MIDI bytes format."""
try:
import pretty_midi
import io
midi_data = item.get('midi', None)
if midi_data is None:
return None
if isinstance(midi_data, bytes):
pm = pretty_midi.PrettyMIDI(io.BytesIO(midi_data))
else:
return None
tempo = pm.estimate_tempo()
time_sig = (4, 4)
if pm.time_signature_changes:
ts = pm.time_signature_changes[0]
time_sig = (ts.numerator, ts.denominator)
notes = []
tpb = 480
for instrument in pm.instruments:
if instrument.is_drum:
continue
for note in instrument.notes:
start_ticks = int(note.start * tempo / 60.0 * tpb)
duration_ticks = int((note.end - note.start) * tempo / 60.0 * tpb)
notes.append({
'pitch': note.pitch,
'start': start_ticks,
'duration': max(1, duration_ticks),
'velocity': note.velocity,
})
return {
'notes': notes,
'tempo': tempo,
'time_sig': time_sig,
'metadata': {},
}
except Exception as e:
logger.debug(f"Failed to parse MIDI bytes: {e}")
return None
def _generate_synthetic_dataset(n_pieces: int = 100, seed: Optional[int] = None) -> List[Dict]:
    """
    Generate random scale-based note data for testing or as a load fallback.

    Each piece picks a scale (major/minor/pentatonic), a root pitch in 48-72,
    a tempo, and a time signature. It is then filled with 8-32 bars of 4-16
    random notes per bar, placed on a 16-step grid. Ticks use 480 per beat;
    pitches are clamped to 21-108.

    Args:
        n_pieces: Number of pieces to generate.
        seed: If given, draw from a private random.Random(seed), so output is
            reproducible without touching global random state. If None, draw
            from the global `random` module, so a caller's random.seed(...)
            still applies.

    Returns:
        List of piece dicts with 'notes', 'tempo', 'time_sig' and 'metadata'.
    """
    logger.info(f"Generating {n_pieces} synthetic pieces...")
    rng = random if seed is None else random.Random(seed)
    scales = {
        'major': [0, 2, 4, 5, 7, 9, 11],
        'minor': [0, 2, 3, 5, 7, 8, 10],
        'pentatonic': [0, 2, 4, 7, 9],
    }
    tpb = 480
    pieces = []
    for _ in range(n_pieces):
        # NOTE: the order of rng calls below is kept fixed so seeded runs stay reproducible.
        scale_name = rng.choice(list(scales.keys()))
        scale = scales[scale_name]
        root = rng.randint(48, 72)  # Middle range
        tempo = rng.choice([80, 100, 120, 140, 160])
        time_sig = rng.choice([(4, 4), (3, 4), (6, 8)])
        # Bar length in quarter-note beats, e.g. 6/8 -> 3 beats.
        beats_per_bar = time_sig[0] * (4.0 / time_sig[1])
        ticks_per_bar = int(tpb * beats_per_bar)
        n_bars = rng.randint(8, 32)
        notes = []
        for bar in range(n_bars):
            n_notes = rng.randint(4, 16)
            for _ in range(n_notes):
                degree = rng.choice(scale)
                octave_offset = rng.choice([-12, 0, 0, 0, 12])
                pitch = max(21, min(108, root + degree + octave_offset))
                position = rng.randint(0, 15) * (ticks_per_bar // 16)
                duration = rng.choice([tpb // 4, tpb // 2, tpb, tpb * 2])
                velocity = rng.randint(40, 110)
                notes.append({
                    'pitch': pitch,
                    'start': bar * ticks_per_bar + position,
                    'duration': duration,
                    'velocity': velocity,
                })
        pieces.append({
            'notes': notes,
            'tempo': tempo,
            'time_sig': time_sig,
            'metadata': {'scale': scale_name, 'root': root},
        })
    return pieces
# ============================================================================
# Preprocessing Pipeline
# ============================================================================
def preprocess_dataset(
    pieces: List[Dict],
    tokenizer: REMIPlusTokenizer,
    max_phrase_len: int = 256,
    bars_per_phrase: int = 1,
) -> Tuple[List[List[int]], List[Dict]]:
    """
    Turn parsed pieces into parallel lists of phrase token IDs and controls.

    For every piece:
        1. Note events are converted to REMI+ tokens (pieces yielding fewer
           than 5 tokens are skipped).
        2. The token stream is split into phrases of `bars_per_phrase` bars
           (phrases with fewer than 3 tokens are skipped).
        3. Each phrase is encoded to integer IDs. Phrases longer than
           `max_phrase_len` are cut to `max_phrase_len - 1` IDs plus EOS.
        4. Control attributes are computed from the phrase's tokens.

    Returns:
        (phrases, controls): index i of `controls` describes index i of
        `phrases`.
    """
    phrase_ids: List[List[int]] = []
    phrase_controls: List[Dict] = []

    for piece in pieces:
        remi = tokenizer.midi_to_remi_tokens(
            piece['notes'],
            piece.get('tempo', 120.0),
            piece.get('time_sig', (4, 4)),
        )
        if len(remi) < 5:
            continue

        for segment in tokenizer.segment_into_phrases(remi, bars_per_phrase):
            if len(segment) < 3:
                continue
            encoded = tokenizer.encode(segment)
            if len(encoded) > max_phrase_len:
                # Reserve the last slot so a truncated phrase still ends with EOS.
                encoded = encoded[:max_phrase_len - 1] + [tokenizer.eos_id]
            phrase_ids.append(encoded)
            phrase_controls.append(tokenizer.compute_controls(segment))

    logger.info(f"Preprocessed {len(pieces)} pieces → {len(phrase_ids)} phrases")
    return phrase_ids, phrase_controls
# ============================================================================
# Complete Data Pipeline
# ============================================================================
def prepare_training_data(
    dataset_name: Optional[str] = None,
    max_pieces: int = None,
    max_phrase_len: int = 256,
    data_dir: str = './data',
) -> Tuple[List[List[int]], List[Dict], REMIPlusTokenizer]:
    """
    Run the whole pipeline: pick a dataset, load it, tokenize it, and persist it.

    Args:
        dataset_name: Registry key to load; None lets auto_select_dataset() choose.
        max_pieces: Cap on the number of records read from the dataset.
        max_phrase_len: Maximum number of token IDs per phrase.
        data_dir: Output directory, created if missing. It receives
            '<dataset>_phrases.pt' (phrases + controls, via torch.save) and the
            tokenizer under 'tokenizer'.

    Returns:
        (phrases, controls, tokenizer) — token ID sequences, their control
        dicts, and the tokenizer used to produce them.
    """
    name = auto_select_dataset() if dataset_name is None else dataset_name

    pieces = load_dataset_notes(name, max_pieces=max_pieces)
    tokenizer = REMIPlusTokenizer()
    phrases, controls = preprocess_dataset(pieces, tokenizer, max_phrase_len)

    # Persist the preprocessed phrases next to the tokenizer.
    os.makedirs(data_dir, exist_ok=True)
    cache_file = os.path.join(data_dir, f'{name}_phrases.pt')
    payload = {'phrases': phrases, 'controls': controls}
    torch.save(payload, cache_file)
    logger.info(f"Cached preprocessed data to {cache_file}")

    tokenizer.save(os.path.join(data_dir, 'tokenizer'))
    return phrases, controls, tokenizer
if __name__ == "__main__":
    # Show the pipeline's logger.info progress messages when run as a script.
    logging.basicConfig(level=logging.INFO)
    phrases, controls, tokenizer = prepare_training_data(max_pieces=50)
    print(f"Prepared {len(phrases)} phrases")
    # Guard against an empty result (e.g. every piece was filtered out).
    if phrases:
        print(f"Sample phrase length: {len(phrases[0])}")
    print(f"Tokenizer vocab size: {tokenizer.vocab_size}")