| """Generation pipeline — callable from Gradio, ZeroGPU-compatible. |
| |
| Wraps the full audio→chart pipeline into a single function that returns |
| a zip file path and chart JSON for the visualizer. |
| """ |
|
|
| import base64 |
| import json |
| import os |
| import shutil |
| import tempfile |
| from datetime import datetime |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
|
|
| from midmid.datatypes import ChartData, NoteEvent |
|
|
# Ticks per quarter note used throughout the generated chart
# (standard .mid/Clone Hero resolution).
RESOLUTION = 192
# Hugging Face Hub repository the chart model is downloaded from.
MODEL_REPO = "markury/midmid3-19m-0326"

# Lazily-initialized model singleton; populated by ensure_model() at startup.
_chart_model = None
|
|
|
|
def ensure_model():
    """Pre-load model on CPU (called at app startup)."""
    global _chart_model
    if _chart_model is not None:
        return _chart_model
    from midmid.inference import load_model_from_hub
    print("Loading chart model from HF Hub...")
    _chart_model = load_model_from_hub(MODEL_REPO, device="cpu")
    print("Chart model loaded.")
    return _chart_model
|
|
|
|
def generate_chart(
    audio_path: str,
    title: str,
    artist: str,
    album: str = "",
    year: str = "",
    genre: str = "rock",
    temperature: float = 0.8,
    num_steps: int = 12,
    progress_cb=None,
) -> tuple[str, dict]:
    """Run the full generation pipeline.

    Args:
        audio_path: Path to uploaded audio file.
        title: Song title.
        artist: Artist name.
        album: Album name (optional).
        year: Release year (optional; defaults to the current year).
        genre: Genre string (optional).
        temperature: Sampling temperature.
        num_steps: Unmasking steps.
        progress_cb: Optional Gradio-style progress callable, invoked as
            ``progress_cb(fraction, desc=message)`` with fraction in [0, 1].

    Returns:
        (zip_path, chart_json) where chart_json has the data for the visualizer.
    """
    # Deferred imports keep module import cheap (e.g. at app startup,
    # before any generation actually runs).
    from midmid.beat_tracker import track_beats
    from midmid.tempo_map import derive_tempo_map, get_median_bpm, estimate_time_signature
    from midmid.offset import calculate_offset
    from midmid.sections import detect_sections
    from midmid.constraints import enforce_constraints
    from midmid.inference import predict_notes, move_models_to_device
    from midmid.midi_writer import write_midi
    from midmid.audio_prep import prepare_audio
    from midmid.ini_writer import write_ini

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ensure_model()
    model.to(device)
    move_models_to_device(device)

    if not year:
        year = str(datetime.now().year)

    # Stage the song folder inside a fresh temp dir; the zip is built from it.
    tmp_dir = tempfile.mkdtemp(prefix="midmid_")
    song_dir = Path(tmp_dir) / f"{title} - {artist}"
    song_dir.mkdir(parents=True, exist_ok=True)

    def _progress(step, total, msg):
        # Forward progress as a 0..1 fraction in Gradio's (fraction, desc=...) style.
        if progress_cb:
            progress_cb(step / total, desc=msg)

    _progress(0, 8, "Tracking beats...")
    beat_data = track_beats(audio_path, device=str(device))

    _progress(1, 8, "Analyzing tempo...")
    tempo_map = derive_tempo_map(beat_data)
    bpm = get_median_bpm(beat_data)
    time_sig = estimate_time_signature(beat_data)
    offset_sec = calculate_offset(beat_data, bpm, beats_per_measure=time_sig)

    _progress(2, 8, "Detecting sections...")
    raw_sections = detect_sections(audio_path)

    # Generate every difficulty from the same tracked beat grid.
    beat_times = list(beat_data.beats)
    difficulties = ["expert", "hard", "medium", "easy"]
    all_notes = {}

    for i, diff_name in enumerate(difficulties):
        _progress(3 + i * 0.2, 8, f"Generating {diff_name} chart...")
        raw_notes = predict_notes(
            audio_path=audio_path,
            model=model,
            beat_times=beat_times,
            difficulty=diff_name,
            device=device,
            temperature=temperature,
            num_steps=num_steps,
        )

        # Map grid indices to musical ticks, then apply per-difficulty rules.
        notes = _grid_to_musical_ticks(raw_notes, beat_times, offset_sec, bpm, RESOLUTION)
        notes = enforce_constraints(notes, diff_name, RESOLUTION)

        # Drop any notes landing past the last tracked beat.
        last_beat_sec = float(beat_data.beats[-1]) if len(beat_data.beats) > 0 else 0
        last_beat_tick = int(round((last_beat_sec + offset_sec) * bpm / 60.0 * RESOLUTION))
        notes = [n for n in notes if n.tick <= last_beat_tick]

        all_notes[diff_name] = notes

    # Safety net: if a difficulty is somehow missing, reuse the first
    # available one so the chart always has all four.
    required = ["expert", "hard", "medium", "easy"]
    for diff in required:
        if diff not in all_notes:
            for fallback in required:
                if fallback in all_notes:
                    all_notes[diff] = all_notes[fallback]
                    break

    _progress(5, 8, "Building chart...")
    tempo_events = _tempo_map_to_ticks(tempo_map, offset_sec, bpm, RESOLUTION)
    section_events = _sections_to_ticks(raw_sections, tempo_map, offset_sec, RESOLUTION)

    # Beat markers run one measure past the last note (or 4 empty measures
    # when no notes were produced at all).
    all_ticks = [n.tick for ns in all_notes.values() for n in ns]
    last_tick = max(all_ticks) + RESOLUTION * time_sig if all_ticks else RESOLUTION * time_sig * 4
    beat_markers = _build_beat_markers(last_tick, RESOLUTION, time_sig)

    chart = ChartData(
        resolution=RESOLUTION,
        tempo_events=tempo_events,
        time_signatures=[(0, time_sig, 4)],
        sections=section_events,
        notes=all_notes,
        beats=beat_markers,
    )

    _progress(6, 8, "Writing MIDI...")
    write_midi(chart, str(song_dir / "notes.mid"))

    _progress(7, 8, "Preparing audio...")
    # Pad the audio with the computed lead-in so tick 0 lines up with the chart.
    prepare_audio(
        audio_path=audio_path,
        output_path=str(song_dir / "song.ogg"),
        silence_duration_sec=offset_sec,
    )

    write_ini(
        output_path=str(song_dir / "song.ini"),
        title=title,
        artist=artist,
        album=album,
        genre=genre,
        year=year,
    )

    # Zip just the song folder; make_archive appends ".zip" to the base name.
    zip_base = Path(tmp_dir) / f"{title} - {artist}"
    zip_path = shutil.make_archive(str(zip_base), "zip", tmp_dir, song_dir.name)

    chart_json = _build_chart_json(
        chart, bpm, offset_sec, audio_path, str(song_dir / "song.ogg"),
    )

    _progress(8, 8, "Done!")
    return zip_path, chart_json
