"""
MuseMorphic Data Pipeline
=========================

MIDI dataset loading and preprocessing for MuseMorphic training.

Pieces are pulled from the Hugging Face Hub (see ``DATASET_REGISTRY``),
normalized to a common note-event representation, tokenized into REMI+
phrases and cached to disk. If a dataset cannot be loaded, a synthetic
scale-based dataset is generated instead so the pipeline still runs.

Registered datasets (chosen by ``priority``, optionally filtered by genre):
    1. maestro_v1_sustain - MAESTRO piano performances with sustain
    2. maestro_v3         - MAESTRO v3 piano performances
    3. midi_dataset_1     - Aria MIDI dataset (raw MIDI bytes)
    4. midi_dataset_2     - MidiCaps dataset (raw MIDI bytes)
"""

import os
import glob
import json
import random
import logging
from typing import List, Dict, Tuple, Optional
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import Dataset

from tokenizer import REMIPlusTokenizer, TokenizerConfig

logger = logging.getLogger(__name__)

# Tick resolution used for every note-event conversion in this module.
# NOTE(review): presumably must match the tokenizer's ticks-per-beat -- confirm
# against TokenizerConfig.
TICKS_PER_BEAT = 480

# Tempo (BPM) assumed when a source gives times in seconds but no tempo.
DEFAULT_TEMPO = 120.0


# ============================================================================
# Dataset Discovery & Download
# ============================================================================

DATASET_REGISTRY = {
    'maestro_v1_sustain': {
        'hf_id': 'roszcz/maestro-v1-sustain',
        'description': 'MAESTRO piano performances with sustain',
        'format': 'note_events',  # 'notes' column: columnar {pitch, start, duration, velocity}
        'priority': 1,
        'genre': 'classical',
    },
    'maestro_v3': {
        'hf_id': 'roszcz/maestro-v3-public',
        'description': 'MAESTRO v3 piano performances',
        'format': 'note_events',
        'priority': 2,
        'genre': 'classical',
    },
    'midi_dataset_1': {
        'hf_id': 'B-K/midi-dataset',
        'description': 'Aria MIDI dataset with MIDI files',
        'format': 'midi_bytes',
        'priority': 3,
        'genre': 'mixed',
    },
    'midi_dataset_2': {
        'hf_id': 'B-K/midi-dataset-2',
        'description': 'MidiCaps dataset with MIDI files',
        'format': 'midi_bytes',
        'priority': 4,
        'genre': 'mixed',
    },
}


def auto_select_dataset(preferred_genre: str = 'any', max_size_gb: float = 2.0) -> str:
    """
    Pick the highest-priority registered dataset matching ``preferred_genre``.

    Datasets are scanned in ascending ``priority`` order. A dataset is
    accepted when ``preferred_genre`` is ``'any'``, equals its genre, or its
    genre is ``'mixed'``. Availability is not checked here;
    ``load_dataset_notes`` handles load failures by falling back to
    synthetic data.

    Args:
        preferred_genre: Genre filter, or ``'any'`` to accept every dataset.
        max_size_gb: Currently unused; kept for API compatibility.

    Returns:
        A key of ``DATASET_REGISTRY``.
    """
    ranked = sorted(DATASET_REGISTRY.items(), key=lambda kv: kv[1]['priority'])
    for name, info in ranked:
        if preferred_genre in ('any', info['genre']) or info['genre'] == 'mixed':
            logger.info("Selected dataset: %s (%s)", name, info['description'])
            return name
    return list(DATASET_REGISTRY.keys())[0]


