| """Beat/kick detection using madmom's RNN beat tracker.""" |
|
|
| import json |
| import subprocess |
| import tempfile |
| from pathlib import Path |
| from typing import Optional |
|
|
| import numpy as np |
| from madmom.features.beats import DBNBeatTrackingProcessor, RNNBeatProcessor |
|
|
| |
# Band edges (Hz) for the ffmpeg highpass/lowpass pair applied in
# _bandpass_filter before beat detection. The 50-500 Hz band is presumably
# chosen to keep kick-drum energy while rejecting cymbals/snare — confirm
# against the material being processed before tightening.
HIGHPASS_CUTOFF = 50
LOWPASS_CUTOFF = 500
|
|
|
|
def _bandpass_filter(input_path: Path) -> Path:
    """Apply a band-pass filter to isolate kick drum transients.

    Runs ffmpeg with a highpass at HIGHPASS_CUTOFF Hz and a lowpass at
    LOWPASS_CUTOFF Hz (50-500 Hz with the current constants).

    Args:
        input_path: Path to the source audio file.

    Returns:
        Path to a temporary filtered WAV file. The caller is responsible
        for deleting it.

    Raises:
        subprocess.CalledProcessError: If ffmpeg exits non-zero (output is
            captured on the exception via capture_output).
        FileNotFoundError: If the ffmpeg binary is not on PATH.
    """
    filtered = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    filtered.close()
    try:
        subprocess.run([
            "ffmpeg", "-y",
            "-i", str(input_path),
            "-af", f"highpass=f={HIGHPASS_CUTOFF},lowpass=f={LOWPASS_CUTOFF}",
            str(filtered.name),
        ], check=True, capture_output=True)
    except Exception:
        # Don't leak the temp file when ffmpeg fails or is missing.
        Path(filtered.name).unlink(missing_ok=True)
        raise
    return Path(filtered.name)
|
|
|
|
def detect_beats(
    drum_stem_path: str | Path,
    min_bpm: float = 55.0,
    max_bpm: float = 215.0,
    transition_lambda: float = 100,
    fps: int = 1000,
) -> np.ndarray:
    """Detect beat timestamps from a drum stem using madmom.

    Uses an ensemble of bidirectional LSTMs to produce a beat activation
    function, then a Dynamic Bayesian Network to decode beat positions.

    Args:
        drum_stem_path: Path to the isolated drum stem WAV file.
        min_bpm: Minimum expected tempo. Narrow this if you know the song's
            approximate BPM for better accuracy.
        max_bpm: Maximum expected tempo.
        transition_lambda: Tempo smoothness — higher values penalise tempo
            changes more (100 = very steady, good for most pop/rock).
        fps: Frames per second for the DBN decoder. The RNN outputs at 100fps;
            higher values interpolate for finer timestamp resolution
            (1ms at 1000fps).

    Returns:
        1D numpy array of beat timestamps in seconds, sorted chronologically.
    """
    drum_stem_path = Path(drum_stem_path)

    # Band-pass the stem first so the activation function is driven mainly
    # by low-frequency (kick-range) energy.
    filtered_path = _bandpass_filter(drum_stem_path)

    # try/finally guarantees the temporary filtered WAV is removed even if
    # the RNN raises (previously it was only deleted on the success path,
    # leaking a file per failed call).
    try:
        act_proc = RNNBeatProcessor()
        activations = act_proc(str(filtered_path))
    finally:
        filtered_path.unlink(missing_ok=True)

    # The RNN emits activations at a fixed 100 fps. If a finer DBN grid is
    # requested, upsample the activation curve with cubic interpolation.
    if fps != 100:
        from scipy.interpolate import interp1d

        n_frames = len(activations)
        t_orig = np.linspace(0, n_frames / 100, n_frames, endpoint=False)
        n_new = int(n_frames * fps / 100)
        t_new = np.linspace(0, n_frames / 100, n_new, endpoint=False)
        activations = interp1d(
            t_orig, activations, kind="cubic", fill_value="extrapolate"
        )(t_new)
        # Cubic interpolation can overshoot below zero; activations must be
        # non-negative, so clamp the lower bound.
        activations = np.clip(activations, 0, None)

    beat_proc = DBNBeatTrackingProcessor(
        min_bpm=min_bpm,
        max_bpm=max_bpm,
        transition_lambda=transition_lambda,
        fps=fps,
        correct=False,
    )
    return beat_proc(activations)
|
|
|
|
def detect_drop(
    audio_path: str | Path,
    beat_times: np.ndarray,
    window_sec: float = 0.5,
) -> float:
    """Find the beat where the biggest energy jump occurs (the drop).

    Computes RMS energy in a window around each beat and returns the beat
    with the largest increase compared to the previous beat.

    Args:
        audio_path: Path to the full mix audio file.
        beat_times: Array of beat timestamps in seconds. Must contain at
            least two beats (a jump needs a predecessor).
        window_sec: Duration of the analysis window around each beat.

    Returns:
        Timestamp (seconds) of the detected drop beat.

    Raises:
        ValueError: If fewer than two beats are provided (previously this
            surfaced as an opaque argmax-of-empty-sequence error).
    """
    import librosa

    if len(beat_times) < 2:
        raise ValueError("detect_drop requires at least two beat timestamps")

    # sr=None preserves the file's native sample rate.
    y, sr = librosa.load(str(audio_path), sr=None, mono=True)
    half_win = int(window_sec / 2 * sr)

    # RMS energy in a window centred on each beat, clipped to file bounds.
    rms_values = []
    for t in beat_times:
        center = int(t * sr)
        start = max(0, center - half_win)
        end = min(len(y), center + half_win)
        segment = y[start:end]
        rms = np.sqrt(np.mean(segment ** 2)) if len(segment) > 0 else 0.0
        rms_values.append(rms)

    rms_values = np.array(rms_values)

    # diffs[i] = energy change from beat i to beat i+1; the drop is the
    # beat AFTER the largest positive jump, hence the +1.
    diffs = np.diff(rms_values)
    drop_idx = int(np.argmax(diffs)) + 1
    drop_time = float(beat_times[drop_idx])

    print(f" Drop detected at beat {drop_idx + 1}: {drop_time:.3f}s "
          f"(energy jump: {diffs[drop_idx - 1]:.4f})")
    return drop_time
