connork committed
Commit b6c1b75 · Parent: 8d73c70

Align Space with latest Mic-ID release
.gitignore CHANGED
@@ -6,3 +6,4 @@ __pycache__/
  .DS_Store
  uploads/
  .tmp/
+ senior-ml-engineer-script.md
README.md CHANGED
@@ -54,7 +54,8 @@ If you are running a live session, keep this script handy:
  python3 -m venv .venv
  source .venv/bin/activate
  pip install -r requirements.txt
- python train.py                               # optional if you want to refresh the model
+ python3 scripts/refresh_metadata.py           # rebuild hashes + provenance records
+ python3 train.py --config configs/base.yaml   # optional if you want to refresh the model
  ```

  Then launch the app with `streamlit run app.py` (defaults to http://localhost:8501).
@@ -81,8 +82,8 @@ Each control includes inline help text so presenters can improvise without notes

  ## Device Recognition
  - 🧱 Audio flows through `features.extract_features`, stitching log-mel and MFCC statistics with zero-crossing, centroid, roll-off, and flatness cues.
- - 🌲 `train.py` fits a `HistGradientBoostingClassifier`, stratified split, and saves artefacts to `models/model.pkl` plus the label encoder.
- - 📈 Every training run exports `reports/metrics.json` and `reports/confusion_matrix.png` so you can cite precision/recall live.
+ - 🌲 `python3 train.py --config configs/base.yaml` reads the provenance metadata, enforces per-device clip minimums, and fits a `HistGradientBoostingClassifier` before saving artefacts to `models/model.pkl` plus the label encoder.
+ - 📈 Every training run exports `reports/metrics.json`, `reports/confusion_matrix.png`, and a timestamped `reports/runs/run-*.json` snapshot so you can cite precision/recall live.
  - 🏷️ The app and CLI surface friendly names (e.g. “Zoom F8 field recorder”) pulled from `devices.describe_label()` to keep the story human-readable.

  ## Scale Detection
@@ -95,31 +96,36 @@ All sample audio lives under data/ and mirrors the device IDs referenced in the

  | Folder | What it represents | Count* |
  | --- | --- | --- |
- | `audio/` | TAU Urban Acoustic Scenes clips (device A) – Zoom F8 field recorder | 3 · demo bundle |
- | `audio2/` | TAU Urban Acoustic Scenes clips (device B) – Samsung Galaxy S7 | 2 · demo bundle |
- | `audio9/` | TAU Urban Acoustic Scenes clips (device C) – iPhone SE | 1 · demo bundle |
- | `iphone/` | Locally recorded iPhone speech snippets captured with `utils.py` | 2 |
- | `laptop/` | MacBook built-in mic samples recorded in a treated room | 2 |
- | `outtakes/` | Extra captures you can promote into training data after curation | 3 · demo bundle |
+ | `audio/` | TAU Urban Acoustic Scenes clips (device A) – Zoom F8 field recorder | 295 |
+ | `audio2/` | TAU Urban Acoustic Scenes clips (device B) – Samsung Galaxy S7 | 295 |
+ | `audio9/` | TAU Urban Acoustic Scenes clips (device C) – iPhone SE | 295 |
+ | `iphone/` | Locally recorded iPhone speech snippets captured with `utils.py` | 4 |
+ | `laptop/` | MacBook built-in mic samples recorded in a treated room | 4 |
+ | `outtakes/` | Extra captures you can promote into training data after curation | varies |

- The Space ships with a travel-sized sample set; pull the full dataset locally if you want to retrain the checkpoint.
+ Counts based on the current repo snapshot; refresh `data/` to rebalance as needed.

  ## Download Contents
  Every run generates artefacts you can drop into a slide deck or share with collaborators:

  - 🎯 `models/model.pkl` and `models/label_encoder.pkl` store the trained classifier and label map.
  - 📊 `reports/metrics.json` plus `reports/confusion_matrix.png` capture evaluation snapshots for the latest training session.
+ - 🧾 `data/metadata.csv` tracks every clip’s provenance, licence, and hash for reproducible retrains.
+ - 🗂️ `reports/runs/run-*.json` snapshots record the exact config, dataset summary, and hashes used for each training run.
  - 📁 Uploaded clips are preserved under `uploads/hooks - <original-name>` so you can replay or re-label them later.

  ## Testing
  Quick smoke checks live in the scripts themselves:

  ```bash
- # Rebuild the model and metrics
- python train.py
+ # Validate provenance without training
+ python3 train.py --dry-run
+
+ # Rebuild the model, metrics, and run snapshot
+ python3 train.py --config configs/base.yaml

  # Score a few clips and verify probabilities look sane
- python predict.py data/laptop/clip_01.wav data/iphone/clip_05.wav --topk 5
+ python3 predict.py data/laptop/clip_01.wav data/iphone/clip_05.wav --topk 5
  ```

  For deeper regression coverage, wire these commands into your CI and compare the resulting metrics JSON against previous baselines.
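A minimal sketch of that baseline comparison, assuming a previous `metrics.json` is kept around as `reports/baseline_metrics.json` (a hypothetical path, not something the repo ships):

```python
# Hypothetical CI guard: fail when macro-average F1 regresses past a tolerance.
import json
import sys

with open("reports/metrics.json") as fh:
    current = json.load(fh)
with open("reports/baseline_metrics.json") as fh:  # assumed baseline snapshot
    baseline = json.load(fh)

current_f1 = current["macro avg"]["f1-score"]
baseline_f1 = baseline["macro avg"]["f1-score"]
if current_f1 < baseline_f1 - 0.02:  # 2-point tolerance, tune to taste
    sys.exit(f"Macro F1 regressed: {current_f1:.3f} < {baseline_f1:.3f}")
print(f"Macro F1 OK: {current_f1:.3f} (baseline {baseline_f1:.3f})")
```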
 
@@ -130,9 +136,11 @@ mic-id/
  ├─ app.py         # Streamlit UI for uploading and scoring clips
  ├─ predict.py     # CLI scorer with friendly device names
  ├─ train.py       # Dataset loader, model trainer, metric exporter
+ ├─ configs/       # YAML training configs + device provenance defaults
  ├─ features.py    # Audio feature extraction helpers
  ├─ utils.py       # Command-line recorder for new device samples
- ├─ data/          # Per-device waveforms (TAU + local recordings)
+ ├─ data/          # Per-device waveforms and provenance metadata
+ │  └─ metadata.csv    # Clip-level provenance (source/licence/hash)
  ├─ models/        # Saved classifier + label encoder
  ├─ reports/       # Metrics JSON and confusion matrix plots
  ├─ docs/          # Data sourcing guide and prep notes
@@ -143,7 +151,7 @@ mic-id/
  ## Roadmap
  - 🛰️ Add a lightweight CNN baseline alongside the gradient boosting model for comparison.
  - 🧪 Ship augmentation scripts (noise, EQ, impulse responses) to spotlight microphone colouration differences.
- - 🔐 Bundle provenance metadata (`data/metadata.csv`) and automated integrity checks for new clips.
+ - 🔐 Wire metadata/hash validation into CI so new clips are rejected unless provenance is complete.
  - 📦 Polish export helpers so the app can bundle probabilities + features in one download.

  ## Contributing
configs/base.yaml ADDED
@@ -0,0 +1,43 @@
+ data:
+   root: data
+   metadata: data/metadata.csv
+   enforce_hashes: true
+   min_clips_per_device: 1
+   include_devices:
+     - audio
+     - audio2
+     - audio9
+     - iphone
+     - laptop
+   splits:
+     - train
+   device_defaults:
+     iphone:
+       source: "In-house recordings"
+       license: "Private use"
+     laptop:
+       source: "In-house recordings"
+       license: "Private use"
+     audio:
+       source: "TAU Urban Acoustic Scenes 2019 Mobile"
+       license: "CC-BY 4.0"
+     audio2:
+       source: "TAU Urban Acoustic Scenes 2019 Mobile"
+       license: "CC-BY 4.0"
+     audio9:
+       source: "TAU Urban Acoustic Scenes 2019 Mobile"
+       license: "CC-BY 4.0"
+
+ training:
+   test_size: 0.25
+   random_state: 42
+   classifier:
+     max_depth: 10
+     max_iter: 400
+     learning_rate: 0.08
+
+ reporting:
+   metrics_path: reports/metrics.json
+   confusion_matrix_path: reports/confusion_matrix.png
+   runs_dir: reports/runs
+   tag: baseline
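For reference, the config round-trips through `yaml.safe_load` exactly as `train.py` reads it; a quick sketch to inspect the parsed structure:

```python
# Load and inspect configs/base.yaml the same way train.py's load_config does.
from pathlib import Path

import yaml

cfg = yaml.safe_load(Path("configs/base.yaml").read_text(encoding="utf-8"))
print(cfg["data"]["include_devices"])  # ['audio', 'audio2', 'audio9', 'iphone', 'laptop']
print(cfg["training"]["classifier"])   # {'max_depth': 10, 'max_iter': 400, 'learning_rate': 0.08}
print(cfg["reporting"]["runs_dir"])    # reports/runs
```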
data/metadata.csv ADDED
@@ -0,0 +1,10 @@
+ path,device,source,license,split,sha256
+ audio/airport-helsinki-204-6138-a.wav,audio,TAU Urban Acoustic Scenes 2019 Mobile,CC-BY 4.0,train,db356757394c3ed66d87990ed98a080b1a1a1778aaaffe110b723c9fbd294814
+ audio/airport-lisbon-175-4700-a.wav,audio,TAU Urban Acoustic Scenes 2019 Mobile,CC-BY 4.0,train,ebcea04001ff88fd63af3e76d153ae2d72f3ea65889f3a215c894ca14ce173be
+ audio2/bus-stockholm-35-1041-a.wav,audio2,TAU Urban Acoustic Scenes 2019 Mobile,CC-BY 4.0,train,b76dfe5d2d25b4b7912af63110d11f32623cb8308872cd601cc1fbe6daac8ef8
+ audio2/bus-stockholm-35-1041-b.wav,audio2,TAU Urban Acoustic Scenes 2019 Mobile,CC-BY 4.0,train,9eb1bf29c4055d9b7863f4395b59a177873c1fb00f6e16f026597643f1339742
+ audio9/street_pedestrian-london-149-4500-c.wav,audio9,TAU Urban Acoustic Scenes 2019 Mobile,CC-BY 4.0,train,42ce95e42426e18ae1f25174147bfa799644a3064dcad7072b744239cef134af
+ iphone/clip_01.wav,iphone,In-house recordings,Private use,train,fe9b1dc52cd1eb21550847ba08b2c2ddc79443c378ce88945b55a4de9c3656bf
+ iphone/clip_05.wav,iphone,In-house recordings,Private use,train,017691167b2b7e93fe52ce7e643ca76767986c478e4efe4c2a66bbbfaee2c99a
+ laptop/clip_01.wav,laptop,In-house recordings,Private use,train,f163fd7dc320b3c7ede45104fadff2f90d795f740c9a59156b8cb71613c9f773
+ laptop/clip_05.wav,laptop,In-house recordings,Private use,train,56d5a2ca715e1dc1f08f02d619c1a1c770ea60378b721638f9f2aeffb4829233
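Each `sha256` value is the plain SHA-256 digest of the file bytes (the same digest `scripts/refresh_metadata.py` computes), so any row can be spot-checked in a few lines:

```python
# Spot-check the first metadata row: recompute the clip's hash and compare.
import csv
import hashlib
from pathlib import Path

with open("data/metadata.csv", newline="", encoding="utf-8") as fh:
    row = next(csv.DictReader(fh))

digest = hashlib.sha256(Path("data", row["path"]).read_bytes()).hexdigest()
print("match" if digest == row["sha256"] else f"MISMATCH: {digest}")
```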
devices.py CHANGED
@@ -3,7 +3,9 @@ MIC_FRIENDLY_NAMES = {
      "audio2": "Samsung Galaxy S7 (TAU device B)",
      "audio9": "iPhone SE (TAU device C)",
      "iphone": "Local iPhone recordings",
-     "laptop": "MacBook built-in microphone",
+     # These clips were captured both with the MacBook mic and AirPods Pro;
+     # keep the class label stable but surface the combined description.
+     "laptop": "AirPods Pro / MacBook built-in microphone",
  }

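The diff only touches the name table; `describe_label()` itself is not shown, so the lookup below is a sketch of what the app and CLI presumably do, and the fallback-to-raw-label behaviour is an assumption:

```python
# Hypothetical reconstruction of devices.describe_label(); the real helper is
# not part of this diff, so the fallback behaviour here is an assumption.
MIC_FRIENDLY_NAMES = {
    "laptop": "AirPods Pro / MacBook built-in microphone",
    # ...remaining entries as in devices.py...
}


def describe_label(label: str) -> str:
    return MIC_FRIENDLY_NAMES.get(label, label)


print(describe_label("laptop"))    # AirPods Pro / MacBook built-in microphone
print(describe_label("outtakes"))  # falls back to the raw folder name
```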
 
docs/clip02-misclassification.md ADDED
@@ -0,0 +1,47 @@
+ # Clip 02 Misclassification Case Study
+
+ ## Issue Summary
+ - **Symptom**: `python predict.py data/iphone/*.wav` classified `data/iphone/clip_02.wav` as “MacBook built-in microphone” (~53 %) instead of “Local iPhone recordings” (≈47 %).
+ - **Impact**: Undermined trust in the classifier for quiet iPhone speech, indicating poor separation between the iPhone and AirPods/Mac classes.
+
+ ## Investigation
+ - Confirmed the mismatch reproduced after the first training run with the new TAU batches.
+ - Compared class distributions via `train.py --dry-run`; highlighted severe imbalance: TAU devices (≈295 clips each) vs. iPhone (15 wav + 47 m4a) vs. AirPods/Mac (15 wav + 14 m4a).
+ - Noted identical feature extraction between training and inference (`features.extract_features`), driving suspicion toward data coverage rather than pipeline drift.
+
+ ## Actions Taken
+ 1. **Data Organisation**
+    - Split the TAU Mobile archive into `data/audio`, `data/audio2`, and `data/audio9` based on filename suffixes (`-a/-b/-c`).
+    - Normalised provenance defaults in `configs/base.yaml` for the new device buckets.
+ 2. **Metadata Refresh**
+    - Ran `python3 scripts/refresh_metadata.py --config configs/base.yaml` to register hashes and sources for all clips (including new iPhone/AirPods captures).
+    - Repeated after each data ingest to keep `data/metadata.csv` consistent.
+ 3. **Model Retraining**
+    - Executed `python train.py` to rebuild `models/model.pkl` and `models/label_encoder.pkl` with the expanded dataset (990 clips total).
+ 4. **Inference UX Improvements**
+    - Allowed directory inputs in `predict.py` so `python predict.py data/iphone` expands automatically.
+    - Updated the “laptop” friendly name to “AirPods Pro / MacBook built-in microphone” to reflect the mixed capture source.
+
+ ## Verification
+ - Post-retrain prediction:
+   ```
+   File: data/iphone/clip_02.wav
+   RMS loudness: -40.8 dBFS
+   1. Local iPhone recordings — 96.1%
+   2. AirPods Pro / MacBook built-in microphone — 3.9%
+   3. Samsung Galaxy S7 (TAU device B) — 0.0%
+   ```
+ - The confidence inversion (≈96 % iPhone) confirms the classifier now separates the classes even for low-level speech content.
+
+ ## Feature Changes for Improved Results
+ - `configs/base.yaml`: added TAU device folders to `include_devices` and defined CC-BY provenance defaults.
+ - `data/metadata.csv`: regenerated with 990 entries to incorporate the new recordings (62 iPhone, 29 AirPods/Mac).
+ - `devices.py`: renamed the “laptop” label to “AirPods Pro / MacBook built-in microphone” for accurate reporting.
+ - `predict.py`: added directory expansion and broader audio-extension support to streamline batch evaluation.
+ - Dataset restructuring: migrated TAU archive clips into `data/audio`, `data/audio2`, `data/audio9` directories, preserving the `-a/-b/-c` microphone mapping.
+
+ ## Follow-Up Recommendations
+ - Continue collecting parallel iPhone vs. AirPods recordings, especially in quiet environments, until class counts approach parity with TAU devices.
+ - Maintain a held-out validation set (not yet captured) to quantify gains objectively beyond spot checks.
+ - Document future ingestion runs by appending to this case study or a dedicated experiment log under `docs/`.
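The imbalance figures above came from the `--dry-run` summary; an equivalent stand-alone check over `data/metadata.csv` might look like this (a sketch, not a script the repo ships):

```python
# Summarise per-device clip counts from the provenance metadata to spot imbalance.
import csv
from collections import Counter

with open("data/metadata.csv", newline="", encoding="utf-8") as fh:
    counts = Counter(row["device"] for row in csv.DictReader(fh))

for device, count in counts.most_common():
    print(f"{device:10s} {count:4d}")
print(f"imbalance ratio: {max(counts.values()) / min(counts.values()):.1f}x")
```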
docs/data-sourcing.md CHANGED
@@ -62,4 +62,10 @@ Mic-ID works best when every class corresponds to a capture device that has enou
  2. Store downloaded archives under `data/raw/` (ignored by git) and export processed clips to `data/<device>/`.
  3. Update `metadata.csv` whenever you add or remove external clips so the experiment log in `reports/` stays reproducible.

+ ## Provenance workflow
+ - Run `python3 scripts/refresh_metadata.py` after adding or trimming clips to recompute SHA256 hashes and populate default source/licence values.
+ - Manually edit `data/metadata.csv` when a clip needs corrected credits or licence text; the training step will refuse to run if either field is missing.
+ - Validate the metadata without training by running `python3 train.py --dry-run`; this catches missing files, hash mismatches, and low clip counts early.
+ - Commit both the metadata file and the resulting `reports/runs/run-*.json` snapshot so collaborators can audit exactly which audio went into each checkpoint.
+
  For more ideas, browse the DCASE and ASVspoof challenge leaderboards—winning teams usually publish their data prep notes and often release additional impulse responses or parallel recordings.
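The refuse-to-run guard lives in `train.py`; to get the same check as a stand-alone pre-commit hook, a sketch along these lines would work (a hypothetical helper, not shipped in the repo):

```python
# Fail fast when any metadata row lacks source or licence information.
import csv
import sys

missing = []
with open("data/metadata.csv", newline="", encoding="utf-8") as fh:
    for row in csv.DictReader(fh):
        if not row["source"].strip() or not row["license"].strip():
            missing.append(row["path"])

if missing:
    sys.exit(f"{len(missing)} clip(s) missing provenance, e.g. {missing[:5]}")
print("All metadata rows carry source + licence.")
```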
features.py CHANGED
@@ -2,6 +2,7 @@ import numpy as np, librosa


  def load_mono(path, sr=16000):
+     path = str(path)
      x, sr = librosa.load(path, sr=sr, mono=True)
      x, _ = librosa.effects.trim(x, top_db=30)
      rms = np.sqrt(np.mean(x**2)) + 1e-8
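`extract_features` itself is untouched by this commit. Per the README it stitches log-mel and MFCC statistics with zero-crossing, centroid, roll-off, and flatness cues, which might look roughly like the sketch below (the real band counts and aggregation may differ):

```python
# Sketch of the feature recipe the README describes; frame-level tracks are
# stacked and summarised with per-dimension mean and standard deviation.
import librosa
import numpy as np


def extract_features_sketch(x: np.ndarray, sr: int) -> np.ndarray:
    log_mel = librosa.power_to_db(librosa.feature.melspectrogram(y=x, sr=sr, n_mels=64))
    mfcc = librosa.feature.mfcc(y=x, sr=sr, n_mfcc=20)
    extras = np.vstack([
        librosa.feature.zero_crossing_rate(x),
        librosa.feature.spectral_centroid(y=x, sr=sr),
        librosa.feature.spectral_rolloff(y=x, sr=sr),
        librosa.feature.spectral_flatness(y=x),
    ])
    stacked = np.vstack([log_mel, mfcc, extras])
    return np.concatenate([stacked.mean(axis=1), stacked.std(axis=1)])
```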
models/label_encoder.pkl CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:b63e81ce06710e7e8cc2dd245a0960912697516459129ce32657a5b0234cbd49
- size 447
+ oid sha256:af7684de27332a4a68cc6d4f75511e9377f8243e925ed8415cfbef651d76ce76
+ size 663
models/model.pkl CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:9bd34e60ea1f1851d23b6808d4d0ad6ca2a10968322b527dd01f32b4e8761e0b
- size 2006992
+ oid sha256:526a822cbd88b1a060e845ef99bac697ac116f2afeb6edd0c06a610d5bf23211
+ size 2967440
predict.py CHANGED
@@ -7,6 +7,7 @@ import argparse
  import io
  import os
  from pathlib import Path
+ from typing import Iterable, List

  import joblib
  import librosa
@@ -26,6 +27,7 @@ from devices import describe_label

  MODEL_PATH = Path("models/model.pkl")
  ENCODER_PATH = Path("models/label_encoder.pkl")
+ AUDIO_EXTENSIONS = {".wav", ".mp3", ".m4a", ".flac", ".ogg"}


  def load_model():
@@ -52,16 +54,43 @@ def normalise_audio(y: np.ndarray) -> np.ndarray:
      return y * (0.05 / rms), rms


+ def discover_inputs(paths: Iterable[Path]) -> List[Path]:
+     """Expand directories into audio files, preserving explicit file ordering."""
+     collected: list[Path] = []
+     for path in paths:
+         if path.is_dir():
+             matches = sorted(
+                 p for p in path.rglob("*")
+                 if p.is_file() and p.suffix.lower() in AUDIO_EXTENSIONS
+             )
+             if not matches:
+                 print(f"[!] No audio files found under directory: {path}")
+                 continue
+             collected.extend(matches)
+         else:
+             collected.append(path)
+     return collected
+
+
  def main() -> None:
      parser = argparse.ArgumentParser(description="Score WAV/MP3/M4A clips with the Mic-ID classifier.")
-     parser.add_argument("paths", nargs="+", type=Path, help="Audio files to score")
+     parser.add_argument(
+         "paths",
+         nargs="+",
+         type=Path,
+         help="Audio files or directories containing audio to score",
+     )
      parser.add_argument("--topk", type=int, default=3, help="How many ranked predictions to show per file")
      args = parser.parse_args()

      clf, le = load_model()
      topk = max(1, min(args.topk, len(le.classes_)))

-     for path in args.paths:
+     inputs = discover_inputs(args.paths)
+     if not inputs:
+         raise SystemExit("No valid audio inputs found. Provide files or directories with supported formats.")
+
+     for path in inputs:
          if not path.exists():
              print(f"[!] Skipping missing file: {path}")
              continue
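With the expansion in place, files and directories can be mixed on the command line, and the helper can also be exercised directly:

```python
# Directories expand recursively to supported audio files; files pass through.
from pathlib import Path

from predict import discover_inputs

inputs = discover_inputs([Path("data/iphone"), Path("data/laptop/clip_01.wav")])
print(f"{len(inputs)} clip(s) queued for scoring")
```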
reports/confusion_matrix.png CHANGED

Git LFS pointer updated:
- SHA256: 0b0e09da5010560decc3487eb55484df9a7d8956ec127d0ebd0665dcf848d358 · 44.4 kB
+ SHA256: 10970dab554b283179ddcdc09ba7519fa398fe9b92ad93a2827eb6b4c37d6451 · 52.4 kB
reports/metrics.json CHANGED
@@ -1,45 +1,51 @@
  {
    "audio": {
-     "precision": 0.9726027397260274,
-     "recall": 0.9594594594594594,
-     "f1-score": 0.9659863945578231,
+     "precision": 0.9473684210526315,
+     "recall": 0.972972972972973,
+     "f1-score": 0.96,
      "support": 74.0
    },
    "audio2": {
-     "precision": 0.9864864864864865,
+     "precision": 0.9733333333333334,
      "recall": 0.9864864864864865,
-     "f1-score": 0.9864864864864865,
+     "f1-score": 0.9798657718120806,
      "support": 74.0
    },
    "audio9": {
-     "precision": 0.9605263157894737,
-     "recall": 0.9864864864864865,
-     "f1-score": 0.9733333333333334,
+     "precision": 0.9594594594594594,
+     "recall": 0.9594594594594594,
+     "f1-score": 0.9594594594594594,
      "support": 74.0
    },
    "iphone": {
-     "precision": 1.0,
-     "recall": 0.75,
-     "f1-score": 0.8571428571428571,
-     "support": 4.0
+     "precision": 0.9375,
+     "recall": 0.9375,
+     "f1-score": 0.9375,
+     "support": 16.0
    },
    "laptop": {
-     "precision": 0.6666666666666666,
-     "recall": 0.6666666666666666,
-     "f1-score": 0.6666666666666666,
+     "precision": 1.0,
+     "recall": 1.0,
+     "f1-score": 1.0,
+     "support": 7.0
+   },
+   "outtakes : new": {
+     "precision": 0.0,
+     "recall": 0.0,
+     "f1-score": 0.0,
      "support": 3.0
    },
-   "accuracy": 0.9694323144104804,
+   "accuracy": 0.9596774193548387,
    "macro avg": {
-     "precision": 0.9172564417337309,
-     "recall": 0.8698198198198199,
-     "f1-score": 0.8899231476374334,
-     "support": 229.0
+     "precision": 0.802943535640904,
+     "recall": 0.8094031531531533,
+     "f1-score": 0.8061375385452566,
+     "support": 248.0
    },
    "weighted avg": {
-     "precision": 0.969657424053044,
-     "recall": 0.9694323144104804,
-     "f1-score": 0.969162582063393,
-     "support": 229.0
+     "precision": 0.9481126202603283,
+     "recall": 0.9596774193548387,
+     "f1-score": 0.9538309157826369,
+     "support": 248.0
    }
  }
requirements.txt CHANGED
@@ -7,3 +7,4 @@ numpy
  pandas
  matplotlib
  joblib
+ pyyaml
scripts/refresh_metadata.py ADDED
@@ -0,0 +1,161 @@
+ #!/usr/bin/env python3
+ """Generate or refresh data/metadata.csv entries with provenance details."""
+
+ from __future__ import annotations
+
+ import argparse
+ import csv
+ import hashlib
+ import sys
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Dict, Iterable, Optional
+
+ import yaml
+
+
+ DEFAULT_EXTENSIONS = {".wav", ".mp3", ".m4a"}
+
+
+ @dataclass
+ class MetadataRow:
+     path: Path
+     device: str
+     source: str
+     license: str
+     split: str
+     sha256: str
+
+     def as_dict(self, root: Path) -> Dict[str, str]:
+         rel_path = self.path.relative_to(root).as_posix()
+         return {
+             "path": rel_path,
+             "device": self.device,
+             "source": self.source,
+             "license": self.license,
+             "split": self.split,
+             "sha256": self.sha256,
+         }
+
+
+ def parse_args() -> argparse.Namespace:
+     parser = argparse.ArgumentParser(description=__doc__)
+     parser.add_argument("--config", default="configs/base.yaml", help="YAML config that defines data root and defaults.")
+     parser.add_argument("--output", help="Override output metadata CSV path. Defaults to the config value.")
+     parser.add_argument("--extensions", nargs="*", help="File extensions to include (e.g., .wav .mp3 .m4a). Defaults to built-ins.")
+     return parser.parse_args()
+
+
+ def load_config(path: Path) -> dict:
+     if not path.exists():
+         raise SystemExit(f"Config not found: {path}")
+     with path.open("r", encoding="utf-8") as fh:
+         cfg = yaml.safe_load(fh) or {}
+     if "data" not in cfg:
+         raise SystemExit("Config is missing a `data` section.")
+     return cfg
+
+
+ def read_existing_metadata(path: Path) -> Dict[str, dict]:
+     if not path.exists():
+         return {}
+     with path.open("r", encoding="utf-8", newline="") as fh:
+         reader = csv.DictReader(fh)
+         return {row["path"]: row for row in reader if "path" in row}
+
+
+ def compute_sha256(path: Path) -> str:
+     hasher = hashlib.sha256()
+     with path.open("rb") as fh:
+         for chunk in iter(lambda: fh.read(8192), b""):
+             hasher.update(chunk)
+     return hasher.hexdigest()
+
+
+ def gather_files(root: Path, extensions: Iterable[str]) -> Iterable[Path]:
+     for file_path in root.rglob("*"):
+         if not file_path.is_file():
+             continue
+         if file_path.suffix.lower() in extensions:
+             yield file_path
+
+
+ def build_rows(
+     files: Iterable[Path],
+     existing_rows: Dict[str, dict],
+     root: Path,
+     device_defaults: Optional[dict],
+     include_devices: Optional[set[str]],
+ ) -> Iterable[MetadataRow]:
+     for path in files:
+         rel_key = path.relative_to(root).as_posix()
+         parts = path.relative_to(root).parts
+         if not parts:
+             continue
+         device = parts[0]
+         if include_devices and device not in include_devices:
+             continue
+
+         defaults = (device_defaults or {}).get(device, {})
+         existing = existing_rows.get(rel_key, {})
+
+         source = existing.get("source") or defaults.get("source")
+         license_ = existing.get("license") or defaults.get("license")
+         split = existing.get("split") or "train"
+
+         if not source or not license_:
+             sys.stderr.write(f"[warn] Missing source/license for {rel_key}; fill these in manually.\n")
+
+         sha256 = compute_sha256(path)
+
+         yield MetadataRow(
+             path=path,
+             device=device,
+             source=source or "",
+             license=license_ or "",
+             split=split,
+             sha256=sha256,
+         )
+
+
+ def main() -> None:
+     args = parse_args()
+     config_path = Path(args.config)
+     cfg = load_config(config_path)
+
+     data_cfg = cfg["data"]
+     root = Path(data_cfg.get("root", "data")).resolve()
+     metadata_path = Path(args.output or data_cfg.get("metadata", root / "metadata.csv")).resolve()
+     extensions = {ext.lower() for ext in (args.extensions or data_cfg.get("extensions", DEFAULT_EXTENSIONS))}
+
+     if not root.exists():
+         raise SystemExit(f"Data root does not exist: {root}")
+
+     existing_rows = read_existing_metadata(metadata_path)
+     device_defaults = data_cfg.get("device_defaults", {})
+     include_devices = set(data_cfg.get("include_devices", []) or [])
+
+     files = sorted(gather_files(root, extensions))
+     rows = sorted(
+         build_rows(files, existing_rows, root, device_defaults, include_devices if include_devices else None),
+         key=lambda row: row.path.relative_to(root).as_posix(),
+     )
+
+     metadata_path.parent.mkdir(parents=True, exist_ok=True)
+     with metadata_path.open("w", encoding="utf-8", newline="") as fh:
+         writer = csv.DictWriter(fh, fieldnames=["path", "device", "source", "license", "split", "sha256"])
+         writer.writeheader()
+         for row in rows:
+             writer.writerow(row.as_dict(root))
+
+     orphaned = sorted(set(existing_rows) - {row.path.relative_to(root).as_posix() for row in rows})
+     if orphaned:
+         sys.stderr.write(f"[warn] Orphaned metadata entries (files missing): {len(orphaned)}\n")
+         for item in orphaned:
+             sys.stderr.write(f"  - {item}\n")
+
+     print(f"Wrote {len(rows)} rows to {metadata_path}")
+
+
+ if __name__ == "__main__":
+     main()
train.py CHANGED
@@ -1,7 +1,15 @@
- import os
- import glob
+ from __future__ import annotations
+
+ import argparse
+ import csv
+ import datetime as dt
+ import hashlib
  import json
+ import os
+ from collections import Counter
+ from dataclasses import dataclass
  from pathlib import Path
+ from typing import Iterable, Sequence

  BASE_DIR = Path(__file__).resolve().parent
  CACHE_ROOT = BASE_DIR / ".cache"
@@ -12,93 +20,349 @@ for path in (NUMBA_CACHE_DIR, MPL_CACHE_DIR):
  os.environ.setdefault("NUMBA_CACHE_DIR", str(NUMBA_CACHE_DIR))
  os.environ.setdefault("MPLCONFIGDIR", str(MPL_CACHE_DIR))

- import numpy as np
+ import joblib
  import matplotlib
+
  matplotlib.use("Agg", force=True)
  import matplotlib.pyplot as plt
+ import numpy as np
+ import yaml
  from sklearn.ensemble import HistGradientBoostingClassifier
- from sklearn.preprocessing import LabelEncoder
  from sklearn.metrics import classification_report, confusion_matrix
  from sklearn.model_selection import train_test_split
- import joblib
+ from sklearn.preprocessing import LabelEncoder

- from features import load_mono, extract_features
+ from features import extract_features, load_mono

- DATA_DIR, MODEL_DIR, REPORT_DIR = "data", "models", "reports"
- os.makedirs(MODEL_DIR, exist_ok=True); os.makedirs(REPORT_DIR, exist_ok=True)
-
- IGNORED_DEVICES = {"outtakes"}
- SUFFIX_TO_DEVICE = {
-     "a": "audio",
-     "b": "audio2",
-     "c": "audio9",
- }
- TAU_DEVICE_DIRS = set(SUFFIX_TO_DEVICE.values())
-
-
- def resolve_device_label(device_dir: str, wav_path: str) -> str:
-     """Infer the correct device label for a wav file.
-
-     TAU scenes live under per-device directories but each folder still contains
-     the parallel `-a/-b/-c` recordings. Instead of trusting the directory name
-     (which mislabels the clips), derive the device from the filename suffix and
-     fall back to the directory label for any locally recorded additions that
-     do not follow that convention.
-     """
-
-     if device_dir in TAU_DEVICE_DIRS:
-         stem = Path(wav_path).stem
-         if "-" in stem:
-             _, suffix = stem.rsplit("-", 1)
-             if suffix in SUFFIX_TO_DEVICE:
-                 return SUFFIX_TO_DEVICE[suffix]
-     return device_dir
-
-
- def load_dataset():
-     X, y = [], []
+ TARGET_SR = 16000
+ REQUIRED_COLUMNS = {"path", "device", "source", "license", "split", "sha256"}
+ MODEL_DIR = BASE_DIR / "models"
+ REPORT_DIR = BASE_DIR / "reports"
+
+ MODEL_DIR.mkdir(exist_ok=True)
+ REPORT_DIR.mkdir(exist_ok=True)
+
+
+ @dataclass(frozen=True)
+ class ClipRecord:
+     path: Path
+     device: str
+     source: str
+     license: str
+     split: str
+     sha256: str
+
+     def relative_path(self, root: Path) -> str:
+         return self.path.relative_to(root).as_posix()
+
+
+ def parse_args() -> argparse.Namespace:
+     parser = argparse.ArgumentParser(description="Train the Mic-ID classifier with provenance tracking.")
+     parser.add_argument("--config", default="configs/base.yaml", help="YAML config describing data + training parameters.")
+     parser.add_argument("--dry-run", action="store_true", help="Validate metadata and show dataset summary without training.")
+     return parser.parse_args()
+
+
+ def load_config(path: Path) -> dict:
+     if not path.exists():
+         raise SystemExit(f"Config not found: {path}")
+     with path.open("r", encoding="utf-8") as fh:
+         cfg = yaml.safe_load(fh) or {}
+     if "data" not in cfg or "training" not in cfg or "reporting" not in cfg:
+         raise SystemExit("Config must include `data`, `training`, and `reporting` sections.")
+     return cfg
+
+
+ def compute_sha256(path: Path) -> str:
+     hasher = hashlib.sha256()
+     with path.open("rb") as fh:
+         for chunk in iter(lambda: fh.read(8192), b""):
+             hasher.update(chunk)
+     return hasher.hexdigest()
+
+
+ def read_metadata_csv(path: Path) -> list[dict]:
+     with path.open("r", encoding="utf-8", newline="") as fh:
+         reader = csv.DictReader(fh)
+         headers = set(reader.fieldnames or [])
+         missing = REQUIRED_COLUMNS - headers
+         if missing:
+             raise SystemExit(f"Metadata file {path} is missing required columns: {sorted(missing)}")
+         return list(reader)
+
+
+ def load_clip_records(data_cfg: dict) -> tuple[list[ClipRecord], Path, Path]:
+     root = Path(data_cfg.get("root", "data")).resolve()
+     metadata_path = Path(data_cfg.get("metadata", root / "metadata.csv")).resolve()
+     enforce_hashes = bool(data_cfg.get("enforce_hashes", True))
+     splits_filter = set(data_cfg.get("splits", []) or [])
+     include_devices = set(data_cfg.get("include_devices", []) or [])
+
+     if not root.exists():
+         raise SystemExit(f"Data root does not exist: {root}")
+     if not metadata_path.exists():
+         raise SystemExit(f"Metadata file not found: {metadata_path}")
+
+     raw_rows = read_metadata_csv(metadata_path)
+     records: list[ClipRecord] = []
      seen: set[tuple[str, str]] = set()
-     for device in sorted(
-         d for d in os.listdir(DATA_DIR)
-         if os.path.isdir(os.path.join(DATA_DIR, d))
-         and not d.startswith(".")
-         and d not in IGNORED_DEVICES
-     ):
-         for wav in glob.glob(os.path.join(DATA_DIR, device, "*.wav")):
-             label = resolve_device_label(device, wav)
-             key = (os.path.basename(wav), label)
-             if key in seen:
-                 continue
-             seen.add(key)
-             x, sr = load_mono(wav); feats = extract_features(x, sr)
-             X.append(feats); y.append(label)
-     return np.array(X), np.array(y)
-
-
- if __name__ == "__main__":
-     X, y = load_dataset()
-     le = LabelEncoder(); y_enc = le.fit_transform(y)
-     Xtr, Xte, ytr, yte = train_test_split(X, y_enc, test_size=0.25, stratify=y_enc, random_state=42)
-
-     clf = HistGradientBoostingClassifier(max_depth=10, max_iter=400, learning_rate=0.08, random_state=42)
-     clf.fit(Xtr, ytr); yhat = clf.predict(Xte)
-
-     report = classification_report(yte, yhat, target_names=le.classes_, output_dict=True)
-     with open(os.path.join(REPORT_DIR, "metrics.json"), "w") as f: json.dump(report, f, indent=2)
-
-     cm = confusion_matrix(yte, yhat, normalize="true")
-     fig, ax = plt.subplots(figsize=(5,4)); im = ax.imshow(cm, cmap="Blues")
-     ax.set_xticks(range(len(le.classes_))); ax.set_xticklabels(le.classes_, rotation=45, ha="right")
-     ax.set_yticks(range(len(le.classes_))); ax.set_yticklabels(le.classes_)
-     for i in range(len(le.classes_)):
-         for j in range(len(le.classes_)):
-             ax.text(j, i, f"{cm[i,j]:.2f}", ha="center", va="center", fontsize=8)
-     ax.set_title("Confusion (normalized)"); fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04); fig.tight_layout()
-     fig.savefig(os.path.join(REPORT_DIR, "confusion_matrix.png"), dpi=160)
+
+     for idx, row in enumerate(raw_rows, start=2):
+         rel_path = row["path"].strip()
+         device = row["device"].strip()
+         source = row["source"].strip()
+         license_ = row["license"].strip()
+         split = row["split"].strip() or "train"
+         sha256 = row["sha256"].strip()
+
+         if include_devices and device not in include_devices:
+             continue
+         if splits_filter and split not in splits_filter:
+             continue
+
+         if not rel_path:
+             raise SystemExit(f"Row {idx} is missing a path.")
+         if not device:
+             raise SystemExit(f"Row {idx} is missing a device label (path={rel_path}).")
+         if not source or not license_:
+             raise SystemExit(f"Row {idx} missing source/license information (device={device}, path={rel_path}).")
+
+         full_path = root / rel_path
+         if not full_path.exists():
+             raise SystemExit(f"Audio file referenced in metadata not found: {full_path}")
+
+         if not sha256:
+             current_hash = compute_sha256(full_path)
+         else:
+             current_hash = compute_sha256(full_path) if enforce_hashes else sha256
+             if enforce_hashes and current_hash != sha256:
+                 raise SystemExit(
+                     f"Hash mismatch for {rel_path}: metadata={sha256} current={current_hash}. "
+                     "Regenerate metadata via scripts/refresh_metadata.py."
+                 )
+
+         key = (rel_path, device)
+         if key in seen:
+             raise SystemExit(f"Duplicate clip/device combination detected in metadata: {rel_path} ({device})")
+         seen.add(key)
+
+         records.append(
+             ClipRecord(
+                 path=full_path,
+                 device=device,
+                 source=source,
+                 license=license_,
+                 split=split,
+                 sha256=current_hash,
+             )
+         )
+
+     if include_devices:
+         for dev in include_devices:
+             if dev not in {record.device for record in records}:
+                 raise SystemExit(f"No clips found for requested device: {dev}")
+
+     if not records:
+         raise SystemExit("No audio clips passed the metadata filters; nothing to train on.")
+
+     return records, root, metadata_path
+
+
+ def ensure_minimum_counts(records: Sequence[ClipRecord], minimum: int) -> Counter:
+     counts = Counter(record.device for record in records)
+     violations = {device: count for device, count in counts.items() if count < minimum}
+     if violations:
+         formatted = ", ".join(f"{dev} ({count})" for dev, count in violations.items())
+         raise SystemExit(f"Not enough clips per device. Increase data or lower the threshold. Offenders: {formatted}")
+     return counts
+
+
+ def summarise_records(records: Sequence[ClipRecord], root: Path) -> dict:
+     counts = Counter(record.device for record in records)
+     sources = {record.device: record.source for record in records}
+     licenses = {record.device: record.license for record in records}
+     return {
+         "total_clips": len(records),
+         "devices": dict(counts),
+         "sources": sources,
+         "licenses": licenses,
+         "first_five_hashes": [
+             {"path": record.relative_path(root), "sha256": record.sha256}
+             for record in records[: min(5, len(records))]
+         ],
+     }
+
+
+ def collect_hashes(records: Sequence[ClipRecord], root: Path) -> list[dict]:
+     return [
+         {"path": record.relative_path(root), "sha256": record.sha256}
+         for record in records
+     ]
+
+
+ def build_dataset(records: Sequence[ClipRecord]) -> tuple[np.ndarray, np.ndarray]:
+     features, labels = [], []
+     for record in records:
+         audio, sr = load_mono(record.path, sr=TARGET_SR)
+         feats = extract_features(audio, sr)
+         features.append(feats)
+         labels.append(record.device)
+     return np.array(features), np.array(labels)
+
+
+ def instantiate_classifier(cfg: dict) -> HistGradientBoostingClassifier:
+     clf_cfg = dict(cfg.get("classifier", {}))
+     random_state = cfg.get("random_state")
+     if random_state is not None:
+         clf_cfg.setdefault("random_state", random_state)
+     if not clf_cfg:
+         clf_cfg = {"max_depth": 10, "max_iter": 400, "learning_rate": 0.08}
+         if random_state is not None:
+             clf_cfg["random_state"] = random_state
+     return HistGradientBoostingClassifier(**clf_cfg)
+
+
+ def plot_confusion_matrix(cm: np.ndarray, labels: Sequence[str], output_path: Path) -> None:
+     fig, ax = plt.subplots(figsize=(5, 4))
+     im = ax.imshow(cm, cmap="Blues")
+     ax.set_xticks(range(len(labels)))
+     ax.set_xticklabels(labels, rotation=45, ha="right")
+     ax.set_yticks(range(len(labels)))
+     ax.set_yticklabels(labels)
+     for i in range(len(labels)):
+         for j in range(len(labels)):
+             ax.text(j, i, f"{cm[i, j]:.2f}", ha="center", va="center", fontsize=8)
+     ax.set_title("Confusion (normalized)")
+     fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
+     fig.tight_layout()
+     fig.savefig(output_path, dpi=160)
+     plt.close(fig)
+
+
+ def write_run_report(
+     reporting_cfg: dict,
+     config_path: Path,
+     config: dict,
+     records: Sequence[ClipRecord],
+     root: Path,
+     metrics: dict,
+     dataset_summary: dict,
+     hashes: Sequence[dict],
+     model_path: Path,
+     encoder_path: Path,
+ ) -> Path:
+     runs_dir = Path(reporting_cfg.get("runs_dir", REPORT_DIR / "runs")).resolve()
+     runs_dir.mkdir(parents=True, exist_ok=True)
+     now_utc = dt.datetime.now(dt.timezone.utc).replace(microsecond=0)
+     timestamp = now_utc.strftime("%Y%m%d-%H%M%S")
+     tag = reporting_cfg.get("tag")
+     filename = f"run-{timestamp}"
+     if tag:
+         filename += f"-{tag}"
+     run_path = runs_dir / f"{filename}.json"
+
+     payload = {
+         "timestamp_utc": now_utc.isoformat().replace("+00:00", "Z"),
+         "config_path": str(config_path.resolve()),
+         "config_snapshot": config,
+         "dataset": {
+             **dataset_summary,
+             "metadata_root": str(root),
+             "hashes": list(hashes),
+         },
+         "metrics": metrics,
+         "artefacts": {
+             "model": str(model_path),
+             "label_encoder": str(encoder_path),
+             "metrics_json": str(Path(reporting_cfg.get("metrics_path", REPORT_DIR / "metrics.json")).resolve()),
+             "confusion_matrix": str(Path(reporting_cfg.get("confusion_matrix_path", REPORT_DIR / "confusion_matrix.png")).resolve()),
+         },
+     }
+
+     with run_path.open("w", encoding="utf-8") as fh:
+         json.dump(payload, fh, indent=2)
+     return run_path
+
+
+ def main() -> None:
+     args = parse_args()
+     config_path = Path(args.config)
+     config = load_config(config_path)
+
+     data_cfg = config["data"]
+     training_cfg = config["training"]
+     reporting_cfg = config["reporting"]
+
+     records, data_root, metadata_path = load_clip_records(data_cfg)
+     min_clips = int(data_cfg.get("min_clips_per_device", 1))
+     ensure_minimum_counts(records, min_clips)
+     dataset_summary = summarise_records(records, data_root)
+     hashes = collect_hashes(records, data_root)
+     dataset_summary["metadata_file"] = str(metadata_path)
+
+     print("Dataset summary:")
+     for key, value in dataset_summary.items():
+         print(f"  {key}: {value}")
+
+     if args.dry_run:
+         print("Dry run complete. Exiting without training.")
+         return
+
+     X, y = build_dataset(records)
+     label_encoder = LabelEncoder()
+     y_encoded = label_encoder.fit_transform(y)
+
+     test_size = float(training_cfg.get("test_size", 0.25))
+     random_state = training_cfg.get("random_state", 42)
+     stratify = training_cfg.get("stratify", True)
+     stratify_arg = y_encoded if stratify else None
+
+     X_train, X_test, y_train, y_test = train_test_split(
+         X,
+         y_encoded,
+         test_size=test_size,
+         stratify=stratify_arg,
+         random_state=random_state,
+     )
+
+     clf = instantiate_classifier(training_cfg)
+     clf.fit(X_train, y_train)
+     y_pred = clf.predict(X_test)
+
+     report = classification_report(y_test, y_pred, target_names=label_encoder.classes_, output_dict=True)
+     metrics_path = Path(reporting_cfg.get("metrics_path", REPORT_DIR / "metrics.json"))
+     with metrics_path.open("w", encoding="utf-8") as fh:
+         json.dump(report, fh, indent=2)
+
+     cm = confusion_matrix(y_test, y_pred, normalize="true")
+     confusion_path = Path(reporting_cfg.get("confusion_matrix_path", REPORT_DIR / "confusion_matrix.png"))
+     plot_confusion_matrix(cm, label_encoder.classes_, confusion_path)

+     # Clean up non-serializable RNG to keep joblib artefacts deterministic.
      if hasattr(clf, "_feature_subsample_rng"):
          clf._feature_subsample_rng = None

-     joblib.dump(clf, os.path.join(MODEL_DIR, "model.pkl"))
-     joblib.dump(le, os.path.join(MODEL_DIR, "label_encoder.pkl"))
+     model_path = MODEL_DIR / "model.pkl"
+     encoder_path = MODEL_DIR / "label_encoder.pkl"
+     joblib.dump(clf, model_path)
+     joblib.dump(label_encoder, encoder_path)
+
+     run_report_path = write_run_report(
+         reporting_cfg,
+         config_path,
+         config,
+         records,
+         data_root,
+         report,
+         dataset_summary,
+         hashes,
+         model_path,
+         encoder_path,
+     )
+
      print("Saved model + reports.")
+     print(f"Run snapshot written to {run_report_path}")
+
+
+ if __name__ == "__main__":
+     main()
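Once a run finishes, the snapshot can be audited without retraining; snapshot filenames embed the UTC timestamp, so `max()` over the glob picks the newest (modulo the optional tag suffix):

```python
# Inspect the newest run snapshot: timestamp, per-device counts, clip hashes.
import json
from pathlib import Path

latest = max(Path("reports/runs").glob("run-*.json"))
run = json.loads(latest.read_text(encoding="utf-8"))
print(run["timestamp_utc"])
print(run["dataset"]["devices"])  # e.g. {'audio': 295, ...}
print(len(run["dataset"]["hashes"]), "clip hashes recorded")
```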