def load_dataset_notes(dataset_name: str, split: str = 'train',
                       max_pieces: Optional[int] = None) -> List[Dict]:
    """
    Load a registered dataset and return it as a list of note-event pieces.

    Each piece is a dict with:
        - notes: List[Dict] with pitch, start, duration, velocity (ticks)
        - tempo: float (BPM)
        - time_sig: Tuple[int, int]
        - metadata: Dict (composer, title, etc.)

    Args:
        dataset_name: Key of ``DATASET_REGISTRY``.
        split: Dataset split to load.
        max_pieces: Maximum number of rows to read; ``None`` reads all of them
            (``0`` reads none).

    Returns:
        Parsed pieces with at least one note. If the dataset cannot be loaded,
        synthetic pieces are returned instead (``max_pieces`` of them, or 100
        when ``max_pieces`` is None).
    """
    from datasets import load_dataset  # optional dependency, imported lazily

    info = DATASET_REGISTRY[dataset_name]
    hf_id = info['hf_id']
    logger.info("Loading dataset: %s (split=%s)", hf_id, split)

    try:
        # NOTE(review): trust_remote_code=True lets the Hub repository run
        # arbitrary Python at load time. Drop it if these repos contain only
        # data files (no loading script).
        ds = load_dataset(hf_id, split=split, trust_remote_code=True)
    except Exception as e:  # best-effort: any load failure -> synthetic data
        logger.warning("Failed to load %s: %s", hf_id, e)
        logger.info("Falling back to synthetic data generation")
        return _generate_synthetic_dataset(100 if max_pieces is None else max_pieces)

    parsers = {
        'note_events': _parse_note_events_format,
        'midi_bytes': _parse_midi_bytes_format,
    }
    parse = parsers.get(info['format'])
    if parse is None:
        logger.warning("Unknown dataset format %r for %s", info['format'], dataset_name)
        return []

    n = len(ds) if max_pieces is None else min(len(ds), max_pieces)
    pieces = []
    for i in range(n):
        piece = parse(ds[i])
        if piece and piece.get('notes'):
            pieces.append(piece)

    logger.info("Loaded %d pieces from %s", len(pieces), dataset_name)
    return pieces


def _seconds_to_ticks(seconds: float, tempo: float = DEFAULT_TEMPO,
                      tpb: int = TICKS_PER_BEAT) -> int:
    """Convert a time in seconds to ticks at a constant ``tempo`` (BPM)."""
    # seconds * (beats per second) * (ticks per beat)
    return int(round(seconds * tempo / 60.0 * tpb))


def _note_sort_key(note: Dict) -> Tuple[int, int]:
    """Order notes by onset tick, then by pitch."""
    return note['start'], note['pitch']


def _extract_metadata(item: Dict) -> Dict:
    """
    Return ``{'composer', 'title'}`` for a dataset row.

    Top-level 'composer'/'title' fields take precedence. Otherwise, if the row
    has a JSON-encoded 'source' string containing those keys, they are read
    from it. NOTE(review): some roszcz/* datasets appear to store metadata
    this way -- confirm against the dataset card. Defaults are
    'Unknown'/'Untitled'.
    """
    extra = {}
    source = item.get('source')
    if isinstance(source, str):
        try:
            parsed = json.loads(source)
        except ValueError:
            parsed = None
        if isinstance(parsed, dict):
            extra = parsed
    return {
        'composer': item.get('composer', extra.get('composer', 'Unknown')),
        'title': item.get('title', extra.get('title', 'Untitled')),
    }


def _parse_note_events_format(item: Dict) -> Optional[Dict]:
    """
    Parse a MAESTRO-style row whose 'notes' field is columnar.

    Expects ``item['notes']`` to be a dict of parallel sequences: 'pitch',
    'start', and either 'duration' or 'end' (times in seconds), plus an
    optional 'velocity' (missing entries default to 80).

    Times are converted to ticks at ``DEFAULT_TEMPO``, the same tempo reported
    in the returned dict, so the notes' tick positions and the piece tempo
    agree.

    Returns:
        A piece dict with notes sorted by (start, pitch), or None if the row
        is malformed.
    """
    try:
        notes_data = item.get('notes', {})
        if not isinstance(notes_data, dict):
            return None

        pitches = notes_data.get('pitch', [])
        starts = notes_data.get('start', [])
        velocities = notes_data.get('velocity', [])
        durations = notes_data.get('duration')
        if durations is None and 'end' in notes_data:
            # Derive durations when only note offsets are stored.
            durations = [float(e) - float(s) for s, e in zip(starts, notes_data['end'])]
        if durations is None:
            durations = []

        n_notes = len(pitches)
        if len(starts) < n_notes or len(durations) < n_notes:
            logger.debug("Note columns have mismatched lengths; skipping piece")
            return None

        notes = []
        for j, (pitch, start_s, dur_s) in enumerate(zip(pitches, starts, durations)):
            notes.append({
                'pitch': int(pitch),
                'start': _seconds_to_ticks(float(start_s)),
                'duration': max(1, _seconds_to_ticks(float(dur_s))),
                'velocity': int(velocities[j]) if j < len(velocities) else 80,
            })
        notes.sort(key=_note_sort_key)

        return {
            'notes': notes,
            'tempo': DEFAULT_TEMPO,  # source rows carry no tempo
            'time_sig': (4, 4),
            'metadata': _extract_metadata(item),
        }
    except Exception as e:  # best-effort: malformed rows are skipped, not fatal
        logger.debug("Failed to parse note events: %s", e)
        return None


