File size: 3,804 Bytes
de9c0fe
 
 
 
 
 
 
 
 
b6c1b75
de9c0fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b6c1b75
de9c0fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b6c1b75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
de9c0fe
 
b6c1b75
 
 
 
 
 
de9c0fe
 
 
 
 
 
b6c1b75
 
 
 
 
de9c0fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
#!/usr/bin/env python3
"""Quick CLI to score audio clips with the trained Mic-ID model."""

from __future__ import annotations

import argparse
import io
import os
from pathlib import Path
from typing import Iterable, List

import joblib
import librosa
import numpy as np

BASE_DIR = Path(__file__).resolve().parent
CACHE_ROOT = BASE_DIR / ".cache"
NUMBA_CACHE_DIR = CACHE_ROOT / "numba"
MPL_CACHE_DIR = CACHE_ROOT / "matplotlib"


def use_local_cache_dir(env_var: str, directory: Path) -> None:
    """Point ``env_var`` at ``directory`` unless the environment already sets it.

    The directory is created on demand. This is best-effort: if it cannot be
    created (e.g. the script is installed somewhere read-only), the variable is
    left unset so the library keeps its own default location instead of the
    CLI crashing at import time.
    """
    if env_var in os.environ:
        # Respect an explicit user/CI override; no directory needed then.
        return
    try:
        directory.mkdir(parents=True, exist_ok=True)
    except OSError:
        # Deliberately non-fatal: a missing cache only costs speed, not correctness.
        return
    os.environ[env_var] = str(directory)


# NOTE(review): ``librosa`` is imported above, before these variables are set.
# They only take effect if numba/matplotlib are first imported after this point
# (e.g. lazily by librosa) — confirm, or move ``import librosa`` below this block.
use_local_cache_dir("NUMBA_CACHE_DIR", NUMBA_CACHE_DIR)
use_local_cache_dir("MPLCONFIGDIR", MPL_CACHE_DIR)

from features import extract_features
from devices import describe_label

# Trained artefacts; these are relative paths, so they resolve against the
# current working directory.
_MODELS_DIR = Path("models")
MODEL_PATH = _MODELS_DIR / "model.pkl"
ENCODER_PATH = _MODELS_DIR / "label_encoder.pkl"
# Suffixes (compared lower-cased) that count as audio when scanning directories.
AUDIO_EXTENSIONS = {".wav", ".mp3", ".m4a", ".flac", ".ogg"}


def load_model():
    """Load the trained classifier and its label encoder.

    ``MODEL_PATH``/``ENCODER_PATH`` are first resolved against the current
    working directory (the historical behaviour), then against the directory
    containing this script, so the CLI also works when launched from elsewhere.
    Both files are always taken from the same location.

    Returns:
        ``(clf, le)`` — the objects deserialised from the model and encoder files.

    Raises:
        SystemExit: if no location holds both artefacts.
    """
    candidates = (
        (MODEL_PATH, ENCODER_PATH),
        (BASE_DIR / MODEL_PATH, BASE_DIR / ENCODER_PATH),
    )
    for model_path, encoder_path in candidates:
        if model_path.exists() and encoder_path.exists():
            # joblib.load unpickles arbitrary objects: only ever point it at
            # artefacts you produced yourself.
            clf = joblib.load(model_path)
            le = joblib.load(encoder_path)
            return clf, le
    raise SystemExit("Trained artefacts not found. Run `python train.py` first.")


def load_audio(path: Path, sr: int = 16000) -> tuple[np.ndarray, int]:
    """Decode ``path`` to a mono signal resampled to ``sr`` Hz.

    The filesystem path is handed straight to ``librosa.load`` for every format.
    Passing a path (rather than an in-memory ``BytesIO``) matters: librosa only
    falls back to its audioread backend for path inputs, and that fallback is
    what decodes containers libsndfile cannot read (e.g. ``.m4a``).

    Args:
        path: Audio file to read.
        sr: Target sample rate passed to ``librosa.load``.

    Returns:
        ``(y, sr)`` as returned by ``librosa.load``.

    Raises:
        Whatever ``librosa.load`` raises for unreadable files; ``main`` reports
        and skips those.
    """
    y, sr_out = librosa.load(path, sr=sr, mono=True)
    return y, sr_out


