mic-id / predict.py
connork
Align Space with latest Mic-ID release
b6c1b75
#!/usr/bin/env python3
"""Quick CLI to score audio clips with the trained Mic-ID model."""
from __future__ import annotations
import argparse
import io
import os
from pathlib import Path
from typing import Iterable, List
import joblib
import librosa
import numpy as np
# Keep numba and matplotlib caches in a ".cache" folder next to this script.
BASE_DIR = Path(__file__).resolve().parent
CACHE_ROOT = BASE_DIR / ".cache"
NUMBA_CACHE_DIR = CACHE_ROOT / "numba"
MPL_CACHE_DIR = CACHE_ROOT / "matplotlib"

# Create each cache directory and advertise it through the matching environment
# variable. setdefault() leaves any value the user already exported untouched.
# This runs before the `features` / `devices` imports that follow it in the file.
for _env_name, _cache_dir in (
    ("NUMBA_CACHE_DIR", NUMBA_CACHE_DIR),
    ("MPLCONFIGDIR", MPL_CACHE_DIR),
):
    _cache_dir.mkdir(parents=True, exist_ok=True)
    os.environ.setdefault(_env_name, str(_cache_dir))
from features import extract_features
from devices import describe_label
# Trained artefacts read by load_model(). Both paths are relative, so they are
# resolved against the current working directory at run time.
# NOTE(review): the cache directories above are anchored to BASE_DIR while these
# are not -- confirm that running the CLI from another directory is not expected.
MODEL_PATH = Path("models/model.pkl")
ENCODER_PATH = Path("models/label_encoder.pkl")
# Lower-case file suffixes accepted by discover_inputs() when it expands a directory.
AUDIO_EXTENSIONS = {".wav", ".mp3", ".m4a", ".flac", ".ogg"}
def load_model():
    """Load the trained classifier and its label encoder from disk.

    Returns:
        A ``(classifier, label_encoder)`` pair unpickled from ``MODEL_PATH`` and
        ``ENCODER_PATH``.

    Raises:
        SystemExit: If either artefact file does not exist.
    """
    artefacts = (MODEL_PATH, ENCODER_PATH)
    if not all(artefact.exists() for artefact in artefacts):
        raise SystemExit("Trained artefacts not found. Run `python train.py` first.")
    # joblib.load unpickles the files, so they must come from a trusted training run.
    classifier = joblib.load(MODEL_PATH)
    encoder = joblib.load(ENCODER_PATH)
    return classifier, encoder
def load_audio(path: Path, sr: int = 16000) -> tuple[np.ndarray, int]:
    """Decode an audio file into a mono waveform at the requested sample rate.

    Args:
        path: Location of the audio file on disk.
        sr: Target sample rate passed to ``librosa.load``.

    Returns:
        A ``(samples, sample_rate)`` pair as returned by ``librosa.load``.

    Raises:
        Exception: Whatever ``librosa.load`` raises when the file cannot be
            opened or decoded; the caller reports it and moves on.
    """
    # Always hand librosa the real path. Per the librosa.load documentation,
    # file-like objects can only be decoded by soundfile, so wrapping the bytes
    # in BytesIO would disable the audioread fallback for codecs soundfile
    # cannot read. A path keeps every backend available.
    samples, sample_rate = librosa.load(path, sr=sr, mono=True)
    return samples, sample_rate
def normalise_audio(y: np.ndarray) -> tuple[np.ndarray, float]:
    """Scale a waveform to a fixed RMS level and report its original loudness.

    Args:
        y: Audio samples.

    Returns:
        A ``(scaled, rms)`` pair. ``scaled`` is ``y`` multiplied by ``0.05 / rms``
        and ``rms`` is the root-mean-square of ``y`` plus a ``1e-8`` floor that
        keeps the division finite for silent input.
    """
    if y.size == 0:
        # np.mean of an empty array is NaN (with a RuntimeWarning). Treat an
        # empty clip like silence instead, which yields just the 1e-8 floor.
        return y.copy(), 1e-8
    rms = float(np.sqrt(np.mean(y**2)) + 1e-8)
    return y * (0.05 / rms), rms
def discover_inputs(paths: Iterable[Path]) -> List[Path]:
    """Flatten a mix of files and directories into a list of paths to score.

    Non-directory arguments are kept as given and in the order given (they are
    not checked for existence here). Each directory is searched recursively and
    replaced by its sorted files whose suffix is in ``AUDIO_EXTENSIONS``; a
    directory with no such files prints a warning and contributes nothing.
    """
    found: list[Path] = []
    for entry in paths:
        if not entry.is_dir():
            found.append(entry)
            continue

        audio_files = [
            candidate
            for candidate in entry.rglob("*")
            if candidate.is_file() and candidate.suffix.lower() in AUDIO_EXTENSIONS
        ]
        audio_files.sort()

        if audio_files:
            found += audio_files
        else:
            print(f"[!] No audio files found under directory: {entry}")
    return found
def main() -> None:
    """Score each requested clip and print its top-ranked microphone labels.

    Parses the command line, loads the trained artefacts, expands directories
    into audio files, then for every clip prints its RMS loudness and the
    ``--topk`` most probable labels.

    Raises:
        SystemExit: If the trained artefacts are missing (from ``load_model``)
            or if no inputs remain after directory expansion.
    """
    parser = argparse.ArgumentParser(description="Score WAV/MP3/M4A clips with the Mic-ID classifier.")
    parser.add_argument(
        "paths",
        nargs="+",
        type=Path,
        help="Audio files or directories containing audio to score",
    )
    parser.add_argument("--topk", type=int, default=3, help="How many ranked predictions to show per file")
    args = parser.parse_args()

    clf, le = load_model()
    # Clamp the request to [1, number of known classes].
    topk = max(1, min(args.topk, len(le.classes_)))

    inputs = discover_inputs(args.paths)
    if not inputs:
        raise SystemExit("No valid audio inputs found. Provide files or directories with supported formats.")

    for path in inputs:
        # Explicit file arguments are not validated by discover_inputs().
        if not path.exists():
            print(f"[!] Skipping missing file: {path}")
            continue
        try:
            y, sr = load_audio(path)
        except Exception as exc:  # pragma: no cover - friendly CLI message
            print(f"[!] Failed to load {path}: {exc}")
            continue
        if y.size == 0:
            # A zero-length clip has no loudness or features to compute; skip it
            # so one bad file does not abort the rest of the batch.
            print(f"[!] Skipping empty audio file: {path}")
            continue

        # rms is the loudness of the clip before normalisation.
        y, rms = normalise_audio(y)
        feats = extract_features(y, sr).reshape(1, -1)
        proba = clf.predict_proba(feats)[0]
        order = np.argsort(proba)[::-1]  # class indices, most probable first

        print(f"\nFile: {path}")
        print(f"RMS loudness: {20 * np.log10(rms + 1e-12):.1f} dBFS")
        for rank, idx in enumerate(order[:topk], start=1):
            # NOTE(review): assumes column i of predict_proba corresponds to
            # le.classes_[i], i.e. the classifier was fitted on the encoder's
            # integer output -- confirm against train.py.
            label = le.classes_[idx]
            friendly = describe_label(label)
            # Separate the label from its score; they were previously printed
            # back-to-back with nothing in between.
            print(f" {rank}. {friendly} - {proba[idx] * 100:.1f}%")
if __name__ == "__main__":
main()