Spaces:
Running
Running
| """ | |
| Virtual MIDI Keyboard - Engines | |
| MIDI processing engines that transform, analyze, or manipulate MIDI events. | |
| """ | |
| from typing import Generator, List, Dict, Any | |
| from midi_model import ( | |
| count_out_of_range_events, | |
| fold_events_to_keyboard_range, | |
| get_model, | |
| ) | |
| # ============================================================================= | |
| # PARROT ENGINE | |
| # ============================================================================= | |
| class ParrotEngine: | |
| """ | |
| Parrot Engine - plays back MIDI exactly as recorded. | |
| This is the simplest engine - it just repeats what the user played. | |
| """ | |
| def __init__(self): | |
| self.name = "Parrot" | |
| def process( | |
| self, | |
| events: List[Dict[str, Any]], | |
| options: Dict[str, Any] | None = None, | |
| request: Any | None = None, | |
| device: str = "auto", | |
| ) -> List[Dict[str, Any]]: | |
| """Return events unchanged""" | |
| if not events: | |
| return [] | |
| return [ | |
| { | |
| "type": e.get("type"), | |
| "note": e.get("note"), | |
| "velocity": e.get("velocity"), | |
| "time": e.get("time"), | |
| "channel": e.get("channel", 0), | |
| } | |
| for e in events | |
| ] | |
| # ============================================================================= | |
| # REVERSE PARROT ENGINE | |
| # ============================================================================= | |
| class ReverseParrotEngine: | |
| """ | |
| Reverse Parrot Engine - plays back MIDI in reverse order. | |
| Takes the recorded performance and reverses the sequence of notes, | |
| playing them backwards while maintaining their timing relationships. | |
| """ | |
| def __init__(self): | |
| self.name = "Reverse Parrot" | |
| def process( | |
| self, | |
| events: List[Dict[str, Any]], | |
| options: Dict[str, Any] | None = None, | |
| request: Any | None = None, | |
| device: str = "auto", | |
| ) -> List[Dict[str, Any]]: | |
| """Reverse the sequence of note numbers while keeping timing and event types""" | |
| if not events: | |
| return [] | |
| # Pair each note_on with its corresponding note_off by matching note numbers. | |
| # This correctly handles overlapping/legato notes where note_off order | |
| # differs from note_on order. | |
| pairs = [] # (note, velocity, on_time, off_time, channel) | |
| pending: Dict[int, List[Dict[str, Any]]] = {} # note -> [note_on events] | |
| for event in events: | |
| etype = event.get("type") | |
| note = event.get("note") | |
| if etype == "note_on": | |
| pending.setdefault(note, []).append(event) | |
| elif etype == "note_off" and note in pending and pending[note]: | |
| on_event = pending[note].pop(0) | |
| if not pending[note]: | |
| del pending[note] | |
| pairs.append( | |
| { | |
| "note": note, | |
| "velocity": on_event.get("velocity"), | |
| "on_time": on_event.get("time"), | |
| "off_time": event.get("time"), | |
| "channel": on_event.get("channel", 0), | |
| } | |
| ) | |
| # Sort pairs by on_time to get the melodic order | |
| pairs.sort(key=lambda p: p["on_time"]) | |
| # Reverse just the note values | |
| notes = [p["note"] for p in pairs] | |
| reversed_notes = list(reversed(notes)) | |
| # Rebuild events with reversed note values, keeping original timing | |
| result = [] | |
| for i, pair in enumerate(pairs): | |
| result.append( | |
| { | |
| "type": "note_on", | |
| "note": reversed_notes[i], | |
| "velocity": pair["velocity"], | |
| "time": pair["on_time"], | |
| "channel": pair["channel"], | |
| } | |
| ) | |
| result.append( | |
| { | |
| "type": "note_off", | |
| "note": reversed_notes[i], | |
| "velocity": 0, | |
| "time": pair["off_time"], | |
| "channel": pair["channel"], | |
| } | |
| ) | |
| # Sort by time to produce a properly ordered event stream | |
| result.sort(key=lambda e: (e["time"], 0 if e["type"] == "note_on" else 1)) | |
| return result | |
| # ============================================================================= | |
| # GODZILLA CONTINUATION ENGINE | |
| # ============================================================================= | |
| class GodzillaContinuationEngine: | |
| """ | |
| Continue a short MIDI phrase with the Godzilla Piano Transformer. | |
| Generates a small continuation and appends it after the input events. | |
| """ | |
| def __init__(self, generate_tokens: int = 32): | |
| self.name = "Godzilla" | |
| self.generate_tokens = generate_tokens | |
| def process( | |
| self, | |
| events: List[Dict[str, Any]], | |
| options: Dict[str, Any] | None = None, | |
| request: Any | None = None, | |
| device: str = "auto", | |
| ) -> List[Dict[str, Any]]: | |
| if not events: | |
| return [] | |
| generate_tokens = self.generate_tokens | |
| seed = None | |
| temperature = 0.9 | |
| top_p = 0.95 | |
| num_candidates = 3 | |
| if isinstance(options, dict): | |
| requested_tokens = options.get("generate_tokens") | |
| if isinstance(requested_tokens, int): | |
| generate_tokens = max(8, min(512, requested_tokens)) | |
| requested_seed = options.get("seed") | |
| if isinstance(requested_seed, int): | |
| seed = requested_seed | |
| requested_temperature = options.get("temperature") | |
| if isinstance(requested_temperature, (int, float)): | |
| temperature = max(0.2, min(1.5, float(requested_temperature))) | |
| requested_top_p = options.get("top_p") | |
| if isinstance(requested_top_p, (int, float)): | |
| top_p = max(0.5, min(0.99, float(requested_top_p))) | |
| requested_candidates = options.get("num_candidates") | |
| if isinstance(requested_candidates, int): | |
| num_candidates = max(1, min(6, requested_candidates)) | |
| model = get_model("godzilla") | |
| new_events = model.generate_continuation( | |
| events, | |
| tokens=generate_tokens, | |
| seed=seed, | |
| temperature=temperature, | |
| top_p=top_p, | |
| num_candidates=num_candidates, | |
| request=request, | |
| device=device, | |
| ) | |
| out_of_range = count_out_of_range_events(new_events) | |
| if out_of_range: | |
| print(f"Godzilla: remapped {out_of_range} out-of-range events by octave folding") | |
| return fold_events_to_keyboard_range(new_events) | |
| def process_streaming( | |
| self, | |
| events: List[Dict[str, Any]], | |
| options: Dict[str, Any] | None = None, | |
| request: Any | None = None, | |
| device: str = "auto", | |
| ) -> Generator[Dict[str, Any], None, None]: | |
| """Yield partial results as notes are generated.""" | |
| if not events: | |
| yield {"status": "complete", "events": [], "tokens_generated": 0, "tokens_total": 0} | |
| return | |
| generate_tokens = self.generate_tokens | |
| seed = None | |
| temperature = 0.9 | |
| top_p = 0.95 | |
| if isinstance(options, dict): | |
| requested_tokens = options.get("generate_tokens") | |
| if isinstance(requested_tokens, int): | |
| generate_tokens = max(8, min(512, requested_tokens)) | |
| requested_seed = options.get("seed") | |
| if isinstance(requested_seed, int): | |
| seed = requested_seed | |
| requested_temperature = options.get("temperature") | |
| if isinstance(requested_temperature, (int, float)): | |
| temperature = max(0.2, min(1.5, float(requested_temperature))) | |
| requested_top_p = options.get("top_p") | |
| if isinstance(requested_top_p, (int, float)): | |
| top_p = max(0.5, min(0.99, float(requested_top_p))) | |
| model = get_model("godzilla") | |
| last_result: Dict[str, Any] = { | |
| "status": "complete", "events": [], "tokens_generated": 0, "tokens_total": generate_tokens | |
| } | |
| for accumulated_events, tokens_generated, tokens_total in model.generate_continuation_streaming( | |
| events, | |
| tokens=generate_tokens, | |
| seed=seed, | |
| temperature=temperature, | |
| top_p=top_p, | |
| request=request, | |
| device=device, | |
| ): | |
| folded = fold_events_to_keyboard_range(accumulated_events) | |
| last_result = { | |
| "status": "generating", | |
| "events": folded, | |
| "tokens_generated": tokens_generated, | |
| "tokens_total": tokens_total, | |
| } | |
| yield last_result | |
| last_result["status"] = "complete" | |
| yield last_result | |
| # ============================================================================= | |
| # ENGINE REGISTRY | |
| # ============================================================================= | |
class EngineRegistry:
    """Registry for managing available MIDI engines.

    Maps engine IDs to engine classes. Every lookup constructs a fresh
    instance with no constructor arguments. All methods are classmethods, so
    they can be called on the class itself (``EngineRegistry.get_engine(...)``)
    or on an instance.
    """

    # engine_id -> engine class. This dict lives on the class, so register()
    # changes what every caller sees.
    _engines: Dict[str, type] = {
        "parrot": ParrotEngine,
        "reverse_parrot": ReverseParrotEngine,
        "godzilla_continue": GodzillaContinuationEngine,
    }

    @classmethod
    def _engine_class(cls, engine_id: str) -> type:
        """Return the class registered under ``engine_id``; raise ValueError if unknown."""
        try:
            return cls._engines[engine_id]
        except KeyError:
            raise ValueError(f"Unknown engine: {engine_id}") from None

    @classmethod
    def register(cls, engine_id: str, engine_class: type):
        """Register a new engine (replacing any engine already under ``engine_id``)."""
        cls._engines[engine_id] = engine_class

    @classmethod
    def get_engine(cls, engine_id: str):
        """Get a new engine instance by ID.

        Raises:
            ValueError: if ``engine_id`` is not registered.
        """
        return cls._engine_class(engine_id)()

    @classmethod
    def list_engines(cls) -> List[str]:
        """List all registered engine IDs, in registration order."""
        return list(cls._engines.keys())

    @classmethod
    def get_engine_info(cls, engine_id: str) -> Dict[str, str]:
        """Get the ID and display name of an engine.

        Instantiates the engine to read its ``name`` attribute.

        Raises:
            ValueError: if ``engine_id`` is not registered.
        """
        engine = cls._engine_class(engine_id)()
        return {
            "id": engine_id,
            "name": engine.name,
        }