# virtual_keyboard / engines.py
# github-actions[bot]
# Deploy to HF Spaces
# 2e0c2a7
"""
Virtual MIDI Keyboard - Engines
MIDI processing engines that transform, analyze, or manipulate MIDI events.
"""
from typing import Generator, List, Dict, Any
from midi_model import (
count_out_of_range_events,
fold_events_to_keyboard_range,
get_model,
)
# =============================================================================
# PARROT ENGINE
# =============================================================================
class ParrotEngine:
"""
Parrot Engine - plays back MIDI exactly as recorded.
This is the simplest engine - it just repeats what the user played.
"""
def __init__(self):
self.name = "Parrot"
def process(
self,
events: List[Dict[str, Any]],
options: Dict[str, Any] | None = None,
request: Any | None = None,
device: str = "auto",
) -> List[Dict[str, Any]]:
"""Return events unchanged"""
if not events:
return []
return [
{
"type": e.get("type"),
"note": e.get("note"),
"velocity": e.get("velocity"),
"time": e.get("time"),
"channel": e.get("channel", 0),
}
for e in events
]
# =============================================================================
# REVERSE PARROT ENGINE
# =============================================================================
class ReverseParrotEngine:
"""
Reverse Parrot Engine - plays back MIDI in reverse order.
Takes the recorded performance and reverses the sequence of notes,
playing them backwards while maintaining their timing relationships.
"""
def __init__(self):
self.name = "Reverse Parrot"
def process(
self,
events: List[Dict[str, Any]],
options: Dict[str, Any] | None = None,
request: Any | None = None,
device: str = "auto",
) -> List[Dict[str, Any]]:
"""Reverse the sequence of note numbers while keeping timing and event types"""
if not events:
return []
# Pair each note_on with its corresponding note_off by matching note numbers.
# This correctly handles overlapping/legato notes where note_off order
# differs from note_on order.
pairs = [] # (note, velocity, on_time, off_time, channel)
pending: Dict[int, List[Dict[str, Any]]] = {} # note -> [note_on events]
for event in events:
etype = event.get("type")
note = event.get("note")
if etype == "note_on":
pending.setdefault(note, []).append(event)
elif etype == "note_off" and note in pending and pending[note]:
on_event = pending[note].pop(0)
if not pending[note]:
del pending[note]
pairs.append(
{
"note": note,
"velocity": on_event.get("velocity"),
"on_time": on_event.get("time"),
"off_time": event.get("time"),
"channel": on_event.get("channel", 0),
}
)
# Sort pairs by on_time to get the melodic order
pairs.sort(key=lambda p: p["on_time"])
# Reverse just the note values
notes = [p["note"] for p in pairs]
reversed_notes = list(reversed(notes))
# Rebuild events with reversed note values, keeping original timing
result = []
for i, pair in enumerate(pairs):
result.append(
{
"type": "note_on",
"note": reversed_notes[i],
"velocity": pair["velocity"],
"time": pair["on_time"],
"channel": pair["channel"],
}
)
result.append(
{
"type": "note_off",
"note": reversed_notes[i],
"velocity": 0,
"time": pair["off_time"],
"channel": pair["channel"],
}
)
# Sort by time to produce a properly ordered event stream
result.sort(key=lambda e: (e["time"], 0 if e["type"] == "note_on" else 1))
return result
# =============================================================================
# GODZILLA CONTINUATION ENGINE
# =============================================================================
class GodzillaContinuationEngine:
"""
Continue a short MIDI phrase with the Godzilla Piano Transformer.
Generates a small continuation and appends it after the input events.
"""
def __init__(self, generate_tokens: int = 32):
self.name = "Godzilla"
self.generate_tokens = generate_tokens
def process(
self,
events: List[Dict[str, Any]],
options: Dict[str, Any] | None = None,
request: Any | None = None,
device: str = "auto",
) -> List[Dict[str, Any]]:
if not events:
return []
generate_tokens = self.generate_tokens
seed = None
temperature = 0.9
top_p = 0.95
num_candidates = 3
if isinstance(options, dict):
requested_tokens = options.get("generate_tokens")
if isinstance(requested_tokens, int):
generate_tokens = max(8, min(512, requested_tokens))
requested_seed = options.get("seed")
if isinstance(requested_seed, int):
seed = requested_seed
requested_temperature = options.get("temperature")
if isinstance(requested_temperature, (int, float)):
temperature = max(0.2, min(1.5, float(requested_temperature)))
requested_top_p = options.get("top_p")
if isinstance(requested_top_p, (int, float)):
top_p = max(0.5, min(0.99, float(requested_top_p)))
requested_candidates = options.get("num_candidates")
if isinstance(requested_candidates, int):
num_candidates = max(1, min(6, requested_candidates))
model = get_model("godzilla")
new_events = model.generate_continuation(
events,
tokens=generate_tokens,
seed=seed,
temperature=temperature,
top_p=top_p,
num_candidates=num_candidates,
request=request,
device=device,
)
out_of_range = count_out_of_range_events(new_events)
if out_of_range:
print(f"Godzilla: remapped {out_of_range} out-of-range events by octave folding")
return fold_events_to_keyboard_range(new_events)
def process_streaming(
self,
events: List[Dict[str, Any]],
options: Dict[str, Any] | None = None,
request: Any | None = None,
device: str = "auto",
) -> Generator[Dict[str, Any], None, None]:
"""Yield partial results as notes are generated."""
if not events:
yield {"status": "complete", "events": [], "tokens_generated": 0, "tokens_total": 0}
return
generate_tokens = self.generate_tokens
seed = None
temperature = 0.9
top_p = 0.95
if isinstance(options, dict):
requested_tokens = options.get("generate_tokens")
if isinstance(requested_tokens, int):
generate_tokens = max(8, min(512, requested_tokens))
requested_seed = options.get("seed")
if isinstance(requested_seed, int):
seed = requested_seed
requested_temperature = options.get("temperature")
if isinstance(requested_temperature, (int, float)):
temperature = max(0.2, min(1.5, float(requested_temperature)))
requested_top_p = options.get("top_p")
if isinstance(requested_top_p, (int, float)):
top_p = max(0.5, min(0.99, float(requested_top_p)))
model = get_model("godzilla")
last_result: Dict[str, Any] = {
"status": "complete", "events": [], "tokens_generated": 0, "tokens_total": generate_tokens
}
for accumulated_events, tokens_generated, tokens_total in model.generate_continuation_streaming(
events,
tokens=generate_tokens,
seed=seed,
temperature=temperature,
top_p=top_p,
request=request,
device=device,
):
folded = fold_events_to_keyboard_range(accumulated_events)
last_result = {
"status": "generating",
"events": folded,
"tokens_generated": tokens_generated,
"tokens_total": tokens_total,
}
yield last_result
last_result["status"] = "complete"
yield last_result
# =============================================================================
# ENGINE REGISTRY
# =============================================================================
class EngineRegistry:
    """Central lookup table mapping engine IDs to engine classes.

    Engines are stored as classes and instantiated with no arguments on demand.
    """

    # engine id -> engine class (shared at class level; ``register`` adds to it)
    _engines = {
        "parrot": ParrotEngine,
        "reverse_parrot": ReverseParrotEngine,
        "godzilla_continue": GodzillaContinuationEngine,
    }

    @classmethod
    def _engine_class(cls, engine_id: str) -> type:
        """Return the class registered under ``engine_id``; raise ValueError if unknown."""
        if engine_id in cls._engines:
            return cls._engines[engine_id]
        raise ValueError(f"Unknown engine: {engine_id}")

    @classmethod
    def register(cls, engine_id: str, engine_class: type):
        """Add (or replace) the engine class stored under ``engine_id``."""
        cls._engines[engine_id] = engine_class

    @classmethod
    def get_engine(cls, engine_id: str):
        """Instantiate and return a new engine for ``engine_id``.

        Raises:
            ValueError: if no engine is registered under ``engine_id``.
        """
        engine_class = cls._engine_class(engine_id)
        return engine_class()

    @classmethod
    def list_engines(cls) -> List[str]:
        """Return the registered engine IDs in registration order."""
        return [engine_id for engine_id in cls._engines]

    @classmethod
    def get_engine_info(cls, engine_id: str) -> Dict[str, str]:
        """Return ``{"id": ..., "name": ...}`` for ``engine_id``.

        A throwaway instance is created to read its ``name`` attribute.

        Raises:
            ValueError: if no engine is registered under ``engine_id``.
        """
        instance = cls._engine_class(engine_id)()
        return {"id": engine_id, "name": instance.name}