|
|
|
|
def select_beats(
    beats: np.ndarray,
    max_duration: float = 15.0,
    min_interval: float = 0.3,
) -> np.ndarray:
    """Pick the beats worth generating frames for.

    Drops everything past the duration cap, then greedily keeps a beat only
    if it is at least ``min_interval`` seconds after the last kept one.

    Args:
        beats: Array of beat timestamps in seconds.
        max_duration: Maximum video duration in seconds.
        min_interval: Minimum time between selected beats in seconds.
            Beats closer together than this are skipped.

    Returns:
        Filtered array of beat timestamps.
    """
    if beats.size == 0:
        return beats

    # Keep only beats inside the allowed duration window.
    in_range = beats[beats <= max_duration]
    if in_range.size == 0:
        return in_range

    # Greedy pass: the first beat is always kept; each later beat must be
    # far enough from the most recently kept one.
    kept = [in_range[0]]
    for candidate in in_range[1:]:
        if candidate - kept[-1] >= min_interval:
            kept.append(candidate)

    return np.array(kept)
|
|
|
|
| def save_beats( |
| beats: np.ndarray, |
| output_path: str | Path, |
| ) -> Path: |
| """Save beat timestamps to a JSON file. |
| |
| Format matches the project convention (same style as lyrics.json): |
| a list of objects with beat index and timestamp. |
| |
| Args: |
| beats: Array of beat timestamps in seconds. |
| output_path: Path to save the JSON file. |
| |
| Returns: |
| Path to the saved JSON file. |
| """ |
| output_path = Path(output_path) |
| output_path.parent.mkdir(parents=True, exist_ok=True) |
|
|
| data = [ |
| {"beat": i + 1, "time": round(float(t), 3)} |
| for i, t in enumerate(beats) |
| ] |
|
|
| with open(output_path, "w") as f: |
| json.dump(data, f, indent=2) |
|
|
| return output_path |
|
|
|
|
def run(
    drum_stem_path: str | Path,
    output_dir: Optional[str | Path] = None,
    min_bpm: float = 55.0,
    max_bpm: float = 215.0,
) -> dict:
    """Full beat detection pipeline: detect, select, and save.

    Args:
        drum_stem_path: Path to the isolated drum stem WAV file.
        output_dir: Directory to save beats.json. Defaults to the
            parent of the drum stem's parent (e.g. data/Gone/ if
            stem is at data/Gone/stems/drums.wav).
        min_bpm: Minimum expected tempo.
        max_bpm: Maximum expected tempo.

    Returns:
        Dict with 'all_beats', 'selected_beats', and 'beats_path'.
    """
    stem = Path(drum_stem_path)

    # Default output dir: two levels above the stem (stems/ -> song dir).
    target_dir = Path(output_dir) if output_dir is not None else stem.parent.parent

    all_beats = detect_beats(stem, min_bpm=min_bpm, max_bpm=max_bpm)
    selected = select_beats(all_beats)

    # When writing into a run_* subdirectory, the full mix lives one level
    # up; otherwise it sits next to beats.json. Take the first audio file
    # found, trying extensions in priority order.
    song_dir = target_dir.parent if target_dir.name.startswith("run_") else target_dir
    audio_path = next(
        (
            match
            for suffix in (".wav", ".mp3", ".flac", ".m4a")
            for match in song_dir.glob(f"*{suffix}")
        ),
        None,
    )

    # Drop detection needs a full mix and at least a few beats to compare.
    drop_time = None
    if audio_path is not None and len(all_beats) > 2:
        drop_time = detect_drop(audio_path, all_beats)

    beats_path = save_beats(all_beats, target_dir / "beats.json")

    if drop_time is not None:
        drop_path = target_dir / "drop.json"
        drop_path.write_text(json.dumps({"drop_time": round(drop_time, 3)}, indent=2))

    return {
        "all_beats": all_beats,
        "selected_beats": selected,
        "beats_path": beats_path,
        "drop_time": drop_time,
    }
|
|
|
|
if __name__ == "__main__":
    import sys

    # Require exactly one argument: the drum stem to analyse.
    if len(sys.argv) < 2:
        print("Usage: python -m src.beat_detector <drum_stem.wav>")
        sys.exit(1)

    outcome = run(sys.argv[1])
    detected = outcome["all_beats"]
    chosen = outcome["selected_beats"]

    print(f"Detected {len(detected)} beats (saved to {outcome['beats_path']})")
    print(f"Selected {len(chosen)} beats (max 15s, min 0.3s apart):")
    for number, stamp in enumerate(chosen, start=1):
        print(f" Beat {number}: {stamp:.3f}s")
|
|