midmid3 / midmid /beat_tracker.py
markury's picture
Fix: lazy imports for ZeroGPU compatibility
b405196
"""Beat and downbeat tracking via beat_this (CPJKU)."""
from dataclasses import dataclass
import numpy as np
@dataclass
class BeatData:
    """Result bundle produced by ``track_beats``.

    All three fields are annotated as NumPy arrays. ``beat_numbers`` has one
    entry per element of ``beats``.
    """

    # Beat positions as reported by the beat tracker.
    beats: np.ndarray
    # Downbeat positions as reported by the beat tracker.
    downbeats: np.ndarray
    # Position of each beat within its bar; 1 marks a beat matched to a downbeat.
    beat_numbers: np.ndarray
def track_beats(audio_path: str, device: str = "cuda") -> BeatData:
    """Detect beats and downbeats in an audio file using beat_this.

    Args:
        audio_path: Path of the audio file to analyse.
        device: Device string forwarded to ``File2Beats``.

    Returns:
        A ``BeatData`` with the tracker's beats and downbeats converted to
        NumPy arrays, plus the per-beat bar positions computed from them.
    """
    # Deferred import: beat_this is only loaded when tracking actually runs,
    # not when this module is imported.
    from beat_this.inference import File2Beats

    tracker = File2Beats(checkpoint_path="final0", device=device)
    beat_times, downbeat_times = tracker(audio_path)
    return BeatData(
        beats=np.asarray(beat_times),
        downbeats=np.asarray(downbeat_times),
        beat_numbers=_assign_beat_numbers(beat_times, downbeat_times),
    )
def _assign_beat_numbers(beats: np.ndarray, downbeats: np.ndarray) -> np.ndarray:
beats = np.asarray(beats)
downbeats_set = set(np.round(downbeats, 6))
numbers = np.zeros(len(beats), dtype=int)
beat_num = 1
for i, t in enumerate(beats):
if round(float(t), 6) in downbeats_set:
beat_num = 1
numbers[i] = beat_num
beat_num += 1
return numbers