|
|
|
|
| def _build_chart_json(chart, bpm, offset_sec, original_audio_path, prepared_audio_path): |
| """Build JSON payload for the client-side visualizer.""" |
| |
| with open(prepared_audio_path, "rb") as f: |
| audio_b64 = base64.b64encode(f.read()).decode("ascii") |
|
|
| notes_json = {} |
| for diff, note_list in chart.notes.items(): |
| notes_json[diff] = [ |
| { |
| "tick": n.tick, |
| "frets": sorted(n.fret_set), |
| "sustain": n.sustain_ticks, |
| "hopo": n.is_hopo, |
| } |
| for n in note_list |
| ] |
|
|
| return { |
| "resolution": chart.resolution, |
| "bpm": bpm, |
| "offset_sec": offset_sec, |
| "tempo_events": [{"tick": t, "bpm": b} for t, b in chart.tempo_events], |
| "time_signatures": [{"tick": t, "num": n, "den": d} for t, n, d in chart.time_signatures], |
| "sections": [{"tick": t, "label": l} for t, l in chart.sections], |
| "beats": [{"tick": t, "downbeat": d} for t, d in chart.beats], |
| "notes": notes_json, |
| "audio_b64": audio_b64, |
| "audio_format": "ogg", |
| } |
|
|
|
|
| |
| |
| |
|
|
def _grid_to_musical_ticks(notes, beat_times, offset_sec, bpm, resolution):
    """Convert grid-indexed notes to absolute, sixteenth-quantized ticks.

    Each incoming note's ``tick`` is treated as an index into a 16th-note
    grid derived from the tracked beat times; output notes carry absolute
    ticks snapped to the sixteenth grid at the given resolution. Notes
    whose grid index falls outside the grid are dropped. With fewer than
    two beats there is no grid, so the input is returned unchanged.
    """
    if len(beat_times) < 2:
        return notes

    sixteenth = resolution // 4

    # Build the grid in milliseconds: four equal slots per beat interval,
    # plus the final beat itself as a terminal slot.
    beat_ms = [sec * 1000.0 for sec in beat_times]
    grid_ms = []
    for left, right in zip(beat_ms, beat_ms[1:]):
        span = right - left
        grid_ms.extend(left + sub * span / 4.0 for sub in range(4))
    grid_ms.append(beat_ms[-1])

    converted = []
    for note in notes:
        idx = note.tick
        if not (0 <= idx < len(grid_ms)):
            continue  # grid index out of range — drop the note

        # Onset in seconds (including the audio lead-in), then snap to the
        # sixteenth-note tick grid.
        onset_sec = grid_ms[idx] / 1000.0 + offset_sec
        snapped = round(onset_sec * bpm / 60.0 * resolution)
        snapped = round(snapped / sixteenth) * sixteenth
        snapped = max(0, snapped)

        # Sustain arrives in milliseconds; quantize it too, with a floor of
        # one sixteenth so any sustain survives quantization.
        sustain = 0
        if note.sustain_ticks > 0:
            sustain_sec = note.sustain_ticks / 1000.0
            raw = sustain_sec * bpm / 60.0 * resolution
            sustain = max(sixteenth, round(raw / sixteenth) * sixteenth)

        converted.append(NoteEvent(
            tick=snapped,
            fret_set=note.fret_set,
            sustain_ticks=sustain,
            is_hopo=note.is_hopo,
        ))

    return converted
