| """ |
| MuseMorphic Data Pipeline |
| ========================== |
| |
| Automatic MIDI dataset discovery, download, and preprocessing. |
| Supports multiple dataset sources with automatic format detection. |
| |
| Datasets (auto-selected by availability and size): |
| 1. MAESTRO v3 (piano, ~1200 pieces, HQ performances) |
| 2. POP909 (pop, ~800 songs, multi-track) |
| 3. Los Angeles MIDI Dataset (diverse, large) |
| 4. Custom MIDI file directories |
| """ |


import os
import glob
import json
import random
import logging
from typing import List, Dict, Tuple, Optional
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import Dataset

from tokenizer import REMIPlusTokenizer, TokenizerConfig

logger = logging.getLogger(__name__)


DATASET_REGISTRY = {
    'maestro_v1_sustain': {
        'hf_id': 'roszcz/maestro-v1-sustain',
        'description': 'MAESTRO piano performances with sustain',
        'format': 'note_events',
        'priority': 1,
        'genre': 'classical',
    },
    'maestro_v3': {
        'hf_id': 'roszcz/maestro-v3-public',
        'description': 'MAESTRO v3 piano performances',
        'format': 'note_events',
        'priority': 2,
        'genre': 'classical',
    },
    'midi_dataset_1': {
        'hf_id': 'B-K/midi-dataset',
        'description': 'Aria MIDI dataset with MIDI files',
        'format': 'midi_bytes',
        'priority': 3,
        'genre': 'mixed',
    },
    'midi_dataset_2': {
        'hf_id': 'B-K/midi-dataset-2',
        'description': 'MidiCaps dataset with MIDI files',
        'format': 'midi_bytes',
        'priority': 4,
        'genre': 'mixed',
    },
}
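
# Sketch of registering an additional Hugging Face dataset (the repo id below
# is a placeholder, not a real dataset). Only the two 'format' values that the
# loaders in this module understand, 'note_events' and 'midi_bytes', are
# supported.
#
# DATASET_REGISTRY['my_midi_set'] = {
#     'hf_id': 'your-user/your-midi-dataset',  # hypothetical repo id
#     'description': 'Custom MIDI collection',
#     'format': 'midi_bytes',
#     'priority': 5,
#     'genre': 'mixed',
# }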


def auto_select_dataset(preferred_genre: str = 'any', max_size_gb: float = 2.0) -> str:
    """
    Select a dataset from DATASET_REGISTRY by priority and genre.

    Priority:
    1. MAESTRO (high quality, well-structured)
    2. B-K MIDI datasets (pre-processed, easy to load)

    Note: selection is purely registry-based; availability is not probed here,
    and max_size_gb is currently unused (reserved for future size filtering).
    """
    for name, info in sorted(DATASET_REGISTRY.items(), key=lambda x: x[1]['priority']):
        if preferred_genre != 'any' and info['genre'] != preferred_genre and info['genre'] != 'mixed':
            continue

        logger.info(f"Selected dataset: {name} ({info['description']})")
        return name

    # No genre match: fall back to the first registered dataset.
    return list(DATASET_REGISTRY.keys())[0]
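

# Minimal availability probe (a sketch, not wired into auto_select_dataset):
# assumes `huggingface_hub` is installed; returns False on any lookup failure.
def _dataset_available(hf_id: str) -> bool:
    try:
        from huggingface_hub import HfApi
        HfApi().dataset_info(hf_id)
        return True
    except Exception:
        return False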


def load_dataset_notes(dataset_name: str, split: str = 'train',
                       max_pieces: Optional[int] = None) -> List[Dict]:
    """
    Load a dataset and return it as a list of note-event dicts.

    Each piece is a dict with:
    - notes: List[Dict] with pitch, start, duration, velocity
             (start/duration in ticks, 480 ticks per beat)
    - tempo: float
    - time_sig: Tuple[int, int]
    - metadata: Dict (composer, title, etc.)
    """
    from datasets import load_dataset

    info = DATASET_REGISTRY[dataset_name]
    hf_id = info['hf_id']

    logger.info(f"Loading dataset: {hf_id} (split={split})")

    try:
        ds = load_dataset(hf_id, split=split, trust_remote_code=True)
    except Exception as e:
        logger.warning(f"Failed to load {hf_id}: {e}")
        logger.info("Falling back to synthetic data generation")
        return _generate_synthetic_dataset(max_pieces or 100)

    pieces = []
    n = min(len(ds), max_pieces) if max_pieces else len(ds)

    for i in range(n):
        item = ds[i]

        if info['format'] == 'note_events':
            piece = _parse_note_events_format(item)
        elif info['format'] == 'midi_bytes':
            piece = _parse_midi_bytes_format(item)
        else:
            continue

        if piece and len(piece.get('notes', [])) > 0:
            pieces.append(piece)

    logger.info(f"Loaded {len(pieces)} pieces from {dataset_name}")
    return pieces


def _parse_note_events_format(item: Dict) -> Optional[Dict]:
    """Parse note events format (MAESTRO-style)."""
    try:
        notes_data = item.get('notes', {})

        if isinstance(notes_data, dict):
            pitches = notes_data.get('pitch', [])
            starts = notes_data.get('start', [])
            durations = notes_data.get('duration', [])
            velocities = notes_data.get('velocity', [])

            # Convert seconds to ticks at the assumed tempo so that the result
            # is consistent with _parse_midi_bytes_format:
            # ticks = seconds * (tempo / 60) beats per second * ticks_per_beat.
            tempo = 120.0
            ticks_per_second = tempo / 60.0 * 480

            notes = []
            for j in range(len(pitches)):
                notes.append({
                    'pitch': int(pitches[j]),
                    'start': int(float(starts[j]) * ticks_per_second),
                    'duration': max(1, int(float(durations[j]) * ticks_per_second)),
                    'velocity': int(velocities[j]) if j < len(velocities) else 80,
                })
        else:
            return None

        return {
            'notes': notes,
            'tempo': tempo,
            'time_sig': (4, 4),
            'metadata': {
                'composer': item.get('composer', 'Unknown'),
                'title': item.get('title', 'Untitled'),
            }
        }
    except Exception as e:
        logger.debug(f"Failed to parse note events: {e}")
        return None


def _parse_midi_bytes_format(item: Dict) -> Optional[Dict]:
    """Parse MIDI bytes format."""
    try:
        import pretty_midi
        import io

        midi_data = item.get('midi', None)
        if midi_data is None:
            return None

        if isinstance(midi_data, bytes):
            pm = pretty_midi.PrettyMIDI(io.BytesIO(midi_data))
        else:
            return None

        tempo = pm.estimate_tempo()
        time_sig = (4, 4)
        if pm.time_signature_changes:
            ts = pm.time_signature_changes[0]
            time_sig = (ts.numerator, ts.denominator)

        notes = []
        tpb = 480
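
        # Drum tracks are skipped; pretty_midi reports note times in seconds,
        # converted here as seconds * (tempo / 60) * ticks_per_beat.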
        for instrument in pm.instruments:
            if instrument.is_drum:
                continue
            for note in instrument.notes:
                start_ticks = int(note.start * tempo / 60.0 * tpb)
                duration_ticks = int((note.end - note.start) * tempo / 60.0 * tpb)
                notes.append({
                    'pitch': note.pitch,
                    'start': start_ticks,
                    'duration': max(1, duration_ticks),
                    'velocity': note.velocity,
                })

        return {
            'notes': notes,
            'tempo': tempo,
            'time_sig': time_sig,
            'metadata': {},
        }
    except Exception as e:
        logger.debug(f"Failed to parse MIDI bytes: {e}")
        return None


def _generate_synthetic_dataset(n_pieces: int = 100) -> List[Dict]:
    """Generate synthetic MIDI-like data for testing/fallback."""
    logger.info(f"Generating {n_pieces} synthetic pieces...")

    pieces = []
    scales = {
        'major': [0, 2, 4, 5, 7, 9, 11],
        'minor': [0, 2, 3, 5, 7, 8, 10],
        'pentatonic': [0, 2, 4, 7, 9],
    }

    for _ in range(n_pieces):
        scale_name = random.choice(list(scales.keys()))
        scale = scales[scale_name]
        root = random.randint(48, 72)
        tempo = random.choice([80, 100, 120, 140, 160])
        time_sig = random.choice([(4, 4), (3, 4), (6, 8)])

        tpb = 480
        beats_per_bar = time_sig[0] * (4.0 / time_sig[1])
        ticks_per_bar = int(tpb * beats_per_bar)
        n_bars = random.randint(8, 32)

        notes = []
        for bar in range(n_bars):
            n_notes = random.randint(4, 16)
            for _ in range(n_notes):
                degree = random.choice(scale)
                octave_offset = random.choice([-12, 0, 0, 0, 12])
                pitch = root + degree + octave_offset
                pitch = max(21, min(108, pitch))

                position = random.randint(0, 15) * (ticks_per_bar // 16)
                start = bar * ticks_per_bar + position

                duration = random.choice([tpb // 4, tpb // 2, tpb, tpb * 2])
                velocity = random.randint(40, 110)

                notes.append({
                    'pitch': pitch,
                    'start': start,
                    'duration': duration,
                    'velocity': velocity,
                })

        pieces.append({
            'notes': notes,
            'tempo': tempo,
            'time_sig': time_sig,
            'metadata': {'scale': scale_name, 'root': root},
        })

    return pieces


def preprocess_dataset(
    pieces: List[Dict],
    tokenizer: REMIPlusTokenizer,
    max_phrase_len: int = 256,
    bars_per_phrase: int = 1,
) -> Tuple[List[List[int]], List[Dict]]:
    """
    Full preprocessing pipeline:
    1. Convert note events → REMI+ tokens
    2. Segment into phrases (bar-level)
    3. Encode to integer IDs
    4. Compute control attributes per phrase

    Returns:
    - phrases: List of token ID sequences
    - controls: List of control dicts per phrase
    """
    all_phrases = []
    all_controls = []

    for piece in pieces:
        notes = piece['notes']
        tempo = piece.get('tempo', 120.0)
        time_sig = piece.get('time_sig', (4, 4))

        # 1. Convert note events to REMI+ tokens
        tokens = tokenizer.midi_to_remi_tokens(notes, tempo, time_sig)

        if len(tokens) < 5:
            continue

        # 2. Segment into bar-level phrases
        phrase_groups = tokenizer.segment_into_phrases(tokens, bars_per_phrase)

        for phrase_tokens in phrase_groups:
            if len(phrase_tokens) < 3:
                continue

            # 3. Encode to integer IDs, truncating overlong phrases but
            #    keeping the EOS token
            ids = tokenizer.encode(phrase_tokens)

            if len(ids) > max_phrase_len:
                ids = ids[:max_phrase_len - 1] + [tokenizer.eos_id]

            all_phrases.append(ids)

            # 4. Compute control attributes for this phrase
            controls = tokenizer.compute_controls(phrase_tokens)
            all_controls.append(controls)

    logger.info(f"Preprocessed {len(pieces)} pieces → {len(all_phrases)} phrases")
    return all_phrases, all_controls
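

# A minimal sketch of how the preprocessed output could be wrapped for
# training. This is an assumption about the downstream trainer (not defined in
# this module): it consumes raw token-id lists plus the per-phrase control
# dict and handles padding/collation itself.
class PhraseDataset(Dataset):
    """Pairs each encoded phrase with its control-attribute dict."""

    def __init__(self, phrases: List[List[int]], controls: List[Dict]):
        assert len(phrases) == len(controls)
        self.phrases = phrases
        self.controls = controls

    def __len__(self) -> int:
        return len(self.phrases)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, Dict]:
        return torch.tensor(self.phrases[idx], dtype=torch.long), self.controls[idx]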


def prepare_training_data(
    dataset_name: Optional[str] = None,
    max_pieces: Optional[int] = None,
    max_phrase_len: int = 256,
    data_dir: str = './data',
) -> Tuple[List[List[int]], List[Dict], REMIPlusTokenizer]:
    """
    Complete data pipeline: discover → download → preprocess → return.

    Args:
        dataset_name: Override auto-selection. None = auto.
        max_pieces: Limit number of pieces to load.
        max_phrase_len: Max tokens per phrase.
        data_dir: Directory for caching.

    Returns:
        phrases: List of token ID sequences
        controls: List of control dicts
        tokenizer: Configured tokenizer
    """
    # 1. Discover: pick a dataset from the registry unless one was given
    if dataset_name is None:
        dataset_name = auto_select_dataset()

    # 2. Download: load note events (falls back to synthetic data on failure)
    pieces = load_dataset_notes(dataset_name, max_pieces=max_pieces)

    # 3. Preprocess: tokenize, segment into phrases, compute controls
    tokenizer = REMIPlusTokenizer()
    phrases, controls = preprocess_dataset(pieces, tokenizer, max_phrase_len)

    # 4. Cache the preprocessed phrases and the tokenizer configuration
    os.makedirs(data_dir, exist_ok=True)
    cache_path = os.path.join(data_dir, f'{dataset_name}_phrases.pt')
    torch.save({
        'phrases': phrases,
        'controls': controls,
    }, cache_path)
    logger.info(f"Cached preprocessed data to {cache_path}")

    tokenizer.save(os.path.join(data_dir, 'tokenizer'))

    return phrases, controls, tokenizer
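

# Sketch of the matching cache loader (the path layout mirrors the one used in
# prepare_training_data above; the tokenizer itself is not restored here and
# would have to be reloaded separately via REMIPlusTokenizer).
def load_cached_phrases(dataset_name: str,
                        data_dir: str = './data') -> Tuple[List[List[int]], List[Dict]]:
    cache_path = os.path.join(data_dir, f'{dataset_name}_phrases.pt')
    cache = torch.load(cache_path)
    return cache['phrases'], cache['controls']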


if __name__ == "__main__":
    # Smoke test: surface the pipeline's logger.info output on the console.
    logging.basicConfig(level=logging.INFO)

    phrases, controls, tokenizer = prepare_training_data(max_pieces=50)
    print(f"Prepared {len(phrases)} phrases")
    if phrases:
        print(f"Sample phrase length: {len(phrases[0])}")
    print(f"Tokenizer vocab size: {tokenizer.vocab_size}")