def _parse_midi_bytes_format(item: Dict) -> Optional[Dict]:
    """
    Parse a row whose 'midi' field holds a raw Standard MIDI File.

    Note times are mapped to ticks through the file's own tempo map
    (``PrettyMIDI.time_to_tick``) and rescaled to ``TICKS_PER_BEAT``, so tempo
    changes inside the file do not cause drift. The reported tempo is the
    file's initial tempo. Drum tracks are skipped.

    Returns:
        A piece dict with notes sorted by (start, pitch), or None if the field
        is missing, not bytes-like, or cannot be parsed.
    """
    try:
        import io
        import pretty_midi

        midi_data = item.get('midi')
        if not isinstance(midi_data, (bytes, bytearray, memoryview)):
            return None
        pm = pretty_midi.PrettyMIDI(io.BytesIO(bytes(midi_data)))

        # Use the tempo written in the file rather than estimate_tempo(),
        # which is an onset-based heuristic.
        _, tempi = pm.get_tempo_changes()
        tempo = float(tempi[0]) if len(tempi) else DEFAULT_TEMPO

        time_sig = (4, 4)
        if pm.time_signature_changes:
            ts = pm.time_signature_changes[0]
            time_sig = (ts.numerator, ts.denominator)

        # Rescale from the file's resolution to this module's ticks-per-beat.
        scale = TICKS_PER_BEAT / pm.resolution

        def to_ticks(seconds: float) -> int:
            return int(round(pm.time_to_tick(seconds) * scale))

        notes = []
        for instrument in pm.instruments:
            if instrument.is_drum:
                continue
            for note in instrument.notes:
                start = to_ticks(note.start)
                notes.append({
                    'pitch': note.pitch,
                    'start': start,
                    'duration': max(1, to_ticks(note.end) - start),
                    'velocity': note.velocity,
                })
        # Notes from different instruments are interleaved by time.
        notes.sort(key=_note_sort_key)

        return {
            'notes': notes,
            'tempo': tempo,
            'time_sig': time_sig,
            'metadata': {},
        }
    except Exception as e:  # best-effort: unparseable files are skipped
        logger.debug("Failed to parse MIDI bytes: %s", e)
        return None