|
|
|
|
| def _tempo_map_to_ticks(tempo_map, offset_sec, bpm, resolution): |
| events = [] |
| for i, (time_sec, bpm_val) in enumerate(tempo_map): |
| if i == 0: |
| events.append((0, bpm_val)) |
| else: |
| adjusted_time = time_sec + offset_sec |
| prev_time = tempo_map[i - 1][0] + offset_sec if i > 0 else 0 |
| dt_sec = adjusted_time - prev_time |
| prev_tick = events[-1][0] |
| prev_bpm = events[-1][1] |
| tick = prev_tick + int(round(dt_sec * prev_bpm / 60.0 * resolution)) |
| events.append((tick, bpm_val)) |
| return events |
|
|
|
|
| def _sections_to_ticks(sections, tempo_map, offset_sec, resolution): |
| if not tempo_map: |
| return [] |
| result = [] |
| bpm = tempo_map[0][1] |
| for time_sec, label in sections: |
| adjusted = time_sec + offset_sec |
| tick = int(round(adjusted * bpm / 60.0 * resolution)) |
| tick = max(0, tick) |
| result.append((tick, label)) |
| return result |
|
|
|
|
| def _build_beat_markers(last_tick, resolution, beats_per_measure): |
| beats = [] |
| tick = 0 |
| beat_in_measure = 0 |
| while tick <= last_tick: |
| beats.append((tick, beat_in_measure == 0)) |
| beat_in_measure = (beat_in_measure + 1) % beats_per_measure |
| tick += resolution |
| return beats |
|
|