markury commited on
Commit
b405196
·
1 Parent(s): b9cff29

Fix: lazy imports for ZeroGPU compatibility

Browse files
Files changed (3) hide show
  1. app.py +1 -4
  2. midmid/beat_tracker.py +1 -1
  3. pipeline.py +11 -9
app.py CHANGED
@@ -10,12 +10,9 @@ try:
10
  except ImportError:
11
  ON_ZEROGPU = False
12
 
13
- from pipeline import ensure_model, generate_chart
14
  from visualizer import build_visualizer_html
15
 
16
- # Pre-load model on CPU at startup
17
- ensure_model()
18
-
19
  PLACEHOLDER_HTML = """
20
  <div style="font-family: system-ui, sans-serif; background: #111; border-radius: 12px;
21
  padding: 60px 20px; text-align: center; color: #666; max-width: 900px; margin: 0 auto;">
 
10
  except ImportError:
11
  ON_ZEROGPU = False
12
 
13
+ from pipeline import generate_chart
14
  from visualizer import build_visualizer_html
15
 
 
 
 
16
  PLACEHOLDER_HTML = """
17
  <div style="font-family: system-ui, sans-serif; background: #111; border-radius: 12px;
18
  padding: 60px 20px; text-align: center; color: #666; max-width: 900px; margin: 0 auto;">
midmid/beat_tracker.py CHANGED
@@ -3,7 +3,6 @@
3
  from dataclasses import dataclass
4
 
5
  import numpy as np
6
- from beat_this.inference import File2Beats
7
 
8
 
9
  @dataclass
@@ -15,6 +14,7 @@ class BeatData:
15
 
16
  def track_beats(audio_path: str, device: str = "cuda") -> BeatData:
17
  """Run beat and downbeat tracking on an audio file."""
 
18
  processor = File2Beats(checkpoint_path="final0", device=device)
19
  beats, downbeats = processor(audio_path)
20
 
 
3
  from dataclasses import dataclass
4
 
5
  import numpy as np
 
6
 
7
 
8
  @dataclass
 
14
 
15
  def track_beats(audio_path: str, device: str = "cuda") -> BeatData:
16
  """Run beat and downbeat tracking on an audio file."""
17
+ from beat_this.inference import File2Beats
18
  processor = File2Beats(checkpoint_path="final0", device=device)
19
  beats, downbeats = processor(audio_path)
20
 
pipeline.py CHANGED
@@ -15,16 +15,7 @@ from pathlib import Path
15
  import numpy as np
16
  import torch
17
 
18
- from midmid.beat_tracker import track_beats
19
- from midmid.tempo_map import derive_tempo_map, get_median_bpm, estimate_time_signature
20
- from midmid.offset import calculate_offset
21
- from midmid.sections import detect_sections
22
- from midmid.constraints import enforce_constraints
23
  from midmid.datatypes import ChartData, NoteEvent
24
- from midmid.inference import load_model_from_hub, predict_notes, move_models_to_device
25
- from midmid.midi_writer import write_midi
26
- from midmid.audio_prep import prepare_audio
27
- from midmid.ini_writer import write_ini
28
 
29
  RESOLUTION = 192
30
  MODEL_REPO = "markury/midmid3-19m-0326"
@@ -37,6 +28,7 @@ def ensure_model():
37
  """Pre-load model on CPU (called at app startup)."""
38
  global _chart_model
39
  if _chart_model is None:
 
40
  print("Loading chart model from HF Hub...")
41
  _chart_model = load_model_from_hub(MODEL_REPO, device="cpu")
42
  print("Chart model loaded.")
@@ -70,6 +62,16 @@ def generate_chart(
70
  Returns:
71
  (zip_path, chart_json) where chart_json has the data for the visualizer.
72
  """
 
 
 
 
 
 
 
 
 
 
73
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
74
  model = ensure_model()
75
  model.to(device)
 
15
  import numpy as np
16
  import torch
17
 
 
 
 
 
 
18
  from midmid.datatypes import ChartData, NoteEvent
 
 
 
 
19
 
20
  RESOLUTION = 192
21
  MODEL_REPO = "markury/midmid3-19m-0326"
 
28
  """Pre-load model on CPU (called at app startup)."""
29
  global _chart_model
30
  if _chart_model is None:
31
+ from midmid.inference import load_model_from_hub
32
  print("Loading chart model from HF Hub...")
33
  _chart_model = load_model_from_hub(MODEL_REPO, device="cpu")
34
  print("Chart model loaded.")
 
62
  Returns:
63
  (zip_path, chart_json) where chart_json has the data for the visualizer.
64
  """
65
+ from midmid.beat_tracker import track_beats
66
+ from midmid.tempo_map import derive_tempo_map, get_median_bpm, estimate_time_signature
67
+ from midmid.offset import calculate_offset
68
+ from midmid.sections import detect_sections
69
+ from midmid.constraints import enforce_constraints
70
+ from midmid.inference import predict_notes, move_models_to_device
71
+ from midmid.midi_writer import write_midi
72
+ from midmid.audio_prep import prepare_audio
73
+ from midmid.ini_writer import write_ini
74
+
75
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
76
  model = ensure_model()
77
  model.to(device)