Fix: lazy imports for ZeroGPU compatibility
Browse files- app.py +1 -4
- midmid/beat_tracker.py +1 -1
- pipeline.py +11 -9
app.py
CHANGED
|
@@ -10,12 +10,9 @@ try:
|
|
| 10 |
except ImportError:
|
| 11 |
ON_ZEROGPU = False
|
| 12 |
|
| 13 |
-
from pipeline import
|
| 14 |
from visualizer import build_visualizer_html
|
| 15 |
|
| 16 |
-
# Pre-load model on CPU at startup
|
| 17 |
-
ensure_model()
|
| 18 |
-
|
| 19 |
PLACEHOLDER_HTML = """
|
| 20 |
<div style="font-family: system-ui, sans-serif; background: #111; border-radius: 12px;
|
| 21 |
padding: 60px 20px; text-align: center; color: #666; max-width: 900px; margin: 0 auto;">
|
|
|
|
| 10 |
except ImportError:
|
| 11 |
ON_ZEROGPU = False
|
| 12 |
|
| 13 |
+
from pipeline import generate_chart
|
| 14 |
from visualizer import build_visualizer_html
|
| 15 |
|
|
|
|
|
|
|
|
|
|
| 16 |
PLACEHOLDER_HTML = """
|
| 17 |
<div style="font-family: system-ui, sans-serif; background: #111; border-radius: 12px;
|
| 18 |
padding: 60px 20px; text-align: center; color: #666; max-width: 900px; margin: 0 auto;">
|
midmid/beat_tracker.py
CHANGED
|
@@ -3,7 +3,6 @@
|
|
| 3 |
from dataclasses import dataclass
|
| 4 |
|
| 5 |
import numpy as np
|
| 6 |
-
from beat_this.inference import File2Beats
|
| 7 |
|
| 8 |
|
| 9 |
@dataclass
|
|
@@ -15,6 +14,7 @@ class BeatData:
|
|
| 15 |
|
| 16 |
def track_beats(audio_path: str, device: str = "cuda") -> BeatData:
|
| 17 |
"""Run beat and downbeat tracking on an audio file."""
|
|
|
|
| 18 |
processor = File2Beats(checkpoint_path="final0", device=device)
|
| 19 |
beats, downbeats = processor(audio_path)
|
| 20 |
|
|
|
|
| 3 |
from dataclasses import dataclass
|
| 4 |
|
| 5 |
import numpy as np
|
|
|
|
| 6 |
|
| 7 |
|
| 8 |
@dataclass
|
|
|
|
| 14 |
|
| 15 |
def track_beats(audio_path: str, device: str = "cuda") -> BeatData:
|
| 16 |
"""Run beat and downbeat tracking on an audio file."""
|
| 17 |
+
from beat_this.inference import File2Beats
|
| 18 |
processor = File2Beats(checkpoint_path="final0", device=device)
|
| 19 |
beats, downbeats = processor(audio_path)
|
| 20 |
|
pipeline.py
CHANGED
|
@@ -15,16 +15,7 @@ from pathlib import Path
|
|
| 15 |
import numpy as np
|
| 16 |
import torch
|
| 17 |
|
| 18 |
-
from midmid.beat_tracker import track_beats
|
| 19 |
-
from midmid.tempo_map import derive_tempo_map, get_median_bpm, estimate_time_signature
|
| 20 |
-
from midmid.offset import calculate_offset
|
| 21 |
-
from midmid.sections import detect_sections
|
| 22 |
-
from midmid.constraints import enforce_constraints
|
| 23 |
from midmid.datatypes import ChartData, NoteEvent
|
| 24 |
-
from midmid.inference import load_model_from_hub, predict_notes, move_models_to_device
|
| 25 |
-
from midmid.midi_writer import write_midi
|
| 26 |
-
from midmid.audio_prep import prepare_audio
|
| 27 |
-
from midmid.ini_writer import write_ini
|
| 28 |
|
| 29 |
RESOLUTION = 192
|
| 30 |
MODEL_REPO = "markury/midmid3-19m-0326"
|
|
@@ -37,6 +28,7 @@ def ensure_model():
|
|
| 37 |
"""Pre-load model on CPU (called at app startup)."""
|
| 38 |
global _chart_model
|
| 39 |
if _chart_model is None:
|
|
|
|
| 40 |
print("Loading chart model from HF Hub...")
|
| 41 |
_chart_model = load_model_from_hub(MODEL_REPO, device="cpu")
|
| 42 |
print("Chart model loaded.")
|
|
@@ -70,6 +62,16 @@ def generate_chart(
|
|
| 70 |
Returns:
|
| 71 |
(zip_path, chart_json) where chart_json has the data for the visualizer.
|
| 72 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 74 |
model = ensure_model()
|
| 75 |
model.to(device)
|
|
|
|
| 15 |
import numpy as np
|
| 16 |
import torch
|
| 17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
from midmid.datatypes import ChartData, NoteEvent
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
RESOLUTION = 192
|
| 21 |
MODEL_REPO = "markury/midmid3-19m-0326"
|
|
|
|
| 28 |
"""Pre-load model on CPU (called at app startup)."""
|
| 29 |
global _chart_model
|
| 30 |
if _chart_model is None:
|
| 31 |
+
from midmid.inference import load_model_from_hub
|
| 32 |
print("Loading chart model from HF Hub...")
|
| 33 |
_chart_model = load_model_from_hub(MODEL_REPO, device="cpu")
|
| 34 |
print("Chart model loaded.")
|
|
|
|
| 62 |
Returns:
|
| 63 |
(zip_path, chart_json) where chart_json has the data for the visualizer.
|
| 64 |
"""
|
| 65 |
+
from midmid.beat_tracker import track_beats
|
| 66 |
+
from midmid.tempo_map import derive_tempo_map, get_median_bpm, estimate_time_signature
|
| 67 |
+
from midmid.offset import calculate_offset
|
| 68 |
+
from midmid.sections import detect_sections
|
| 69 |
+
from midmid.constraints import enforce_constraints
|
| 70 |
+
from midmid.inference import predict_notes, move_models_to_device
|
| 71 |
+
from midmid.midi_writer import write_midi
|
| 72 |
+
from midmid.audio_prep import prepare_audio
|
| 73 |
+
from midmid.ini_writer import write_ini
|
| 74 |
+
|
| 75 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 76 |
model = ensure_model()
|
| 77 |
model.to(device)
|