def normalise_audio(y: np.ndarray, target_rms: float = 0.05) -> tuple[np.ndarray, float]:
    """Scale ``y`` so that its RMS level becomes ``target_rms``.

    Args:
        y: Audio samples.
        target_rms: Desired RMS of the returned signal (default 0.05).

    Returns:
        ``(scaled, rms)`` where ``rms`` is the RMS of the *input* plus a small
        epsilon (1e-8) that keeps the division finite for silent or empty input.
    """
    eps = 1e-8
    # An empty clip is treated like silence rather than producing NaN via mean([]).
    mean_square = float(np.mean(np.square(y))) if y.size else 0.0
    rms = float(np.sqrt(mean_square)) + eps
    return y * (target_rms / rms), rms


def discover_inputs(paths: Iterable[Path]) -> List[Path]:
    """Flatten ``paths`` into the list of files to score.

    Entries that are not directories (including ones that do not exist) are
    passed through unchanged, in the order given. Each directory is replaced by
    the sorted list of files beneath it, recursively, whose lower-cased suffix
    is in ``AUDIO_EXTENSIONS``. A directory with no such files prints a warning
    and contributes nothing.
    """
    result: list[Path] = []
    for entry in paths:
        if not entry.is_dir():
            result.append(entry)
            continue
        found = sorted(
            filter(
                lambda candidate: (
                    candidate.is_file()
                    and candidate.suffix.lower() in AUDIO_EXTENSIONS
                ),
                entry.rglob("*"),
            )
        )
        if found:
            result.extend(found)
        else:
            print(f"[!] No audio files found under directory: {entry}")
    return result


def main() -> None:
    """Parse CLI arguments, then score every discovered audio file.

    For each input, prints the RMS loudness of the original signal (in dBFS)
    and the ``--topk`` most probable labels with their probabilities. Missing,
    unreadable or empty files are reported and skipped, so one bad clip does
    not abort a batch.
    """
    parser = argparse.ArgumentParser(description="Score WAV/MP3/M4A clips with the Mic-ID classifier.")
    parser.add_argument(
        "paths",
        nargs="+",
        type=Path,
        help="Audio files or directories containing audio to score",
    )
    parser.add_argument("--topk", type=int, default=3, help="How many ranked predictions to show per file")
    args = parser.parse_args()

    clf, le = load_model()
    # Clamp to [1, number of known classes] so zero/negative/huge values still work.
    topk = max(1, min(args.topk, len(le.classes_)))

    inputs = discover_inputs(args.paths)
    if not inputs:
        raise SystemExit("No valid audio inputs found. Provide files or directories with supported formats.")

    for path in inputs:
        if not path.exists():
            print(f"[!] Skipping missing file: {path}")
            continue
        try:
            y, sr = load_audio(path)
        except Exception as exc:  # pragma: no cover - friendly CLI message
            print(f"[!] Failed to load {path}: {exc}")
            continue
        if y.size == 0:
            # Nothing to measure or classify; feature extraction on zero samples is meaningless.
            print(f"[!] Skipping empty audio file: {path}")
            continue
        # rms is the loudness of the input *before* normalisation.
        y, rms = normalise_audio(y)
        feats = extract_features(y, sr).reshape(1, -1)
        proba = clf.predict_proba(feats)[0]
        order = np.argsort(proba)[::-1]  # class indices, most probable first
        print(f"\nFile: {path}")
        print(f"RMS loudness: {20 * np.log10(rms + 1e-12):.1f} dBFS")
        for rank, idx in enumerate(order[:topk], start=1):
            # NOTE(review): assumes predict_proba column i corresponds to
            # le.classes_[i] (clf fit on encoder labels 0..n-1) — confirm in train.py.
            label = le.classes_[idx]
            friendly = describe_label(label)
            print(f"  {rank}. {friendly} - {proba[idx] * 100:.1f}%")


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()