| """Beat and downbeat tracking via beat_this (CPJKU).""" |
|
|
import functools
from dataclasses import dataclass

import numpy as np
|
|
|
|
@dataclass
class BeatData:
    """Bundle of beat-tracking results for one audio file.

    ``beats`` and ``beat_numbers`` are parallel arrays (one entry per beat);
    ``downbeats`` holds the subset of times marking the start of a bar.
    Times are whatever the beat_this processor returns (presumably seconds —
    TODO confirm).
    """

    # Time of every detected beat.
    beats: np.ndarray
    # Time of every detected downbeat (bar start).
    downbeats: np.ndarray
    # Position of each beat within its bar, 1 marking a downbeat.
    beat_numbers: np.ndarray
|
|
|
|
def track_beats(
    audio_path: str,
    device: str = "cuda",
    *,
    checkpoint_path: str = "final0",
) -> BeatData:
    """Run beat and downbeat tracking on an audio file.

    Args:
        audio_path: Path of the audio file handed to beat_this's
            ``File2Beats`` processor.
        device: Torch device string forwarded to ``File2Beats``.
        checkpoint_path: beat_this checkpoint name or path forwarded to
            ``File2Beats`` (defaults to the previously hard-coded "final0").

    Returns:
        A :class:`BeatData` with the beat and downbeat arrays as NumPy arrays
        and the per-beat bar position computed by ``_assign_beat_numbers``.
    """
    processor = _get_processor(checkpoint_path, device)
    raw_beats, raw_downbeats = processor(audio_path)

    # Convert once so the numbering helper and the result share the same arrays.
    beats = np.asarray(raw_beats)
    downbeats = np.asarray(raw_downbeats)

    return BeatData(
        beats=beats,
        downbeats=downbeats,
        beat_numbers=_assign_beat_numbers(beats, downbeats),
    )


@functools.lru_cache(maxsize=4)
def _get_processor(checkpoint_path: str, device: str):
    """Build (once per checkpoint/device pair) and return a File2Beats processor.

    Constructing the processor loads the model checkpoint, so it is cached
    instead of being rebuilt on every ``track_beats`` call. Cached models stay
    resident (on ``device``) for the life of the process; the cache is bounded
    to a few combinations. A failed construction is not cached.
    """
    # Imported lazily so importing this module does not require beat_this/torch.
    from beat_this.inference import File2Beats

    return File2Beats(checkpoint_path=checkpoint_path, device=device)
|
|
|
|
def _assign_beat_numbers(
    beats: np.ndarray,
    downbeats: np.ndarray,
    tol: float = 1e-3,
) -> np.ndarray:
    """Number each beat by its position in the bar, with 1 meaning downbeat.

    A beat counts as a downbeat when some entry of ``downbeats`` lies within
    ``tol`` of it; downbeats with no beat that close are ignored. After each
    downbeat the following beats are counted 2, 3, ... until the next one.

    Beats before the first downbeat (a pickup / anacrusis) are numbered so
    they end on the bar length inferred from the first two downbeats, e.g. a
    single pickup beat before 4-beat bars gets 4. If that length cannot be
    inferred (fewer than two downbeats, or the pickup is at least as long as
    the first bar), pickup beats count up from 2. A non-downbeat is therefore
    never labelled 1.

    Args:
        beats: Beat times, assumed to be in ascending order (presumably
            seconds — TODO confirm units against beat_this).
        downbeats: Downbeat times; order does not matter.
        tol: Maximum distance at which a downbeat is matched to a beat. The
            1 ms default is far below any realistic inter-beat interval but
            absorbs float rounding and float32/float64 mismatches.

    Returns:
        Integer array with one bar position per beat (empty for no beats).
    """
    beats = np.asarray(beats, dtype=float)
    numbers = np.zeros(len(beats), dtype=int)
    if numbers.size == 0:
        return numbers

    # Flag beats that have a downbeat within ``tol``. With the downbeats
    # sorted, each beat only has to be compared with its two neighbours.
    marks = np.sort(np.asarray(downbeats, dtype=float).ravel())
    if marks.size:
        pos = np.searchsorted(marks, beats)
        left = marks[np.clip(pos - 1, 0, marks.size - 1)]
        right = marks[np.clip(pos, 0, marks.size - 1)]
        is_downbeat = (np.abs(beats - left) <= tol) | (np.abs(beats - right) <= tol)
    else:
        is_downbeat = np.zeros(len(beats), dtype=bool)

    # Choose the counter value "before" the first beat so pickup beats end on
    # the inferred bar length (modelled on beat_this's save_beat_tsv).
    down_idx = np.flatnonzero(is_downbeat)
    counter = 1
    if down_idx.size >= 2:
        pickup = int(down_idx[0])
        bar_len = int(down_idx[1] - down_idx[0])
        if pickup < bar_len:
            counter = bar_len - pickup

    for i, down in enumerate(is_downbeat):
        counter = 1 if down else counter + 1
        numbers[i] = counter

    return numbers
|
|