def _generate_synthetic_dataset(n_pieces: int = 100, seed: Optional[int] = None) -> List[Dict]:
    """
    Generate random scale-based pieces for testing or as a load fallback.

    Each piece picks a scale (major/minor/pentatonic), a root in MIDI 48-72,
    a tempo, a time signature, and 8-32 bars. Each bar gets 4-16 notes on a
    16-step grid, with pitches clamped to the piano range 21-108.

    Args:
        n_pieces: Number of pieces to generate (non-positive -> empty list).
        seed: If given, use a private ``random.Random(seed)`` for reproducible
            output; otherwise draw from the global ``random`` module.

    Returns:
        Pieces in the same dict format as ``load_dataset_notes``, with notes
        sorted by (start, pitch).
    """
    rng = random if seed is None else random.Random(seed)
    logger.info("Generating %d synthetic pieces...", n_pieces)

    scales = {
        'major': [0, 2, 4, 5, 7, 9, 11],
        'minor': [0, 2, 3, 5, 7, 8, 10],
        'pentatonic': [0, 2, 4, 7, 9],
    }
    tpb = TICKS_PER_BEAT
    pieces = []

    for _ in range(n_pieces):
        scale_name = rng.choice(list(scales))
        scale = scales[scale_name]
        root = rng.randint(48, 72)  # middle register
        tempo = rng.choice([80, 100, 120, 140, 160])
        time_sig = rng.choice([(4, 4), (3, 4), (6, 8)])

        beats_per_bar = time_sig[0] * (4.0 / time_sig[1])  # in quarter notes
        ticks_per_bar = int(tpb * beats_per_bar)
        n_bars = rng.randint(8, 32)

        notes = []
        for bar in range(n_bars):
            for _ in range(rng.randint(4, 16)):
                degree = rng.choice(scale)
                octave_offset = rng.choice([-12, 0, 0, 0, 12])
                pitch = min(108, max(21, root + degree + octave_offset))
                position = rng.randint(0, 15) * (ticks_per_bar // 16)
                notes.append({
                    'pitch': pitch,
                    'start': bar * ticks_per_bar + position,
                    'duration': rng.choice([tpb // 4, tpb // 2, tpb, tpb * 2]),
                    'velocity': rng.randint(40, 110),
                })
        notes.sort(key=_note_sort_key)

        pieces.append({
            'notes': notes,
            'tempo': tempo,
            'time_sig': time_sig,
            'metadata': {'scale': scale_name, 'root': root},
        })

    return pieces


# ============================================================================
# Preprocessing Pipeline
# ============================================================================

def preprocess_dataset(
    pieces: List[Dict],
    tokenizer: REMIPlusTokenizer,
    max_phrase_len: int = 256,
    bars_per_phrase: int = 1,
) -> Tuple[List[List[int]], List[Dict]]:
    """
    Full preprocessing pipeline:
        1. Convert note events -> REMI+ tokens
        2. Segment into phrases of ``bars_per_phrase`` bars
        3. Encode to integer IDs (truncated to ``max_phrase_len``, ending in EOS)
        4. Compute control attributes per phrase

    Pieces with fewer than 5 tokens and phrases with fewer than 3 tokens are
    skipped.

    Returns:
        - phrases: List of token ID sequences
        - controls: List of control dicts, aligned index-for-index with phrases

    Raises:
        ValueError: if ``max_phrase_len`` < 1, since truncation could then not
            honor the length limit.
    """
    if max_phrase_len < 1:
        raise ValueError(f"max_phrase_len must be >= 1, got {max_phrase_len}")

    all_phrases = []
    all_controls = []

    for piece in pieces:
        tokens = tokenizer.midi_to_remi_tokens(
            piece['notes'],
            piece.get('tempo', 120.0),
            piece.get('time_sig', (4, 4)),
        )
        if len(tokens) < 5:
            continue

        for phrase_tokens in tokenizer.segment_into_phrases(tokens, bars_per_phrase):
            if len(phrase_tokens) < 3:
                continue

            ids = tokenizer.encode(phrase_tokens)
            if len(ids) > max_phrase_len:
                # Keep room for EOS so truncated phrases still terminate.
                ids = ids[:max_phrase_len - 1] + [tokenizer.eos_id]

            all_phrases.append(ids)
            all_controls.append(tokenizer.compute_controls(phrase_tokens))

    logger.info("Preprocessed %d pieces → %d phrases", len(pieces), len(all_phrases))
    return all_phrases, all_controls


# ============================================================================
# Complete Data Pipeline
# ============================================================================

def prepare_training_data(
    dataset_name: Optional[str] = None,
    max_pieces: Optional[int] = None,
    max_phrase_len: int = 256,
    data_dir: str = './data',
) -> Tuple[List[List[int]], List[Dict], REMIPlusTokenizer]:
    """
    Complete data pipeline: select -> load -> preprocess -> cache -> return.

    Args:
        dataset_name: Override auto-selection. None = auto.
        max_pieces: Limit number of pieces to load (None = all).
        max_phrase_len: Max tokens per phrase.
        data_dir: Output directory. The preprocessed phrases/controls are
            written to ``<data_dir>/<dataset_name>_phrases.pt`` and the
            tokenizer to ``<data_dir>/tokenizer``. The cache is write-only:
            it is not read back, so every call re-runs the full pipeline. If
            loading fell back to synthetic data, the cache file still carries
            the requested dataset's name.

    Returns:
        phrases: List of token ID sequences
        controls: List of control dicts
        tokenizer: Configured tokenizer
    """
    if dataset_name is None:
        dataset_name = auto_select_dataset()

    pieces = load_dataset_notes(dataset_name, max_pieces=max_pieces)
    tokenizer = REMIPlusTokenizer()
    phrases, controls = preprocess_dataset(pieces, tokenizer, max_phrase_len)

    os.makedirs(data_dir, exist_ok=True)
    cache_path = os.path.join(data_dir, f'{dataset_name}_phrases.pt')
    torch.save({
        'phrases': phrases,
        'controls': controls,
    }, cache_path)
    logger.info("Cached preprocessed data to %s", cache_path)

    tokenizer.save(os.path.join(data_dir, 'tokenizer'))

    return phrases, controls, tokenizer


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    phrases, controls, tokenizer = prepare_training_data(max_pieces=50)
    print(f"Prepared {len(phrases)} phrases")
    if phrases:
        print(f"Sample phrase length: {len(phrases[0])}")
    print(f"Tokenizer vocab size: {tokenizer.vocab_size}")