connork committed
Commit b6c1b75 · Parent: 8d73c70

Align Space with latest Mic-ID release

Files changed:
- .gitignore +1 -0
- README.md +23 -15
- configs/base.yaml +43 -0
- data/metadata.csv +10 -0
- devices.py +3 -1
- docs/clip02-misclassification.md +47 -0
- docs/data-sourcing.md +6 -0
- features.py +1 -0
- models/label_encoder.pkl +2 -2
- models/model.pkl +2 -2
- predict.py +31 -2
- reports/confusion_matrix.png +2 -2
- reports/metrics.json +30 -24
- requirements.txt +1 -0
- scripts/refresh_metadata.py +161 -0
- train.py +333 -69
.gitignore CHANGED
@@ -6,3 +6,4 @@ __pycache__/
 .DS_Store
 uploads/
 .tmp/
+senior-ml-engineer-script.md
README.md CHANGED
@@ -54,7 +54,8 @@ If you are running a live session, keep this script handy:
 python3 -m venv .venv
 source .venv/bin/activate
 pip install -r requirements.txt
-
+python3 scripts/refresh_metadata.py           # rebuild hashes + provenance records
+python3 train.py --config configs/base.yaml   # optional if you want to refresh the model
 ```
 
 Then launch the app with `streamlit run app.py` (defaults to http://localhost:8501).
@@ -81,8 +82,8 @@ Each control includes inline help text so presenters can improvise without notes
 
 ## Device Recognition
 - 🧱 Audio flows through `features.extract_features`, stitching log-mel and MFCC statistics with zero-crossing, centroid, roll-off, and flatness cues.
-- 🌲 `train.py`
-- 📈 Every training run exports `reports/metrics.json`
+- 🌲 `python3 train.py --config configs/base.yaml` reads the provenance metadata, enforces per-device clip minimums, and fits a `HistGradientBoostingClassifier` before saving artefacts to `models/model.pkl` plus the label encoder.
+- 📈 Every training run exports `reports/metrics.json`, `reports/confusion_matrix.png`, and a timestamped `reports/runs/run-*.json` snapshot so you can cite precision/recall live.
 - 🏷️ The app and CLI surface friendly names (e.g. “Zoom F8 field recorder”) pulled from `devices.describe_label()` to keep the story human-readable.
 
 ## Scale Detection
@@ -95,31 +96,36 @@ All sample audio lives under `data/` and mirrors the device IDs referenced in the
 
 | Folder | What it represents | Count* |
 | --- | --- | --- |
-| `audio/` | TAU Urban Acoustic Scenes clips (device A) – Zoom F8 field recorder |
-| `audio2/` | TAU Urban Acoustic Scenes clips (device B) – Samsung Galaxy S7 |
-| `audio9/` | TAU Urban Acoustic Scenes clips (device C) – iPhone SE |
-| `iphone/` | Locally recorded iPhone speech snippets captured with `utils.py` |
-| `laptop/` | MacBook built-in mic samples recorded in a treated room |
-| `outtakes/` | Extra captures you can promote into training data after curation |
+| `audio/` | TAU Urban Acoustic Scenes clips (device A) – Zoom F8 field recorder | 295 |
+| `audio2/` | TAU Urban Acoustic Scenes clips (device B) – Samsung Galaxy S7 | 295 |
+| `audio9/` | TAU Urban Acoustic Scenes clips (device C) – iPhone SE | 295 |
+| `iphone/` | Locally recorded iPhone speech snippets captured with `utils.py` | 4 |
+| `laptop/` | MacBook built-in mic samples recorded in a treated room | 4 |
+| `outtakes/` | Extra captures you can promote into training data after curation | varies |
 
-⋆
+⋆Counts based on the current repo snapshot; refresh `data/` to rebalance as needed.
 
 ## Download Contents
 Every run generates artefacts you can drop into a slide deck or share with collaborators:
 
 - 🎯 `models/model.pkl` and `models/label_encoder.pkl` store the trained classifier and label map.
 - 📊 `reports/metrics.json` plus `reports/confusion_matrix.png` capture evaluation snapshots for the latest training session.
+- 🧾 `data/metadata.csv` tracks every clip’s provenance, licence, and hash for reproducible retrains.
+- 🗂️ `reports/runs/run-*.json` snapshots record the exact config, dataset summary, and hashes used for each training run.
 - 📁 Uploaded clips are preserved under `uploads/hooks - <original-name>` so you can replay or re-label them later.
 
 ## Testing
 Quick smoke checks live in the scripts themselves:
 
 ```bash
-#
-
+# Validate provenance without training
+python3 train.py --dry-run
+
+# Rebuild the model, metrics, and run snapshot
+python3 train.py --config configs/base.yaml
 
 # Score a few clips and verify probabilities look sane
-
+python3 predict.py data/laptop/clip_01.wav data/iphone/clip_05.wav --topk 5
 ```
 
 For deeper regression coverage, wire these commands into your CI and compare the resulting metrics JSON against previous baselines.
@@ -130,9 +136,11 @@ mic-id/
 ├─ app.py           # Streamlit UI for uploading and scoring clips
 ├─ predict.py       # CLI scorer with friendly device names
 ├─ train.py         # Dataset loader, model trainer, metric exporter
+├─ configs/         # YAML training configs + device provenance defaults
 ├─ features.py      # Audio feature extraction helpers
 ├─ utils.py         # Command-line recorder for new device samples
-├─ data/            # Per-device waveforms
+├─ data/            # Per-device waveforms and provenance metadata
+│  └─ metadata.csv  # Clip-level provenance (source/licence/hash)
 ├─ models/          # Saved classifier + label encoder
 ├─ reports/         # Metrics JSON and confusion matrix plots
 ├─ docs/            # Data sourcing guide and prep notes
@@ -143,7 +151,7 @@ mic-id/
 ## Roadmap
 - 🛰️ Add a lightweight CNN baseline alongside the gradient boosting model for comparison.
 - 🧪 Ship augmentation scripts (noise, EQ, impulse responses) to spotlight microphone colouration differences.
-- 🔐
+- 🔐 Wire metadata/hash validation into CI so new clips are rejected unless provenance is complete.
 - 📦 Polish export helpers so the app can bundle probabilities + features in one download.
 
 ## Contributing
configs/base.yaml ADDED
@@ -0,0 +1,43 @@
+data:
+  root: data
+  metadata: data/metadata.csv
+  enforce_hashes: true
+  min_clips_per_device: 1
+  include_devices:
+    - audio
+    - audio2
+    - audio9
+    - iphone
+    - laptop
+  splits:
+    - train
+  device_defaults:
+    iphone:
+      source: "In-house recordings"
+      license: "Private use"
+    laptop:
+      source: "In-house recordings"
+      license: "Private use"
+    audio:
+      source: "TAU Urban Acoustic Scenes 2019 Mobile"
+      license: "CC-BY 4.0"
+    audio2:
+      source: "TAU Urban Acoustic Scenes 2019 Mobile"
+      license: "CC-BY 4.0"
+    audio9:
+      source: "TAU Urban Acoustic Scenes 2019 Mobile"
+      license: "CC-BY 4.0"
+
+training:
+  test_size: 0.25
+  random_state: 42
+  classifier:
+    max_depth: 10
+    max_iter: 400
+    learning_rate: 0.08
+
+reporting:
+  metrics_path: reports/metrics.json
+  confusion_matrix_path: reports/confusion_matrix.png
+  runs_dir: reports/runs
+  tag: baseline
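Note: a minimal sketch of consuming this config, mirroring the `load_config` helpers that `train.py` and `scripts/refresh_metadata.py` define in this commit (the standalone script below is illustrative, not part of the repo):

```python
from pathlib import Path

import yaml  # satisfied by the pyyaml line added to requirements.txt


def load_base_config(path: str = "configs/base.yaml") -> dict:
    """Parse the training config and fail fast when the data section is absent."""
    cfg = yaml.safe_load(Path(path).read_text(encoding="utf-8")) or {}
    if "data" not in cfg:
        raise SystemExit("Config is missing a `data` section.")
    return cfg


if __name__ == "__main__":
    cfg = load_base_config()
    print(cfg["data"]["include_devices"])  # ['audio', 'audio2', 'audio9', 'iphone', 'laptop']
```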
data/metadata.csv ADDED
@@ -0,0 +1,10 @@
+path,device,source,license,split,sha256
+audio/airport-helsinki-204-6138-a.wav,audio,TAU Urban Acoustic Scenes 2019 Mobile,CC-BY 4.0,train,db356757394c3ed66d87990ed98a080b1a1a1778aaaffe110b723c9fbd294814
+audio/airport-lisbon-175-4700-a.wav,audio,TAU Urban Acoustic Scenes 2019 Mobile,CC-BY 4.0,train,ebcea04001ff88fd63af3e76d153ae2d72f3ea65889f3a215c894ca14ce173be
+audio2/bus-stockholm-35-1041-a.wav,audio2,TAU Urban Acoustic Scenes 2019 Mobile,CC-BY 4.0,train,b76dfe5d2d25b4b7912af63110d11f32623cb8308872cd601cc1fbe6daac8ef8
+audio2/bus-stockholm-35-1041-b.wav,audio2,TAU Urban Acoustic Scenes 2019 Mobile,CC-BY 4.0,train,9eb1bf29c4055d9b7863f4395b59a177873c1fb00f6e16f026597643f1339742
+audio9/street_pedestrian-london-149-4500-c.wav,audio9,TAU Urban Acoustic Scenes 2019 Mobile,CC-BY 4.0,train,42ce95e42426e18ae1f25174147bfa799644a3064dcad7072b744239cef134af
+iphone/clip_01.wav,iphone,In-house recordings,Private use,train,fe9b1dc52cd1eb21550847ba08b2c2ddc79443c378ce88945b55a4de9c3656bf
+iphone/clip_05.wav,iphone,In-house recordings,Private use,train,017691167b2b7e93fe52ce7e643ca76767986c478e4efe4c2a66bbbfaee2c99a
+laptop/clip_01.wav,laptop,In-house recordings,Private use,train,f163fd7dc320b3c7ede45104fadff2f90d795f740c9a59156b8cb71613c9f773
+laptop/clip_05.wav,laptop,In-house recordings,Private use,train,56d5a2ca715e1dc1f08f02d619c1a1c770ea60378b721638f9f2aeffb4829233
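Note: each `sha256` value is the digest of the referenced file; a small sketch to spot-check rows by hand, using the same 8 KiB streaming recipe as `compute_sha256` in `scripts/refresh_metadata.py` (the standalone checker itself is illustrative):

```python
import csv
import hashlib
from pathlib import Path


def sha256_of(path: Path) -> str:
    # Stream in 8 KiB chunks, matching compute_sha256 in scripts/refresh_metadata.py.
    hasher = hashlib.sha256()
    with path.open("rb") as fh:
        for chunk in iter(lambda: fh.read(8192), b""):
            hasher.update(chunk)
    return hasher.hexdigest()


with open("data/metadata.csv", newline="", encoding="utf-8") as fh:
    for row in csv.DictReader(fh):
        actual = sha256_of(Path("data") / row["path"])
        print("ok " if actual == row["sha256"] else "BAD", row["path"])
```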
devices.py CHANGED
@@ -3,7 +3,9 @@ MIC_FRIENDLY_NAMES = {
     "audio2": "Samsung Galaxy S7 (TAU device B)",
     "audio9": "iPhone SE (TAU device C)",
     "iphone": "Local iPhone recordings",
-    "laptop": "MacBook built-in microphone",
+    # These clips were captured both with the MacBook mic and AirPods Pro;
+    # keep the class label stable but surface the combined description.
+    "laptop": "AirPods Pro / MacBook built-in microphone",
 }
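Note: `describe_label()` itself is outside this hunk; a plausible sketch of how the table is consumed is below. The `audio` entry and the fallback-to-raw-label behaviour are assumptions, not confirmed repo code:

```python
MIC_FRIENDLY_NAMES = {
    # "audio" entry assumed from the README table (device A = Zoom F8); only
    # the entries above actually appear in this diff.
    "audio": "Zoom F8 field recorder (TAU device A)",
    "audio2": "Samsung Galaxy S7 (TAU device B)",
    "audio9": "iPhone SE (TAU device C)",
    "iphone": "Local iPhone recordings",
    "laptop": "AirPods Pro / MacBook built-in microphone",
}


def describe_label(label: str) -> str:
    # Hypothetical fallback: echo unknown labels instead of raising.
    return MIC_FRIENDLY_NAMES.get(label, label)
```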
docs/clip02-misclassification.md ADDED
@@ -0,0 +1,47 @@
+# Clip 02 Misclassification Case Study
+
+## Issue Summary
+- **Symptom**: `python predict.py data/iphone/*.wav` classified `data/iphone/clip_02.wav` as “MacBook built-in microphone” (~53%) instead of “Local iPhone recordings” (≈47%).
+- **Impact**: Undermined trust in the classifier for quiet iPhone speech, indicating poor separation between the iPhone and AirPods/Mac classes.
+
+## Investigation
+- Confirmed the mismatch reproduced after the first training run with the new TAU batches.
+- Compared class distributions via `train.py --dry-run`; highlighted severe imbalance: TAU devices (≈295 clips each) vs. iPhone (15 wav + 47 m4a) vs. AirPods/Mac (15 wav + 14 m4a).
+- Noted identical feature extraction between training and inference (`features.extract_features`), driving suspicion toward data coverage rather than pipeline drift.
+
+## Actions Taken
+1. **Data Organisation**
+   - Split the TAU Mobile archive into `data/audio`, `data/audio2`, and `data/audio9` based on filename suffixes (`-a/-b/-c`).
+   - Normalised provenance defaults in `configs/base.yaml` for the new device buckets.
+2. **Metadata Refresh**
+   - Ran `python3 scripts/refresh_metadata.py --config configs/base.yaml` to register hashes and sources for all clips (including new iPhone/AirPods captures).
+   - Repeated after each data ingest to keep `data/metadata.csv` consistent.
+3. **Model Retraining**
+   - Executed `python train.py` to rebuild `models/model.pkl` and `models/label_encoder.pkl` with the expanded dataset (990 clips total).
+4. **Inference UX Improvements**
+   - Allowed directory inputs in `predict.py` so `python predict.py data/iphone` expands automatically.
+   - Updated the “laptop” friendly name to “AirPods Pro / MacBook built-in microphone” to reflect the mixed capture source.
+
+## Verification
+- Post-retrain prediction:
+  ```
+  File: data/iphone/clip_02.wav
+  RMS loudness: -40.8 dBFS
+  1. Local iPhone recordings — 96.1%
+  2. AirPods Pro / MacBook built-in microphone — 3.9%
+  3. Samsung Galaxy S7 (TAU device B) — 0.0%
+  ```
+- The confidence inversion (≈96% iPhone) confirms the classifier now separates the classes even for low-level speech content.
+
+## Feature Changes for Improved Results
+- `configs/base.yaml`: added TAU device folders to `include_devices` and defined CC-BY provenance defaults.
+- `data/metadata.csv`: regenerated with 990 entries to incorporate the new recordings (62 iPhone, 29 AirPods/Mac).
+- `devices.py`: renamed the “laptop” label to “AirPods Pro / MacBook built-in microphone” for accurate reporting.
+- `predict.py`: added directory expansion and broader audio-extension support to streamline batch evaluation.
+- Dataset restructuring: migrated TAU archive clips into `data/audio`, `data/audio2`, `data/audio9` directories, preserving the `-a/-b/-c` microphone mapping.
+
+## Follow-Up Recommendations
+- Continue collecting parallel iPhone vs. AirPods recordings, especially in quiet environments, until class counts approach parity with TAU devices.
+- Maintain a held-out validation set (not yet captured) to quantify gains objectively beyond spot checks.
+- Document future ingestion runs by appending to this case study or a dedicated experiment log under `docs/`.
docs/data-sourcing.md CHANGED
@@ -62,4 +62,10 @@ Mic-ID works best when every class corresponds to a capture device that has enough
 2. Store downloaded archives under `data/raw/` (ignored by git) and export processed clips to `data/<device>/`.
 3. Update `metadata.csv` whenever you add or remove external clips so the experiment log in `reports/` stays reproducible.
 
+## Provenance workflow
+- Run `python3 scripts/refresh_metadata.py` after adding or trimming clips to recompute SHA256 hashes and populate default source/licence values.
+- Manually edit `data/metadata.csv` when a clip needs corrected credits or licence text; the training step will refuse to run if either field is missing.
+- Validate the metadata without training by running `python3 train.py --dry-run`; this catches missing files, hash mismatches, and low clip counts early.
+- Commit both the metadata file and the resulting `reports/runs/run-*.json` snapshot so collaborators can audit exactly which audio went into each checkpoint.
+
 For more ideas, browse the DCASE and ASVspoof challenge leaderboards—winning teams usually publish their data prep notes and often release additional impulse responses or parallel recordings.
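Note: the CI gate this workflow implies can stay tiny; a sketch that shells out to the dry run and fails the build on a non-zero exit (the wrapper script is hypothetical; `train.py --dry-run` raising `SystemExit` on bad metadata is what this commit actually provides):

```python
import subprocess
import sys

# train.py --dry-run exits non-zero on missing files, hash mismatches,
# or absent source/licence fields, so the return code is the whole check.
result = subprocess.run(
    [sys.executable, "train.py", "--dry-run", "--config", "configs/base.yaml"],
    capture_output=True,
    text=True,
)
if result.returncode != 0:
    print(result.stderr or result.stdout)
    raise SystemExit("Provenance validation failed; rejecting these clips.")
print("Provenance OK.")
```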
features.py CHANGED
@@ -2,6 +2,7 @@ import numpy as np, librosa
 
 
 def load_mono(path, sr=16000):
+    path = str(path)
     x, sr = librosa.load(path, sr=sr, mono=True)
     x, _ = librosa.effects.trim(x, top_db=30)
     rms = np.sqrt(np.mean(x**2)) + 1e-8
models/label_encoder.pkl CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:af7684de27332a4a68cc6d4f75511e9377f8243e925ed8415cfbef651d76ce76
+size 663
models/model.pkl CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:526a822cbd88b1a060e845ef99bac697ac116f2afeb6edd0c06a610d5bf23211
+size 2967440
predict.py CHANGED
@@ -7,6 +7,7 @@ import argparse
 import io
 import os
 from pathlib import Path
+from typing import Iterable, List
 
 import joblib
 import librosa
@@ -26,6 +27,7 @@ from devices import describe_label
 
 MODEL_PATH = Path("models/model.pkl")
 ENCODER_PATH = Path("models/label_encoder.pkl")
+AUDIO_EXTENSIONS = {".wav", ".mp3", ".m4a", ".flac", ".ogg"}
 
 
 def load_model():
@@ -52,16 +54,43 @@ def normalise_audio(y: np.ndarray) -> np.ndarray:
     return y * (0.05 / rms), rms
 
 
+def discover_inputs(paths: Iterable[Path]) -> List[Path]:
+    """Expand directories into audio files, preserving explicit file ordering."""
+    collected: list[Path] = []
+    for path in paths:
+        if path.is_dir():
+            matches = sorted(
+                p for p in path.rglob("*")
+                if p.is_file() and p.suffix.lower() in AUDIO_EXTENSIONS
+            )
+            if not matches:
+                print(f"[!] No audio files found under directory: {path}")
+                continue
+            collected.extend(matches)
+        else:
+            collected.append(path)
+    return collected
+
+
 def main() -> None:
     parser = argparse.ArgumentParser(description="Score WAV/MP3/M4A clips with the Mic-ID classifier.")
-    parser.add_argument(
+    parser.add_argument(
+        "paths",
+        nargs="+",
+        type=Path,
+        help="Audio files or directories containing audio to score",
+    )
     parser.add_argument("--topk", type=int, default=3, help="How many ranked predictions to show per file")
     args = parser.parse_args()
 
     clf, le = load_model()
     topk = max(1, min(args.topk, len(le.classes_)))
 
-
+    inputs = discover_inputs(args.paths)
+    if not inputs:
+        raise SystemExit("No valid audio inputs found. Provide files or directories with supported formats.")
+
+    for path in inputs:
         if not path.exists():
             print(f"[!] Skipping missing file: {path}")
             continue
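Note: the scoring body after the `continue` is elided in this view; a rough sketch of the likely per-file ranking, reusing the repo's own helpers (`load_mono`, `extract_features`, and `describe_label` exist in this commit; the glue below is assumed):

```python
import numpy as np

from devices import describe_label
from features import extract_features, load_mono


def rank_clip(path, clf, le, topk: int = 3) -> None:
    # Assumed glue: featurise one clip and print top-k class probabilities,
    # matching the output format shown in docs/clip02-misclassification.md.
    audio, sr = load_mono(path)
    probs = clf.predict_proba([extract_features(audio, sr)])[0]
    for rank, idx in enumerate(np.argsort(probs)[::-1][:topk], start=1):
        print(f"{rank}. {describe_label(le.classes_[idx])} — {probs[idx] * 100:.1f}%")
```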
reports/confusion_matrix.png CHANGED
(binary plot stored via Git LFS; pointer updated with the retrained confusion matrix)
reports/metrics.json CHANGED
@@ -1,45 +1,51 @@
 {
   "audio": {
-    "precision": 0.
-    "recall": 0.
-    "f1-score": 0.
+    "precision": 0.9473684210526315,
+    "recall": 0.972972972972973,
+    "f1-score": 0.96,
     "support": 74.0
   },
   "audio2": {
-    "precision": 0.
+    "precision": 0.9733333333333334,
     "recall": 0.9864864864864865,
-    "f1-score": 0.
+    "f1-score": 0.9798657718120806,
     "support": 74.0
   },
   "audio9": {
-    "precision": 0.
-    "recall": 0.
-    "f1-score": 0.
+    "precision": 0.9594594594594594,
+    "recall": 0.9594594594594594,
+    "f1-score": 0.9594594594594594,
     "support": 74.0
   },
   "iphone": {
-    "precision":
-    "recall": 0.
-    "f1-score": 0.
-    "support":
+    "precision": 0.9375,
+    "recall": 0.9375,
+    "f1-score": 0.9375,
+    "support": 16.0
   },
   "laptop": {
-    "precision":
-    "recall":
-    "f1-score":
+    "precision": 1.0,
+    "recall": 1.0,
+    "f1-score": 1.0,
+    "support": 7.0
+  },
+  "outtakes : new": {
+    "precision": 0.0,
+    "recall": 0.0,
+    "f1-score": 0.0,
     "support": 3.0
   },
-  "accuracy": 0.
+  "accuracy": 0.9596774193548387,
   "macro avg": {
-    "precision": 0.
-    "recall": 0.
-    "f1-score": 0.
-    "support":
+    "precision": 0.802943535640904,
+    "recall": 0.8094031531531533,
+    "f1-score": 0.8061375385452566,
+    "support": 248.0
   },
   "weighted avg": {
-    "precision": 0.
-    "recall": 0.
-    "f1-score": 0.
-    "support":
+    "precision": 0.9481126202603283,
+    "recall": 0.9596774193548387,
+    "f1-score": 0.9538309157826369,
+    "support": 248.0
   }
 }
requirements.txt CHANGED
@@ -7,3 +7,4 @@ numpy
 pandas
 matplotlib
 joblib
+pyyaml
scripts/refresh_metadata.py ADDED
@@ -0,0 +1,161 @@
+#!/usr/bin/env python3
+"""Generate or refresh data/metadata.csv entries with provenance details."""
+
+from __future__ import annotations
+
+import argparse
+import csv
+import hashlib
+import sys
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Dict, Iterable, Optional
+
+import yaml
+
+
+DEFAULT_EXTENSIONS = {".wav", ".mp3", ".m4a"}
+
+
+@dataclass
+class MetadataRow:
+    path: Path
+    device: str
+    source: str
+    license: str
+    split: str
+    sha256: str
+
+    def as_dict(self, root: Path) -> Dict[str, str]:
+        rel_path = self.path.relative_to(root).as_posix()
+        return {
+            "path": rel_path,
+            "device": self.device,
+            "source": self.source,
+            "license": self.license,
+            "split": self.split,
+            "sha256": self.sha256,
+        }
+
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(description=__doc__)
+    parser.add_argument("--config", default="configs/base.yaml", help="YAML config that defines data root and defaults.")
+    parser.add_argument("--output", help="Override output metadata CSV path. Defaults to the config value.")
+    parser.add_argument("--extensions", nargs="*", help="File extensions to include (e.g., .wav .mp3 .m4a). Defaults to built-ins.")
+    return parser.parse_args()
+
+
+def load_config(path: Path) -> dict:
+    if not path.exists():
+        raise SystemExit(f"Config not found: {path}")
+    with path.open("r", encoding="utf-8") as fh:
+        cfg = yaml.safe_load(fh) or {}
+    if "data" not in cfg:
+        raise SystemExit("Config is missing a `data` section.")
+    return cfg
+
+
+def read_existing_metadata(path: Path) -> Dict[str, dict]:
+    if not path.exists():
+        return {}
+    with path.open("r", encoding="utf-8", newline="") as fh:
+        reader = csv.DictReader(fh)
+        return {row["path"]: row for row in reader if "path" in row}
+
+
+def compute_sha256(path: Path) -> str:
+    hasher = hashlib.sha256()
+    with path.open("rb") as fh:
+        for chunk in iter(lambda: fh.read(8192), b""):
+            hasher.update(chunk)
+    return hasher.hexdigest()
+
+
+def gather_files(root: Path, extensions: Iterable[str]) -> Iterable[Path]:
+    for file_path in root.rglob("*"):
+        if not file_path.is_file():
+            continue
+        if file_path.suffix.lower() in extensions:
+            yield file_path
+
+
+def build_rows(
+    files: Iterable[Path],
+    existing_rows: Dict[str, dict],
+    root: Path,
+    device_defaults: Optional[dict],
+    include_devices: Optional[set[str]],
+) -> Iterable[MetadataRow]:
+    for path in files:
+        rel_key = path.relative_to(root).as_posix()
+        parts = path.relative_to(root).parts
+        if not parts:
+            continue
+        device = parts[0]
+        if include_devices and device not in include_devices:
+            continue
+
+        defaults = (device_defaults or {}).get(device, {})
+        existing = existing_rows.get(rel_key, {})
+
+        source = existing.get("source") or defaults.get("source")
+        license_ = existing.get("license") or defaults.get("license")
+        split = existing.get("split") or "train"
+
+        if not source or not license_:
+            sys.stderr.write(f"[warn] Missing source/license for {rel_key}; fill these in manually.\n")
+
+        sha256 = compute_sha256(path)
+
+        yield MetadataRow(
+            path=path,
+            device=device,
+            source=source or "",
+            license=license_ or "",
+            split=split,
+            sha256=sha256,
+        )
+
+
+def main() -> None:
+    args = parse_args()
+    config_path = Path(args.config)
+    cfg = load_config(config_path)
+
+    data_cfg = cfg["data"]
+    root = Path(data_cfg.get("root", "data")).resolve()
+    metadata_path = Path(args.output or data_cfg.get("metadata", root / "metadata.csv")).resolve()
+    extensions = {ext.lower() for ext in (args.extensions or data_cfg.get("extensions", DEFAULT_EXTENSIONS))}
+
+    if not root.exists():
+        raise SystemExit(f"Data root does not exist: {root}")
+
+    existing_rows = read_existing_metadata(metadata_path)
+    device_defaults = data_cfg.get("device_defaults", {})
+    include_devices = set(data_cfg.get("include_devices", []) or [])
+
+    files = sorted(gather_files(root, extensions))
+    rows = sorted(
+        build_rows(files, existing_rows, root, device_defaults, include_devices if include_devices else None),
+        key=lambda row: row.path.relative_to(root).as_posix(),
+    )
+
+    metadata_path.parent.mkdir(parents=True, exist_ok=True)
+    with metadata_path.open("w", encoding="utf-8", newline="") as fh:
+        writer = csv.DictWriter(fh, fieldnames=["path", "device", "source", "license", "split", "sha256"])
+        writer.writeheader()
+        for row in rows:
+            writer.writerow(row.as_dict(root))
+
+    orphaned = sorted(set(existing_rows) - {row.path.relative_to(root).as_posix() for row in rows})
+    if orphaned:
+        sys.stderr.write(f"[warn] Orphaned metadata entries (files missing): {len(orphaned)}\n")
+        for item in orphaned:
+            sys.stderr.write(f"  - {item}\n")
+
+    print(f"Wrote {len(rows)} rows to {metadata_path}")
+
+
+if __name__ == "__main__":
+    main()
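Note: after a refresh it is worth eyeballing the class balance that `train.py` will later enforce via `min_clips_per_device`; a sketch using nothing but the CSV:

```python
import csv
from collections import Counter

# Count clips per device straight from the provenance file.
with open("data/metadata.csv", newline="", encoding="utf-8") as fh:
    counts = Counter(row["device"] for row in csv.DictReader(fh))

for device, count in counts.most_common():
    print(f"{device:10s} {count}")
```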
train.py CHANGED
@@ -1,7 +1,15 @@
-import
-
+from __future__ import annotations
+
+import argparse
+import csv
+import datetime as dt
+import hashlib
 import json
+import os
+from collections import Counter
+from dataclasses import dataclass
 from pathlib import Path
+from typing import Iterable, Sequence
 
 BASE_DIR = Path(__file__).resolve().parent
 CACHE_ROOT = BASE_DIR / ".cache"
@@ -12,93 +20,349 @@ for path in (NUMBA_CACHE_DIR, MPL_CACHE_DIR):
 os.environ.setdefault("NUMBA_CACHE_DIR", str(NUMBA_CACHE_DIR))
 os.environ.setdefault("MPLCONFIGDIR", str(MPL_CACHE_DIR))
 
-import
+import joblib
 import matplotlib
+
 matplotlib.use("Agg", force=True)
 import matplotlib.pyplot as plt
+import numpy as np
+import yaml
 from sklearn.ensemble import HistGradientBoostingClassifier
-from sklearn.preprocessing import LabelEncoder
 from sklearn.metrics import classification_report, confusion_matrix
 from sklearn.model_selection import train_test_split
-import
-
-from features import load_mono, extract_features
-
-IGNORED_DEVICES = {"outtakes"}
-SUFFIX_TO_DEVICE = {
-    "a": "audio",
-    "b": "audio2",
-    "c": "audio9",
-}
-TAU_DEVICE_DIRS = set(SUFFIX_TO_DEVICE.values())
-
-def resolve_device_label(device_dir: str, wav_path: str) -> str:
-    """Infer the correct device label for a wav file."""
-    if device_dir in TAU_DEVICE_DIRS:
-        stem = Path(wav_path).stem
-        if "-" in stem:
-            _, suffix = stem.rsplit("-", 1)
-            if suffix in SUFFIX_TO_DEVICE:
-                return SUFFIX_TO_DEVICE[suffix]
-    return device_dir
-
+from sklearn.preprocessing import LabelEncoder
+
+from features import extract_features, load_mono
+
+TARGET_SR = 16000
+REQUIRED_COLUMNS = {"path", "device", "source", "license", "split", "sha256"}
+MODEL_DIR = BASE_DIR / "models"
+REPORT_DIR = BASE_DIR / "reports"
+
+MODEL_DIR.mkdir(exist_ok=True)
+REPORT_DIR.mkdir(exist_ok=True)
+
+
+@dataclass(frozen=True)
+class ClipRecord:
+    path: Path
+    device: str
+    source: str
+    license: str
+    split: str
+    sha256: str
+
+    def relative_path(self, root: Path) -> str:
+        return self.path.relative_to(root).as_posix()
+
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(description="Train the Mic-ID classifier with provenance tracking.")
+    parser.add_argument("--config", default="configs/base.yaml", help="YAML config describing data + training parameters.")
+    parser.add_argument("--dry-run", action="store_true", help="Validate metadata and show dataset summary without training.")
+    return parser.parse_args()
+
+
+def load_config(path: Path) -> dict:
+    if not path.exists():
+        raise SystemExit(f"Config not found: {path}")
+    with path.open("r", encoding="utf-8") as fh:
+        cfg = yaml.safe_load(fh) or {}
+    if "data" not in cfg or "training" not in cfg or "reporting" not in cfg:
+        raise SystemExit("Config must include `data`, `training`, and `reporting` sections.")
+    return cfg
+
+
+def compute_sha256(path: Path) -> str:
+    hasher = hashlib.sha256()
+    with path.open("rb") as fh:
+        for chunk in iter(lambda: fh.read(8192), b""):
+            hasher.update(chunk)
+    return hasher.hexdigest()
+
+
+def read_metadata_csv(path: Path) -> list[dict]:
+    with path.open("r", encoding="utf-8", newline="") as fh:
+        reader = csv.DictReader(fh)
+        headers = set(reader.fieldnames or [])
+        missing = REQUIRED_COLUMNS - headers
+        if missing:
+            raise SystemExit(f"Metadata file {path} is missing required columns: {sorted(missing)}")
+        return list(reader)
+
+
+def load_clip_records(data_cfg: dict) -> tuple[list[ClipRecord], Path, Path]:
+    root = Path(data_cfg.get("root", "data")).resolve()
+    metadata_path = Path(data_cfg.get("metadata", root / "metadata.csv")).resolve()
+    enforce_hashes = bool(data_cfg.get("enforce_hashes", True))
+    splits_filter = set(data_cfg.get("splits", []) or [])
+    include_devices = set(data_cfg.get("include_devices", []) or [])
+
+    if not root.exists():
+        raise SystemExit(f"Data root does not exist: {root}")
+    if not metadata_path.exists():
+        raise SystemExit(f"Metadata file not found: {metadata_path}")
+
+    raw_rows = read_metadata_csv(metadata_path)
+    records: list[ClipRecord] = []
     seen: set[tuple[str, str]] = set()
-    for device in sorted(
-        d for d in os.listdir(DATA_DIR)
-        if os.path.isdir(os.path.join(DATA_DIR, d))
-        and not d.startswith(".")
-        and d not in IGNORED_DEVICES
-    ):
-        for wav in glob.glob(os.path.join(DATA_DIR, device, "*.wav")):
-            label = resolve_device_label(device, wav)
-            key = (os.path.basename(wav), label)
-            if key in seen:
-                continue
-            seen.add(key)
-            x, sr = load_mono(wav); feats = extract_features(x, sr)
-            X.append(feats); y.append(label)
-    return np.array(X), np.array(y)
-
-if __name__ == "__main__":
-    X, y = load_dataset()
-    le = LabelEncoder(); y_enc = le.fit_transform(y)
-    Xtr, Xte, ytr, yte = train_test_split(X, y_enc, test_size=0.25, stratify=y_enc, random_state=42)
-
-    clf = HistGradientBoostingClassifier(max_depth=10, max_iter=400, learning_rate=0.08, random_state=42)
-    clf.fit(Xtr, ytr); yhat = clf.predict(Xte)
-
-    report = classification_report(yte, yhat, target_names=le.classes_, output_dict=True)
-    with open(os.path.join(REPORT_DIR, "metrics.json"), "w") as f: json.dump(report, f, indent=2)
-
-    cm = confusion_matrix(yte, yhat, normalize="true")
-    fig, ax = plt.subplots(figsize=(5,4)); im = ax.imshow(cm, cmap="Blues")
-    ax.set_xticks(range(len(le.classes_))); ax.set_xticklabels(le.classes_, rotation=45, ha="right")
-    ax.set_yticks(range(len(le.classes_))); ax.set_yticklabels(le.classes_)
-    for i in range(len(le.classes_)):
-        for j in range(len(le.classes_)):
-            ax.text(j, i, f"{cm[i,j]:.2f}", ha="center", va="center", fontsize=8)
-    ax.set_title("Confusion (normalized)"); fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04); fig.tight_layout()
-    fig.savefig(os.path.join(REPORT_DIR, "confusion_matrix.png"), dpi=160)
+
+    for idx, row in enumerate(raw_rows, start=2):
+        rel_path = row["path"].strip()
+        device = row["device"].strip()
+        source = row["source"].strip()
+        license_ = row["license"].strip()
+        split = row["split"].strip() or "train"
+        sha256 = row["sha256"].strip()
+
+        if include_devices and device not in include_devices:
+            continue
+        if splits_filter and split not in splits_filter:
+            continue
+
+        if not rel_path:
+            raise SystemExit(f"Row {idx} is missing a path.")
+        if not device:
+            raise SystemExit(f"Row {idx} is missing a device label (path={rel_path}).")
+        if not source or not license_:
+            raise SystemExit(f"Row {idx} missing source/license information (device={device}, path={rel_path}).")
+
+        full_path = root / rel_path
+        if not full_path.exists():
+            raise SystemExit(f"Audio file referenced in metadata not found: {full_path}")
+
+        if not sha256:
+            current_hash = compute_sha256(full_path)
+        else:
+            current_hash = compute_sha256(full_path) if enforce_hashes else sha256
+            if enforce_hashes and current_hash != sha256:
+                raise SystemExit(
+                    f"Hash mismatch for {rel_path}: metadata={sha256} current={current_hash}. "
+                    "Regenerate metadata via scripts/refresh_metadata.py."
+                )
+
+        key = (rel_path, device)
+        if key in seen:
+            raise SystemExit(f"Duplicate clip/device combination detected in metadata: {rel_path} ({device})")
+        seen.add(key)
+
+        records.append(
+            ClipRecord(
+                path=full_path,
+                device=device,
+                source=source,
+                license=license_,
+                split=split,
+                sha256=current_hash,
+            )
+        )
+
+    if include_devices:
+        for dev in include_devices:
+            if dev not in {record.device for record in records}:
+                raise SystemExit(f"No clips found for requested device: {dev}")
+
+    if not records:
+        raise SystemExit("No audio clips passed the metadata filters; nothing to train on.")
+
+    return records, root, metadata_path
+
+
+def ensure_minimum_counts(records: Sequence[ClipRecord], minimum: int) -> Counter:
+    counts = Counter(record.device for record in records)
+    violations = {device: count for device, count in counts.items() if count < minimum}
+    if violations:
+        formatted = ", ".join(f"{dev} ({count})" for dev, count in violations.items())
+        raise SystemExit(f"Not enough clips per device. Increase data or lower the threshold. Offenders: {formatted}")
+    return counts
+
+
+def summarise_records(records: Sequence[ClipRecord], root: Path) -> dict:
+    counts = Counter(record.device for record in records)
+    sources = {record.device: record.source for record in records}
+    licenses = {record.device: record.license for record in records}
+    return {
+        "total_clips": len(records),
+        "devices": dict(counts),
+        "sources": sources,
+        "licenses": licenses,
+        "first_five_hashes": [
+            {"path": record.relative_path(root), "sha256": record.sha256}
+            for record in records[: min(5, len(records))]
+        ],
+    }
+
+
+def collect_hashes(records: Sequence[ClipRecord], root: Path) -> list[dict]:
+    return [
+        {"path": record.relative_path(root), "sha256": record.sha256}
+        for record in records
+    ]
+
+
+def build_dataset(records: Sequence[ClipRecord]) -> tuple[np.ndarray, np.ndarray]:
+    features, labels = [], []
+    for record in records:
+        audio, sr = load_mono(record.path, sr=TARGET_SR)
+        feats = extract_features(audio, sr)
+        features.append(feats)
+        labels.append(record.device)
+    return np.array(features), np.array(labels)
+
+
+def instantiate_classifier(cfg: dict) -> HistGradientBoostingClassifier:
+    # Fall back to the baseline hyperparameters when the config omits a
+    # classifier section, then apply the shared random_state either way.
+    clf_cfg = dict(cfg.get("classifier", {}))
+    if not clf_cfg:
+        clf_cfg = {"max_depth": 10, "max_iter": 400, "learning_rate": 0.08}
+    random_state = cfg.get("random_state")
+    if random_state is not None:
+        clf_cfg.setdefault("random_state", random_state)
+    return HistGradientBoostingClassifier(**clf_cfg)
+
+
+def plot_confusion_matrix(cm: np.ndarray, labels: Sequence[str], output_path: Path) -> None:
+    fig, ax = plt.subplots(figsize=(5, 4))
+    im = ax.imshow(cm, cmap="Blues")
+    ax.set_xticks(range(len(labels)))
+    ax.set_xticklabels(labels, rotation=45, ha="right")
+    ax.set_yticks(range(len(labels)))
+    ax.set_yticklabels(labels)
+    for i in range(len(labels)):
+        for j in range(len(labels)):
+            ax.text(j, i, f"{cm[i, j]:.2f}", ha="center", va="center", fontsize=8)
+    ax.set_title("Confusion (normalized)")
+    fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
+    fig.tight_layout()
+    fig.savefig(output_path, dpi=160)
+    plt.close(fig)
+
+
+def write_run_report(
+    reporting_cfg: dict,
+    config_path: Path,
+    config: dict,
+    records: Sequence[ClipRecord],
+    root: Path,
+    metrics: dict,
+    dataset_summary: dict,
+    hashes: Sequence[dict],
+    model_path: Path,
+    encoder_path: Path,
+) -> Path:
+    runs_dir = Path(reporting_cfg.get("runs_dir", REPORT_DIR / "runs")).resolve()
+    runs_dir.mkdir(parents=True, exist_ok=True)
+    now_utc = dt.datetime.now(dt.timezone.utc).replace(microsecond=0)
+    timestamp = now_utc.strftime("%Y%m%d-%H%M%S")
+    tag = reporting_cfg.get("tag")
+    filename = f"run-{timestamp}"
+    if tag:
+        filename += f"-{tag}"
+    run_path = runs_dir / f"{filename}.json"
+
+    payload = {
+        "timestamp_utc": now_utc.isoformat().replace("+00:00", "Z"),
+        "config_path": str(config_path.resolve()),
+        "config_snapshot": config,
+        "dataset": {
+            **dataset_summary,
+            "metadata_root": str(root),
+            "hashes": list(hashes),
+        },
+        "metrics": metrics,
+        "artefacts": {
+            "model": str(model_path),
+            "label_encoder": str(encoder_path),
+            "metrics_json": str(Path(reporting_cfg.get("metrics_path", REPORT_DIR / "metrics.json")).resolve()),
+            "confusion_matrix": str(Path(reporting_cfg.get("confusion_matrix_path", REPORT_DIR / "confusion_matrix.png")).resolve()),
+        },
+    }
+
+    with run_path.open("w", encoding="utf-8") as fh:
+        json.dump(payload, fh, indent=2)
+    return run_path
+
+
+def main() -> None:
+    args = parse_args()
+    config_path = Path(args.config)
+    config = load_config(config_path)
+
+    data_cfg = config["data"]
+    training_cfg = config["training"]
+    reporting_cfg = config["reporting"]
+
+    records, data_root, metadata_path = load_clip_records(data_cfg)
+    min_clips = int(data_cfg.get("min_clips_per_device", 1))
+    ensure_minimum_counts(records, min_clips)
+    dataset_summary = summarise_records(records, data_root)
+    hashes = collect_hashes(records, data_root)
+    dataset_summary["metadata_file"] = str(metadata_path)
+
+    print("Dataset summary:")
+    for key, value in dataset_summary.items():
+        print(f"  {key}: {value}")
+
+    if args.dry_run:
+        print("Dry run complete. Exiting without training.")
+        return
+
+    X, y = build_dataset(records)
+    label_encoder = LabelEncoder()
+    y_encoded = label_encoder.fit_transform(y)
+
+    test_size = float(training_cfg.get("test_size", 0.25))
+    random_state = training_cfg.get("random_state", 42)
+    stratify = training_cfg.get("stratify", True)
+    stratify_arg = y_encoded if stratify else None
+
+    X_train, X_test, y_train, y_test = train_test_split(
+        X,
+        y_encoded,
+        test_size=test_size,
+        stratify=stratify_arg,
+        random_state=random_state,
+    )
+
+    clf = instantiate_classifier(training_cfg)
+    clf.fit(X_train, y_train)
+    y_pred = clf.predict(X_test)
+
+    report = classification_report(y_test, y_pred, target_names=label_encoder.classes_, output_dict=True)
+    metrics_path = Path(reporting_cfg.get("metrics_path", REPORT_DIR / "metrics.json"))
+    with metrics_path.open("w", encoding="utf-8") as fh:
+        json.dump(report, fh, indent=2)
+
+    cm = confusion_matrix(y_test, y_pred, normalize="true")
+    confusion_path = Path(reporting_cfg.get("confusion_matrix_path", REPORT_DIR / "confusion_matrix.png"))
+    plot_confusion_matrix(cm, label_encoder.classes_, confusion_path)
+
+    # Clean up non-serializable RNG to keep joblib artefacts deterministic.
     if hasattr(clf, "_feature_subsample_rng"):
         clf._feature_subsample_rng = None
 
-
-
+    model_path = MODEL_DIR / "model.pkl"
+    encoder_path = MODEL_DIR / "label_encoder.pkl"
+    joblib.dump(clf, model_path)
+    joblib.dump(label_encoder, encoder_path)
+
+    run_report_path = write_run_report(
+        reporting_cfg,
+        config_path,
+        config,
+        records,
+        data_root,
+        report,
+        dataset_summary,
+        hashes,
+        model_path,
+        encoder_path,
+    )
+
     print("Saved model + reports.")
+    print(f"Run snapshot written to {run_report_path}")
+
+
+if __name__ == "__main__":
+    main()
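Note: the README suggests comparing metrics JSON against previous baselines; a minimal sketch that diffs the two newest run snapshots on accuracy (field names follow the `write_run_report` payload above; the 0.02 budget is arbitrary):

```python
import json
from pathlib import Path

# Compare the newest two reports/runs/run-*.json snapshots on accuracy.
runs = sorted(Path("reports/runs").glob("run-*.json"))
if len(runs) >= 2:
    prev, curr = (json.loads(p.read_text(encoding="utf-8")) for p in runs[-2:])
    drop = prev["metrics"]["accuracy"] - curr["metrics"]["accuracy"]
    if drop > 0.02:  # arbitrary regression budget
        raise SystemExit(f"Accuracy regressed by {drop:.3f} vs {runs[-2].name}")
    print(f"Accuracy {curr['metrics']['accuracy']:.3f} (change {-drop:+.3f})")
```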