Spaces (status: Sleeping)

Abdelrahman Almatrooshi committed
Commit 22a6915

Deploy snapshot from main b7a59b11809483dfc959f196f1930240f2662c49

This view is limited to 50 files because it contains too many changes. See the raw diff for the full changeset.
- .coveragerc +23 -0
- .dockerignore +31 -0
- .gitattributes +5 -0
- .gitignore +56 -0
- Dockerfile +35 -0
- LICENSE +21 -0
- README.md +366 -0
- api/__init__.py +1 -0
- api/db.py +201 -0
- api/drawing.py +124 -0
- app.py +1 -0
- assets/focusguard-demo.gif +3 -0
- checkpoints/L2CSNet_gaze360.pkl +3 -0
- checkpoints/README.md +47 -0
- checkpoints/hybrid_combiner.joblib +3 -0
- checkpoints/hybrid_focus_config.json +10 -0
- checkpoints/meta_best.npz +3 -0
- checkpoints/meta_mlp.npz +3 -0
- checkpoints/mlp_best.pt +3 -0
- checkpoints/scaler_mlp.joblib +3 -0
- checkpoints/xgboost_face_orientation_best.json +0 -0
- config/README.md +45 -0
- config/__init__.py +60 -0
- config/clearml_enrich.py +87 -0
- config/default.yaml +80 -0
- data_preparation/README.md +90 -0
- data_preparation/__init__.py +0 -0
- data_preparation/data_exploration.ipynb +0 -0
- data_preparation/prepare_dataset.py +279 -0
- docker-compose.yml +5 -0
- download_l2cs_weights.py +37 -0
- eslint.config.js +42 -0
- evaluation/GROUPED_SPLIT_BENCHMARK.md +13 -0
- evaluation/README.md +84 -0
- evaluation/THRESHOLD_JUSTIFICATION.md +100 -0
- evaluation/feature_importance.py +279 -0
- evaluation/feature_selection_justification.md +53 -0
- evaluation/grouped_split_benchmark.py +107 -0
- evaluation/justify_thresholds.py +573 -0
- evaluation/logs/.gitkeep +0 -0
- evaluation/plots/confusion_matrix_mlp.png +0 -0
- evaluation/plots/confusion_matrix_xgb.png +0 -0
- evaluation/plots/ear_distribution.png +0 -0
- evaluation/plots/geo_weight_search.png +0 -0
- evaluation/plots/hybrid_weight_search.png +0 -0
- evaluation/plots/hybrid_xgb_weight_search.png +0 -0
- evaluation/plots/mar_distribution.png +0 -0
- evaluation/plots/roc_mlp.png +0 -0
- evaluation/plots/roc_xgb.png +0 -0
- index.html +17 -0
.coveragerc
ADDED
@@ -0,0 +1,23 @@
+[run]
+branch = True
+source =
+    .
+omit =
+    .venv/*
+    venv/*
+    */site-packages/*
+    tests/*
+    notebooks/*
+    evaluation/*
+    models/mlp/train.py
+    models/mlp/sweep.py
+    models/mlp/eval_accuracy.py
+    models/cnn/eye_attention/train.py
+    models\collect_features.py
+[report]
+show_missing = True
+skip_covered = False
+precision = 1
+exclude_lines =
+    pragma: no cover
+    if __name__ == .__main__.:
.dockerignore
ADDED
@@ -0,0 +1,31 @@
+.git
+.gitattributes
+.github
+node_modules
+dist
+venv
+.venv
+__pycache__
+*.pyc
+.pytest_cache
+.mypy_cache
+.ruff_cache
+
+notebooks/
+evaluation/
+tests/
+others/
+*.ipynb
+requirements-dev.txt
+pytest.ini
+eslint.config.js
+docker-compose.yml
+
+models/L2CS-Net/L2CS-Net-backup/
+
+*.db
+
+.DS_Store
+.cursor
+.vscode
+*.swp
.gitattributes
ADDED
@@ -0,0 +1,5 @@
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.gif filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,56 @@
+# Logs
+logs
+*.log
+npm-debug.log*
+yarn-debug.log*
+yarn-error.log*
+pnpm-debug.log*
+lerna-debug.log*
+
+node_modules/
+dist/
+dist-ssr/
+*.local
+
+# Editor directories and files
+.vscode/
+.idea/
+.DS_Store
+*.suo
+*.ntvs*
+*.njsproj
+*.sln
+*.sw?
+*.py[cod]
+*$py.class
+*.so
+.Python
+venv/
+.venv/
+env/
+.env
+*.egg-info/
+.eggs/
+build/
+Thumbs.db
+ignore/
+
+# Coverage / caches
+.coverage
+htmlcov/
+
+# Project specific
+focus_guard.db
+test_focus_guard.db
+# Large weights: fetch at build/runtime (see download_l2cs_weights.py)
+checkpoints/L2CSNet_gaze360.pkl
+models/L2CS-Net/models/L2CSNet_gaze360.pkl
+# Training artefacts (too large for HF Hub; keep local only)
+data/
+data_preparation/collected*/
+best_eye_cnn.pth
+checkpoints/model_best.joblib
+__pycache__/
+docs/
+docs
+LOCAL_TESTING.md
Dockerfile
ADDED
@@ -0,0 +1,35 @@
+FROM python:3.10-slim
+
+RUN useradd -m -u 1000 user
+ENV HOME=/home/user PATH=/home/user/.local/bin:$PATH
+ENV PYTHONUNBUFFERED=1
+
+WORKDIR /app
+
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    libglib2.0-0 libsm6 libxrender1 libxext6 libxcb1 libgl1 libgles2 libegl1 libgomp1 \
+    ffmpeg libavcodec-dev libavformat-dev libavutil-dev libswscale-dev \
+    libavdevice-dev libopus-dev libvpx-dev libsrtp2-dev \
+    build-essential nodejs npm git \
+    && rm -rf /var/lib/apt/lists/*
+
+RUN pip install --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cpu
+
+COPY requirements.txt ./
+RUN pip install --no-cache-dir -r requirements.txt
+
+COPY . .
+
+RUN npm install && npm run build && mkdir -p /app/static && cp -R dist/* /app/static/ \
+    && rm -rf node_modules dist
+
+ENV FOCUSGUARD_CACHE_DIR=/app/.cache/focusguard
+RUN python -c "from models.face_mesh import _ensure_model; _ensure_model()"
+RUN python download_l2cs_weights.py || echo "[WARN] L2CS weights not downloaded — will run without gaze model"
+
+RUN mkdir -p /app/data && chown -R user:user /app
+
+USER user
+EXPOSE 7860
+
+CMD ["bash", "start.sh"]
LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2026 k23172173
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
README.md
ADDED
@@ -0,0 +1,366 @@
+---
+title: FocusGuard
+emoji: 👁️
+colorFrom: blue
+colorTo: indigo
+sdk: docker
+app_port: 7860
+pinned: false
+short_description: Real-time webcam focus detection via MediaPipe + MLP/XGBoost
+---
+
+# FocusGuard
+
+Real-time webcam-based visual attention estimation. MediaPipe Face Mesh extracts 17 features (EAR, gaze ratios, head pose, PERCLOS) per frame, selects 10, and routes them through MLP or XGBoost for binary focused/unfocused classification. Includes a local OpenCV demo and a full React + FastAPI web app with WebSocket/WebRTC video streaming.
+
+
+
+---
+
+## Team
+
+**Team name:** FocusGuards (5CCSAGAP Large Group Project)
+
+**Members:** Yingao Zheng, Mohamed Alketbi, Abdelrahman Almatrooshi, Junhao Zhou, Kexin Wang, Langyuan Huang, Saba Al-Gafri, Ayten Arab, Jaroslav Rakoto-Miklas
+
+---
+
+## Links
+
+### Project access
+
+- Git repository: [GAP_Large_project](https://github.kcl.ac.uk/k23172173/GAP_Large_project)
+- Deployed app (Hugging Face): [FocusGuard/final_v2](https://huggingface.co/spaces/FocusGuard/final_v2)
+- ClearML experiments: [FocusGuards Large Group Project](https://app.5ccsagap.er.kcl.ac.uk/projects/ce218b2f751641c68042f8fa216f8746/experiments)
+
+### Data and checkpoints
+
+- Checkpoints (Google Drive): [Download folder](https://drive.google.com/drive/folders/15yYHKgCHg5AFIBb04XnVaeqHRukwBLAd?usp=drive_link)
+- Dataset (Google Drive): [Dataset folder](https://drive.google.com/drive/folders/1fwACM6i6uVGFkTlJKSlqVhizzgrHl_gY?usp=sharing)
+- Data consent form (PDF): [Consent document](https://drive.google.com/file/d/1g1Hc764ffljoKrjApD6nmWDCXJGYTR0j/view?usp=drive_link)
+
+The deployed app contains the full feature set (session history, L2CS calibration, model selector, achievements).
+
+---
+
+## Trained models
+
+Model checkpoints are **not included** in the submission archive. Download them before running inference.
+
+### Option 1: Hugging Face Space
+
+Pre-trained checkpoints are available in the Hugging Face Space files:
+
+```
+https://huggingface.co/spaces/FocusGuard/final_v2/tree/main/checkpoints
+```
+
+Download and place into `checkpoints/`:
+
+| File | Description |
+|------|-------------|
+| `mlp_best.pt` | PyTorch MLP (10-64-32-2, ~2,850 params) |
+| `xgboost_face_orientation_best.json` | XGBoost (600 trees, depth 8, lr 0.1489) |
+| `scaler_mlp.joblib` | StandardScaler fit on training data |
+| `hybrid_focus_config.json` | Hybrid pipeline fusion weights |
+| `hybrid_combiner.joblib` | Hybrid combiner |
+| `L2CSNet_gaze360.pkl` | L2CS-Net ResNet50 gaze weights (96 MB) |
+
+### Option 2: ClearML
+
+Models are registered as ClearML OutputModels under project "FocusGuards Large Group Project".
+
+| Model | Task ID | Model ID |
+|-------|---------|----------|
+| MLP | `3899b5aa0c3348b28213a3194322cdf7` | `56f94b799f624bdc845fa50c4d0606fe` |
+| XGBoost | `c0ceb8e7e8194a51a7a31078cc47775c` | `6727b8de334f4ca0961c46b436f6fb7c` |
+
+**UI:** Open a task on the [experiments page](https://app.5ccsagap.er.kcl.ac.uk/projects/ce218b2f751641c68042f8fa216f8746/experiments), go to Artifacts > Output Models, and download.
+
+**Python:**
+
+```python
+from clearml import Model
+
+mlp = Model(model_id="56f94b799f624bdc845fa50c4d0606fe")
+mlp_path = mlp.get_local_copy()  # downloads .pt
+
+xgb = Model(model_id="6727b8de334f4ca0961c46b436f6fb7c")
+xgb_path = xgb.get_local_copy()  # downloads .json
+```
+
+Copy the downloaded files into `checkpoints/`.
+
+### Option 3: Google Drive (submission fallback)
+
+If ClearML access is restricted, download checkpoints from:
+https://drive.google.com/drive/folders/15yYHKgCHg5AFIBb04XnVaeqHRukwBLAd?usp=drive_link
+
+Place all files under `checkpoints/`.
+
+### Option 4: Retrain from scratch
+
+```bash
+python -m models.mlp.train
+python -m models.xgboost.train
+```
+
+This regenerates `checkpoints/mlp_best.pt`, `checkpoints/xgboost_face_orientation_best.json`, and scalers. Requires training data under `data/collected_*/`.
+
+---
+
+## Project layout
+
+```
+config/
+  default.yaml                hyperparameters, thresholds, app settings
+  __init__.py                 config loader + ClearML flattener
+  clearml_enrich.py           ClearML task enrichment + artifact upload
+data_preparation/
+  prepare_dataset.py          load/split/scale .npz files (pooled + LOPO)
+  data_exploration.ipynb      EDA: distributions, class balance, correlations
+models/
+  face_mesh.py                MediaPipe 478-point face landmarks
+  head_pose.py                yaw/pitch/roll via solvePnP, face-orientation score
+  eye_scorer.py               EAR, MAR, gaze ratios, PERCLOS
+  collect_features.py         real-time feature extraction + webcam labelling CLI
+  gaze_calibration.py         9-point polynomial gaze calibration
+  gaze_eye_fusion.py          fuses calibrated gaze with eye openness
+  mlp/                        MLP training, eval, Optuna sweep
+  xgboost/                    XGBoost training, eval, ClearML + Optuna sweeps
+  L2CS-Net/                   vendored L2CS-Net (ResNet50, Gaze360)
+checkpoints/                  (excluded from archive; see download instructions above)
+notebooks/
+  mlp.ipynb                   MLP training + LOPO in Jupyter
+  xgboost.ipynb               XGBoost training + LOPO in Jupyter
+evaluation/
+  justify_thresholds.py       LOPO threshold + weight grid search
+  feature_importance.py       XGBoost gain + leave-one-feature-out ablation
+  grouped_split_benchmark.py  pooled vs LOPO comparison
+  plots/                      ROC curves, confusion matrices, weight searches
+  logs/                       JSON training logs
+tests/
+  test_*.py                   unit + integration tests (pytest)
+.coveragerc                   coverage config
+ui/
+  pipeline.py                 all 5 pipeline classes + output smoothing
+  live_demo.py                OpenCV webcam demo
+src/                          React (Vite) frontend source
+static/                       built frontend assets (after npm build)
+main.py                       FastAPI application entry point
+package.json                  frontend package manifest
+requirements.txt
+pytest.ini
+```
+
+---
+
+## Setup
+
+Recommended versions:
+
+- Python 3.10-3.11
+- Node.js 18+ (needed only for frontend rebuild/dev)
+
+```bash
+python -m venv venv
+source venv/bin/activate  # Windows: venv\Scripts\activate
+pip install -r requirements.txt
+```
+
+Then download checkpoints (see above).
+
+If you need to rebuild frontend assets locally:
+
+```bash
+npm install
+npm run build
+mkdir -p static && cp -r dist/* static/
+```
+
+---
+
+## Run
+
+### Local OpenCV demo
+
+```bash
+python ui/live_demo.py
+python ui/live_demo.py --xgb  # XGBoost
+```
+
+Controls: `m` cycle mesh overlay, `1-5` switch pipeline mode, `q` quit.
+
+### Web app (without Docker)
+
+```bash
+source venv/bin/activate
+python -m uvicorn main:app --host 0.0.0.0 --port 7860
+```
+
+Open http://localhost:7860
+
+### Web app (Docker)
+
+```bash
+docker-compose up  # serves on port 7860
+```
+
+---
+
+## Data collection
+
+```bash
+python -m models.collect_features --name <participant>
+```
+
+Records webcam sessions with real-time binary labelling (spacebar toggles focused/unfocused). Saves per-frame feature vectors to `data/collected_<participant>/` as `.npz` files. Raw video is never stored.
+
+9 participants recorded 5-10 min sessions across varied environments (144,793 frames total, 61.5% focused / 38.5% unfocused). All participants provided informed consent. Dataset files are not included in this repository.
+
+Consent document: https://drive.google.com/file/d/1g1Hc764ffljoKrjApD6nmWDCXJGYTR0j/view?usp=drive_link
+The raw participant dataset is excluded from this submission (coursework policy and privacy constraints). It can be shared with module staff on request: https://drive.google.com/drive/folders/1fwACM6i6uVGFkTlJKSlqVhizzgrHl_gY?usp=sharing
+
+---
+
+## Pipeline
+
+```
+Webcam frame
+  --> MediaPipe Face Mesh (478 landmarks)
+  --> Head pose (solvePnP): yaw, pitch, roll, s_face, head_deviation
+  --> Eye scorer: EAR_left, EAR_right, EAR_avg, s_eye, MAR
+  --> Gaze ratios: h_gaze, v_gaze, gaze_offset
+  --> Temporal tracker: PERCLOS, blink_rate, closure_dur, yawn_dur
+  --> 17 features --> select 10 --> clip to physiological bounds
+  --> ML model (MLP / XGBoost) or geometric scorer
+  --> Asymmetric EMA smoothing (alpha_up=0.55, alpha_down=0.45)
+  --> FOCUSED / UNFOCUSED
+```
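The smoothing step in the diagram admits a compact reading: one EMA coefficient when the score rises, another when it falls. A minimal sketch using the quoted alpha_up/alpha_down values; the class name and exact update rule are illustrative, not the `ui/pipeline.py` implementation:

```python
class AsymmetricEMA:
    """Smooth a per-frame focus score in [0, 1], reacting faster to rises than drops."""

    def __init__(self, alpha_up: float = 0.55, alpha_down: float = 0.45):
        self.alpha_up = alpha_up
        self.alpha_down = alpha_down
        self.value: float | None = None

    def update(self, score: float) -> float:
        if self.value is None:
            self.value = score          # first frame: no history to smooth against
        else:
            # Heavier weight on new evidence when the score is rising.
            alpha = self.alpha_up if score > self.value else self.alpha_down
            self.value = alpha * score + (1 - alpha) * self.value
        return self.value
```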
+
+Five runtime modes share the same feature extraction backbone:
+
+| Mode | Description |
+|------|-------------|
+| **Geometric** | Deterministic scoring: 0.7 * s_face + 0.3 * s_eye, cosine-decay with max_angle=22 deg (sketched below) |
+| **XGBoost** | 600-tree gradient-boosted ensemble, threshold 0.28 (LOPO-optimal) |
+| **MLP** | PyTorch 10-64-32-2 perceptron, threshold 0.23 (LOPO-optimal) |
+| **Hybrid** | 30% MLP + 70% geometric ensemble (LOPO F1 = 0.841) |
+| **L2CS** | Deep gaze estimation via L2CS-Net (ResNet50, Gaze360 pretrained) |
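A sketch of one plausible form of the Geometric row's scoring, assuming deviation angles in degrees; the exact decay in `models/head_pose.py` may differ:

```python
import math

def cosine_decay_score(angle_deg: float, max_angle: float = 22.0) -> float:
    """Map a head-pose deviation angle to [0, 1]: 1 at 0 deg, 0 at max_angle and beyond."""
    a = min(abs(angle_deg), max_angle)
    return 0.5 * (1.0 + math.cos(math.pi * a / max_angle))

def geometric_score(s_face: float, s_eye: float) -> float:
    # Weights from the table above (0.7 face, 0.3 eye).
    return 0.7 * s_face + 0.3 * s_eye
```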
+
+Any mode can be combined with L2CS Boost mode (35% base + 65% L2CS, fused threshold 0.52). Off-screen gaze produces near-zero L2CS score via cosine decay, acting as a soft veto.
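A sketch of how that score-level fusion and soft veto can work, using the weights and threshold quoted above (function name and signature are illustrative):

```python
def l2cs_boost(base_score: float, l2cs_score: float,
               base_w: float = 0.35, l2cs_w: float = 0.65,
               fused_threshold: float = 0.52) -> bool:
    """Score-level fusion for L2CS Boost mode.

    Because the L2CS score decays toward 0 for off-screen gaze, the large
    l2cs_w drags the fused score under the threshold even when the base
    model is confident: the 'soft veto' described above.
    """
    fused = base_w * base_score + l2cs_w * l2cs_score
    return fused >= fused_threshold
```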
+
+---
+
+## Training
+
+Both scripts read all hyperparameters from `config/default.yaml`.
+
+```bash
+python -m models.mlp.train
+python -m models.xgboost.train
+```
+
+Outputs: `checkpoints/` (model + scaler) and `evaluation/logs/` (CSVs, JSON summaries).
+
+### ClearML experiment tracking
+
+```bash
+USE_CLEARML=1 python -m models.mlp.train
+USE_CLEARML=1 CLEARML_QUEUE=gpu python -m models.xgboost.train
+USE_CLEARML=1 python -m evaluation.justify_thresholds --clearml
+```
+
+Logs hyperparameters, per-epoch scalars, confusion matrices, ROC curves, model registration, dataset stats, and reproducibility artifacts (config YAML, requirements.txt, git SHA).
+
+Reference experiment IDs:
+
+| Model | ClearML experiment ID |
+|-------|------------------------|
+| MLP (`models.mlp.train`) | `3899b5aa0c3348b28213a3194322cdf7` |
+| XGBoost (`models.xgboost.train`) | `c0ceb8e7e8194a51a7a31078cc47775c` |
+
+---
+
+## Evaluation
+
+```bash
+python -m evaluation.justify_thresholds       # LOPO threshold + weight search
+python -m evaluation.grouped_split_benchmark  # pooled vs LOPO comparison
+python -m evaluation.feature_importance      # XGBoost gain + LOFO ablation
+```
+
+### Results (pooled random split, 15% test)
+
+| Model | Accuracy | F1 | ROC-AUC |
+|-------|----------|----|---------|
+| XGBoost (600 trees, depth 8) | 95.87% | 0.959 | 0.991 |
+| MLP (64-32) | 92.92% | 0.929 | 0.971 |
+
+### Results (LOPO, 9 participants)
+
+| Model | LOPO AUC | Best threshold (Youden's J) | F1 at best threshold |
+|-------|----------|-----------------------------|----------------------|
+| MLP | 0.862 | 0.228 | 0.858 |
+| XGBoost | 0.870 | 0.280 | 0.855 |
+
+Best geometric face weight (alpha) = 0.7 (mean LOPO F1 = 0.820).
+Best hybrid MLP weight (w_mlp) = 0.3 (mean LOPO F1 = 0.841).
+
+The ~12 pp drop from pooled to LOPO reflects temporal data leakage and confirms LOPO as the primary generalisation metric.
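For reference, a sketch of selecting a threshold by Youden's J with scikit-learn; `evaluation/justify_thresholds.py` may organise this differently across LOPO folds:

```python
import numpy as np
from sklearn.metrics import roc_curve

def youden_threshold(y_true: np.ndarray, y_score: np.ndarray) -> float:
    """Return the score threshold maximising J = TPR - FPR over the ROC curve."""
    fpr, tpr, thresholds = roc_curve(y_true, y_score)
    return float(thresholds[np.argmax(tpr - fpr)])
```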
+
+### Feature ablation
+
+| Channel subset | Mean LOPO F1 |
+|----------------|--------------|
+| All 10 features | 0.829 |
+| Eye state only | 0.807 |
+| Head pose only | 0.748 |
+| Gaze only | 0.726 |
+
+Top-5 XGBoost gain: `s_face` (10.27), `ear_right` (9.54), `head_deviation` (8.83), `ear_avg` (6.96), `perclos` (5.68).
+
+---
+
+## L2CS Gaze Tracking
+
+L2CS-Net predicts where your eyes are looking, not just where your head is pointed, catching the scenario where the head faces the screen but the eyes wander.
+
+**Standalone mode:** Select L2CS as the model.
+
+**Boost mode:** Select any other model, then enable the GAZE toggle. L2CS runs alongside the base model with score-level fusion (35% base / 65% L2CS). Off-screen gaze triggers a soft veto.
+
+**Calibration:** Click Calibrate during a session. A fullscreen overlay shows 9 target dots (3x3 grid). After all 9 points, a degree-2 polynomial maps gaze angles to screen coordinates with IQR outlier filtering and centre-point bias correction.
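A sketch of the degree-2 polynomial fit at the core of that calibration, assuming nine collected (yaw, pitch) to screen-coordinate pairs; the IQR outlier filtering and centre-point bias correction in `models/gaze_calibration.py` are omitted here:

```python
import numpy as np

def fit_gaze_map(angles: np.ndarray, screen: np.ndarray) -> np.ndarray:
    """Fit a quadratic map from gaze angles to screen coordinates.

    angles: (N, 2) yaw/pitch in degrees; screen: (N, 2) normalised screen coords.
    Returns a (6, 2) coefficient matrix for the design [1, yaw, pitch, yaw^2, pitch^2, yaw*pitch].
    """
    yaw, pitch = angles[:, 0], angles[:, 1]
    A = np.column_stack([np.ones_like(yaw), yaw, pitch, yaw**2, pitch**2, yaw * pitch])
    coeffs, *_ = np.linalg.lstsq(A, screen, rcond=None)
    return coeffs
```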
+
+L2CS weight lookup order in runtime:
+
+1. `checkpoints/L2CSNet_gaze360.pkl`
+2. `models/L2CS-Net/models/L2CSNet_gaze360.pkl`
+3. `models/L2CSNet_gaze360.pkl`
+
+---
+
+## Config
+
+All hyperparameters and app settings are in `config/default.yaml`. Override with `FOCUSGUARD_CONFIG=/path/to/custom.yaml`.
+
+---
+
+## Tests
+
+Included checks:
+
+- data prep helpers and real split consistency (`test_data_preparation.py`; split test **skips** if `data/collected_*/*.npz` is absent)
+- feature clipping (`test_models_clip_features.py`)
+- pipeline integration (`test_pipeline_integration.py`)
+- gaze calibration / fusion diagnostics (`test_gaze_pipeline.py`)
+- FastAPI health, settings, sessions (`test_health_endpoint.py`, `test_api_settings.py`, `test_api_sessions.py`)
+
+```bash
+pytest
+```
+
+Coverage is enabled by default via `pytest.ini` (`--cov` / term report). For HTML coverage: `pytest --cov-report=html`.
+
+**Stack:** Python, PyTorch, XGBoost, MediaPipe, OpenCV, L2CS-Net, FastAPI, React/Vite, SQLite, Docker, ClearML, pytest.
api/__init__.py
ADDED
@@ -0,0 +1 @@
+# API package: db, drawing, routes, websocket.
api/db.py
ADDED
@@ -0,0 +1,201 @@
+"""SQLite DB for focus sessions and user settings."""
+
+from __future__ import annotations
+
+import asyncio
+import json
+from datetime import datetime
+
+import aiosqlite
+
+
+def get_db_path() -> str:
+    """Database file path from config or default."""
+    try:
+        from config import get
+        return get("app.db_path") or "focus_guard.db"
+    except Exception:
+        return "focus_guard.db"
+
+
+async def init_database(db_path: str | None = None) -> None:
+    """Create focus_sessions, focus_events, user_settings tables if missing."""
+    path = db_path or get_db_path()
+    async with aiosqlite.connect(path) as db:
+        await db.execute("""
+            CREATE TABLE IF NOT EXISTS focus_sessions (
+                id INTEGER PRIMARY KEY AUTOINCREMENT,
+                start_time TIMESTAMP NOT NULL,
+                end_time TIMESTAMP,
+                duration_seconds INTEGER DEFAULT 0,
+                focus_score REAL DEFAULT 0.0,
+                total_frames INTEGER DEFAULT 0,
+                focused_frames INTEGER DEFAULT 0,
+                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+            )
+        """)
+        await db.execute("""
+            CREATE TABLE IF NOT EXISTS focus_events (
+                id INTEGER PRIMARY KEY AUTOINCREMENT,
+                session_id INTEGER NOT NULL,
+                timestamp TIMESTAMP NOT NULL,
+                is_focused BOOLEAN NOT NULL,
+                confidence REAL NOT NULL,
+                detection_data TEXT,
+                FOREIGN KEY (session_id) REFERENCES focus_sessions (id)
+            )
+        """)
+        await db.execute("""
+            CREATE TABLE IF NOT EXISTS user_settings (
+                id INTEGER PRIMARY KEY CHECK (id = 1),
+                model_name TEXT DEFAULT 'mlp'
+            )
+        """)
+        await db.execute("""
+            INSERT OR IGNORE INTO user_settings (id, model_name)
+            VALUES (1, 'mlp')
+        """)
+        await db.commit()
+
+
+async def create_session(db_path: str | None = None) -> int:
+    """Insert a new focus session. Returns session id."""
+    path = db_path or get_db_path()
+    async with aiosqlite.connect(path) as db:
+        cursor = await db.execute(
+            "INSERT INTO focus_sessions (start_time) VALUES (?)",
+            (datetime.now().isoformat(),),
+        )
+        await db.commit()
+        return cursor.lastrowid
+
+
+async def end_session(session_id: int, db_path: str | None = None) -> dict | None:
+    """Close session and return summary (duration, focus_score, etc.)."""
+    path = db_path or get_db_path()
+    async with aiosqlite.connect(path) as db:
+        cursor = await db.execute(
+            "SELECT start_time, total_frames, focused_frames FROM focus_sessions WHERE id = ?",
+            (session_id,),
+        )
+        row = await cursor.fetchone()
+    if not row:
+        return None
+    start_time_str, total_frames, focused_frames = row
+    start_time = datetime.fromisoformat(start_time_str)
+    end_time = datetime.now()
+    duration = (end_time - start_time).total_seconds()
+    focus_score = focused_frames / total_frames if total_frames > 0 else 0.0
+    async with aiosqlite.connect(path) as db:
+        await db.execute("""
+            UPDATE focus_sessions
+            SET end_time = ?, duration_seconds = ?, focus_score = ?
+            WHERE id = ?
+        """, (end_time.isoformat(), int(duration), focus_score, session_id))
+        await db.commit()
+    return {
+        "session_id": session_id,
+        "start_time": start_time_str,
+        "end_time": end_time.isoformat(),
+        "duration_seconds": int(duration),
+        "focus_score": round(focus_score, 3),
+        "total_frames": total_frames,
+        "focused_frames": focused_frames,
+    }
+
+
+async def store_focus_event(
+    session_id: int,
+    is_focused: bool,
+    confidence: float,
+    metadata: dict,
+    db_path: str | None = None,
+) -> None:
+    """Append one focus event and update session counters."""
+    path = db_path or get_db_path()
+    async with aiosqlite.connect(path) as db:
+        await db.execute("""
+            INSERT INTO focus_events (session_id, timestamp, is_focused, confidence, detection_data)
+            VALUES (?, ?, ?, ?, ?)
+        """, (session_id, datetime.now().isoformat(), is_focused, confidence, json.dumps(metadata)))
+        await db.execute("""
+            UPDATE focus_sessions
+            SET total_frames = total_frames + 1,
+                focused_frames = focused_frames + ?
+            WHERE id = ?
+        """, (1 if is_focused else 0, session_id))
+        await db.commit()
+
+
+class EventBuffer:
+    """Buffer focus events and flush to DB in batches to avoid per-frame writes."""
+
+    def __init__(self, db_path: str | None = None, flush_interval: float = 2.0):
+        self._db_path = db_path or get_db_path()
+        self._flush_interval = flush_interval
+        self._buf: list = []
+        self._lock = asyncio.Lock()
+        self._task: asyncio.Task | None = None
+        self._total_frames = 0
+        self._focused_frames = 0
+
+    def start(self) -> None:
+        if self._task is None:
+            self._task = asyncio.create_task(self._flush_loop())
+
+    async def stop(self) -> None:
+        if self._task:
+            self._task.cancel()
+            try:
+                await self._task
+            except asyncio.CancelledError:
+                pass
+            self._task = None
+        await self._flush()
+
+    def add(self, session_id: int, is_focused: bool, confidence: float, metadata: dict) -> None:
+        self._buf.append((
+            session_id,
+            datetime.now().isoformat(),
+            is_focused,
+            confidence,
+            json.dumps(metadata),
+        ))
+        self._total_frames += 1
+        if is_focused:
+            self._focused_frames += 1
+
+    async def _flush_loop(self) -> None:
+        while True:
+            await asyncio.sleep(self._flush_interval)
+            await self._flush()
+
+    async def _flush(self) -> None:
+        async with self._lock:
+            if not self._buf:
+                return
+            batch = self._buf[:]
+            total = self._total_frames
+            focused = self._focused_frames
+            self._buf.clear()
+            self._total_frames = 0
+            self._focused_frames = 0
+        if not batch:
+            return
+        session_id = batch[0][0]
+        try:
+            async with aiosqlite.connect(self._db_path) as db:
+                await db.executemany("""
+                    INSERT INTO focus_events (session_id, timestamp, is_focused, confidence, detection_data)
+                    VALUES (?, ?, ?, ?, ?)
+                """, batch)
+                await db.execute("""
+                    UPDATE focus_sessions
+                    SET total_frames = total_frames + ?,
+                        focused_frames = focused_frames + ?
+                    WHERE id = ?
+                """, (total, focused, session_id))
+                await db.commit()
+        except Exception as e:
+            import logging
+            logging.getLogger(__name__).warning("DB flush error: %s", e)
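A minimal usage sketch for `EventBuffer`, assuming an event loop is running; the surrounding session wiring is illustrative, not the app's actual FastAPI hooks:

```python
import asyncio
from api.db import EventBuffer, create_session, end_session, init_database

async def main() -> None:
    await init_database()
    session_id = await create_session()
    buf = EventBuffer(flush_interval=2.0)
    buf.start()                                          # begins the background flush loop
    buf.add(session_id, True, 0.91, {"s_face": 0.88})    # cheap in-memory append per frame
    await asyncio.sleep(2.5)                             # let one flush cycle write the batch
    await buf.stop()                                     # cancels the loop, flushes the remainder
    print(await end_session(session_id))

asyncio.run(main())
```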
api/drawing.py
ADDED
@@ -0,0 +1,124 @@
+"""Server-side face mesh and HUD drawing for WebRTC/WS video frames."""
+
+from __future__ import annotations
+
+import cv2
+import numpy as np
+
+from mediapipe.tasks.python.vision import FaceLandmarksConnections
+from models.face_mesh import FaceMeshDetector
+
+_FONT = cv2.FONT_HERSHEY_SIMPLEX
+_CYAN = (255, 255, 0)
+_GREEN = (0, 255, 0)
+_MAGENTA = (255, 0, 255)
+_ORANGE = (0, 165, 255)
+_RED = (0, 0, 255)
+_WHITE = (255, 255, 255)
+_LIGHT_GREEN = (144, 238, 144)
+
+_TESSELATION_CONNS = [(c.start, c.end) for c in FaceLandmarksConnections.FACE_LANDMARKS_TESSELATION]
+_CONTOUR_CONNS = [(c.start, c.end) for c in FaceLandmarksConnections.FACE_LANDMARKS_CONTOURS]
+_LEFT_EYEBROW = [70, 63, 105, 66, 107, 55, 65, 52, 53, 46]
+_RIGHT_EYEBROW = [300, 293, 334, 296, 336, 285, 295, 282, 283, 276]
+_NOSE_BRIDGE = [6, 197, 195, 5, 4, 1, 19, 94, 2]
+_LIPS_OUTER = [61, 146, 91, 181, 84, 17, 314, 405, 321, 375, 291, 409, 270, 269, 267, 0, 37, 39, 40, 185, 61]
+_LIPS_INNER = [78, 95, 88, 178, 87, 14, 317, 402, 318, 324, 308, 415, 310, 311, 312, 13, 82, 81, 80, 191, 78]
+_LEFT_EAR_POINTS = [33, 160, 158, 133, 153, 145]
+_RIGHT_EAR_POINTS = [362, 385, 387, 263, 373, 380]
+
+
+def _lm_px(lm: np.ndarray, idx: int, w: int, h: int) -> tuple[int, int]:
+    return (int(lm[idx, 0] * w), int(lm[idx, 1] * h))
+
+
+def _draw_polyline(
+    frame: np.ndarray, lm: np.ndarray, indices: list[int], w: int, h: int, color: tuple, thickness: int
+) -> None:
+    for i in range(len(indices) - 1):
+        cv2.line(
+            frame,
+            _lm_px(lm, indices[i], w, h),
+            _lm_px(lm, indices[i + 1], w, h),
+            color,
+            thickness,
+            cv2.LINE_AA,
+        )
+
+
+def draw_face_mesh(frame: np.ndarray, lm: np.ndarray, w: int, h: int) -> None:
+    """Draw tessellation, contours, eyebrows, nose, lips, eyes, irises, gaze lines on frame."""
+    overlay = frame.copy()
+    for s, e in _TESSELATION_CONNS:
+        cv2.line(overlay, _lm_px(lm, s, w, h), _lm_px(lm, e, w, h), (200, 200, 200), 1, cv2.LINE_AA)
+    cv2.addWeighted(overlay, 0.3, frame, 0.7, 0, frame)
+    for s, e in _CONTOUR_CONNS:
+        cv2.line(frame, _lm_px(lm, s, w, h), _lm_px(lm, e, w, h), _CYAN, 1, cv2.LINE_AA)
+    _draw_polyline(frame, lm, _LEFT_EYEBROW, w, h, _LIGHT_GREEN, 2)
+    _draw_polyline(frame, lm, _RIGHT_EYEBROW, w, h, _LIGHT_GREEN, 2)
+    _draw_polyline(frame, lm, _NOSE_BRIDGE, w, h, _ORANGE, 1)
+    _draw_polyline(frame, lm, _LIPS_OUTER, w, h, _MAGENTA, 1)
+    _draw_polyline(frame, lm, _LIPS_INNER, w, h, (200, 0, 200), 1)
+    left_pts = np.array([_lm_px(lm, i, w, h) for i in FaceMeshDetector.LEFT_EYE_INDICES], dtype=np.int32)
+    cv2.polylines(frame, [left_pts], True, _GREEN, 2, cv2.LINE_AA)
+    right_pts = np.array([_lm_px(lm, i, w, h) for i in FaceMeshDetector.RIGHT_EYE_INDICES], dtype=np.int32)
+    cv2.polylines(frame, [right_pts], True, _GREEN, 2, cv2.LINE_AA)
+    for indices in [_LEFT_EAR_POINTS, _RIGHT_EAR_POINTS]:
+        for idx in indices:
+            cv2.circle(frame, _lm_px(lm, idx, w, h), 3, (0, 255, 255), -1, cv2.LINE_AA)
+    for iris_idx, eye_inner, eye_outer in [
+        (FaceMeshDetector.LEFT_IRIS_INDICES, 133, 33),
+        (FaceMeshDetector.RIGHT_IRIS_INDICES, 362, 263),
+    ]:
+        iris_pts = np.array([_lm_px(lm, i, w, h) for i in iris_idx], dtype=np.int32)
+        center = iris_pts[0]
+        if len(iris_pts) >= 5:
+            radii = [np.linalg.norm(iris_pts[j] - center) for j in range(1, 5)]
+            radius = max(int(np.mean(radii)), 2)
+            cv2.circle(frame, tuple(center), radius, _MAGENTA, 2, cv2.LINE_AA)
+        cv2.circle(frame, tuple(center), 2, _WHITE, -1, cv2.LINE_AA)
+        eye_cx = int((lm[eye_inner, 0] + lm[eye_outer, 0]) / 2.0 * w)
+        eye_cy = int((lm[eye_inner, 1] + lm[eye_outer, 1]) / 2.0 * h)
+        dx, dy = center[0] - eye_cx, center[1] - eye_cy
+        cv2.line(
+            frame,
+            tuple(center),
+            (int(center[0] + dx * 3), int(center[1] + dy * 3)),
+            _RED,
+            1,
+            cv2.LINE_AA,
+        )
+
+
+def draw_hud(frame: np.ndarray, result: dict, model_name: str) -> None:
+    """Draw status bar and detail overlay (FOCUSED/NOT FOCUSED, conf, s_face, s_eye, MAR, yawn)."""
+    h, w = frame.shape[:2]
+    is_focused = result["is_focused"]
+    status = "FOCUSED" if is_focused else "NOT FOCUSED"
+    color = _GREEN if is_focused else _RED
+    cv2.rectangle(frame, (0, 0), (w, 55), (0, 0, 0), -1)
+    cv2.putText(frame, status, (10, 28), _FONT, 0.8, color, 2, cv2.LINE_AA)
+    cv2.putText(frame, model_name.upper(), (w - 150, 28), _FONT, 0.45, _WHITE, 1, cv2.LINE_AA)
+    conf = result.get("mlp_prob", result.get("raw_score", 0.0))
+    mar_s = f" MAR:{result['mar']:.2f}" if result.get("mar") is not None else ""
+    sf, se = result.get("s_face", 0), result.get("s_eye", 0)
+    detail = f"conf:{conf:.2f} S_face:{sf:.2f} S_eye:{se:.2f}{mar_s}"
+    cv2.putText(frame, detail, (10, 48), _FONT, 0.4, _WHITE, 1, cv2.LINE_AA)
+    if result.get("yaw") is not None:
+        cv2.putText(
+            frame,
+            f"yaw:{result['yaw']:+.0f} pitch:{result['pitch']:+.0f} roll:{result['roll']:+.0f}",
+            (w - 280, 48),
+            _FONT,
+            0.4,
+            (180, 180, 180),
+            1,
+            cv2.LINE_AA,
+        )
+    if result.get("is_yawning"):
+        cv2.putText(frame, "YAWN", (10, 75), _FONT, 0.7, _ORANGE, 2, cv2.LINE_AA)
+
+
+def get_tesselation_connections() -> list[tuple[int, int]]:
+    """Return tessellation edge pairs for client-side face mesh (cached by client)."""
+    return list(_TESSELATION_CONNS)
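A usage sketch for the two drawing entry points, with a placeholder frame and random normalised landmarks; the `result` keys follow what `draw_hud` reads above:

```python
import numpy as np
from api.drawing import draw_face_mesh, draw_hud

frame = np.zeros((480, 640, 3), dtype=np.uint8)   # BGR frame, as from OpenCV capture
lm = np.random.rand(478, 3)                       # placeholder 478-point mesh in [0, 1]
draw_face_mesh(frame, lm, w=640, h=480)
result = {"is_focused": True, "raw_score": 0.82, "s_face": 0.9, "s_eye": 0.7,
          "mar": 0.21, "yaw": 3.0, "pitch": -2.0, "roll": 1.0}
draw_hud(frame, result, model_name="mlp")         # draws status bar + detail overlay in place
```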
app.py
ADDED
@@ -0,0 +1 @@
+from main import app
assets/focusguard-demo.gif
ADDED
Git LFS Details
checkpoints/L2CSNet_gaze360.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8a7f3480d868dd48261e1d59f915b0ef0bb33ea12ea00938fb2168f212080665
+size 95849977
checkpoints/README.md
ADDED
@@ -0,0 +1,47 @@
+# checkpoints
+
+Checkpoint files are not included in the submission archive.
+
+## Required files
+
+Place the following files in this folder:
+
+- `mlp_best.pt`
+- `xgboost_face_orientation_best.json`
+- `scaler_mlp.joblib`
+- `hybrid_focus_config.json`
+- `hybrid_combiner.joblib`
+- `L2CSNet_gaze360.pkl`
+
+## Optional generated files
+
+- `meta_best.npz`
+- `meta_mlp.npz`
+
+They are metadata artifacts and are not required for standard inference.
+
+## Download sources
+
+Use any one source:
+
+- Hugging Face Space files: [checkpoints folder](https://huggingface.co/spaces/FocusGuard/final_v2/tree/main/checkpoints)
+- ClearML project: [FocusGuards Large Group Project](https://app.5ccsagap.er.kcl.ac.uk/projects/ce218b2f751641c68042f8fa216f8746/experiments)
+- Google Drive fallback: [checkpoint folder](https://drive.google.com/drive/folders/15yYHKgCHg5AFIBb04XnVaeqHRukwBLAd?usp=drive_link)
+
+## Verify files
+
+Run from repo root:
+
+```bash
+ls -lh checkpoints
+```
+
+You should see all required filenames above.
+
+## L2CS lookup order in runtime
+
+The app checks for L2CS weights in this order:
+
+1. `checkpoints/L2CSNet_gaze360.pkl`
+2. `models/L2CS-Net/models/L2CSNet_gaze360.pkl`
+3. `models/L2CSNet_gaze360.pkl`
checkpoints/hybrid_combiner.joblib
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7e460c6ca8d2cadf37727456401a0d63028ba23cb6401f0835d869abfa2e053c
+size 965
checkpoints/hybrid_focus_config.json
ADDED
@@ -0,0 +1,10 @@
+{
+  "use_xgb": false,
+  "w_mlp": 0.3,
+  "w_geo": 0.7,
+  "threshold": 0.35,
+  "use_yawn_veto": true,
+  "geo_face_weight": 0.7,
+  "geo_eye_weight": 0.3,
+  "mar_yawn_threshold": 0.55
+}
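A sketch of how these fusion values could be applied at inference time; the actual decision path goes through `hybrid_combiner.joblib` and `ui/pipeline.py`, so this helper is illustrative only:

```python
import json

with open("checkpoints/hybrid_focus_config.json") as f:
    cfg = json.load(f)

def hybrid_decision(p_mlp: float, s_geo: float, mar: float) -> bool:
    """Weighted MLP/geometric fusion with an optional yawn veto."""
    if cfg["use_yawn_veto"] and mar > cfg["mar_yawn_threshold"]:
        return False                                  # an active yawn vetoes "focused"
    score = cfg["w_mlp"] * p_mlp + cfg["w_geo"] * s_geo
    return score >= cfg["threshold"]
```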
checkpoints/meta_best.npz
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5d78d1df5e25536a2c82c4b8f5fd0c26dd35f44b28fd59761634cbf78c7546f8
+size 4196
checkpoints/meta_mlp.npz
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4771c61cdf0711aa640b4d600a0851d344414cd16c1c2f75afc90e3c6135d14b
+size 840
checkpoints/mlp_best.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c2f55129785b6882c304483aa5399f5bf6c9ed6e73dfec7ca6f36cd0436156c8
+size 14497
checkpoints/scaler_mlp.joblib
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2038d5b051d4de303c5688b1b861a0b53b1307a52b9447bfa48e8c7ace749329
+size 823
checkpoints/xgboost_face_orientation_best.json
ADDED
The diff for this file is too large to render. See raw diff.
config/README.md
ADDED
@@ -0,0 +1,45 @@
+# config
+
+Centralised configuration for FocusGuard. Every training script, pipeline, and evaluation tool reads from this package rather than hardcoding values. This ensures reproducibility across experiments and consistent behaviour between training and deployment.
+
+## Files
+
+| File | Purpose |
+|------|---------|
+| `default.yaml` | Single source of truth for all hyperparameters, thresholds, clipping bounds, and app settings |
+| `__init__.py` | YAML loader with dotted-key access (`get("pipeline.mlp_threshold")`), ClearML flattener, and project-wide constants |
+| `clearml_enrich.py` | ClearML experiment tracking helpers: environment tags, config/requirements upload, model metadata |
+
+## Usage
+
+```python
+from config import get
+
+lr = get("mlp.lr")                         # 0.001
+threshold = get("pipeline.mlp_threshold")  # 0.23
+clip_yaw = get("data.clip.yaw")            # [-45, 45]
+```
+
+Override the default path by setting `FOCUSGUARD_CONFIG` to point at a different YAML file.
+
+## Key sections in default.yaml
+
+| Section | What it controls |
+|---------|------------------|
+| `app` | DB path, inference size (640x480), 4 workers, default model |
+| `l2cs_boost` | Boost-mode fusion weights (35% base / 65% L2CS), fused threshold 0.52 |
+| `mlp` | 30 epochs, batch 32, lr 0.001, hidden sizes [64, 32] |
+| `xgboost` | 600 trees, depth 8, lr 0.1489, regularisation from 40-trial Optuna sweep |
+| `data.clip` | Physiological clipping: yaw +/-45, pitch +/-30, EAR [0, 0.85], MAR [0, 1.0] |
+| `pipeline.geometric` | max_angle 22 deg, face/eye weights 0.7/0.3, asymmetric EMA (alpha_up=0.55, alpha_down=0.45) |
+| `pipeline` | Production thresholds: MLP 0.23, XGBoost 0.28, hybrid 0.35 (all derived from LOPO Youden's J) |
+| `evaluation` | Seed 42, weight search ranges for geometric alpha and hybrid w_mlp |
+
+## ClearML enrichment
+
+`clearml_enrich.py` provides reusable helpers called by all training and evaluation scripts (a usage sketch follows below):
+
+- `enrich_task(task, role)` adds tags (Python version, OS, torch/CUDA, git SHA)
+- `upload_repro_artifacts(task)` pins the exact YAML config and requirements.txt
+- `attach_output_metrics(model, metrics)` surfaces headline metrics on registered model cards
+- `task_done_summary(task, summary)` sets a human-readable task comment
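A sketch of how a training script could wire these helpers together; the `Task` construction mirrors standard ClearML, and the task name and metric string are placeholders:

```python
from clearml import Task
from config import CLEARML_PROJECT_NAME, flatten_for_clearml, load_config
from config.clearml_enrich import enrich_task, task_done_summary, upload_repro_artifacts

task = Task.init(project_name=CLEARML_PROJECT_NAME, task_name="mlp_train_demo")
task.connect(flatten_for_clearml(load_config()))  # every config key as a flat task parameter
enrich_task(task, role="training")                # py/OS/torch-device/git tags for UI filtering
upload_repro_artifacts(task)                      # pin the exact YAML + requirements.txt
# ... training loop would run here ...
task_done_summary(task, "val_f1=0.86 threshold=0.23")
task.close()
```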
config/__init__.py
ADDED
@@ -0,0 +1,60 @@
+"""Load app and model config from YAML. Single source for hyperparameters and tunables."""
+
+from __future__ import annotations
+
+import os
+
+# ClearML UI project name (must match the project in your ClearML workspace).
+CLEARML_PROJECT_NAME = "FocusGuards Large Group Project"
+from pathlib import Path
+from typing import Any
+
+_CONFIG: dict[str, Any] | None = None
+
+
+def _default_path() -> Path:
+    return Path(__file__).resolve().parent / "default.yaml"
+
+
+def load_config(path: str | Path | None = None) -> dict[str, Any]:
+    """Load YAML config. Uses FOCUSGUARD_CONFIG env or config/default.yaml."""
+    global _CONFIG
+    if _CONFIG is not None:
+        return _CONFIG
+    import yaml
+    p = path or os.environ.get("FOCUSGUARD_CONFIG") or _default_path()
+    p = Path(p)
+    if not p.is_file():
+        _CONFIG = {}
+        return _CONFIG
+    with open(p, "r", encoding="utf-8") as f:
+        _CONFIG = yaml.safe_load(f) or {}
+    return _CONFIG
+
+
+def get(key_path: str, default: Any = None) -> Any:
+    """Return a nested config value. E.g. get('app.db_path'), get('mlp.epochs')."""
+    cfg = load_config()
+    for part in key_path.split("."):
+        if not isinstance(cfg, dict) or part not in cfg:
+            return default
+        cfg = cfg[part]
+    return cfg
+
+
+def flatten_for_clearml(cfg: dict[str, Any] | None = None, prefix: str = "") -> dict[str, Any]:
+    """Flatten nested config so every value appears as a ClearML task parameter (no nested dicts)."""
+    cfg = cfg if cfg is not None else load_config()
+    out = {}
+    for k, v in cfg.items():
+        key = f"{prefix}/{k}" if prefix else k
+        if isinstance(v, dict) and v and not any(isinstance(x, (dict, list)) for x in v.values()):
+            for k2, v2 in v.items():
+                out[f"{key}/{k2}"] = v2
+        elif isinstance(v, dict) and v:
+            out.update(flatten_for_clearml(v, key))
+        elif isinstance(v, list):
+            out[key] = str(v)
+        else:
+            out[key] = v
+    return out
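A quick illustration of the flattening behaviour on a hand-made nested dict (the values here are invented for the example, not the real config):

```python
from config import flatten_for_clearml

nested = {"mlp": {"lr": 0.001, "hidden_sizes": [64, 32]},
          "pipeline": {"geometric": {"alpha": 0.7}}}
print(flatten_for_clearml(nested))
# {'mlp/lr': 0.001, 'mlp/hidden_sizes': '[64, 32]', 'pipeline/geometric/alpha': 0.7}
```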
config/clearml_enrich.py
ADDED
@@ -0,0 +1,87 @@
"""Extra ClearML polish: env tags, config snapshot, output model metadata."""

from __future__ import annotations

import os
import subprocess
import sys
from pathlib import Path


def project_root() -> Path:
    return Path(__file__).resolve().parent.parent


def active_config_path() -> Path:
    env = os.environ.get("FOCUSGUARD_CONFIG")
    if env:
        return Path(env).expanduser()
    return Path(__file__).resolve().parent / "default.yaml"


def enrich_task(task, *, role: str) -> None:
    """Tags for filtering in the UI (Python, OS, torch device, git revision)."""
    tags = [
        role,
        f"py{sys.version_info.major}{sys.version_info.minor}",
        sys.platform.replace(" ", "_"),
    ]
    try:
        import torch

        ver = torch.__version__.split("+")[0].replace(".", "_")
        tags.append(f"torch_{ver}")
        tags.append("cuda" if torch.cuda.is_available() else "cpu")
    except ImportError:
        tags.append("no_torch")
    rev = _git_short_rev()
    if rev:
        tags.append(f"git_{rev}")
    task.add_tags(tags)


def _git_short_rev() -> str | None:
    root = project_root()
    try:
        p = subprocess.run(
            ["git", "rev-parse", "--short", "HEAD"],
            cwd=str(root),
            capture_output=True,
            text=True,
            timeout=6,
            check=False,
        )
        if p.returncode == 0 and p.stdout:
            return p.stdout.strip()
    except (OSError, subprocess.TimeoutExpired):
        pass
    return None


def upload_repro_artifacts(task) -> None:
    """Pin the exact YAML + requirements file used for this run."""
    cfg = active_config_path()
    if cfg.is_file():
        task.upload_artifact(name="config_yaml", artifact_object=str(cfg))
    req = project_root() / "requirements.txt"
    if req.is_file():
        task.upload_artifact(name="requirements_txt", artifact_object=str(req))


def attach_output_metrics(output_model, metrics: dict[str, float | str]) -> None:
    """Surface headline metrics on the registered model card."""
    for k, v in metrics.items():
        key = str(k).replace("/", "_")
        try:
            output_model.set_metadata(key, str(v))
        except Exception:
            pass


def task_done_summary(task, summary: str) -> None:
    setter = getattr(task, "set_comment", None)
    if callable(setter):
        try:
            setter(summary)
        except Exception:
            pass
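A hedged wiring sketch for these helpers, assuming ClearML's standard `Task.init` entry point; the project/task names, role string, and summary text here are illustrative, not the project's actual values:

```python
from clearml import Task

from config.clearml_enrich import enrich_task, task_done_summary, upload_repro_artifacts

task = Task.init(project_name="FocusGuard", task_name="train_mlp")  # illustrative names
enrich_task(task, role="training")   # py/OS/torch/git tags for UI filtering
upload_repro_artifacts(task)         # pin the YAML config + requirements.txt
# ... training loop ...
task_done_summary(task, "val_f1=0.86 @ epoch 30")  # illustrative summary
```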
config/default.yaml
ADDED
@@ -0,0 +1,80 @@
# FocusGuard app and model config. Override with FOCUSGUARD_CONFIG env path if needed.

app:
  db_path: "focus_guard.db"
  inference_size: [640, 480]
  inference_workers: 4
  default_model: "mlp"
  calibration_verify_target: [0.5, 0.5]
  no_face_confidence_cap: 0.1

l2cs_boost:
  base_weight: 0.35
  l2cs_weight: 0.65
  veto_threshold: 0.38
  fused_threshold: 0.52

mlp:
  model_name: "face_orientation"
  epochs: 30
  batch_size: 32
  lr: 0.001
  seed: 42
  split_ratios: [0.7, 0.15, 0.15]
  hidden_sizes: [64, 32]

xgboost:
  n_estimators: 600
  max_depth: 8
  learning_rate: 0.1489
  subsample: 0.9625
  colsample_bytree: 0.9013
  reg_alpha: 1.1407
  reg_lambda: 2.4181
  eval_metric: "logloss"

data:
  split_ratios: [0.7, 0.15, 0.15]
  clip:
    yaw: [-45, 45]
    pitch: [-30, 30]
    roll: [-30, 30]
    ear: [0, 0.85]
    mar: [0, 1.0]
    gaze_offset: [0, 0.50]
    perclos: [0, 0.80]
    blink_rate: [0, 30.0]
    closure_duration: [0, 10.0]
    yawn_duration: [0, 10.0]

pipeline:
  geometric:
    max_angle: 22.0
    alpha: 0.7
    beta: 0.3
    threshold: 0.55
  smoother:
    alpha_up: 0.55
    alpha_down: 0.45
    grace_frames: 10
  hybrid_defaults:
    w_mlp: 0.3
    w_geo: 0.7
    threshold: 0.35
    geo_face_weight: 0.7
    geo_eye_weight: 0.3
  mlp_threshold: 0.23
  xgboost_threshold: 0.28

evaluation:
  seed: 42
  mlp_sklearn:
    hidden_layer_sizes: [64, 32]
    max_iter: 200
    validation_fraction: 0.15
  geo_weights:
    face: 0.7
    eye: 0.3
  threshold_search:
    alphas: [0.2, 0.85]
    w_mlps: [0.3, 0.85]
data_preparation/README.md
ADDED
@@ -0,0 +1,90 @@
# data_preparation

Handles loading, splitting, scaling, and serving the collected dataset for training and evaluation.

## Links

- Participant consent form: [Consent document](https://drive.google.com/file/d/1g1Hc764ffljoKrjApD6nmWDCXJGYTR0j/view?usp=drive_link)
- Dataset (staff access): [Dataset folder](https://drive.google.com/drive/folders/1fwACM6i6uVGFkTlJKSlqVhizzgrHl_gY?usp=sharing)

## Data collection protocol

Nine team members each recorded 5-10 minute webcam sessions using a purpose-built tool (`models/collect_features.py`). During recording:

- Participants simulated **focused** behaviour (reading, typing) and **unfocused** behaviour (looking at a phone, turning away)
- Binary labels were annotated in real time via key presses
- Sessions were recorded across different rooms, workspaces, and home offices using consumer webcams under varying lighting
- Real-time quality guidance warned if class balance fell outside 30-70% or if fewer than 10 state transitions occurred
- An automated post-collection quality report validated minimum duration (120 s), sample count (3,000+ frames), class balance, and transition frequency

All participants provided informed consent for their facial landmark data to be used within this coursework project. Raw video frames are never stored; only the 17-dimensional feature vectors and binary labels are saved.

The raw participant dataset is excluded from this repository (coursework policy and privacy constraints); it is shared separately via the dataset link above.

## Dataset summary

| Metric | Value |
|--------|-------|
| Participants | 9 |
| Total frames | 144,793 |
| Class balance | 61.5% focused / 38.5% unfocused |
| Features extracted | 17 per frame |
| Features selected | 10 (used by ML models) |

## Data format

Training data lives under `data/collected_<participant>/` as `.npz` files. Each file contains:

| Key | Shape | Description |
|-----|-------|-------------|
| `features` | (N, 17) | Float array of extracted features |
| `labels` | (N,) | Binary: 0 = unfocused, 1 = focused |
| `feature_names` | (17,) | String names matching `FEATURE_NAMES` in `collect_features.py` |

Data files are not included in this repository due to privacy considerations.
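Reading one session file follows directly from the table above; a short sketch (the path is hypothetical, since data files are not shipped):

```python
import numpy as np

d = np.load("data/collected_alice/session1.npz", allow_pickle=True)  # hypothetical path
X = d["features"]                 # (N, 17) float feature matrix
y = d["labels"]                   # (N,) 0 = unfocused, 1 = focused
names = list(d["feature_names"])  # 17 feature name strings
print(X.shape, f"{y.mean():.1%} focused", names[:3])
```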

## Files

| File | Purpose |
|------|---------|
| `prepare_dataset.py` | Core data pipeline: loads `.npz`, applies feature selection, stratified splits, StandardScaler on train only |
| `data_exploration.ipynb` | Exploratory analysis: feature distributions, class balance, per-person statistics, correlation heatmaps |

## Feature selection

`SELECTED_FEATURES["face_orientation"]` defines the 10 features used by all ML models:

**Head pose (3):** `head_deviation`, `s_face`, `pitch`
**Eye state (4):** `ear_left`, `ear_right`, `ear_avg`, `perclos`
**Gaze (3):** `h_gaze`, `gaze_offset`, `s_eye`

Excluded: `v_gaze` (noisy), `mar` (1.7% trigger rate), `yaw`/`roll` (redundant with `head_deviation`/`s_face`), `blink_rate`/`closure_duration`/`yawn_duration` (temporal overlap with `perclos`).

Selection was validated by XGBoost gain importance and LOPO channel ablation:

| Channel subset | Mean LOPO F1 |
|---------------|-------------|
| All 10 features | 0.829 |
| Eye state only | 0.807 |
| Head pose only | 0.748 |
| Gaze only | 0.726 |

## Key functions

The table below lists the module's entry points; a short usage sketch follows it.

| Function | What it does |
|----------|-------------|
| `load_all_pooled(model_name)` | Concatenates all participant data into one array |
| `load_per_person(model_name)` | Returns `{person: (X, y)}` dict for LOPO cross-validation |
| `get_numpy_splits(model_name)` | Returns scaled train/val/test numpy arrays (70/15/15 split) |
| `get_dataloaders(model_name)` | Returns PyTorch DataLoaders for MLP training |
| `get_default_split_config()` | Returns split ratios and seed from `config/default.yaml` |
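A usage sketch, under the assumption that collected data is present locally (functions as defined in `prepare_dataset.py`):

```python
from data_preparation.prepare_dataset import get_numpy_splits, load_per_person

# Scaled, stratified 70/15/15 arrays for classical ML
splits, n_features, n_classes, scaler = get_numpy_splits("face_orientation")
print(splits["X_train"].shape, n_features, n_classes)

# Per-person dict for LOPO cross-validation
by_person, X_all, y_all = load_per_person("face_orientation")
print(sorted(by_person), X_all.shape)
```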

## Data cleaning

Applied before splitting (in `ui/pipeline.py` at inference, in `prepare_dataset.py` for training):

1. Angles clipped to physiological ranges (yaw +/-45, pitch/roll +/-30)
2. `head_deviation` recomputed from clipped angles (not clipped after computation)
3. EAR clipped to [0, 0.85], MAR to [0, 1.0]
4. Physiological bounds on gaze_offset, PERCLOS, blink_rate, closure/yawn duration
5. StandardScaler fit on training split only, applied to val/test
data_preparation/__init__.py
ADDED
File without changes

data_preparation/data_exploration.ipynb
ADDED
The diff for this file is too large to render.
data_preparation/prepare_dataset.py
ADDED
@@ -0,0 +1,279 @@
"""
Single source for pooled train/val/test data and splits.

- Data: load_all_pooled() / load_per_person() from data/collected_*/*.npz (same pattern everywhere).
- Splits: get_numpy_splits() / get_dataloaders() use stratified train/val/test with a fixed seed from config.
- Test is held out before any preprocessing; StandardScaler is fit on train only, then applied to val and test.
"""

import os
import glob

import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

torch = None
Dataset = object  # type: ignore
DataLoader = None

# Defaults for stratified split (overridden by config when available)
_DEFAULT_SPLIT_RATIOS = (0.7, 0.15, 0.15)
_DEFAULT_SPLIT_SEED = 42


def _require_torch():
    global torch, Dataset, DataLoader
    if torch is None:
        try:
            import torch as _torch
            from torch.utils.data import Dataset as _Dataset, DataLoader as _DataLoader
        except ImportError as exc:  # pragma: no cover
            raise ImportError("PyTorch not installed") from exc

        torch = _torch
        Dataset = _Dataset  # type: ignore
        DataLoader = _DataLoader  # type: ignore

    return torch, Dataset, DataLoader

DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")

SELECTED_FEATURES = {
    "face_orientation": [
        'head_deviation', 's_face', 's_eye', 'h_gaze', 'pitch',
        'ear_left', 'ear_avg', 'ear_right', 'gaze_offset', 'perclos'
    ],
    "eye_behaviour": [
        'ear_left', 'ear_right', 'ear_avg', 'mar',
        'blink_rate', 'closure_duration', 'perclos', 'yawn_duration'
    ]
}


class FeatureVectorDataset(Dataset):
    def __init__(self, features: np.ndarray, labels: np.ndarray):
        torch_mod, _, _ = _require_torch()
        self.features = torch_mod.tensor(features, dtype=torch_mod.float32)
        self.labels = torch_mod.tensor(labels, dtype=torch_mod.long)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]


# ── Low-level helpers ────────────────────────────────────────────────────

def _clean_npz(raw, names):
    """Apply clipping rules in-place. Shared by all loaders."""
    for col, lo, hi in [('yaw', -45, 45), ('pitch', -30, 30), ('roll', -30, 30)]:
        if col in names:
            raw[:, names.index(col)] = np.clip(raw[:, names.index(col)], lo, hi)
    for feat in ['ear_left', 'ear_right', 'ear_avg']:
        if feat in names:
            raw[:, names.index(feat)] = np.clip(raw[:, names.index(feat)], 0, 0.85)
    return raw


def _load_one_npz(npz_path, target_features):
    """Load a single .npz file, clean and select features. Returns (X, y, selected_feature_names)."""
    data = np.load(npz_path, allow_pickle=True)
    raw = data['features'].astype(np.float32)
    labels = data['labels'].astype(np.int64)
    names = list(data['feature_names'])
    raw = _clean_npz(raw, names)
    selected = [f for f in target_features if f in names]
    idx = [names.index(f) for f in selected]
    return raw[:, idx], labels, selected


# ── Public data loaders ──────────────────────────────────────────────────

def load_all_pooled(model_name: str = "face_orientation", data_dir: str = None):
    """Load all collected_*/*.npz, clean, select features, concatenate.

    Returns (X_all, y_all, all_feature_names).
    """
    data_dir = data_dir or DATA_DIR
    target_features = SELECTED_FEATURES.get(model_name, SELECTED_FEATURES["face_orientation"])
    pattern = os.path.join(data_dir, "collected_*", "*.npz")
    npz_files = sorted(glob.glob(pattern))

    if not npz_files:
        raise FileNotFoundError(
            f"No .npz files matching {pattern}. "
            "Collect data first with `python -m models.collect_features --name <name>`."
        )

    all_X, all_y = [], []
    all_names = None
    for npz_path in npz_files:
        X, y, names = _load_one_npz(npz_path, target_features)
        if all_names is None:
            all_names = names
        all_X.append(X)
        all_y.append(y)
        print(f"[DATA] + {os.path.basename(npz_path)}: {X.shape[0]} samples")

    X_all = np.concatenate(all_X, axis=0)
    y_all = np.concatenate(all_y, axis=0)
    print(f"[DATA] Loaded {len(npz_files)} file(s) for '{model_name}': "
          f"{X_all.shape[0]} total samples, {X_all.shape[1]} features")
    return X_all, y_all, all_names


def load_per_person(model_name: str = "face_orientation", data_dir: str = None):
    """Load collected_*/*.npz grouped by person (folder name).

    Returns dict { person_name: (X, y) } where X/y are per-person numpy arrays.
    Also returns (X_all, y_all) as pooled data.
    """
    data_dir = data_dir or DATA_DIR
    target_features = SELECTED_FEATURES.get(model_name, SELECTED_FEATURES["face_orientation"])
    pattern = os.path.join(data_dir, "collected_*", "*.npz")
    npz_files = sorted(glob.glob(pattern))

    if not npz_files:
        raise FileNotFoundError(f"No .npz files matching {pattern}")

    by_person = {}
    all_X, all_y = [], []
    for npz_path in npz_files:
        folder = os.path.basename(os.path.dirname(npz_path))
        person = folder.replace("collected_", "", 1)
        X, y, _ = _load_one_npz(npz_path, target_features)
        all_X.append(X)
        all_y.append(y)
        if person not in by_person:
            by_person[person] = []
        by_person[person].append((X, y))
        print(f"[DATA] + {person}/{os.path.basename(npz_path)}: {X.shape[0]} samples")

    for person, chunks in by_person.items():
        by_person[person] = (
            np.concatenate([c[0] for c in chunks], axis=0),
            np.concatenate([c[1] for c in chunks], axis=0),
        )

    X_all = np.concatenate(all_X, axis=0)
    y_all = np.concatenate(all_y, axis=0)
    print(f"[DATA] {len(by_person)} persons, {X_all.shape[0]} total samples, {X_all.shape[1]} features")
    return by_person, X_all, y_all


def load_raw_npz(npz_path):
    """Load a single .npz without cleaning or feature selection. For exploration notebooks."""
    data = np.load(npz_path, allow_pickle=True)
    features = data['features'].astype(np.float32)
    labels = data['labels'].astype(np.int64)
    names = list(data['feature_names'])
    return features, labels, names


# ── Legacy helpers (used by models/mlp/train.py and models/xgboost/train.py) ─

def _load_real_data(model_name: str):
    X, y, _ = load_all_pooled(model_name)
    return X, y


def _generate_synthetic_data(model_name: str):
    target_features = SELECTED_FEATURES.get(model_name, SELECTED_FEATURES["face_orientation"])
    n = 500
    d = len(target_features)
    c = 2
    rng = np.random.RandomState(42)
    features = rng.randn(n, d).astype(np.float32)
    labels = rng.randint(0, c, size=n).astype(np.int64)
    print(f"[DATA] Using synthetic data for '{model_name}': {n} samples, {d} features, {c} classes")
    return features, labels


def get_default_split_config():
    """Return (split_ratios, seed) from config so all scripts use the same split. Reproducible and consistent."""
    try:
        from config import get
        data = get("data") or {}
        ratios = data.get("split_ratios", list(_DEFAULT_SPLIT_RATIOS))
        seed = get("mlp.seed") or _DEFAULT_SPLIT_SEED
        return (tuple(ratios), int(seed))
    except Exception:
        return (_DEFAULT_SPLIT_RATIOS, _DEFAULT_SPLIT_SEED)


def _split_and_scale(features, labels, split_ratios, seed, scale):
    """Stratified train/val/test split. Test is held out first; val is split from the rest.
    No training data is used for validation or test. Scaler is fit on train only, then
    applied to val and test (no leakage from val/test into scaling).
    """
    test_ratio = split_ratios[2]
    val_ratio = split_ratios[1] / (split_ratios[0] + split_ratios[1])

    X_train_val, X_test, y_train_val, y_test = train_test_split(
        features, labels, test_size=test_ratio, random_state=seed, stratify=labels,
    )
    X_train, X_val, y_train, y_val = train_test_split(
        X_train_val, y_train_val, test_size=val_ratio, random_state=seed, stratify=y_train_val,
    )

    scaler = None
    if scale:
        scaler = StandardScaler()
        X_train = scaler.fit_transform(X_train)
        X_val = scaler.transform(X_val)
        X_test = scaler.transform(X_test)
        print("[DATA] Applied StandardScaler (fitted on training split only)")

    splits = {
        "X_train": X_train, "y_train": y_train,
        "X_val": X_val, "y_val": y_val,
        "X_test": X_test, "y_test": y_test,
    }

    print(f"[DATA] Split (stratified): train={len(y_train)}, val={len(y_val)}, test={len(y_test)}")
    return splits, scaler


def get_numpy_splits(model_name: str, split_ratios=None, seed=None, scale: bool = True):
    """Return train/val/test numpy arrays. Uses config defaults for split_ratios/seed when None.
    Same dataset and split logic as get_dataloaders for consistent evaluation."""
    if split_ratios is None or seed is None:
        _ratios, _seed = get_default_split_config()
        split_ratios = split_ratios if split_ratios is not None else _ratios
        seed = seed if seed is not None else _seed
    features, labels = _load_real_data(model_name)
    num_features = features.shape[1]
    num_classes = int(labels.max()) + 1
    if num_classes < 2:
        raise ValueError("Dataset has only one class; need at least 2 for classification.")
    splits, scaler = _split_and_scale(features, labels, split_ratios, seed, scale)
    return splits, num_features, num_classes, scaler


def get_dataloaders(model_name: str, batch_size: int = 32, split_ratios=None, seed=None, scale: bool = True):
    """Return PyTorch DataLoaders. Uses config defaults for split_ratios/seed when None.
    Test set is held out before preprocessing; scaler fit on train only."""
    if split_ratios is None or seed is None:
        _ratios, _seed = get_default_split_config()
        split_ratios = split_ratios if split_ratios is not None else _ratios
        seed = seed if seed is not None else _seed
    _, _, dataloader_cls = _require_torch()
    features, labels = _load_real_data(model_name)
    num_features = features.shape[1]
    num_classes = int(labels.max()) + 1
    if num_classes < 2:
        raise ValueError("Dataset has only one class; need at least 2 for classification.")
    splits, scaler = _split_and_scale(features, labels, split_ratios, seed, scale)

    train_ds = FeatureVectorDataset(splits["X_train"], splits["y_train"])
    val_ds = FeatureVectorDataset(splits["X_val"], splits["y_val"])
    test_ds = FeatureVectorDataset(splits["X_test"], splits["y_test"])

    train_loader = dataloader_cls(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = dataloader_cls(val_ds, batch_size=batch_size, shuffle=False)
    test_loader = dataloader_cls(test_ds, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader, num_features, num_classes, scaler
docker-compose.yml
ADDED
@@ -0,0 +1,5 @@
services:
  focus-guard:
    build: .
    ports:
      - "7860:7860"
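With the standard Docker Compose CLI, a local run is one command; the port mapping comes straight from the file above:

```bash
docker compose up --build
# app served at http://localhost:7860
```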
download_l2cs_weights.py
ADDED
@@ -0,0 +1,37 @@
#!/usr/bin/env python3
# Downloads L2CS-Net Gaze360 weights into checkpoints/

import os
import sys

CHECKPOINTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "checkpoints")
DEST = os.path.join(CHECKPOINTS_DIR, "L2CSNet_gaze360.pkl")
GDRIVE_ID = "1dL2Jokb19_SBSHAhKHOxJsmYs5-GoyLo"


def main():
    if os.path.isfile(DEST):
        print(f"[OK] Weights already at {DEST}")
        return

    try:
        import gdown
    except ImportError:
        print("gdown not installed. Run: pip install gdown")
        sys.exit(1)

    os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
    print(f"Downloading L2CS-Net weights to {DEST} ...")
    gdown.download(f"https://drive.google.com/uc?id={GDRIVE_ID}", DEST, quiet=False)

    if os.path.isfile(DEST):
        print(f"[OK] Downloaded ({os.path.getsize(DEST) / 1024 / 1024:.1f} MB)")
    else:
        print("[ERR] Download failed. Manual download:")
        print("  https://drive.google.com/drive/folders/17p6ORr-JQJcw-eYtG2WGNiuS_qVKwdWd")
        print(f"  Place L2CSNet_gaze360.pkl in {CHECKPOINTS_DIR}/")
        sys.exit(1)


if __name__ == "__main__":
    main()
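Usage is a single invocation from the repo root; `gdown` is the only extra dependency, as the script itself reports:

```bash
pip install gdown
python download_l2cs_weights.py
```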
eslint.config.js
ADDED
@@ -0,0 +1,42 @@
import js from '@eslint/js'
import globals from 'globals'
import reactHooks from 'eslint-plugin-react-hooks'
import reactRefresh from 'eslint-plugin-react-refresh'
import { defineConfig, globalIgnores } from 'eslint/config'

export default defineConfig([
  globalIgnores([
    'dist',
    'node_modules',
    '.venv',
    'venv',
    'static',
    'coverage',
    'htmlcov',
  ]),
  {
    files: ['vite.config.js'],
    languageOptions: { globals: globals.node },
  },
  {
    files: ['**/*.{js,jsx}'],
    ignores: ['vite.config.js'],
    extends: [
      js.configs.recommended,
      reactHooks.configs.flat.recommended,
      reactRefresh.configs.vite,
    ],
    languageOptions: {
      ecmaVersion: 2020,
      globals: globals.browser,
      parserOptions: {
        ecmaVersion: 'latest',
        ecmaFeatures: { jsx: true },
        sourceType: 'module',
      },
    },
    rules: {
      'no-unused-vars': ['error', { varsIgnorePattern: '^[A-Z_]' }],
    },
  },
])
evaluation/GROUPED_SPLIT_BENCHMARK.md
ADDED
@@ -0,0 +1,13 @@
# Grouped vs pooled split benchmark

This compares the same XGBoost config under two evaluation protocols.

Config: `{'n_estimators': 600, 'max_depth': 8, 'learning_rate': 0.1489, 'subsample': 0.9625, 'colsample_bytree': 0.9013, 'reg_alpha': 1.1407, 'reg_lambda': 2.4181, 'eval_metric': 'logloss'}`
Quick mode: yes (n_estimators=200)

| Protocol | Accuracy | F1 (weighted) | ROC-AUC |
|----------|---------:|--------------:|--------:|
| Pooled random split (70/15/15) | 0.9510 | 0.9507 | 0.9869 |
| Grouped LOPO (9 folds) | 0.8303 | 0.8304 | 0.8801 |

Use grouped LOPO as the primary generalisation metric when reporting model quality.
evaluation/README.md
ADDED
@@ -0,0 +1,84 @@
# evaluation

Systematic evaluation scripts and generated reports. All evaluation uses Leave-One-Person-Out (LOPO) cross-validation over 9 participants (~145k samples) as the primary generalisation metric; a compact sketch of the protocol follows.
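The sketch below mirrors the LOPO loop the scripts share, using the repo's own loader; the model fit/score step is elided as illustrative:

```python
import numpy as np

from data_preparation.prepare_dataset import load_per_person

by_person, _, _ = load_per_person("face_orientation")
for held_out in sorted(by_person):
    X_te, y_te = by_person[held_out]
    X_tr = np.concatenate([by_person[p][0] for p in by_person if p != held_out])
    y_tr = np.concatenate([by_person[p][1] for p in by_person if p != held_out])
    # fit a classifier on (X_tr, y_tr), evaluate on (X_te, y_te), average over folds
```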

## Scripts

| Script | What it does | Runtime |
|--------|-------------|---------|
| `justify_thresholds.py` | LOPO threshold search (Youden's J) for MLP and XGBoost; geometric alpha grid search; hybrid w_mlp grid search | ~10-15 min |
| `feature_importance.py` | XGBoost gain importance + leave-one-feature-out LOPO ablation | ~20 min (full) |
| `grouped_split_benchmark.py` | Compares pooled random split vs LOPO on the same XGBoost config | ~5 min |

### Quick mode

Add `--quick` to reduce the tree count for faster iteration:

```bash
python -m evaluation.grouped_split_benchmark --quick
python -m evaluation.feature_importance --quick --skip-lofo
```

### ClearML support

```bash
USE_CLEARML=1 python -m evaluation.justify_thresholds --clearml
```

Logs threshold search results, weight grid searches, and generated reports as ClearML artifacts.

## Generated reports

| Report | Contents |
|--------|----------|
| `THRESHOLD_JUSTIFICATION.md` | ML thresholds (MLP t*=0.228, XGBoost t*=0.280), geometric weights (alpha=0.7), hybrid weights (w_mlp=0.3), EAR/MAR physiological constants |
| `GROUPED_SPLIT_BENCHMARK.md` | Pooled (95.1% acc) vs LOPO (83.0% acc) comparison |
| `feature_selection_justification.md` | Domain rationale, XGBoost gain ranking, channel ablation results |

## Generated plots

All plots are in `plots/` and referenced by the generated reports.

### ROC curves (LOPO, 9 folds, 144k samples)

| Plot | Model | AUC | Optimal threshold |
|------|-------|-----|-------------------|
|  | MLP | 0.862 | 0.228 |
|  | XGBoost | 0.870 | 0.280 |

Red dots mark the Youden's J optimal operating points. Both thresholds fall well below 0.50 due to cross-person probability compression under LOPO.

### Confusion matrices

| MLP | XGBoost |
|-----|---------|
|  |  |

### Weight grid searches

| Geometric alpha search | Hybrid w_mlp search |
|----------------------|-------------------|
|  |  |

Geometric pipeline: face-dominant weighting (alpha=0.7) generalises best across participants.
Hybrid pipeline: a low MLP weight (0.3) with a strong geometric anchor gives the best LOPO F1 (0.841).

### Physiological distributions

| EAR distribution | MAR distribution |
|-----------------|-----------------|
|  |  |

EAR thresholds (closed=0.16, blink=0.21, open=0.30) and the MAR yawn threshold (0.55) are validated against these distributions.

## Key findings

1. LOPO drops ~12 pp vs the pooled split, confirming the importance of person-independent evaluation
2. Threshold optimisation alone yields +2-4 pp F1 without retraining
3. All three feature channels contribute (removing any one drops F1 by 2-10 pp)
4. `s_face` and `ear_right` are the highest-gain features, confirming that head pose and eye state are the strongest focus indicators
5. The geometric anchor (70% weight) stabilises the hybrid model against per-person variance

## Evaluation logs

Training logs (per-epoch CSVs and JSON summaries) are written to `logs/` by the MLP and XGBoost training scripts.
evaluation/THRESHOLD_JUSTIFICATION.md
ADDED
@@ -0,0 +1,100 @@
# Threshold Justification Report

Auto-generated by `evaluation/justify_thresholds.py` using LOPO cross-validation over 9 participants (~145k samples).

## 0. Latest random split checkpoints (15% test split)

From the latest training runs:

| Model | Accuracy | F1 | ROC-AUC |
|-------|----------|-----|---------|
| XGBoost | 95.87% | 0.9585 | 0.9908 |
| MLP | 92.92% | 0.9287 | 0.9714 |

## 1. ML Model Decision Thresholds

XGBoost config used for this report: `{'n_estimators': 600, 'max_depth': 8, 'learning_rate': 0.1489, 'subsample': 0.9625, 'colsample_bytree': 0.9013, 'reg_alpha': 1.1407, 'reg_lambda': 2.4181, 'eval_metric': 'logloss'}`.

Thresholds selected via **Youden's J statistic** (J = sensitivity + specificity - 1) on pooled LOPO held-out predictions.

| Model | LOPO AUC | Optimal Threshold (Youden's J) | F1 @ Optimal | F1 @ 0.50 |
|-------|----------|-------------------------------|--------------|-----------|
| MLP | 0.8624 | **0.228** | 0.8578 | 0.8149 |
| XGBoost | 0.8695 | **0.280** | 0.8549 | 0.8324 |



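The Youden's J selection above reduces to a scan over the ROC curve; a self-contained sketch with scikit-learn, where toy arrays stand in for the pooled LOPO labels and probabilities:

```python
import numpy as np
from sklearn.metrics import roc_curve

# Toy stand-ins for pooled held-out labels/probabilities
rng = np.random.default_rng(0)
y_true = rng.integers(0, 2, size=1000)
y_prob = np.clip(0.4 * y_true + rng.normal(0.3, 0.2, size=1000), 0, 1)

fpr, tpr, thresholds = roc_curve(y_true, y_prob)
j = tpr - fpr                      # Youden's J = sensitivity + specificity - 1
t_star = thresholds[np.argmax(j)]  # optimal operating threshold
print(f"t* = {t_star:.3f}")
```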
## 2. Geometric Pipeline Weights (s_face vs s_eye)

Grid search over face weight alpha in {0.2 ... 0.8}. Eye weight = 1 - alpha. Threshold per fold via Youden's J.

| Face Weight (alpha) | Mean LOPO F1 |
|--------------------:|-------------:|
| 0.2 | 0.7926 |
| 0.3 | 0.8002 |
| 0.4 | 0.7719 |
| 0.5 | 0.7868 |
| 0.6 | 0.8184 |
| 0.7 | 0.8195 **<-- selected** |
| 0.8 | 0.8126 |

**Best:** alpha = 0.7 (face 70%, eye 30%)



## 3. Hybrid Pipeline Weights (MLP vs Geometric)

Grid search over w_mlp in {0.3 ... 0.8}. w_geo = 1 - w_mlp. The geometric sub-score uses the same weights as the geometric pipeline (face=0.7, eye=0.3). If you change the geometric weights, re-run this script; the optimal w_mlp can shift.

| MLP Weight (w_mlp) | Mean LOPO F1 |
|-------------------:|-------------:|
| 0.3 | 0.8409 **<-- selected** |
| 0.4 | 0.8246 |
| 0.5 | 0.8164 |
| 0.6 | 0.8106 |
| 0.7 | 0.8039 |
| 0.8 | 0.8016 |

**Best:** w_mlp = 0.3 (MLP 30%, geometric 70%)



## 4. Eye and Mouth Aspect Ratio Thresholds

### EAR (Eye Aspect Ratio)

Reference: Soukupova & Cech, "Real-Time Eye Blink Detection Using Facial Landmarks" (2016) established EAR ~ 0.2 as a blink threshold.

Our thresholds define a linear interpolation zone around this established value:

| Constant | Value | Justification |
|----------|------:|---------------|
| `ear_closed` | 0.16 | Below this, eyes are fully shut. 16.3% of samples fall here. |
| `EAR_BLINK_THRESH` | 0.21 | Blink detection point; close to the 0.2 reference. 21.2% of samples below. |
| `ear_open` | 0.30 | Above this, eyes are fully open. 70.4% of samples here. |

Between 0.16 and 0.30 the `_ear_score` function linearly interpolates from 0 to 1, providing a smooth transition rather than a hard binary cutoff; a sketch of this scoring follows the plot below.


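A sketch of that linear scoring; the project's actual `_ear_score` may differ in detail, but the constants are the ones tabled above:

```python
def ear_score(ear: float, ear_closed: float = 0.16, ear_open: float = 0.30) -> float:
    """0 when fully shut, 1 when fully open, linear in between."""
    if ear <= ear_closed:
        return 0.0
    if ear >= ear_open:
        return 1.0
    return (ear - ear_closed) / (ear_open - ear_closed)
```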
### MAR (Mouth Aspect Ratio)

| Constant | Value | Justification |
|----------|------:|---------------|
| `MAR_YAWN_THRESHOLD` | 0.55 | Only 1.7% of samples exceed this, confirming it captures genuine yawns without false positives. |



## 5. Other Constants

| Constant | Value | Rationale |
|----------|------:|-----------|
| `gaze_max_offset` | 0.28 | Max iris displacement (normalised) before the gaze score drops to zero. Corresponds to ~56% of the eye width; beyond this the iris is at the extreme edge. |
| `max_angle` | 22.0 deg | Head deviation beyond which the face score = 0. Based on a typical monitor-viewing cone: at 60 cm distance and a 24" monitor, the viewing angle is ~20-25 degrees. |
| `roll_weight` | 0.5 | Roll is less indicative of inattention than yaw/pitch (tilting the head doesn't mean looking away), so it is down-weighted by 50%. |
| `EMA alpha` | 0.3 | Smoothing factor for the focus score. Gives a ~3-4 frame effective window; balances responsiveness vs flicker. |
| `grace_frames` | 15 | ~0.5 s at 30 fps before penalising no-face. Allows brief occlusions (e.g. a hand gesture) without dropping the score. |
| `PERCLOS_WINDOW` | 60 frames | 2 s at 30 fps; standard PERCLOS measurement window (Dinges & Grace, 1998). |
| `BLINK_WINDOW_SEC` | 30 s | Blink rate measured over 30 s; typical spontaneous blink rate is 15-20/min (Bentivoglio et al., 1997). |
evaluation/feature_importance.py
ADDED
@@ -0,0 +1,279 @@
"""
Feature importance and leave-one-feature-out ablation for the 10 face_orientation features.
Run: python -m evaluation.feature_importance

Outputs:
- XGBoost gain-based importance (from trained checkpoint)
- Leave-one-feature-out LOPO F1 (ablation): drop each feature in turn, report mean LOPO F1.
- Writes evaluation/feature_selection_justification.md
"""

import os
import sys
import argparse

import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import f1_score
from xgboost import XGBClassifier

_PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if _PROJECT_ROOT not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT)

from data_preparation.prepare_dataset import get_default_split_config, load_per_person, SELECTED_FEATURES
from models.xgboost.config import XGB_BASE_PARAMS, build_xgb_classifier, get_xgb_params

_, SEED = get_default_split_config()
FEATURES = SELECTED_FEATURES["face_orientation"]


def _resolve_xgb_path():
    return os.path.join(_PROJECT_ROOT, "checkpoints", "xgboost_face_orientation_best.json")


def xgb_feature_importance():
    """Load trained XGBoost and return gain-based importance for the 10 features."""
    path = _resolve_xgb_path()
    if not os.path.isfile(path):
        print(f"[WARN] No XGBoost checkpoint at {path}; skip importance.")
        return None
    model = XGBClassifier()
    model.load_model(path)
    imp = model.get_booster().get_score(importance_type="gain")
    # Booster uses f0, f1, ...; we use same order as FEATURES (training order)
    by_idx = {int(k.replace("f", "")): v for k, v in imp.items() if k.startswith("f")}
    order = [by_idx.get(i, 0.0) for i in range(len(FEATURES))]
    return dict(zip(FEATURES, order))


def _make_eval_model(seed: int, quick: bool):
    if not quick:
        return build_xgb_classifier(seed, verbosity=0)

    params = get_xgb_params()
    params["n_estimators"] = 200
    params["random_state"] = seed
    params["verbosity"] = 0
    return XGBClassifier(**params)


def run_ablation_lopo(by_person, persons, quick: bool):
    """Leave-one-feature-out: for each feature, train XGBoost on the other 9 with LOPO, report mean F1."""
    results = {}
    for drop_feat in FEATURES:
        print(f"  -> dropping {drop_feat} ({len(results)+1}/{len(FEATURES)})")
        idx_keep = [i for i, f in enumerate(FEATURES) if f != drop_feat]
        f1s = []
        for held_out in persons:
            train_X = np.concatenate([by_person[p][0] for p in persons if p != held_out])
            train_y = np.concatenate([by_person[p][1] for p in persons if p != held_out])
            X_test, y_test = by_person[held_out]

            X_tr = train_X[:, idx_keep]
            X_te = X_test[:, idx_keep]
            scaler = StandardScaler().fit(X_tr)
            X_tr_sc = scaler.transform(X_tr)
            X_te_sc = scaler.transform(X_te)

            xgb = _make_eval_model(SEED, quick)
            xgb.fit(X_tr_sc, train_y)
            pred = xgb.predict(X_te_sc)
            f1s.append(f1_score(y_test, pred, average="weighted"))
        results[drop_feat] = np.mean(f1s)
    return results


def run_baseline_lopo_f1(by_person, persons, quick: bool):
    """Full 10-feature LOPO mean F1 for reference."""
    f1s = []
    for held_out in persons:
        train_X = np.concatenate([by_person[p][0] for p in persons if p != held_out])
        train_y = np.concatenate([by_person[p][1] for p in persons if p != held_out])
        X_test, y_test = by_person[held_out]
        scaler = StandardScaler().fit(train_X)
        X_tr_sc = scaler.transform(train_X)
        X_te_sc = scaler.transform(X_test)
        xgb = _make_eval_model(SEED, quick)
        xgb.fit(X_tr_sc, train_y)
        pred = xgb.predict(X_te_sc)
        f1s.append(f1_score(y_test, pred, average="weighted"))
    return np.mean(f1s)


# Channel subsets for ablation (subset name -> list of feature names)
CHANNEL_SUBSETS = {
    "head_pose": ["head_deviation", "s_face", "pitch"],
    "eye_state": ["ear_left", "ear_avg", "ear_right", "perclos"],
    "gaze": ["h_gaze", "gaze_offset", "s_eye"],
}


def run_channel_ablation(by_person, persons, quick: bool, baseline: float):
    """LOPO XGBoost with head-only, eye-only, gaze-only, and all 10. Returns dict subset_name -> mean F1."""
    results = {}
    for subset_name, feat_list in CHANNEL_SUBSETS.items():
        print(f"  -> channel {subset_name}")
        idx_keep = [FEATURES.index(f) for f in feat_list]
        f1s = []
        for held_out in persons:
            train_X = np.concatenate([by_person[p][0] for p in persons if p != held_out])
            train_y = np.concatenate([by_person[p][1] for p in persons if p != held_out])
            X_test, y_test = by_person[held_out]
            X_tr = train_X[:, idx_keep]
            X_te = X_test[:, idx_keep]
            scaler = StandardScaler().fit(X_tr)
            X_tr_sc = scaler.transform(X_tr)
            X_te_sc = scaler.transform(X_te)
            xgb = _make_eval_model(SEED, quick)
            xgb.fit(X_tr_sc, train_y)
            pred = xgb.predict(X_te_sc)
            f1s.append(f1_score(y_test, pred, average="weighted"))
        results[subset_name] = np.mean(f1s)
    results["all_10"] = baseline
    return results


def _parse_args():
    parser = argparse.ArgumentParser(description="Feature importance + LOPO ablation")
    parser.add_argument(
        "--quick",
        action="store_true",
        help="Use fewer trees (200) for faster iteration.",
    )
    parser.add_argument(
        "--skip-lofo",
        action="store_true",
        help="Skip leave-one-feature-out ablation.",
    )
    parser.add_argument(
        "--skip-channel",
        action="store_true",
        help="Skip channel ablation.",
    )
    return parser.parse_args()


def main():
    args = _parse_args()
    print("=== Feature importance (XGBoost gain) ===")
    if args.quick:
        print("Running in quick mode (n_estimators=200).")
    imp = xgb_feature_importance()
    if imp:
        for name in FEATURES:
            print(f"  {name}: {imp.get(name, 0):.2f}")
        order = sorted(imp.items(), key=lambda x: -x[1])
        print("  Top-5 by gain:", [x[0] for x in order[:5]])

    print("\n[DATA] Loading per-person splits once...")
    by_person, _, _ = load_per_person("face_orientation")
    persons = sorted(by_person.keys())

    print("\n=== Baseline LOPO (all 10 features) ===")
    baseline = run_baseline_lopo_f1(by_person, persons, quick=args.quick)
    print(f"  Baseline (all 10 features) mean LOPO F1: {baseline:.4f}")

    ablation = None
    worst_drop = None
    if args.skip_lofo:
        print("\n=== Leave-one-feature-out ablation (LOPO mean F1) ===")
        print("  skipped (--skip-lofo)")
    else:
        print("\n=== Leave-one-feature-out ablation (LOPO mean F1) ===")
        ablation = run_ablation_lopo(by_person, persons, quick=args.quick)
        for feat in FEATURES:
            delta = baseline - ablation[feat]
            print(f"  drop {feat}: F1={ablation[feat]:.4f} (Δ={delta:+.4f})")
        worst_drop = min(ablation.items(), key=lambda x: x[1])
        print(f"  Largest F1 drop when dropping: {worst_drop[0]} (F1={worst_drop[1]:.4f})")

    channel_f1 = None
    if args.skip_channel:
        print("\n=== Channel ablation (LOPO mean F1) ===")
        print("  skipped (--skip-channel)")
    else:
        print("\n=== Channel ablation (LOPO mean F1) ===")
        channel_f1 = run_channel_ablation(by_person, persons, quick=args.quick, baseline=baseline)
        for name, f1 in channel_f1.items():
            print(f"  {name}: {f1:.4f}")

    out_dir = os.path.join(_PROJECT_ROOT, "evaluation")
    out_path = os.path.join(out_dir, "feature_selection_justification.md")
    lines = [
        "# Feature selection justification",
        "",
        "The face_orientation model uses 10 of 17 extracted features. This document summarises empirical support.",
        "",
        "## 1. Domain rationale",
        "",
        "The 10 features were chosen to cover three channels:",
        "- **Head pose:** head_deviation, s_face, pitch",
        "- **Eye state:** ear_left, ear_right, ear_avg, perclos",
        "- **Gaze:** h_gaze, gaze_offset, s_eye",
        "",
        "Excluded: v_gaze (noisy), mar (rare events), yaw/roll (redundant with head_deviation/s_face), blink_rate/closure_duration/yawn_duration (temporal overlap with perclos).",
        "",
        "## 2. XGBoost feature importance (gain)",
        "",
        f"Config used: `{XGB_BASE_PARAMS}`.",
        "Quick mode: " + ("yes (200 trees)" if args.quick else "no (full config)"),
        "",
        "From the trained XGBoost checkpoint (gain on the 10 features):",
        "",
        "| Feature | Gain |",
        "|---------|------|",
    ]
    if imp:
        for name in FEATURES:
            lines.append(f"| {name} | {imp.get(name, 0):.2f} |")
        order = sorted(imp.items(), key=lambda x: -x[1])
        lines.append("")
        lines.append(f"**Top 5 by gain:** {', '.join(x[0] for x in order[:5])}.")
    else:
        lines.append("(Run with XGBoost checkpoint to populate.)")
    lines.extend([
        "",
        "## 3. Leave-one-feature-out ablation (LOPO)",
        "",
        f"Baseline (all 10 features) mean LOPO F1: **{baseline:.4f}**.",
        "",
    ])
    if ablation is None:
        lines.append("Skipped in this run (`--skip-lofo`).")
    else:
        lines.extend([
            "| Feature dropped | Mean LOPO F1 | Δ vs baseline |",
            "|------------------|--------------|---------------|",
        ])
        for feat in FEATURES:
            delta = baseline - ablation[feat]
            lines.append(f"| {feat} | {ablation[feat]:.4f} | {delta:+.4f} |")
        lines.append("")
        lines.append(f"Dropping **{worst_drop[0]}** hurts most (F1={worst_drop[1]:.4f}), consistent with it being important.")

    lines.append("")
    lines.append("## 4. Channel ablation (LOPO)")
    lines.append("")
    if channel_f1 is None:
        lines.append("Skipped in this run (`--skip-channel`).")
    else:
        lines.append("| Subset | Mean LOPO F1 |")
        lines.append("|--------|--------------|")
        for name in ["head_pose", "eye_state", "gaze", "all_10"]:
            lines.append(f"| {name} | {channel_f1[name]:.4f} |")
    lines.append("")
    lines.append("## 5. Conclusion")
    lines.append("")
    if ablation is None:
        lines.append("Selection is supported by (1) domain rationale (three attention channels), (2) XGBoost gain importance, and (3) channel ablation. Run without `--skip-lofo` for full leave-one-out ablation.")
    else:
        lines.append("Selection is supported by (1) domain rationale (three attention channels), (2) XGBoost gain importance, and (3) leave-one-out ablation. SHAP or correlation-based pruning can be added in future work.")
    lines.append("")
    with open(out_path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines))
    print(f"\nReport written to {out_path}")


if __name__ == "__main__":
    main()
evaluation/feature_selection_justification.md
ADDED
@@ -0,0 +1,53 @@
# Feature selection justification

The face_orientation model uses 10 of 17 extracted features. This document summarises empirical support.

## 1. Domain rationale

The 10 features were chosen to cover three channels:
- **Head pose:** head_deviation, s_face, pitch
- **Eye state:** ear_left, ear_right, ear_avg, perclos
- **Gaze:** h_gaze, gaze_offset, s_eye

Excluded: v_gaze (noisy), mar (rare events), yaw/roll (redundant with head_deviation/s_face), blink_rate/closure_duration/yawn_duration (temporal overlap with perclos).

## 2. XGBoost feature importance (gain)

Config used: `{'n_estimators': 600, 'max_depth': 8, 'learning_rate': 0.1489, 'subsample': 0.9625, 'colsample_bytree': 0.9013, 'reg_alpha': 1.1407, 'reg_lambda': 2.4181, 'eval_metric': 'logloss'}`.
Quick mode: yes (200 trees)

From the trained XGBoost checkpoint (gain on the 10 features):

| Feature | Gain |
|---------|------|
| head_deviation | 8.83 |
| s_face | 10.27 |
| s_eye | 2.18 |
| h_gaze | 4.99 |
| pitch | 4.64 |
| ear_left | 3.57 |
| ear_avg | 6.96 |
| ear_right | 9.54 |
| gaze_offset | 1.80 |
| perclos | 5.68 |

**Top 5 by gain:** s_face, ear_right, head_deviation, ear_avg, perclos.

## 3. Leave-one-feature-out ablation (LOPO)

Baseline (all 10 features) mean LOPO F1: **0.8286**.

Skipped in this run (`--skip-lofo`).

## 4. Channel ablation (LOPO)

| Subset | Mean LOPO F1 |
|--------|--------------|
| head_pose | 0.7480 |
| eye_state | 0.8071 |
| gaze | 0.7260 |
| all_10 | 0.8286 |

## 5. Conclusion

Selection is supported by (1) domain rationale (three attention channels), (2) XGBoost gain importance, and (3) channel ablation. Run without `--skip-lofo` for full leave-one-out ablation.
evaluation/grouped_split_benchmark.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Compare pooled random split vs grouped LOPO for XGBoost."""

import os
import sys

import numpy as np
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score

_PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if _PROJECT_ROOT not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT)

from data_preparation.prepare_dataset import get_default_split_config, get_numpy_splits, load_per_person
from models.xgboost.config import build_xgb_classifier, XGB_BASE_PARAMS

MODEL_NAME = "face_orientation"
OUT_PATH = os.path.join(_PROJECT_ROOT, "evaluation", "GROUPED_SPLIT_BENCHMARK.md")


def run_pooled_split():
    split_ratios, seed = get_default_split_config()
    splits, _, _, _ = get_numpy_splits(
        model_name=MODEL_NAME,
        split_ratios=split_ratios,
        seed=seed,
        scale=False,
    )
    model = build_xgb_classifier(seed, verbosity=0, early_stopping_rounds=30)
    model.fit(
        splits["X_train"],
        splits["y_train"],
        eval_set=[(splits["X_val"], splits["y_val"])],
        verbose=False,
    )
    probs = model.predict_proba(splits["X_test"])[:, 1]
    preds = (probs >= 0.5).astype(int)
    y = splits["y_test"]
    return {
        "accuracy": float(accuracy_score(y, preds)),
        "f1": float(f1_score(y, preds, average="weighted")),
        "auc": float(roc_auc_score(y, probs)),
    }

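# Leave-one-person-out (LOPO): hold out each participant in turn, train on the
# rest, and average the held-out metrics. This measures cross-person
# generalisation, which a pooled random split can overstate because frames from
# the same participant then appear in both train and test.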
def run_grouped_lopo():
    by_person, _, _ = load_per_person(MODEL_NAME)
    persons = sorted(by_person.keys())
    scores = {"accuracy": [], "f1": [], "auc": []}

    _, seed = get_default_split_config()
    for held_out in persons:
        train_x = np.concatenate([by_person[p][0] for p in persons if p != held_out], axis=0)
        train_y = np.concatenate([by_person[p][1] for p in persons if p != held_out], axis=0)
        test_x, test_y = by_person[held_out]

        model = build_xgb_classifier(seed, verbosity=0)
        model.fit(train_x, train_y, verbose=False)
        probs = model.predict_proba(test_x)[:, 1]
        preds = (probs >= 0.5).astype(int)

        scores["accuracy"].append(float(accuracy_score(test_y, preds)))
        scores["f1"].append(float(f1_score(test_y, preds, average="weighted")))
        scores["auc"].append(float(roc_auc_score(test_y, probs)))

    return {
        "accuracy": float(np.mean(scores["accuracy"])),
        "f1": float(np.mean(scores["f1"])),
        "auc": float(np.mean(scores["auc"])),
        "folds": len(persons),
    }


def write_report(pooled, grouped):
    lines = [
        "# Grouped vs pooled split benchmark",
        "",
        "This compares the same XGBoost config under two evaluation protocols.",
        "",
        f"Config: `{XGB_BASE_PARAMS}`",
        "",
        "| Protocol | Accuracy | F1 (weighted) | ROC-AUC |",
        "|----------|---------:|--------------:|--------:|",
        f"| Pooled random split (70/15/15) | {pooled['accuracy']:.4f} | {pooled['f1']:.4f} | {pooled['auc']:.4f} |",
        f"| Grouped LOPO ({grouped['folds']} folds) | {grouped['accuracy']:.4f} | {grouped['f1']:.4f} | {grouped['auc']:.4f} |",
        "",
        "Use grouped LOPO as the primary generalisation metric when reporting model quality.",
        "",
    ]

    with open(OUT_PATH, "w", encoding="utf-8") as f:
        f.write("\n".join(lines))
    print(f"[LOG] Wrote {OUT_PATH}")


def main():
    pooled = run_pooled_split()
    grouped = run_grouped_lopo()
    write_report(pooled, grouped)
    print(
        "[DONE] pooled_f1={:.4f} grouped_f1={:.4f}".format(
            pooled["f1"], grouped["f1"]
        )
    )


if __name__ == "__main__":
    main()
evaluation/justify_thresholds.py
ADDED
@@ -0,0 +1,573 @@
# LOPO threshold/weight analysis. Run: python -m evaluation.justify_thresholds
# ClearML logging: set USE_CLEARML=1 env var or pass --clearml flag

import glob
import os
import sys

import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from sklearn.neural_network import MLPClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_curve, roc_auc_score, f1_score

_PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, _PROJECT_ROOT)

from data_preparation.prepare_dataset import get_default_split_config, load_per_person, SELECTED_FEATURES
from models.xgboost.config import XGB_BASE_PARAMS, build_xgb_classifier

PLOTS_DIR = os.path.join(os.path.dirname(__file__), "plots")
REPORT_PATH = os.path.join(os.path.dirname(__file__), "THRESHOLD_JUSTIFICATION.md")
_, SEED = get_default_split_config()

_USE_CLEARML = os.environ.get("USE_CLEARML", "0") == "1" or "--clearml" in sys.argv or bool(os.environ.get("CLEARML_TASK_ID"))
_CLEARML_QUEUE = os.environ.get("CLEARML_QUEUE", "")

_task = None
_logger = None

if _USE_CLEARML:
    try:
        from clearml import Task
        from config import CLEARML_PROJECT_NAME, flatten_for_clearml
        _task = Task.init(
            project_name=CLEARML_PROJECT_NAME,
            task_name="Threshold Justification",
            tags=["evaluation", "thresholds"],
        )
        from config.clearml_enrich import enrich_task, upload_repro_artifacts

        enrich_task(_task, role="eval_thresholds")
        flat = flatten_for_clearml()
        flat["evaluation/SEED"] = SEED
        flat["evaluation/n_participants"] = 9
        _task.connect(flat)
        upload_repro_artifacts(_task)
        _logger = _task.get_logger()
        if _CLEARML_QUEUE:
            print(f"[ClearML] Enqueuing to queue '{_CLEARML_QUEUE}'.")
            _task.execute_remotely(queue_name=_CLEARML_QUEUE)
            sys.exit(0)
        print(f"ClearML enabled — logging to project '{CLEARML_PROJECT_NAME}'")
    except ImportError:
        print("WARNING: ClearML not installed. Continuing without logging.")
        _USE_CLEARML = False

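# Youden's J statistic: J = TPR - FPR (equivalently, sensitivity + specificity - 1).
# The ROC threshold that maximises J weights false positives and false negatives
# equally; it is the criterion used for every threshold pick below.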
def _youdens_j(y_true, y_prob):
    fpr, tpr, thresholds = roc_curve(y_true, y_prob)
    j = tpr - fpr
    idx = j.argmax()
    auc = roc_auc_score(y_true, y_prob)
    return float(thresholds[idx]), fpr, tpr, thresholds, float(auc)


def _f1_at_threshold(y_true, y_prob, threshold):
    return f1_score(y_true, (y_prob >= threshold).astype(int), zero_division=0)


def _plot_roc(fpr, tpr, auc, opt_thresh, opt_idx, title, path, clearml_title=None):
    fig, ax = plt.subplots(figsize=(6, 5))
    ax.plot(fpr, tpr, lw=2, label=f"ROC (AUC = {auc:.4f})")
    ax.plot(fpr[opt_idx], tpr[opt_idx], "ro", markersize=10,
            label=f"Youden's J optimum (t = {opt_thresh:.3f})")
    ax.plot([0, 1], [0, 1], "k--", lw=1, alpha=0.5)
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_title(title)
    ax.legend(loc="lower right")
    fig.tight_layout()

    # Log to ClearML before closing the figure
    if _logger and clearml_title:
        _logger.report_matplotlib_figure(
            title=clearml_title, series="ROC", figure=fig, iteration=0
        )

    fig.savefig(path, dpi=150)
    plt.close(fig)
    print(f" saved {path}")


def run_lopo_models():
    print("\n=== LOPO: MLP and XGBoost ===")
    by_person, _, _ = load_per_person("face_orientation")
    persons = sorted(by_person.keys())

    results = {"mlp": {"y": [], "p": []}, "xgb": {"y": [], "p": []}}

    for i, held_out in enumerate(persons):
        X_test, y_test = by_person[held_out]

        train_X = np.concatenate([by_person[p][0] for p in persons if p != held_out])
        train_y = np.concatenate([by_person[p][1] for p in persons if p != held_out])

        scaler = StandardScaler().fit(train_X)
        X_tr_sc = scaler.transform(train_X)
        X_te_sc = scaler.transform(X_test)

        mlp = MLPClassifier(
            hidden_layer_sizes=(64, 32), activation="relu",
            max_iter=200, early_stopping=True, validation_fraction=0.15,
            random_state=SEED, verbose=False,
        )
        mlp.fit(X_tr_sc, train_y)
        mlp_prob = mlp.predict_proba(X_te_sc)[:, 1]
        results["mlp"]["y"].append(y_test)
        results["mlp"]["p"].append(mlp_prob)

        xgb = build_xgb_classifier(SEED, verbosity=0)
        xgb.fit(X_tr_sc, train_y)
        xgb_prob = xgb.predict_proba(X_te_sc)[:, 1]
        results["xgb"]["y"].append(y_test)
        results["xgb"]["p"].append(xgb_prob)

        print(f" fold {i+1}/{len(persons)}: held out {held_out} "
              f"({X_test.shape[0]} samples)")

    for key in results:
        results[key]["y"] = np.concatenate(results[key]["y"])
        results[key]["p"] = np.concatenate(results[key]["p"])

    return results


def analyse_model_thresholds(results):
    print("\n=== Model threshold analysis ===")
    model_stats = {}

    for name, label in [("mlp", "MLP"), ("xgb", "XGBoost")]:
        y, p = results[name]["y"], results[name]["p"]
        opt_t, fpr, tpr, thresholds, auc = _youdens_j(y, p)
        j = tpr - fpr
        opt_idx = j.argmax()
        f1_opt = _f1_at_threshold(y, p, opt_t)
        f1_50 = _f1_at_threshold(y, p, 0.50)

        path = os.path.join(PLOTS_DIR, f"roc_{name}.png")
        _plot_roc(fpr, tpr, auc, opt_t, opt_idx,
                  f"LOPO ROC — {label} (9 folds, 144k samples)", path,
                  clearml_title=f"ROC_{label}")

        model_stats[name] = {
            "label": label, "auc": auc,
            "opt_threshold": opt_t, "f1_opt": f1_opt, "f1_50": f1_50,
        }
        print(f" {label}: AUC={auc:.4f}, optimal threshold={opt_t:.3f} "
              f"(F1={f1_opt:.4f}), F1@0.50={f1_50:.4f}")

        # Log scalars to ClearML
        if _logger:
            _logger.report_single_value(f"{label} Optimal Threshold", opt_t)
            _logger.report_single_value(f"{label} AUC", auc)
            _logger.report_single_value(f"{label} F1 @ Optimal", f1_opt)
            _logger.report_single_value(f"{label} F1 @ 0.5", f1_50)

    return model_stats

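# In the weight searches below, the decision threshold for each candidate weight
# is fitted with Youden's J on the training persons only and then applied to the
# held-out person, so the reported F1 never tunes its threshold on test data.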
def run_geo_weight_search():
    print("\n=== Geometric weight grid search ===")

    by_person, _, _ = load_per_person("face_orientation")
    persons = sorted(by_person.keys())
    features = SELECTED_FEATURES["face_orientation"]
    sf_idx = features.index("s_face")
    se_idx = features.index("s_eye")

    alphas = np.arange(0.2, 0.85, 0.1).round(1)
    alpha_f1 = {a: [] for a in alphas}

    for held_out in persons:
        X_test, y_test = by_person[held_out]
        sf = X_test[:, sf_idx]
        se = X_test[:, se_idx]

        train_X = np.concatenate([by_person[p][0] for p in persons if p != held_out])
        train_y = np.concatenate([by_person[p][1] for p in persons if p != held_out])
        sf_tr = train_X[:, sf_idx]
        se_tr = train_X[:, se_idx]

        for a in alphas:
            score_tr = a * sf_tr + (1.0 - a) * se_tr
            opt_t, *_ = _youdens_j(train_y, score_tr)

            score_te = a * sf + (1.0 - a) * se
            f1 = _f1_at_threshold(y_test, score_te, opt_t)
            alpha_f1[a].append(f1)

    mean_f1 = {a: np.mean(f1s) for a, f1s in alpha_f1.items()}
    best_alpha = max(mean_f1, key=mean_f1.get)

    fig, ax = plt.subplots(figsize=(7, 4))
    ax.bar([f"{a:.1f}" for a in alphas],
           [mean_f1[a] for a in alphas], color="steelblue")
    ax.set_xlabel("Face weight (alpha); eye weight = 1 - alpha")
    ax.set_ylabel("Mean LOPO F1")
    ax.set_title("Geometric Pipeline: Face vs Eye Weight Search")
    ax.set_ylim(bottom=max(0, min(mean_f1.values()) - 0.05))
    for i, a in enumerate(alphas):
        ax.text(i, mean_f1[a] + 0.003, f"{mean_f1[a]:.3f}",
                ha="center", va="bottom", fontsize=8)
    fig.tight_layout()

    # Log to ClearML before closing
    if _logger:
        _logger.report_matplotlib_figure(
            title="Geo Weight Search", series="F1 vs Alpha", figure=fig, iteration=0
        )

    path = os.path.join(PLOTS_DIR, "geo_weight_search.png")
    fig.savefig(path, dpi=150)
    plt.close(fig)
    print(f" saved {path}")

    print(f" Best alpha (face weight) = {best_alpha:.1f}, "
          f"mean LOPO F1 = {mean_f1[best_alpha]:.4f}")

    # Log scalars to ClearML
    if _logger:
        _logger.report_single_value("Geo Best Alpha", best_alpha)
        for i, a in enumerate(sorted(alphas)):
            _logger.report_scalar(
                "Geo Weight Search", "Mean LOPO F1",
                iteration=i, value=mean_f1[a]
            )

    return dict(mean_f1), best_alpha

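# Reuses the per-fold MLP probabilities from run_lopo_models; the offset-based
# slicing below assumes the same sorted(persons) fold order, which both
# functions share.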
def run_hybrid_weight_search(lopo_results):
    print("\n=== Hybrid weight grid search ===")

    by_person, _, _ = load_per_person("face_orientation")
    persons = sorted(by_person.keys())
    features = SELECTED_FEATURES["face_orientation"]
    sf_idx = features.index("s_face")
    se_idx = features.index("s_eye")

    GEO_FACE_W = 0.7
    GEO_EYE_W = 0.3

    w_mlps = np.arange(0.3, 0.85, 0.1).round(1)
    wmf1 = {w: [] for w in w_mlps}
    mlp_p = lopo_results["mlp"]["p"]
    offset = 0
    for held_out in persons:
        X_test, y_test = by_person[held_out]
        n = X_test.shape[0]
        mlp_prob_fold = mlp_p[offset:offset + n]
        offset += n

        sf = X_test[:, sf_idx]
        se = X_test[:, se_idx]
        geo_score = np.clip(GEO_FACE_W * sf + GEO_EYE_W * se, 0, 1)

        train_X = np.concatenate([by_person[p][0] for p in persons if p != held_out])
        train_y = np.concatenate([by_person[p][1] for p in persons if p != held_out])
        sf_tr = train_X[:, sf_idx]
        se_tr = train_X[:, se_idx]
        geo_tr = np.clip(GEO_FACE_W * sf_tr + GEO_EYE_W * se_tr, 0, 1)

        scaler = StandardScaler().fit(train_X)
        mlp_tr = MLPClassifier(
            hidden_layer_sizes=(64, 32), activation="relu",
            max_iter=200, early_stopping=True, validation_fraction=0.15,
            random_state=SEED, verbose=False,
        )
        mlp_tr.fit(scaler.transform(train_X), train_y)
        mlp_prob_tr = mlp_tr.predict_proba(scaler.transform(train_X))[:, 1]

        for w in w_mlps:
            combo_tr = w * mlp_prob_tr + (1.0 - w) * geo_tr
            opt_t, *_ = _youdens_j(train_y, combo_tr)

            combo_te = w * mlp_prob_fold + (1.0 - w) * geo_score
            f1 = _f1_at_threshold(y_test, combo_te, opt_t)
            wmf1[w].append(f1)

    mean_f1 = {w: np.mean(f1s) for w, f1s in wmf1.items()}
    best_w = max(mean_f1, key=mean_f1.get)

    fig, ax = plt.subplots(figsize=(7, 4))
    ax.bar([f"{w:.1f}" for w in w_mlps],
           [mean_f1[w] for w in w_mlps], color="darkorange")
    ax.set_xlabel("MLP weight (w_mlp); geo weight = 1 - w_mlp")
    ax.set_ylabel("Mean LOPO F1")
    ax.set_title("Hybrid Pipeline: MLP vs Geometric Weight Search")
    ax.set_ylim(bottom=max(0, min(mean_f1.values()) - 0.05))
    for i, w in enumerate(w_mlps):
        ax.text(i, mean_f1[w] + 0.003, f"{mean_f1[w]:.3f}",
                ha="center", va="bottom", fontsize=8)
    fig.tight_layout()

    # Log to ClearML before closing
    if _logger:
        _logger.report_matplotlib_figure(
            title="Hybrid Weight Search", series="F1 vs w_mlp", figure=fig, iteration=0
        )

    path = os.path.join(PLOTS_DIR, "hybrid_weight_search.png")
    fig.savefig(path, dpi=150)
    plt.close(fig)
    print(f" saved {path}")

    print(f" Best w_mlp = {best_w:.1f}, mean LOPO F1 = {mean_f1[best_w]:.4f}")

    # Log scalars to ClearML
    if _logger:
        _logger.report_single_value("Hybrid Best w_mlp", best_w)
        for i, w in enumerate(sorted(w_mlps)):
            _logger.report_scalar(
                "Hybrid Weight Search", "Mean LOPO F1",
                iteration=i, value=mean_f1[w]
            )

    return dict(mean_f1), best_w


def plot_distributions():
    print("\n=== EAR / MAR distributions ===")
    npz_files = sorted(glob.glob(os.path.join(_PROJECT_ROOT, "data", "collected_*", "*.npz")))

    all_ear_l, all_ear_r, all_mar, all_labels = [], [], [], []
    for f in npz_files:
        d = np.load(f, allow_pickle=True)
        names = list(d["feature_names"])
        feat = d["features"].astype(np.float32)
        lab = d["labels"].astype(np.int64)
        all_ear_l.append(feat[:, names.index("ear_left")])
        all_ear_r.append(feat[:, names.index("ear_right")])
        all_mar.append(feat[:, names.index("mar")])
        all_labels.append(lab)

    ear_l = np.concatenate(all_ear_l)
    ear_r = np.concatenate(all_ear_r)
    mar = np.concatenate(all_mar)
    labels = np.concatenate(all_labels)
    ear_min = np.minimum(ear_l, ear_r)
    ear_plot = np.clip(ear_min, 0, 0.85)
    mar_plot = np.clip(mar, 0, 1.5)

    # EAR distribution plot
    fig_ear, ax = plt.subplots(figsize=(7, 4))
    ax.hist(ear_plot[labels == 1], bins=100, alpha=0.6, label="Focused (1)", density=True)
    ax.hist(ear_plot[labels == 0], bins=100, alpha=0.6, label="Unfocused (0)", density=True)
    for val, lbl, c in [
        (0.16, "ear_closed = 0.16", "red"),
        (0.21, "EAR_BLINK = 0.21", "orange"),
        (0.30, "ear_open = 0.30", "green"),
    ]:
        ax.axvline(val, color=c, ls="--", lw=1.5, label=lbl)
    ax.set_xlabel("min(left_EAR, right_EAR)")
    ax.set_ylabel("Density")
    ax.set_title("EAR Distribution by Class (144k samples)")
    ax.legend(fontsize=8)
    fig_ear.tight_layout()

    # Log to ClearML before closing
    if _logger:
        _logger.report_matplotlib_figure(
            title="EAR Distribution", series="by class", figure=fig_ear, iteration=0
        )

    path = os.path.join(PLOTS_DIR, "ear_distribution.png")
    fig_ear.savefig(path, dpi=150)
    plt.close(fig_ear)
    print(f" saved {path}")

    # MAR distribution plot
    fig_mar, ax = plt.subplots(figsize=(7, 4))
    ax.hist(mar_plot[labels == 1], bins=100, alpha=0.6, label="Focused (1)", density=True)
    ax.hist(mar_plot[labels == 0], bins=100, alpha=0.6, label="Unfocused (0)", density=True)
    ax.axvline(0.55, color="red", ls="--", lw=1.5, label="MAR_YAWN = 0.55")
    ax.set_xlabel("Mouth Aspect Ratio (MAR)")
    ax.set_ylabel("Density")
    ax.set_title("MAR Distribution by Class (144k samples)")
    ax.legend(fontsize=8)
    fig_mar.tight_layout()

    # Log to ClearML before closing
    if _logger:
        _logger.report_matplotlib_figure(
            title="MAR Distribution", series="by class", figure=fig_mar, iteration=0
        )

    path = os.path.join(PLOTS_DIR, "mar_distribution.png")
    fig_mar.savefig(path, dpi=150)
    plt.close(fig_mar)
    print(f" saved {path}")

    closed_pct = np.mean(ear_min < 0.16) * 100
    blink_pct = np.mean(ear_min < 0.21) * 100
    open_pct = np.mean(ear_min >= 0.30) * 100
    yawn_pct = np.mean(mar > 0.55) * 100

    stats = {
        "ear_below_016": closed_pct,
        "ear_below_021": blink_pct,
        "ear_above_030": open_pct,
        "mar_above_055": yawn_pct,
        "n_samples": len(ear_min),
    }
    print(f" EAR<0.16 (closed): {closed_pct:.1f}% | EAR<0.21 (blink): {blink_pct:.1f}% | "
          f"EAR>=0.30 (open): {open_pct:.1f}%")
    print(f" MAR>0.55 (yawn): {yawn_pct:.1f}%")
    return stats


def write_report(model_stats, geo_f1, best_alpha, hybrid_f1, best_w, dist_stats):
    lines = []
    lines.append("# Threshold Justification Report")
    lines.append("")
    lines.append("Auto-generated by `evaluation/justify_thresholds.py` using LOPO cross-validation "
                 "over 9 participants (~145k samples).")
    lines.append("")

    lines.append("## 1. ML Model Decision Thresholds")
    lines.append("")
    lines.append(f"XGBoost config used for this report: `{XGB_BASE_PARAMS}`.")
    lines.append("")
    lines.append("Thresholds selected via **Youden's J statistic** (J = sensitivity + specificity - 1) "
                 "on pooled LOPO held-out predictions.")
    lines.append("")
    lines.append("| Model | LOPO AUC | Optimal Threshold (Youden's J) | F1 @ Optimal | F1 @ 0.50 |")
    lines.append("|-------|----------|-------------------------------|--------------|-----------|")
    for key in ("mlp", "xgb"):
        s = model_stats[key]
        lines.append(f"| {s['label']} | {s['auc']:.4f} | **{s['opt_threshold']:.3f}** | "
                     f"{s['f1_opt']:.4f} | {s['f1_50']:.4f} |")
    lines.append("")
    lines.append("")
    lines.append("")
    lines.append("")
    lines.append("")

    lines.append("## 2. Geometric Pipeline Weights (s_face vs s_eye)")
    lines.append("")
    lines.append("Grid search over face weight alpha in {0.2 ... 0.8}. "
                 "Eye weight = 1 - alpha. Threshold per fold via Youden's J.")
    lines.append("")
    lines.append("| Face Weight (alpha) | Mean LOPO F1 |")
    lines.append("|--------------------:|-------------:|")
    for a in sorted(geo_f1.keys()):
        marker = " **<-- selected**" if a == best_alpha else ""
        lines.append(f"| {a:.1f} | {geo_f1[a]:.4f}{marker} |")
    lines.append("")
    lines.append(f"**Best:** alpha = {best_alpha:.1f} (face {best_alpha*100:.0f}%, "
                 f"eye {(1-best_alpha)*100:.0f}%)")
    lines.append("")
    lines.append("")
    lines.append("")

    lines.append("## 3. Hybrid Pipeline Weights (MLP vs Geometric)")
    lines.append("")
    lines.append("Grid search over w_mlp in {0.3 ... 0.8}. w_geo = 1 - w_mlp. "
                 "Geometric sub-score uses same weights as geometric pipeline (face=0.7, eye=0.3). "
                 "If you change geometric weights, re-run this script — optimal w_mlp can shift.")
    lines.append("")
    lines.append("| MLP Weight (w_mlp) | Mean LOPO F1 |")
    lines.append("|-------------------:|-------------:|")
    for w in sorted(hybrid_f1.keys()):
        marker = " **<-- selected**" if w == best_w else ""
        lines.append(f"| {w:.1f} | {hybrid_f1[w]:.4f}{marker} |")
    lines.append("")
    lines.append(f"**Best:** w_mlp = {best_w:.1f} (MLP {best_w*100:.0f}%, "
                 f"geometric {(1-best_w)*100:.0f}%)")
    lines.append("")
    lines.append("")
    lines.append("")

    lines.append("## 4. Eye and Mouth Aspect Ratio Thresholds")
    lines.append("")
    lines.append("### EAR (Eye Aspect Ratio)")
    lines.append("")
    lines.append("Reference: Soukupova & Cech, \"Real-Time Eye Blink Detection Using Facial "
                 "Landmarks\" (2016) established EAR ~ 0.2 as a blink threshold.")
    lines.append("")
    lines.append("Our thresholds define a linear interpolation zone around this established value:")
    lines.append("")
    lines.append("| Constant | Value | Justification |")
    lines.append("|----------|------:|---------------|")
    lines.append(f"| `ear_closed` | 0.16 | Below this, eyes are fully shut. "
                 f"{dist_stats['ear_below_016']:.1f}% of samples fall here. |")
    lines.append(f"| `EAR_BLINK_THRESH` | 0.21 | Blink detection point; close to the 0.2 reference. "
                 f"{dist_stats['ear_below_021']:.1f}% of samples below. |")
    lines.append(f"| `ear_open` | 0.30 | Above this, eyes are fully open. "
                 f"{dist_stats['ear_above_030']:.1f}% of samples here. |")
    lines.append("")
    lines.append("Between 0.16 and 0.30 the `_ear_score` function linearly interpolates from 0 to 1, "
                 "providing a smooth transition rather than a hard binary cutoff.")
    lines.append("")
    lines.append("")
    lines.append("")
    lines.append("### MAR (Mouth Aspect Ratio)")
    lines.append("")
    lines.append("| Constant | Value | Justification |")
    lines.append("|----------|------:|---------------|")
    lines.append(f"| `MAR_YAWN_THRESHOLD` | 0.55 | Only {dist_stats['mar_above_055']:.1f}% of "
                 f"samples exceed this, confirming it captures genuine yawns without false positives. |")
    lines.append("")
    lines.append("")
    lines.append("")

    lines.append("## 5. Other Constants")
    lines.append("")
    lines.append("| Constant | Value | Rationale |")
    lines.append("|----------|------:|-----------|")
    lines.append("| `gaze_max_offset` | 0.28 | Max iris displacement (normalised) before gaze score "
                 "drops to zero. Corresponds to ~56% of the eye width; beyond this the iris is at "
                 "the extreme edge. |")
    lines.append("| `max_angle` | 22.0 deg | Head deviation beyond which face score = 0. Based on "
                 "typical monitor-viewing cone: at 60 cm distance and a 24\" monitor, the viewing "
                 "angle is ~20-25 degrees. |")
    lines.append("| `roll_weight` | 0.5 | Roll is less indicative of inattention than yaw/pitch "
                 "(tilting head doesn't mean looking away), so it's down-weighted by 50%. |")
    lines.append("| `EMA alpha` | 0.3 | Smoothing factor for focus score. "
                 "Gives ~3-4 frame effective window; balances responsiveness vs flicker. |")
    lines.append("| `grace_frames` | 15 | ~0.5 s at 30 fps before penalising no-face. Allows brief "
                 "occlusions (e.g. hand gesture) without dropping score. |")
    lines.append("| `PERCLOS_WINDOW` | 60 frames | 2 s at 30 fps; standard PERCLOS measurement "
                 "window (Dinges & Grace, 1998). |")
    lines.append("| `BLINK_WINDOW_SEC` | 30 s | Blink rate measured over 30 s; typical spontaneous "
                 "blink rate is 15-20/min (Bentivoglio et al., 1997). |")
    lines.append("")

    with open(REPORT_PATH, "w", encoding="utf-8") as f:
        f.write("\n".join(lines))
    print(f"\nReport written to {REPORT_PATH}")


def main():
    os.makedirs(PLOTS_DIR, exist_ok=True)

    lopo_results = run_lopo_models()
    model_stats = analyse_model_thresholds(lopo_results)
    geo_f1, best_alpha = run_geo_weight_search()
    hybrid_f1, best_w = run_hybrid_weight_search(lopo_results)
    dist_stats = plot_distributions()

    write_report(model_stats, geo_f1, best_alpha, hybrid_f1, best_w, dist_stats)

    # Close ClearML task
    if _task:
        from config.clearml_enrich import task_done_summary

        if os.path.isfile(REPORT_PATH):
            _task.upload_artifact(
                name="threshold_justification_report",
                artifact_object=REPORT_PATH,
            )
        task_done_summary(
            _task,
            "LOPO threshold / weight analysis; see artifact threshold_justification_report and plots in Debug samples.",
        )
        _task.close()
        print("ClearML task closed.")

    print("\nDone.")


if __name__ == "__main__":
    main()
evaluation/logs/.gitkeep
ADDED
File without changes
evaluation/plots/confusion_matrix_mlp.png
ADDED
evaluation/plots/confusion_matrix_xgb.png
ADDED
evaluation/plots/ear_distribution.png
ADDED
evaluation/plots/geo_weight_search.png
ADDED
evaluation/plots/hybrid_weight_search.png
ADDED
evaluation/plots/hybrid_xgb_weight_search.png
ADDED
evaluation/plots/mar_distribution.png
ADDED
evaluation/plots/roc_mlp.png
ADDED
evaluation/plots/roc_xgb.png
ADDED
index.html
ADDED
@@ -0,0 +1,17 @@
<!doctype html>
<html lang="en">

<head>
  <meta charset="UTF-8" />
  <link rel="icon" type="image/svg+xml" href="/vite.svg" />
  <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  <title>Focus Guard</title>
  <link href="https://fonts.googleapis.com/css2?family=Nunito:wght@400;700&display=swap" rel="stylesheet">
</head>

<body>
  <div id="root"></div>
  <script type="module" src="/src/main.jsx"></script>
</body>

</html>