k22056537 committed on
Commit
f488769
·
1 Parent(s): f85441f

evaluation: channel ablation script + feature importance LOPO

README.md CHANGED
@@ -1,25 +1,25 @@
1
  # FocusGuard
2
 
3
- Real-time webcam-based focus detection system combining geometric feature extraction with machine learning classification. The pipeline extracts 17 facial features (EAR, gaze, head pose, PERCLOS, blink rate, etc.) from MediaPipe landmarks and classifies attentiveness using MLP and XGBoost models. Served via a React + FastAPI web application with live WebSocket video.
4
 
5
- ## 1. Project Structure
6
 
7
  ```
8
- ├── data/ Raw collected sessions (collected_<name>/*.npz)
9
- ├── data_preparation/ Data loading, cleaning, and exploration
10
- ├── notebooks/ Training notebooks (MLP, XGBoost) with LOPO evaluation
11
- ├── models/ Feature extraction modules and training scripts
12
- ├── checkpoints/ All saved weights (mlp_best.pt, xgboost_*_best.json, GRU, scalers)
13
- ├── evaluation/ Training logs and metrics (JSON)
14
- ├── ui/ Live OpenCV demo and inference pipeline
15
- ├── src/ React/Vite frontend source
16
- ├── static/ Built frontend (served by FastAPI)
17
- ├── app.py / main.py FastAPI backend (API, WebSocket, DB)
18
- ├── requirements.txt Python dependencies
19
- └── package.json Frontend dependencies
20
  ```
21
 
22
- ## 2. Setup
23
 
24
  ```bash
25
  python -m venv venv
@@ -27,65 +27,56 @@ source venv/bin/activate
27
  pip install -r requirements.txt
28
  ```
29
 
30
- Frontend (only needed if modifying the React app):
31
 
32
  ```bash
33
  npm install
34
  npm run build
35
- cp -r dist/* static/
36
  ```
37
 
38
- ## 3. Running
39
 
40
- **Web application (API + frontend):**
41
 
42
  ```bash
43
- uvicorn main:app --host 0.0.0.0 --port 7860
 
44
  ```
45
 
46
- Open http://localhost:7860 in a browser.
47
 
48
- **Live camera demo (OpenCV):**
49
 
50
  ```bash
51
  python ui/live_demo.py
52
- python ui/live_demo.py --xgb # XGBoost mode
53
  ```
54
 
55
- **Training:**
56
 
57
  ```bash
58
- python -m models.mlp.train # MLP
59
- python -m models.xgboost.train # XGBoost
60
  ```
61
 
62
- ## 4. Dataset
63
-
64
- - **9 participants**, each recorded via webcam with real-time labelling (focused / unfocused)
65
- - **144,793 total samples**, 10 selected features, binary classification
66
- - Collected using `python -m models.collect_features --name <name>`
67
- - Stored as `.npz` files in `data/collected_<name>/`
68
-
69
- ## 5. Models
70
 
71
- | Model | Test Accuracy | Test F1 | ROC-AUC |
72
- |-------|--------------|---------|---------|
73
- | XGBoost (600 trees, depth 8, lr 0.149) | 95.87% | 0.959 | 0.991 |
74
- | MLP (64→32, 30 epochs, lr 1e-3) | 92.92% | 0.929 | 0.971 |
75
 
76
- Both evaluated on a held-out 15% stratified test split. LOPO (Leave-One-Person-Out) cross-validation available in `notebooks/`.
77
 
78
- ## 6. Feature Pipeline
 
 
 
79
 
80
- 1. **Face mesh** — MediaPipe 478-landmark detection
81
- 2. **Head pose** — solvePnP → yaw, pitch, roll, face score, gaze offset, head deviation
82
- 3. **Eye scorer** — EAR (left/right/avg), horizontal/vertical gaze ratio, MAR
83
- 4. **Temporal tracking** — PERCLOS, blink rate, closure duration, yawn duration
84
- 5. **Classification** — 10-feature vector → MLP or XGBoost → focused / unfocused
85
 
86
- ## 7. Tech Stack
 
 
 
 
87
 
88
- - **Backend:** Python, FastAPI, WebSocket, aiosqlite
89
- - **Frontend:** React, Vite, TypeScript
90
- - **ML:** PyTorch (MLP), XGBoost, scikit-learn
91
- - **Vision:** MediaPipe, OpenCV
 
1
  # FocusGuard
2
 
3
+ Webcam-based focus detection: MediaPipe face mesh → 17 features (EAR, gaze, head pose, PERCLOS, etc.) → MLP or XGBoost → focused/unfocused. React + FastAPI app with WebSocket video.
4
 
5
+ ## Project layout
6
 
7
  ```
8
+ ├── data/ collected_<name>/*.npz
9
+ ├── data_preparation/ loaders, split, scale
10
+ ├── notebooks/ MLP/XGB training + LOPO
11
+ ├── models/ face_mesh, head_pose, eye_scorer, train scripts
12
+ ├── checkpoints/ mlp_best.pt, xgboost_*_best.json, scalers
13
+ ├── evaluation/ logs, plots, justify_thresholds
14
+ ├── ui/ pipeline.py, live_demo.py
15
+ ├── src/ React frontend
16
+ ├── static/ built frontend (after npm run build)
17
+ ├── main.py, app.py FastAPI backend
18
+ ├── requirements.txt
19
+ └── package.json
20
  ```
21
 
22
+ ## Setup
23
 
24
  ```bash
25
  python -m venv venv
 
27
  pip install -r requirements.txt
28
  ```
29
 
30
+ To rebuild the frontend after changes:
31
 
32
  ```bash
33
  npm install
34
  npm run build
35
+ mkdir -p static && cp -r dist/* static/
36
  ```
37
 
38
+ ## Run
39
 
40
+ **Web app:** Activate the venv and run uvicorn through Python so it picks up the environment's dependencies (otherwise you get `ModuleNotFoundError: aiosqlite`):
41
 
42
  ```bash
43
+ source venv/bin/activate
44
+ python -m uvicorn main:app --host 0.0.0.0 --port 7860
45
  ```
46
 
47
+ Then open http://localhost:7860.
48
 
49
+ **OpenCV demo:**
50
 
51
  ```bash
52
  python ui/live_demo.py
53
+ python ui/live_demo.py --xgb
54
  ```
55
 
56
+ **Train:**
57
 
58
  ```bash
59
+ python -m models.mlp.train
60
+ python -m models.xgboost.train
61
  ```
62
 
63
+ ## Data
64
 
65
+ 9 participants, 144,793 samples, 10 features, binary labels. Collect with `python -m models.collect_features --name <name>`. Data lives in `data/collected_<name>/`.
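The `.npz` round-trip can be sketched in isolation; the key names (`features`, `labels`) are stand-ins here, since the exact keys are whatever `collect_features` writes:

```python
import io
import numpy as np

# Hypothetical per-session layout: one features matrix plus binary labels per frame.
buf = io.BytesIO()
np.savez(buf,
         features=np.zeros((5, 17), dtype=np.float32),
         labels=np.array([0, 1, 1, 0, 1]))
buf.seek(0)
data = np.load(buf)
print(data["features"].shape, int(data["labels"].sum()))
```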
 
 
 
66
 
67
+ ## Results (held-out 15% stratified test split)
68
 
69
+ | Model | Accuracy | F1 | ROC-AUC |
70
+ |-------|----------|-----|---------|
71
+ | XGBoost (600 trees, depth 8) | 95.87% | 0.959 | 0.991 |
72
+ | MLP (64→32) | 92.92% | 0.929 | 0.971 |
73
 
74
+ ## Pipeline
 
 
 
 
75
 
76
+ 1. Face mesh (MediaPipe 478 pts)
77
+ 2. Head pose → yaw, pitch, roll, scores, gaze offset
78
+ 3. Eye scorer → EAR, gaze ratio, MAR
79
+ 4. Temporal → PERCLOS, blink rate, yawn
80
+ 5. 10-d vector → MLP or XGBoost → focused / unfocused
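Step 3's EAR is the standard six-landmark eye-aspect ratio; a minimal sketch, assuming the conventional p1..p6 ordering (corner, top x2, corner, bottom x2) — the project's exact MediaPipe indices live in `eye_scorer.py`:

```python
import numpy as np

def eye_aspect_ratio(pts: np.ndarray) -> float:
    """EAR = (|p2-p6| + |p3-p5|) / (2 * |p1-p4|) for six eye landmarks."""
    v1 = np.linalg.norm(pts[1] - pts[5])  # top-bottom gap, inner pair
    v2 = np.linalg.norm(pts[2] - pts[4])  # top-bottom gap, outer pair
    h = np.linalg.norm(pts[0] - pts[3])   # eye width
    return (v1 + v2) / (2.0 * h)

# Open eye: tall vertical gaps relative to width give a high EAR.
open_eye = np.array([[0, 0], [1, 1], [2, 1], [3, 0], [2, -1], [1, -1]], float)
print(round(eye_aspect_ratio(open_eye), 3))
```

A closed eye collapses the vertical gaps, so EAR drops toward zero; PERCLOS is then the fraction of recent frames below an EAR threshold.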
81
 
82
+ **Stack:** FastAPI, aiosqlite, React/Vite, PyTorch, XGBoost, MediaPipe, OpenCV.
 
 
 
data_preparation/README.md CHANGED
@@ -1,75 +1,9 @@
1
  # data_preparation/
2
 
3
- Shared data loading, cleaning, and exploratory analysis.
4
 
5
- ## 1. Files
6
 
7
- | File | Description |
8
- |------|-------------|
9
- | `prepare_dataset.py` | Central data loading module used by all training scripts and notebooks |
10
- | `data_exploration.ipynb` | EDA notebook: feature distributions, class balance, correlations |
11
 
12
- ## 2. prepare_dataset.py
13
-
14
- Provides a consistent pipeline for loading raw `.npz` data from `data/`:
15
-
16
- | Function | Purpose |
17
- |----------|---------|
18
- | `load_all_pooled(model_name)` | Load all participants, clean, select features, concatenate |
19
- | `load_per_person(model_name)` | Load grouped by person (for LOPO cross-validation) |
20
- | `get_numpy_splits(model_name)` | Load + stratified 70/15/15 split + StandardScaler |
21
- | `get_dataloaders(model_name)` | Same as above, wrapped in PyTorch DataLoaders |
22
- | `_split_and_scale(features, labels, ...)` | Reusable split + optional scaling |
23
-
24
- ### Cleaning rules
25
-
26
- - `yaw` clipped to [-45, 45], `pitch`/`roll` to [-30, 30]
27
- - `ear_left`, `ear_right`, `ear_avg` clipped to [0, 0.85]
28
-
29
- ### Selected features (face_orientation)
30
-
31
- `head_deviation`, `s_face`, `s_eye`, `h_gaze`, `pitch`, `ear_left`, `ear_avg`, `ear_right`, `gaze_offset`, `perclos`
32
-
33
- ## 3. data_exploration.ipynb
34
-
35
- Run from this folder or from the project root. Covers:
36
-
37
- 1. Per-feature statistics (mean, std, min, max)
38
- 2. Class distribution (focused vs unfocused)
39
- 3. Feature histograms and box plots
40
- 4. Correlation matrix
41
-
42
- ## 4. How to run
43
-
44
- `prepare_dataset.py` is a **library module**, not a standalone script. You don’t run it directly; you import it from code that needs data.
45
-
46
- **From repo root:**
47
-
48
- ```bash
49
- # Optional: quick test that loading works
50
- python -c "
51
- from data_preparation.prepare_dataset import load_all_pooled
52
- X, y, names = load_all_pooled('face_orientation')
53
- print(f'Loaded {X.shape[0]} samples, {X.shape[1]} features: {names}')
54
- "
55
- ```
56
-
57
- **Used by:**
58
-
59
- - `python -m models.mlp.train`
60
- - `python -m models.xgboost.train`
61
- - `notebooks/mlp.ipynb`, `notebooks/xgboost.ipynb`
62
- - `data_preparation/data_exploration.ipynb`
63
-
64
- ## 5. Usage (in code)
65
-
66
- ```python
67
- from data_preparation.prepare_dataset import load_all_pooled, get_numpy_splits
68
-
69
- # pooled data
70
- X, y, names = load_all_pooled("face_orientation")
71
-
72
- # ready-to-train splits
73
- splits, n_features, n_classes, scaler = get_numpy_splits("face_orientation")
74
- X_train, y_train = splits["X_train"], splits["y_train"]
75
- ```
 
1
  # data_preparation/
2
 
3
+ Load and split the `.npz` data. Used by all training code and notebooks.
4
 
5
+ **prepare_dataset.py:** `load_all_pooled()`, `load_per_person()` for LOPO, `get_numpy_splits()` (XGBoost), `get_dataloaders()` (MLP). Cleans yaw/pitch/roll and EAR to fixed ranges. Face_orientation uses 10 features: head_deviation, s_face, s_eye, h_gaze, pitch, ear_left, ear_avg, ear_right, gaze_offset, perclos.
6
 
7
+ **data_exploration.ipynb:** EDA stats, class balance, histograms, correlations.
 
 
 
8
 
9
+ You don’t run `prepare_dataset.py` directly; import it from `models.mlp.train`, `models.xgboost.train`, or the notebooks.
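The stratified 70/15/15 split with train-only scaling that `get_numpy_splits()` performs can be sketched on synthetic data (the function below is illustrative, not the module's API):

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

def split_and_scale(X, y, ratios=(0.7, 0.15, 0.15), seed=42):
    """Stratified 70/15/15 split; the scaler is fit on the train set only."""
    X_tr, X_rest, y_tr, y_rest = train_test_split(
        X, y, test_size=ratios[1] + ratios[2], stratify=y, random_state=seed)
    X_val, X_te, y_val, y_te = train_test_split(
        X_rest, y_rest, test_size=ratios[2] / (ratios[1] + ratios[2]),
        stratify=y_rest, random_state=seed)
    scaler = StandardScaler().fit(X_tr)
    splits = {"X_train": scaler.transform(X_tr), "y_train": y_tr,
              "X_val": scaler.transform(X_val), "y_val": y_val,
              "X_test": scaler.transform(X_te), "y_test": y_te}
    return splits, scaler

rng = np.random.default_rng(0)
X, y = rng.normal(size=(1000, 10)), rng.integers(0, 2, 1000)
splits, scaler = split_and_scale(X, y)
print(len(splits["y_train"]), len(splits["y_val"]), len(splits["y_test"]))
```

Fitting the scaler on the training portion alone avoids leaking validation/test statistics into training.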
data_preparation/prepare_dataset.py CHANGED
@@ -209,6 +209,8 @@ def get_numpy_splits(model_name: str, split_ratios=(0.7, 0.15, 0.15), seed: int
209
  features, labels = _load_real_data(model_name)
210
  num_features = features.shape[1]
211
  num_classes = int(labels.max()) + 1
 
 
212
  splits, scaler = _split_and_scale(features, labels, split_ratios, seed, scale)
213
  return splits, num_features, num_classes, scaler
214
 
@@ -218,6 +220,8 @@ def get_dataloaders(model_name: str, batch_size: int = 32, split_ratios=(0.7, 0.
218
  features, labels = _load_real_data(model_name)
219
  num_features = features.shape[1]
220
  num_classes = int(labels.max()) + 1
 
 
221
  splits, scaler = _split_and_scale(features, labels, split_ratios, seed, scale)
222
 
223
  train_ds = FeatureVectorDataset(splits["X_train"], splits["y_train"])
 
209
  features, labels = _load_real_data(model_name)
210
  num_features = features.shape[1]
211
  num_classes = int(labels.max()) + 1
212
+ if num_classes < 2:
213
+ raise ValueError("Dataset has only one class; need at least 2 for classification.")
214
  splits, scaler = _split_and_scale(features, labels, split_ratios, seed, scale)
215
  return splits, num_features, num_classes, scaler
216
 
 
220
  features, labels = _load_real_data(model_name)
221
  num_features = features.shape[1]
222
  num_classes = int(labels.max()) + 1
223
+ if num_classes < 2:
224
+ raise ValueError("Dataset has only one class; need at least 2 for classification.")
225
  splits, scaler = _split_and_scale(features, labels, split_ratios, seed, scale)
226
 
227
  train_ds = FeatureVectorDataset(splits["X_train"], splits["y_train"])
evaluation/README.md CHANGED
@@ -1,79 +1,19 @@
1
  # evaluation/
2
 
3
- Training logs, threshold analysis, and performance metrics.
4
 
5
- ## 1. Contents
6
 
7
- ```
8
- logs/ # training run logs (JSON)
9
- plots/ # threshold justification figures (ROC, weight search, EAR/MAR)
10
- justify_thresholds.py # LOPO analysis script
11
- feature_importance.py # XGBoost importance + leave-one-out ablation
12
- THRESHOLD_JUSTIFICATION.md # report (auto-generated by justify_thresholds)
13
- feature_selection_justification.md # report (auto-generated by feature_importance)
14
- ```
15
-
16
- **Logs (when present):** Scripts write to `evaluation/logs/`. MLP script: `face_orientation_training_log.json`; XGBoost script/notebook: `xgboost_face_orientation_training_log.json`; MLP notebook may write `mlp_face_orientation_training_log.json`.
17
- ```
18
- logs/
19
- ├── face_orientation_training_log.json # from models/mlp/train.py
20
- ├── xgboost_face_orientation_training_log.json
21
- └── (optional) mlp_face_orientation_training_log.json # from notebooks/mlp.ipynb
22
- ```
23
-
24
- ## 2. Log Format
25
-
26
- Each JSON file records the full training history:
27
-
28
- **MLP logs:**
29
- ```json
30
- {
31
- "config": { "epochs": 30, "lr": 0.001, "batch_size": 32, ... },
32
- "history": {
33
- "train_loss": [0.287, 0.260, ...],
34
- "val_loss": [0.256, 0.245, ...],
35
- "train_acc": [0.889, 0.901, ...],
36
- "val_acc": [0.905, 0.909, ...]
37
- },
38
- "test": { "accuracy": 0.929, "f1": 0.929, "roc_auc": 0.971 }
39
- }
40
- ```
41
-
42
- **XGBoost logs:**
43
- ```json
44
- {
45
- "config": { "n_estimators": 600, "max_depth": 8, "learning_rate": 0.149, ... },
46
- "train_losses": [0.577, ...],
47
- "val_losses": [0.576, ...],
48
- "test": { "accuracy": 0.959, "f1": 0.959, "roc_auc": 0.991 }
49
- }
50
- ```
51
-
52
- ## 3. Threshold justification
53
 
54
- Thresholds and weights used in the app (geometric, MLP, XGBoost, hybrid) are justified in **THRESHOLD_JUSTIFICATION.md**. The report is generated by:
55
 
56
  ```bash
57
  python -m evaluation.justify_thresholds
58
  ```
59
 
60
- From repo root, with venv active. The script runs LOPO over 9 participants (~145k samples), computes ROC + Youden's J for ML/XGB thresholds, grid-searches geometric and hybrid weights, and plots EAR/MAR distributions. It writes:
61
-
62
- - `plots/roc_mlp.png`, `plots/roc_xgb.png`
63
- - `plots/geo_weight_search.png`, `plots/hybrid_weight_search.png`
64
- - `plots/ear_distribution.png`, `plots/mar_distribution.png`
65
- - `THRESHOLD_JUSTIFICATION.md`
66
-
67
- Takes ~10–15 minutes. Re-run after changing data or pipeline weights (e.g. geometric face/eye); hybrid optimal w_mlp depends on the geometric sub-score weights.
68
-
69
- ## 4. Feature selection justification
70
-
71
- Run `python -m evaluation.feature_importance` to compute XGBoost gain-based importance for the 10 face_orientation features and a leave-one-feature-out LOPO ablation. Writes **feature_selection_justification.md** with tables. Use this to justify the 10-of-17 feature set (ablation + importance; see PAPER_AUDIT §2.7).
72
 
73
- ## 5. Generated by
74
 
75
- - `python -m models.mlp.train` MLP log in `logs/`
76
- - `python -m models.xgboost.train` → XGBoost log in `logs/`
77
- - `python -m evaluation.justify_thresholds` → plots + THRESHOLD_JUSTIFICATION.md
78
- - `python -m evaluation.feature_importance` → feature_selection_justification.md
79
- - Notebooks in `notebooks/` can also write logs here
 
1
  # evaluation/
2
 
3
+ Training logs, threshold/weight analysis, and metrics.
4
 
5
+ **Contents:** `logs/` (JSON from training runs), `plots/` (ROC, weight search, EAR/MAR), `justify_thresholds.py`, `feature_importance.py`, and the generated markdown reports.
6
 
7
+ **Logs:** MLP training writes `face_orientation_training_log.json`; XGBoost writes `xgboost_face_orientation_training_log.json`. Both land in `evaluation/logs/`.
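Reading a log back is a plain JSON parse; a sketch assuming the MLP log shape (`config` / `history` / `test`, values truncated):

```python
import json

# Minimal MLP-style log; real files hold the full per-epoch history.
log = json.loads("""
{
  "config": {"epochs": 30, "lr": 0.001, "batch_size": 32},
  "history": {"val_acc": [0.905, 0.909, 0.907]},
  "test": {"accuracy": 0.929, "f1": 0.929, "roc_auc": 0.971}
}
""")

def best_epoch(log: dict) -> int:
    """1-based index of the epoch with the highest validation accuracy."""
    val = log["history"]["val_acc"]
    return max(range(len(val)), key=val.__getitem__) + 1

print(best_epoch(log), log["test"]["roc_auc"])
```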
 
8
 
9
+ **Threshold report:** Generate `THRESHOLD_JUSTIFICATION.md` and plots with:
10
 
11
  ```bash
12
  python -m evaluation.justify_thresholds
13
  ```
14
 
15
+ (LOPO over 9 participants, Youden's J, weight grid search; ~10–15 min.) Outputs go to `plots/` and the markdown file.
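The Youden's J step can be illustrated in isolation: given held-out scores, the chosen threshold is the ROC point maximizing TPR - FPR (toy data, not the project's):

```python
import numpy as np
from sklearn.metrics import roc_curve

# Toy binary labels and classifier scores.
y_true = np.array([0, 0, 0, 0, 1, 1, 1, 1])
scores = np.array([0.10, 0.20, 0.30, 0.65, 0.40, 0.60, 0.80, 0.90])

fpr, tpr, thr = roc_curve(y_true, scores, drop_intermediate=False)
j = tpr - fpr                      # Youden's J at each candidate threshold
best_thr = thr[np.argmax(j)]
print(best_thr, round(float(j.max()), 2))
```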
16
 
17
+ **Feature importance:** Run `python -m evaluation.feature_importance` for XGBoost gain and leave-one-feature-out LOPO; writes `feature_selection_justification.md`.
18
 
19
+ **Who writes here:** `models.mlp.train`, `models.xgboost.train`, `evaluation.justify_thresholds`, `evaluation.feature_importance`, and the notebooks.
 
 
 
 
evaluation/feature_importance.py CHANGED
@@ -70,7 +70,7 @@ def run_ablation_lopo():
70
  n_estimators=600, max_depth=8, learning_rate=0.05,
71
  subsample=0.8, colsample_bytree=0.8,
72
  reg_alpha=0.1, reg_lambda=1.0,
73
- use_label_encoder=False, eval_metric="logloss",
74
  random_state=SEED, verbosity=0,
75
  )
76
  xgb.fit(X_tr_sc, train_y)
@@ -96,7 +96,7 @@ def run_baseline_lopo_f1():
96
  n_estimators=600, max_depth=8, learning_rate=0.05,
97
  subsample=0.8, colsample_bytree=0.8,
98
  reg_alpha=0.1, reg_lambda=1.0,
99
- use_label_encoder=False, eval_metric="logloss",
100
  random_state=SEED, verbosity=0,
101
  )
102
  xgb.fit(X_tr_sc, train_y)
@@ -105,6 +105,47 @@ def run_baseline_lopo_f1():
105
  return np.mean(f1s)
106
 
107
 
108
  def main():
109
  print("=== Feature importance (XGBoost gain) ===")
110
  imp = xgb_feature_importance()
@@ -124,6 +165,11 @@ def main():
124
  worst_drop = min(ablation.items(), key=lambda x: x[1])
125
  print(f" Largest F1 drop when dropping: {worst_drop[0]} (F1={worst_drop[1]:.4f})")
126
 
 
 
 
 
 
127
  out_dir = os.path.join(_PROJECT_ROOT, "evaluation")
128
  out_path = os.path.join(out_dir, "feature_selection_justification.md")
129
  lines = [
 
70
  n_estimators=600, max_depth=8, learning_rate=0.05,
71
  subsample=0.8, colsample_bytree=0.8,
72
  reg_alpha=0.1, reg_lambda=1.0,
73
+ eval_metric="logloss",
74
  random_state=SEED, verbosity=0,
75
  )
76
  xgb.fit(X_tr_sc, train_y)
 
96
  n_estimators=600, max_depth=8, learning_rate=0.05,
97
  subsample=0.8, colsample_bytree=0.8,
98
  reg_alpha=0.1, reg_lambda=1.0,
99
+ eval_metric="logloss",
100
  random_state=SEED, verbosity=0,
101
  )
102
  xgb.fit(X_tr_sc, train_y)
 
105
  return np.mean(f1s)
106
 
107
 
108
+ # Channel subsets for ablation (subset name -> list of feature names)
109
+ CHANNEL_SUBSETS = {
110
+ "head_pose": ["head_deviation", "s_face", "pitch"],
111
+ "eye_state": ["ear_left", "ear_avg", "ear_right", "perclos"],
112
+ "gaze": ["h_gaze", "gaze_offset", "s_eye"],
113
+ }
114
+
115
+
116
+ def run_channel_ablation():
117
+ """LOPO XGBoost with head-only, eye-only, gaze-only, and all 10. Returns dict subset_name -> mean F1."""
118
+ by_person, _, _ = load_per_person("face_orientation")
119
+ persons = sorted(by_person.keys())
120
+ results = {}
121
+ for subset_name, feat_list in CHANNEL_SUBSETS.items():
122
+ idx_keep = [FEATURES.index(f) for f in feat_list]
123
+ f1s = []
124
+ for held_out in persons:
125
+ train_X = np.concatenate([by_person[p][0] for p in persons if p != held_out])
126
+ train_y = np.concatenate([by_person[p][1] for p in persons if p != held_out])
127
+ X_test, y_test = by_person[held_out]
128
+ X_tr = train_X[:, idx_keep]
129
+ X_te = X_test[:, idx_keep]
130
+ scaler = StandardScaler().fit(X_tr)
131
+ X_tr_sc = scaler.transform(X_tr)
132
+ X_te_sc = scaler.transform(X_te)
133
+ xgb = XGBClassifier(
134
+ n_estimators=600, max_depth=8, learning_rate=0.05,
135
+ subsample=0.8, colsample_bytree=0.8,
136
+ reg_alpha=0.1, reg_lambda=1.0,
137
+ eval_metric="logloss",
138
+ random_state=SEED, verbosity=0,
139
+ )
140
+ xgb.fit(X_tr_sc, train_y)
141
+ pred = xgb.predict(X_te_sc)
142
+ f1s.append(f1_score(y_test, pred, average="weighted"))
143
+ results[subset_name] = np.mean(f1s)
144
+ baseline = run_baseline_lopo_f1()
145
+ results["all_10"] = baseline
146
+ return results
147
+
148
+
149
  def main():
150
  print("=== Feature importance (XGBoost gain) ===")
151
  imp = xgb_feature_importance()
 
165
  worst_drop = min(ablation.items(), key=lambda x: x[1])
166
  print(f" Largest F1 drop when dropping: {worst_drop[0]} (F1={worst_drop[1]:.4f})")
167
 
168
+ print("\n=== Channel ablation (LOPO mean F1) ===")
169
+ channel_f1 = run_channel_ablation()
170
+ for name, f1 in channel_f1.items():
171
+ print(f" {name}: {f1:.4f}")
172
+
173
  out_dir = os.path.join(_PROJECT_ROOT, "evaluation")
174
  out_path = os.path.join(out_dir, "feature_selection_justification.md")
175
  lines = [
evaluation/justify_thresholds.py CHANGED
@@ -95,7 +95,7 @@ def run_lopo_models():
95
  n_estimators=600, max_depth=8, learning_rate=0.05,
96
  subsample=0.8, colsample_bytree=0.8,
97
  reg_alpha=0.1, reg_lambda=1.0,
98
- use_label_encoder=False, eval_metric="logloss",
99
  random_state=SEED, verbosity=0,
100
  )
101
  xgb.fit(X_tr_sc, train_y)
@@ -430,7 +430,7 @@ def run_hybrid_xgb_weight_search(lopo_results):
430
  n_estimators=600, max_depth=8, learning_rate=0.05,
431
  subsample=0.8, colsample_bytree=0.8,
432
  reg_alpha=0.1, reg_lambda=1.0,
433
- use_label_encoder=False, eval_metric="logloss",
434
  random_state=SEED, verbosity=0,
435
  )
436
  xgb_tr.fit(X_tr_sc, train_y)
@@ -504,7 +504,7 @@ def run_hybrid_lr_combiner(lopo_results, use_xgb=True):
504
  n_estimators=600, max_depth=8, learning_rate=0.05,
505
  subsample=0.8, colsample_bytree=0.8,
506
  reg_alpha=0.1, reg_lambda=1.0,
507
- use_label_encoder=False, eval_metric="logloss",
508
  random_state=SEED, verbosity=0,
509
  )
510
  xgb_tr.fit(X_tr_sc, train_y)
 
95
  n_estimators=600, max_depth=8, learning_rate=0.05,
96
  subsample=0.8, colsample_bytree=0.8,
97
  reg_alpha=0.1, reg_lambda=1.0,
98
+ eval_metric="logloss",
99
  random_state=SEED, verbosity=0,
100
  )
101
  xgb.fit(X_tr_sc, train_y)
 
430
  n_estimators=600, max_depth=8, learning_rate=0.05,
431
  subsample=0.8, colsample_bytree=0.8,
432
  reg_alpha=0.1, reg_lambda=1.0,
433
+ eval_metric="logloss",
434
  random_state=SEED, verbosity=0,
435
  )
436
  xgb_tr.fit(X_tr_sc, train_y)
 
504
  n_estimators=600, max_depth=8, learning_rate=0.05,
505
  subsample=0.8, colsample_bytree=0.8,
506
  reg_alpha=0.1, reg_lambda=1.0,
507
+ eval_metric="logloss",
508
  random_state=SEED, verbosity=0,
509
  )
510
  xgb_tr.fit(X_tr_sc, train_y)
evaluation/run_channel_ablation_only.py ADDED
@@ -0,0 +1,63 @@
1
+ """Run only channel ablation LOPO (no leave-one-out). Quick run for paper data."""
2
+ import os
3
+ import sys
4
+ import numpy as np
5
+ from sklearn.preprocessing import StandardScaler
6
+ from sklearn.metrics import f1_score
7
+ from xgboost import XGBClassifier
8
+
9
+ _PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
10
+ sys.path.insert(0, _PROJECT_ROOT)
11
+ from data_preparation.prepare_dataset import load_per_person, SELECTED_FEATURES
12
+
13
+ SEED = 42
14
+ FEATURES = SELECTED_FEATURES["face_orientation"]
15
+ CHANNEL_SUBSETS = {
16
+ "head_pose": ["head_deviation", "s_face", "pitch"],
17
+ "eye_state": ["ear_left", "ear_avg", "ear_right", "perclos"],
18
+ "gaze": ["h_gaze", "gaze_offset", "s_eye"],
19
+ }
20
+
21
+
22
+ def main():
23
+ by_person, _, _ = load_per_person("face_orientation")
24
+ persons = sorted(by_person.keys())
25
+ results = {}
26
+ for subset_name, feat_list in CHANNEL_SUBSETS.items():
27
+ idx_keep = [FEATURES.index(f) for f in feat_list]
28
+ f1s = []
29
+ for held_out in persons:
30
+ train_X = np.concatenate([by_person[p][0] for p in persons if p != held_out])
31
+ train_y = np.concatenate([by_person[p][1] for p in persons if p != held_out])
32
+ X_test, y_test = by_person[held_out]
33
+ X_tr = train_X[:, idx_keep]
34
+ X_te = X_test[:, idx_keep]
35
+ scaler = StandardScaler().fit(X_tr)
36
+ xgb = XGBClassifier(n_estimators=600, max_depth=8, learning_rate=0.05,
37
+ subsample=0.8, colsample_bytree=0.8, reg_alpha=0.1, reg_lambda=1.0,
38
+ eval_metric="logloss", random_state=SEED, verbosity=0)
39
+ xgb.fit(scaler.transform(X_tr), train_y)
40
+ pred = xgb.predict(scaler.transform(X_te))
41
+ f1s.append(f1_score(y_test, pred, average="weighted"))
42
+ results[subset_name] = np.mean(f1s)
43
+ print(f"{subset_name}: {results[subset_name]:.4f}")
44
+ # baseline
45
+ f1s = []
46
+ for held_out in persons:
47
+ train_X = np.concatenate([by_person[p][0] for p in persons if p != held_out])
48
+ train_y = np.concatenate([by_person[p][1] for p in persons if p != held_out])
49
+ X_test, y_test = by_person[held_out]
50
+ scaler = StandardScaler().fit(train_X)
51
+ xgb = XGBClassifier(n_estimators=600, max_depth=8, learning_rate=0.05,
52
+ subsample=0.8, colsample_bytree=0.8, reg_alpha=0.1, reg_lambda=1.0,
53
+ eval_metric="logloss", random_state=SEED, verbosity=0)
54
+ xgb.fit(scaler.transform(train_X), train_y)
55
+ pred = xgb.predict(scaler.transform(X_test))
56
+ f1s.append(f1_score(y_test, pred, average="weighted"))
57
+ results["all_10"] = np.mean(f1s)
58
+ print(f"all_10: {results['all_10']:.4f}")
59
+ return results
60
+
61
+
62
+ if __name__ == "__main__":
63
+ main()
main.py CHANGED
@@ -14,6 +14,7 @@ import math
14
  import os
15
  from pathlib import Path
16
  from typing import Callable
 
17
  import asyncio
18
  import concurrent.futures
19
  import threading
@@ -148,8 +149,42 @@ _MESH_INDICES = sorted(set(
148
  # Build a lookup: original_index -> position in sparse array, so client can reconstruct.
149
  _MESH_INDEX_SET = set(_MESH_INDICES)
150
 
151
- # Initialize FastAPI app
152
- app = FastAPI(title="Focus Guard API")
 
153
 
154
  # Add CORS middleware
155
  app.add_middleware(
@@ -163,7 +198,18 @@ app.add_middleware(
163
  # Global variables
164
  db_path = "focus_guard.db"
165
  pcs = set()
166
- _cached_model_name = "mlp" # in-memory cache, updated via /api/settings
167
 
168
  async def _wait_for_ice_gathering(pc: RTCPeerConnection):
169
  if pc.iceGatheringState == "complete":
@@ -454,72 +500,10 @@ class _EventBuffer:
454
  except Exception as e:
455
  print(f"[DB] Flush error: {e}")
456
 
457
- # ================ STARTUP/SHUTDOWN ================
458
-
459
- pipelines = {
460
- "geometric": None,
461
- "mlp": None,
462
- "hybrid": None,
463
- "xgboost": None,
464
- }
465
-
466
- # Thread pool for CPU-bound inference so the event loop stays responsive.
467
- _inference_executor = concurrent.futures.ThreadPoolExecutor(
468
- max_workers=4,
469
- thread_name_prefix="inference",
470
- )
471
- # One lock per pipeline so shared state (TemporalTracker, etc.) is not corrupted when
472
- # multiple frames are processed in parallel by the thread pool.
473
- _pipeline_locks = {name: threading.Lock() for name in ("geometric", "mlp", "hybrid", "xgboost")}
474
-
475
-
476
  def _process_frame_safe(pipeline, frame, model_name: str):
477
- """Run process_frame in executor with per-pipeline lock."""
478
  with _pipeline_locks[model_name]:
479
  return pipeline.process_frame(frame)
480
 
481
- @app.on_event("startup")
482
- async def startup_event():
483
- global pipelines, _cached_model_name
484
- print(" Starting Focus Guard API...")
485
- await init_database()
486
- # Load cached model name from DB
487
- async with aiosqlite.connect(db_path) as db:
488
- cursor = await db.execute("SELECT model_name FROM user_settings WHERE id = 1")
489
- row = await cursor.fetchone()
490
- if row:
491
- _cached_model_name = row[0]
492
- print("[OK] Database initialized")
493
-
494
- try:
495
- pipelines["geometric"] = FaceMeshPipeline()
496
- print("[OK] FaceMeshPipeline (geometric) loaded")
497
- except Exception as e:
498
- print(f"[WARN] FaceMeshPipeline unavailable: {e}")
499
-
500
- try:
501
- pipelines["mlp"] = MLPPipeline()
502
- print("[OK] MLPPipeline loaded")
503
- except Exception as e:
504
- print(f"[ERR] Failed to load MLPPipeline: {e}")
505
-
506
- try:
507
- pipelines["hybrid"] = HybridFocusPipeline()
508
- print("[OK] HybridFocusPipeline loaded")
509
- except Exception as e:
510
- print(f"[WARN] HybridFocusPipeline unavailable: {e}")
511
-
512
- try:
513
- pipelines["xgboost"] = XGBoostPipeline()
514
- print("[OK] XGBoostPipeline loaded")
515
- except Exception as e:
516
- print(f"[ERR] Failed to load XGBoostPipeline: {e}")
517
-
518
- @app.on_event("shutdown")
519
- async def shutdown_event():
520
- _inference_executor.shutdown(wait=False)
521
- print(" Shutting down Focus Guard API...")
522
-
523
  # ================ WEBRTC SIGNALING ================
524
 
525
  @app.post("/api/webrtc/offer")
@@ -898,6 +882,22 @@ async def update_settings(settings: SettingsUpdate):
898
  await db.commit()
899
  return {"status": "success", "updated": len(updates) > 0}
900
 
901
  @app.get("/api/stats/summary")
902
  async def get_stats_summary():
903
  async with aiosqlite.connect(db_path) as db:
 
14
  import os
15
  from pathlib import Path
16
  from typing import Callable
17
+ from contextlib import asynccontextmanager
18
  import asyncio
19
  import concurrent.futures
20
  import threading
 
149
  # Build a lookup: original_index -> position in sparse array, so client can reconstruct.
150
  _MESH_INDEX_SET = set(_MESH_INDICES)
151
 
152
+ @asynccontextmanager
153
+ async def lifespan(app):
154
+ global _cached_model_name
155
+ print(" Starting Focus Guard API...")
156
+ await init_database()
157
+ async with aiosqlite.connect(db_path) as db:
158
+ cursor = await db.execute("SELECT model_name FROM user_settings WHERE id = 1")
159
+ row = await cursor.fetchone()
160
+ if row:
161
+ _cached_model_name = row[0]
162
+ print("[OK] Database initialized")
163
+ try:
164
+ pipelines["geometric"] = FaceMeshPipeline()
165
+ print("[OK] FaceMeshPipeline (geometric) loaded")
166
+ except Exception as e:
167
+ print(f"[WARN] FaceMeshPipeline unavailable: {e}")
168
+ try:
169
+ pipelines["mlp"] = MLPPipeline()
170
+ print("[OK] MLPPipeline loaded")
171
+ except Exception as e:
172
+ print(f"[ERR] Failed to load MLPPipeline: {e}")
173
+ try:
174
+ pipelines["hybrid"] = HybridFocusPipeline()
175
+ print("[OK] HybridFocusPipeline loaded")
176
+ except Exception as e:
177
+ print(f"[WARN] HybridFocusPipeline unavailable: {e}")
178
+ try:
179
+ pipelines["xgboost"] = XGBoostPipeline()
180
+ print("[OK] XGBoostPipeline loaded")
181
+ except Exception as e:
182
+ print(f"[ERR] Failed to load XGBoostPipeline: {e}")
183
+ yield
184
+ _inference_executor.shutdown(wait=False)
185
+ print(" Shutting down Focus Guard API...")
186
+
187
+ app = FastAPI(title="Focus Guard API", lifespan=lifespan)
188
 
189
  # Add CORS middleware
190
  app.add_middleware(
 
198
  # Global variables
199
  db_path = "focus_guard.db"
200
  pcs = set()
201
+ _cached_model_name = "mlp"
202
+ pipelines = {
203
+ "geometric": None,
204
+ "mlp": None,
205
+ "hybrid": None,
206
+ "xgboost": None,
207
+ }
208
+ _inference_executor = concurrent.futures.ThreadPoolExecutor(
209
+ max_workers=4,
210
+ thread_name_prefix="inference",
211
+ )
212
+ _pipeline_locks = {name: threading.Lock() for name in ("geometric", "mlp", "hybrid", "xgboost")}
213
 
214
  async def _wait_for_ice_gathering(pc: RTCPeerConnection):
215
  if pc.iceGatheringState == "complete":
 
500
  except Exception as e:
501
  print(f"[DB] Flush error: {e}")
502
 
503
  def _process_frame_safe(pipeline, frame, model_name: str):
 
504
  with _pipeline_locks[model_name]:
505
  return pipeline.process_frame(frame)
506
 
507
  # ================ WEBRTC SIGNALING ================
508
 
509
  @app.post("/api/webrtc/offer")
 
882
  await db.commit()
883
  return {"status": "success", "updated": len(updates) > 0}
884
 
885
+ @app.get("/api/stats/system")
886
+ async def get_system_stats():
887
+ """Return server CPU and memory usage for UI display."""
888
+ try:
889
+ import psutil
890
+ cpu = psutil.cpu_percent(interval=0.1)
891
+ mem = psutil.virtual_memory()
892
+ return {
893
+ "cpu_percent": round(cpu, 1),
894
+ "memory_percent": round(mem.percent, 1),
895
+ "memory_used_mb": round(mem.used / (1024 * 1024), 0),
896
+ "memory_total_mb": round(mem.total / (1024 * 1024), 0),
897
+ }
898
+ except ImportError:
899
+ return {"cpu_percent": None, "memory_percent": None, "memory_used_mb": None, "memory_total_mb": None}
900
+
901
  @app.get("/api/stats/summary")
902
  async def get_stats_summary():
903
  async with aiosqlite.connect(db_path) as db:
models/README.md CHANGED
@@ -1,51 +1,16 @@
 # models/
 
-Feature extraction modules and model training scripts.
-
-## 1. Feature Extraction
-
-Root-level modules form the real-time inference pipeline:
-
-| Module | Input | Output |
-|--------|-------|--------|
-| `face_mesh.py` | BGR frame | 478 MediaPipe landmarks |
-| `head_pose.py` | Landmarks, frame size | yaw, pitch, roll, face/eye score, gaze offset, head deviation |
-| `eye_scorer.py` | Landmarks | EAR (left/right/avg), gaze ratio (h/v), MAR |
-| `collect_features.py` | BGR frame | 17-d feature vector + temporal features (PERCLOS, blink rate, etc.) |
-
-## 2. Training Scripts
-
-| Folder | Model | Command |
-|--------|-------|---------|
-| `mlp/` | PyTorch MLP (64→32, 2-class) | `python -m models.mlp.train` |
-| `xgboost/` | XGBoost (600 trees, depth 8) | `python -m models.xgboost.train` |
-
-### mlp/
-
-- `train.py` — training loop with early stopping, ClearML opt-in
-- `sweep.py` — hyperparameter search (Optuna: lr, batch_size)
-- `eval_accuracy.py` — load checkpoint and print test metrics
-- Saves to **`checkpoints/mlp_best.pt`**
-
-### xgboost/
-
-- `train.py` — training with eval-set logging
-- `sweep.py` / `sweep_local.py` — hyperparameter search (Optuna + ClearML)
-- `eval_accuracy.py` — load checkpoint and print test metrics
-- Saves to **`checkpoints/xgboost_face_orientation_best.json`**
-
-## 3. Data Loading
-
-All training scripts import from `data_preparation.prepare_dataset`:
-
-```python
-from data_preparation.prepare_dataset import get_numpy_splits  # XGBoost
-from data_preparation.prepare_dataset import get_dataloaders   # MLP (PyTorch)
-```
-
-## 4. Results
-
-| Model | Test Accuracy | F1 | ROC-AUC |
-|-------|--------------|-----|---------|
-| XGBoost | 95.87% | 0.959 | 0.991 |
-| MLP | 92.92% | 0.929 | 0.971 |
+Feature extraction (face mesh, head pose, eye scorer, collect_features) and training scripts.
+
+**Extraction:** `face_mesh.py` → landmarks; `head_pose.py` → yaw/pitch/roll, scores; `eye_scorer.py` → EAR, gaze, MAR; `collect_features.py` → 17-d vector + PERCLOS, blink, etc.
+
+**Training:**
+
+| Path | Command | Checkpoint |
+|------|---------|------------|
+| mlp/ | `python -m models.mlp.train` | checkpoints/mlp_best.pt |
+| xgboost/ | `python -m models.xgboost.train` | checkpoints/xgboost_face_orientation_best.json |
+
+MLP: train.py, sweep.py, eval_accuracy.py. XGB: train.py, sweep_local.py, eval_accuracy.py. Both use `data_preparation.prepare_dataset` (get_numpy_splits / get_dataloaders).
+
+**Results:** XGBoost 95.87% acc, 0.959 F1, 0.991 AUC; MLP 92.92%, 0.929, 0.971.
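The shared `get_numpy_splits` / `get_dataloaders` contract mentioned in the README keeps both trainers on identical data. As a hypothetical illustration only (the real, stratified implementation lives in `data_preparation.prepare_dataset`; this un-stratified sketch just shows the 70/15/15 interface shape):

```python
import numpy as np

def split_70_15_15(X, y, seed=42):
    """Shuffle indices, then cut them 70/15/15 into train/val/test tuples."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(X))
    n_train = int(0.70 * len(X))
    n_val = int(0.15 * len(X))
    tr, va, te = np.split(idx, [n_train, n_train + n_val])
    return (X[tr], y[tr]), (X[va], y[va]), (X[te], y[te])

X = np.zeros((100, 17))  # 17-d feature vectors, as in the extraction table
y = np.zeros(100, dtype=int)
train, val, test = split_70_15_15(X, y)
```

Seeding the split is what makes the MLP and XGBoost results in the table comparable across runs.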
models/xgboost/add_accuracy.py CHANGED
@@ -32,7 +32,6 @@ for idx, row in df.iterrows():
     "reg_alpha": float(row["reg_alpha"]),
     "reg_lambda": float(row["reg_lambda"]),
     "random_state": 42,
-    "use_label_encoder": False,
     "verbosity": 0,
     "eval_metric": "logloss"
 }
models/xgboost/sweep_local.py CHANGED
@@ -35,7 +35,6 @@ def objective(trial):
     "reg_lambda": trial.suggest_float("reg_lambda", 0.5, 5.0),
     "eval_metric": "logloss",
     "random_state": SEED,
-    "use_label_encoder": False,
     "verbosity": 0
 }
 
models/xgboost/train.py CHANGED
@@ -81,7 +81,8 @@ def main():
     eval_set=[(X_train, y_train), (X_val, y_val)],
     verbose=10,
 )
-print(f"[TRAIN] Best iteration: {model.best_iteration} / {CFG['n_estimators']}")
+best_it = getattr(model, "best_iteration", None)
+print(f"[TRAIN] Best iteration: {best_it} / {CFG['n_estimators']}")

 # ── Evaluation ────────────────────────────────────────────────
 evals = model.evals_result()
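The `getattr(model, "best_iteration", None)` guard above exists because XGBoost only defines `best_iteration` on a fitted model when early stopping was used; a plain attribute read can raise `AttributeError` otherwise. A small illustration of the pattern with a stand-in object (not the real `XGBClassifier`):

```python
class DummyModel:
    """Stand-in for a fitted model that may or may not expose best_iteration."""
    pass

model = DummyModel()
# Without the guard this would raise AttributeError; with it we get None.
best_it = getattr(model, "best_iteration", None)

# Once the attribute exists (e.g. early stopping ran), getattr returns it as usual.
model.best_iteration = 137
```

This keeps the log line working whether or not the config enables early stopping.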
notebooks/README.md CHANGED
@@ -1,42 +1,7 @@
 # notebooks/
 
-Training and evaluation notebooks for MLP and XGBoost models.
-
-## 1. Files
-
-| Notebook | Model | Description |
-|----------|-------|-------------|
-| `mlp.ipynb` | PyTorch MLP | Training, evaluation, and LOPO cross-validation |
-| `xgboost.ipynb` | XGBoost | Training, evaluation, and LOPO cross-validation |
-
-## 2. Structure (both notebooks)
-
-Each notebook follows the same layout:
-
-1. **Imports and CFG** — single config dict, project root setup
-2. **ClearML (optional)** — opt-in experiment tracking
-3. **Data loading** — uses `data_preparation.prepare_dataset` for consistent loading
-4. **Random split training** — 70/15/15 stratified split with per-epoch/round logging
-5. **Loss curves** — train vs validation loss plots
-6. **Test evaluation** — accuracy, F1, ROC-AUC, classification report, confusion matrix
-7. **Checkpoint saving** — model weights + JSON training log
-8. **LOPO evaluation** — Leave-One-Person-Out cross-validation across all 9 participants
-9. **LOPO summary** — per-person accuracy table + bar chart
-
-## 3. Running
-
-Open in Jupyter or VS Code with the Python kernel set to the project venv:
-
-```bash
-source venv/bin/activate
-jupyter notebook notebooks/mlp.ipynb
-```
-
-Make sure the kernel's working directory is either the project root or `notebooks/` — the path resolution handles both.
-
-## 4. Results
-
-| Model | Random Split Accuracy | Random Split F1 | LOPO (mean) |
-|-------|-----------------------|-----------------|-------------|
-| XGBoost | 95.87% | 0.959 | see notebook |
-| MLP | 92.92% | 0.929 | see notebook |
+MLP and XGBoost training + LOPO evaluation.
+
+**Files:** `mlp.ipynb`, `xgboost.ipynb`. Same flow: config → data from prepare_dataset → 70/15/15 train → loss curves → test metrics → save checkpoint + JSON log → LOPO over 9 participants.
+
+Run in Jupyter with the project venv; set kernel cwd to repo root or `notebooks/`.
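The LOPO (Leave-One-Person-Out) loop the notebooks run can be sketched in a few lines: hold out each participant in turn, train on the rest, and score the held-out person. This is a minimal sketch under stated assumptions; `train_and_score` is a hypothetical stand-in for the notebooks' fit/eval cells, and the dummy scorer below only reports fold sizes:

```python
import numpy as np

def lopo_scores(X, y, person_ids, train_and_score):
    """Run one train/eval round per unique participant, returning {person: score}."""
    scores = {}
    for person in np.unique(person_ids):
        test_mask = person_ids == person
        scores[person] = train_and_score(
            X[~test_mask], y[~test_mask],   # train on everyone else
            X[test_mask], y[test_mask],     # evaluate on the held-out person
        )
    return scores

persons = np.repeat(np.arange(9), 10)          # 9 participants, 10 samples each
X, y = np.zeros((90, 17)), np.zeros(90, dtype=int)
scores = lopo_scores(X, y, persons,
                     lambda Xtr, ytr, Xte, yte: len(Xte) / len(persons))
```

Splitting by participant rather than by frame is what makes LOPO a fair generalization estimate here: frames from one person never appear in both train and test.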
notebooks/xgboost.ipynb CHANGED
@@ -169,7 +169,6 @@
     "    reg_alpha=CFG[\"reg_alpha\"],\n",
     "    reg_lambda=CFG[\"reg_lambda\"],\n",
     "    eval_metric=CFG[\"eval_metric\"],\n",
-    "    use_label_encoder=False,\n",
     "    random_state=CFG[\"seed\"],\n",
     "    verbosity=1,\n",
     ")\n",
@@ -355,7 +354,6 @@
     "        reg_alpha=cfg[\"reg_alpha\"],\n",
     "        reg_lambda=cfg[\"reg_lambda\"],\n",
     "        eval_metric=cfg[\"eval_metric\"],\n",
-    "        use_label_encoder=False,\n",
     "        random_state=cfg[\"seed\"],\n",
     "        verbosity=0,\n",
     "    )\n",
requirements.txt CHANGED
@@ -11,7 +11,10 @@ joblib>=1.2.0
 torch>=2.0.0
 fastapi>=0.104.0
 uvicorn[standard]>=0.24.0
+httpx>=0.24.0
 aiosqlite>=0.19.0
+psutil>=5.9.0
 pydantic>=2.0.0
 xgboost>=2.0.0
 clearml>=2.0.2
+pytest>=7.0.0
src/App.css CHANGED
@@ -106,16 +106,15 @@ body {
 /* PAGE A SPECIFIC */
 #page-a {
   justify-content: center; /* Center vertically */
-  /* Note: the React structure changed, so tweak this margin-top if it sits too low */
   margin-top: -40px;
-  flex: 1; /* fill the remaining space so vertical centering works */
+  flex: 1;
 }

 #page-a h1 {
   font-size: 80px;
   margin: 0 0 10px 0;
   color: #000;
-  text-align: center; /* keep the text centered */
+  text-align: center;
 }

 #page-a p {
@@ -145,13 +144,13 @@ body {
 #page-b {
   justify-content: space-evenly; /* Distribute vertical space */
   padding-bottom: 20px;
-  min-height: calc(100vh - 60px); /* make sure the height is sufficient */
+  min-height: calc(100vh - 60px);
 }

 /* 1. Display Area */
 #display-area {
   width: 60%;
-  height: 50vh; /* use vh units so the height scales across screen sizes */
+  height: 50vh;
   min-height: 300px;
   border: 2px solid #ddd;
   border-radius: 12px;
@@ -162,14 +161,13 @@ body {
   color: #555;
   font-size: 24px;
   position: relative;
-  /* make sure the video element stays centered and does not overflow */
   overflow: hidden;
 }

 #display-area video {
   width: 100%;
   height: 100%;
-  object-fit: cover; /* similar to background-size: cover */
+  object-fit: cover;
 }

 /* 2. Timeline Area */
@@ -367,7 +365,6 @@ body {
 #focus-chart {
   display: block;
   margin: 0 auto;
-  /* let the chart scale within its container */
   max-width: 100%;
 }
 
@@ -504,13 +501,12 @@ input[type="number"] {
   font-family: 'Nunito', sans-serif;
 }

-/* --- new: center the buttons and widen them --- */
 .setting-group .action-btn {
-  display: inline-block; /* allow side-by-side display */
-  width: 48%; /* each button takes about half the width (2% gap) */
-  margin: 15px 1%; /* 15px vertical, 1% horizontal for centering and separation */
-  text-align: center; /* center the label */
-  box-sizing: border-box; /* keep the border from widening the button and wrapping */
+  display: inline-block;
+  width: 48%;
+  margin: 15px 1%;
+  text-align: center;
+  box-sizing: border-box;
 }

 #save-settings {
@@ -625,7 +621,6 @@ details p {
 }

 /* ================ SESSION SUMMARY MODAL ================ */
-/* these styles can be reused if we build a modal later */
 .modal-overlay {
   position: fixed;
   top: 0;
@@ -743,17 +738,14 @@ details p {
   flex-direction: column;
 }
 }
-/* =========================================
-   SESSION RESULT OVERLAY (new)
-   ========================================= */
-
+/* SESSION RESULT OVERLAY */
 .session-result-overlay {
   position: absolute;
   top: 0;
   left: 0;
   width: 100%;
   height: 100%;
-  background-color: rgba(0, 0, 0, 0.85); /* dark translucent background */
+  background-color: rgba(0, 0, 0, 0.85);
   display: flex;
   flex-direction: column;
   justify-content: center;
@@ -761,13 +753,13 @@ details p {
   color: white;
   z-index: 10;
   animation: fadeIn 0.5s ease;
-  backdrop-filter: blur(5px); /* background blur effect (optional) */
+  backdrop-filter: blur(5px);
 }

 .session-result-overlay h3 {
   font-size: 32px;
   margin-bottom: 30px;
-  color: #4cd137; /* green title */
+  color: #4cd137;
   text-transform: uppercase;
   letter-spacing: 2px;
 }
@@ -775,7 +767,7 @@ details p {
 .session-result-overlay .result-item {
   display: flex;
   justify-content: space-between;
-  width: 200px; /* constrain the width */
+  width: 200px;
   margin-bottom: 15px;
   font-size: 20px;
   border-bottom: 1px solid rgba(255,255,255,0.2);
@@ -790,7 +782,7 @@ details p {
 .session-result-overlay .value {
   color: #fff;
   font-weight: bold;
-  font-family: 'Courier New', monospace; /* looks like data */
+  font-family: 'Courier New', monospace;
 }

 @keyframes fadeIn {
@@ -798,7 +790,7 @@ details p {
   to { opacity: 1; transform: scale(1); }
 }

-/* ================= welcome modal styles ================= */
+/* welcome modal */
 .welcome-modal-overlay {
   position: fixed;
   top: 0; left: 0; right: 0; bottom: 0;
@@ -822,7 +814,7 @@ border: 1px solid #333;
 .welcome-modal p { margin-bottom: 30px; color: #ccc; }
 .welcome-buttons { display: flex; gap: 20px; justify-content: center; }

-/* ================= top-left avatar styles (revised) ================= */
+/* top avatar */
 #top-menu {
   position: relative;
   display: flex;
@@ -852,30 +844,29 @@

 .avatar-circle.user { background-color: #555; }
 .avatar-circle.admin { background-color: #ffaa00; border-color: #fff; box-shadow: 0 0 10px rgba(255, 170, 0, 0.5); }
-/* ================= home 2x2 responsive button grid (final fix) ================= */
+/* home 2x2 button grid */
 .home-button-grid {
   display: grid;
-  grid-template-columns: 1fr 1fr; /* force two equal columns, no exceptions */
-  gap: 20px; /* spacing between buttons */
+  grid-template-columns: 1fr 1fr;
+  gap: 20px;
   width: 100%;
-  max-width: 500px; /* cap the width so it doesn't look oversized on wide screens */
-  margin: 40px auto 0 auto; /* 40px top margin, auto sides for true centering */
+  max-width: 500px;
+  margin: 40px auto 0 auto;
 }

 .home-button-grid .btn-main {
   width: 100%;
-  height: 60px; /* uniform height, neatly stacked */
-  margin: 0; /* clear default margins */
+  height: 60px;
+  margin: 0;
   padding: 10px;
   font-size: 1rem;
   display: flex;
   justify-content: center;
   align-items: center;
   text-align: center;
-  box-sizing: border-box; /* keep border and padding from breaking the grid */
+  box-sizing: border-box;
 }

-/* 📱 mobile adaptation (auto-scales below 600px) */
 @media (max-width: 600px) {
   .home-button-grid {
     gap: 15px;
src/App.jsx CHANGED
@@ -16,7 +16,7 @@ function App() {
   const [sessionResult, setSessionResult] = useState(null);
   const [role, setRole] = useState('user');

-  // On page load, silently clear the database, no popup!
+  //
   useEffect(() => {
     fetch('/api/history', { method: 'DELETE' }).catch(err => console.error(err));

@@ -37,7 +37,7 @@ function App() {
   };
 }, []);

-  // Clicking the avatar jumps straight to the Home page
+  //
   const handleAvatarClick = () => {
     setActiveTab('home');
   };
@@ -76,7 +76,7 @@ function App() {
   </button>
 </nav>

-  {/* pass the core state down to the Home component */}
+  {/* pass state to Home */}
 {activeTab === 'home' && <Home setActiveTab={setActiveTab} role={role} setRole={setRole} />}

 <FocusPageLocal
src/components/Achievement.jsx CHANGED
@@ -7,10 +7,10 @@ function Achievement() {
     avg_focus_score: 0,
     streak_days: 0
   });
+  const [systemStats, setSystemStats] = useState(null);
   const [badges, setBadges] = useState([]);
   const [loading, setLoading] = useState(true);

-  // format the time display
   const formatTime = (seconds) => {
     const hours = Math.floor(seconds / 3600);
     const minutes = Math.floor((seconds % 3600) / 60);
@@ -18,7 +18,6 @@ function Achievement() {
     return `${minutes}m`;
   };

-  // load the summary statistics
   useEffect(() => {
     fetch('/api/stats/summary')
       .then(res => res.json())
@@ -33,11 +32,21 @@ function Achievement() {
     });
   }, []);

-  // compute badges from the statistics
+  useEffect(() => {
+    const fetchSystem = () => {
+      fetch('/api/stats/system')
+        .then(res => res.json())
+        .then(data => setSystemStats(data))
+        .catch(() => setSystemStats(null));
+    };
+    fetchSystem();
+    const interval = setInterval(fetchSystem, 3000);
+    return () => clearInterval(interval);
+  }, []);
+
   const calculateBadges = (data) => {
     const earnedBadges = [];

-    // first-session badge
     if (data.total_sessions >= 1) {
       earnedBadges.push({
         id: 'first-session',
@@ -48,7 +57,6 @@ function Achievement() {
       });
     }

-    // 10-session badge
     if (data.total_sessions >= 10) {
       earnedBadges.push({
         id: 'ten-sessions',
@@ -59,7 +67,6 @@ function Achievement() {
       });
     }

-    // 50-session badge
     if (data.total_sessions >= 50) {
       earnedBadges.push({
         id: 'fifty-sessions',
@@ -70,7 +77,6 @@ function Achievement() {
       });
     }

-    // focus-master badge (average focus > 80%)
     if (data.avg_focus_score >= 0.8 && data.total_sessions >= 5) {
       earnedBadges.push({
         id: 'focus-master',
@@ -81,7 +87,6 @@ function Achievement() {
       });
     }

-    // streak-days badge
     if (data.streak_days >= 7) {
       earnedBadges.push({
         id: 'week-streak',
@@ -102,7 +107,6 @@ function Achievement() {
       });
     }

-    // duration badge (10+ hours)
     if (data.total_focus_time >= 36000) {
       earnedBadges.push({
         id: 'ten-hours',
@@ -113,7 +117,6 @@ function Achievement() {
       });
     }

-    // locked badges (examples)
     const allBadges = [
       {
         id: 'first-session',
@@ -186,6 +189,22 @@ function Achievement() {
   </div>
 ) : (
   <>
+    {systemStats && systemStats.cpu_percent != null && (
+      <div style={{
+        textAlign: 'center',
+        marginBottom: '12px',
+        padding: '8px 12px',
+        background: 'rgba(0,0,0,0.2)',
+        borderRadius: '8px',
+        fontSize: '13px',
+        color: '#aaa'
+      }}>
+        Server: CPU <strong style={{ color: '#8f8' }}>{systemStats.cpu_percent}%</strong>
+        {' · '}
+        RAM <strong style={{ color: '#8af' }}>{systemStats.memory_percent}%</strong>
+        {systemStats.memory_used_mb != null && ` (${systemStats.memory_used_mb}/${systemStats.memory_total_mb} MB)`}
+      </div>
+    )}
     <div className="stats-grid">
       <div className="stat-card">
         <div className="stat-number" id="total-sessions">{stats.total_sessions}</div>
src/components/Customise.jsx CHANGED
@@ -6,10 +6,10 @@ function Customise() {
   const [notificationsEnabled, setNotificationsEnabled] = useState(true);
   const [threshold, setThreshold] = useState(30);

-  // ref to the hidden file input
+  //
   const fileInputRef = useRef(null);

-  // 1. load settings
+  //
   useEffect(() => {
     fetch('/api/settings')
       .then(res => res.json())
@@ -24,7 +24,7 @@ function Customise() {
     .catch(err => console.error("Failed to load settings", err));
 }, []);

-  // 2. save settings
+  //
   const handleSave = async () => {
     const settings = {
       sensitivity: parseInt(sensitivity),
@@ -46,34 +46,34 @@ function Customise() {
   }
 };

-  // 3. export data
+  //
   const handleExport = async () => {
     try {
-      // fetch the full history
+      //
       const response = await fetch('/api/sessions?filter=all');
       if (!response.ok) throw new Error("Failed to fetch data");

       const data = await response.json();

-      // build a JSON blob
+      //
       const jsonString = JSON.stringify(data, null, 2);
-      // keep a copy in browser storage
+      //
       localStorage.setItem('focus_magic_backup', jsonString);

       const blob = new Blob([jsonString], { type: 'application/json' });

-      // create a temporary download link
+      //
       const url = URL.createObjectURL(blob);
       const link = document.createElement('a');
       link.href = url;
-      // file name includes the current date
+      //
       link.download = `focus-guard-backup-${new Date().toISOString().slice(0, 10)}.json`;

-      // trigger the download
+      //
       document.body.appendChild(link);
       link.click();

-      // clean up
+      //
       document.body.removeChild(link);
       URL.revokeObjectURL(url);
     } catch (error) {
@@ -82,12 +82,12 @@ function Customise() {
   }
 };

-  // 4. open the import file picker
+  //
   const triggerImport = () => {
     fileInputRef.current.click();
   };

-  // 5. handle the imported file
+  //
   const handleFileChange = async (event) => {
     const file = event.target.files[0];
     if (!file) return;
@@ -98,12 +98,12 @@ function Customise() {
     const content = e.target.result;
     const sessions = JSON.parse(content);

-      // basic validation: make sure it is an array
+      //
     if (!Array.isArray(sessions)) {
       throw new Error("Invalid file format: Expected a list of sessions.");
     }

-      // send to the backend for storage
+      //
     const response = await fetch('/api/import', {
       method: 'POST',
       headers: { 'Content-Type': 'application/json' },
@@ -119,13 +119,13 @@ function Customise() {
   } catch (err) {
     alert("Error parsing file: " + err.message);
   }
-    // reset the input so the same file can be uploaded again
+    //
   event.target.value = '';
 };
 reader.readAsText(file);
};

-  // 6. clear history
+  //
   const handleClearHistory = async () => {
     if (!window.confirm("Are you sure? This will delete ALL your session history permanently.")) {
       return;
@@ -187,7 +187,7 @@ function Customise() {
 <div className="setting-group">
   <h2>Data Management</h2>

-  {/* hidden file input, accepts json only */}
+  {/* hidden file input, json only */}
   <input
     type="file"
     ref={fileInputRef}
@@ -197,17 +197,17 @@ function Customise() {
 />

 <div style={{ display: 'flex', gap: '10px', justifyContent: 'center', flexWrap: 'wrap' }}>
-  {/* Export button */}
+  {/* Export */}
   <button id="export-data" className="action-btn blue" onClick={handleExport} style={{ width: '30%', minWidth: '120px' }}>
     Export Data
   </button>

-  {/* Import button */}
+  {/* Import */}
   <button id="import-data" className="action-btn yellow" onClick={triggerImport} style={{ width: '30%', minWidth: '120px' }}>
     Import Data
   </button>

-  {/* Clear button */}
+  {/* Clear */}
   <button id="clear-history" className="action-btn red" onClick={handleClearHistory} style={{ width: '30%', minWidth: '120px' }}>
     Clear History
   </button>
src/components/FocusPage.jsx CHANGED
@@ -6,9 +6,9 @@ function FocusPage({ videoManager, sessionResult, setSessionResult, isActive, di
6
 
7
  const videoRef = displayVideoRef;
8
 
9
- // 辅助函数:格式化时间
10
  const formatDuration = (seconds) => {
11
- // 如果是 0,直接显示 0s (或者你可以保留原来的 0m 0s)
12
  if (seconds === 0) return "0s";
13
 
14
  const mins = Math.floor(seconds / 60);
@@ -19,7 +19,7 @@ function FocusPage({ videoManager, sessionResult, setSessionResult, isActive, di
19
  useEffect(() => {
20
  if (!videoManager) return;
21
 
22
- // 设置回调函数来更新时间轴
23
  const originalOnStatusUpdate = videoManager.callbacks.onStatusUpdate;
24
  videoManager.callbacks.onStatusUpdate = (isFocused) => {
25
  setTimelineEvents(prev => {
@@ -27,11 +27,11 @@ function FocusPage({ videoManager, sessionResult, setSessionResult, isActive, di
27
  if (newEvents.length > 60) newEvents.shift();
28
  return newEvents;
29
  });
30
- // 调用原始回调(如果有)
31
  if (originalOnStatusUpdate) originalOnStatusUpdate(isFocused);
32
  };
33
 
34
- // 清理函数:不再自动停止session,只清理回调
35
  return () => {
36
  if (videoManager) {
37
  videoManager.callbacks.onStatusUpdate = originalOnStatusUpdate;
@@ -42,7 +42,7 @@ function FocusPage({ videoManager, sessionResult, setSessionResult, isActive, di
42
  const handleStart = async () => {
43
  try {
44
  if (videoManager) {
45
- setSessionResult(null); // 开始时清除结果层
46
  setTimelineEvents([]);
47
 
48
  console.log('🎬 Initializing camera...');
@@ -114,16 +114,16 @@ function FocusPage({ videoManager, sessionResult, setSessionResult, isActive, di
114
  }
115
  };
116
 
117
- // 浮窗功能
118
  const handleFloatingWindow = () => {
119
  handlePiP();
120
  };
121
 
122
  // ==========================================
123
- // 新增功能:预览按钮的处理函数
124
  // ==========================================
125
  const handlePreview = () => {
126
- // 强制设置一个 0 分 0 秒的假数据,触发 overlay 显示
127
  setSessionResult({
128
  duration_seconds: 0,
129
  focus_score: 0
@@ -165,7 +165,7 @@ function FocusPage({ videoManager, sessionResult, setSessionResult, isActive, di
165
  style={{ width: '100%', height: '100%', objectFit: 'contain' }}
166
  />
167
 
168
- {/* 结果覆盖层 */}
169
  {sessionResult && (
170
  <div className="session-result-overlay">
171
  <h3>Session Complete!</h3>
@@ -178,7 +178,7 @@ function FocusPage({ videoManager, sessionResult, setSessionResult, isActive, di
178
  <span className="value">{(sessionResult.focus_score * 100).toFixed(1)}%</span>
179
  </div>
180
 
181
- {/* 新增:加一个小按钮方便关闭预览 */}
182
  <button
183
  onClick={handleCloseOverlay}
184
  style={{
@@ -226,11 +226,11 @@ function FocusPage({ videoManager, sessionResult, setSessionResult, isActive, di
226
  <button id="btn-cam-start" className="action-btn green" onClick={handleStart}>Start</button>
227
  <button id="btn-floating" className="action-btn yellow" onClick={handleFloatingWindow}>Floating Window</button>
228
 
229
- {/* change: temporarily turn the Models button into Preview, or add it after */}
230
  <button
231
  id="btn-preview"
232
  className="action-btn"
233
- style={{ backgroundColor: '#6c5ce7' }} // purple to stand out
234
  onClick={handlePreview}
235
  >
236
  Preview Result
 
6
 
7
  const videoRef = displayVideoRef;
8
 
9
+ // helper: format a duration
10
  const formatDuration = (seconds) => {
11
+ // special-case zero
12
  if (seconds === 0) return "0s";
13
 
14
  const mins = Math.floor(seconds / 60);
 
19
  useEffect(() => {
20
  if (!videoManager) return;
21
 
22
+ // install callback to update the timeline
23
  const originalOnStatusUpdate = videoManager.callbacks.onStatusUpdate;
24
  videoManager.callbacks.onStatusUpdate = (isFocused) => {
25
  setTimelineEvents(prev => {
 
27
  if (newEvents.length > 60) newEvents.shift();
28
  return newEvents;
29
  });
30
+ // invoke the original callback, if any
31
  if (originalOnStatusUpdate) originalOnStatusUpdate(isFocused);
32
  };
33
 
34
+ // cleanup: only restore the callback; do not auto-stop the session
35
  return () => {
36
  if (videoManager) {
37
  videoManager.callbacks.onStatusUpdate = originalOnStatusUpdate;
 
42
  const handleStart = async () => {
43
  try {
44
  if (videoManager) {
45
+ setSessionResult(null); // clear the result overlay on start
46
  setTimelineEvents([]);
47
 
48
  console.log('🎬 Initializing camera...');
 
114
  }
115
  };
116
 
117
+ // floating-window feature
118
  const handleFloatingWindow = () => {
119
  handlePiP();
120
  };
121
 
122
  // ==========================================
123
+ // new: handler for the Preview button
124
  // ==========================================
125
  const handlePreview = () => {
126
+ // force fake 0m 0s data to trigger the overlay
127
  setSessionResult({
128
  duration_seconds: 0,
129
  focus_score: 0
 
165
  style={{ width: '100%', height: '100%', objectFit: 'contain' }}
166
  />
167
 
168
+ {/* result overlay */}
169
  {sessionResult && (
170
  <div className="session-result-overlay">
171
  <h3>Session Complete!</h3>
 
178
  <span className="value">{(sessionResult.focus_score * 100).toFixed(1)}%</span>
179
  </div>
180
 
181
+ {/* close preview */}
182
  <button
183
  onClick={handleCloseOverlay}
184
  style={{
 
226
  <button id="btn-cam-start" className="action-btn green" onClick={handleStart}>Start</button>
227
  <button id="btn-floating" className="action-btn yellow" onClick={handleFloatingWindow}>Floating Window</button>
228
 
229
+ {/* preview button */}
230
  <button
231
  id="btn-preview"
232
  className="action-btn"
233
+ style={{ backgroundColor: '#6c5ce7' }}
234
  onClick={handlePreview}
235
  >
236
  Preview Result
src/components/FocusPageLocal.jsx CHANGED
@@ -4,15 +4,15 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
4
  const [currentFrame, setCurrentFrame] = useState(15);
5
  const [timelineEvents, setTimelineEvents] = useState([]);
6
  const [stats, setStats] = useState(null);
 
7
  const [availableModels, setAvailableModels] = useState([]);
8
  const [currentModel, setCurrentModel] = useState('mlp');
9
 
10
  const localVideoRef = useRef(null);
11
  const displayCanvasRef = useRef(null);
12
- const pipVideoRef = useRef(null); // hidden video element for PiP
13
  const pipStreamRef = useRef(null);
14
 
15
- // helper: format a duration
16
  const formatDuration = (seconds) => {
17
  if (seconds === 0) return "0s";
18
  const mins = Math.floor(seconds / 60);
@@ -23,7 +23,6 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
23
  useEffect(() => {
24
  if (!videoManager) return;
25
 
26
- // install callback to update the timeline
27
  const originalOnStatusUpdate = videoManager.callbacks.onStatusUpdate;
28
  videoManager.callbacks.onStatusUpdate = (isFocused) => {
29
  setTimelineEvents(prev => {
@@ -34,7 +33,6 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
34
  if (originalOnStatusUpdate) originalOnStatusUpdate(isFocused);
35
  };
36
 
37
- // periodically refresh stats
38
  const statsInterval = setInterval(() => {
39
  if (videoManager && videoManager.getStats) {
40
  setStats(videoManager.getStats());
@@ -60,6 +58,19 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
60
  .catch(err => console.error('Failed to fetch models:', err));
61
  }, []);
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  const handleModelChange = async (modelName) => {
64
  try {
65
  const res = await fetch('/api/settings', {
@@ -129,7 +140,7 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
129
 
130
  const handlePiP = async () => {
131
  try {
132
- // check the video manager exists and is running
133
  if (!videoManager || !videoManager.isStreaming) {
134
  alert('Please start the video first.');
135
  return;
@@ -140,20 +151,20 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
140
  return;
141
  }
142
 
143
- // already in PiP with this video: exit
144
  if (document.pictureInPictureElement === pipVideoRef.current) {
145
  await document.exitPictureInPicture();
146
  console.log('PiP exited');
147
  return;
148
  }
149
 
150
- // check browser support
151
  if (!document.pictureInPictureEnabled) {
152
  alert('Picture-in-Picture is not supported in this browser.');
153
  return;
154
  }
155
 
156
- // create or fetch the PiP video element
157
  const pipVideo = pipVideoRef.current;
158
  if (!pipVideo) {
159
  alert('PiP video element not ready.');
@@ -162,7 +173,7 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
162
 
163
  const isSafariPiP = typeof pipVideo.webkitSetPresentationMode === 'function';
164
 
165
- // prefer the canvas stream (with detection boxes); fall back to the camera stream
166
  let stream = pipStreamRef.current;
167
  if (!stream) {
168
  const capture = displayCanvasRef.current.captureStream;
@@ -180,7 +191,7 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
180
  pipStreamRef.current = stream;
181
  }
182
 
183
- // ensure the stream has tracks
184
  if (!stream || stream.getTracks().length === 0) {
185
  alert('Failed to capture video stream from canvas.');
186
  return;
@@ -188,7 +199,7 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
188
 
189
  pipVideo.srcObject = stream;
190
 
191
- // play the video (Safari may not fire onloadedmetadata)
192
  if (pipVideo.readyState < 2) {
193
  await new Promise((resolve) => {
194
  const onReady = () => {
@@ -198,7 +209,7 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
198
  };
199
  pipVideo.addEventListener('loadeddata', onReady);
200
  pipVideo.addEventListener('canplay', onReady);
201
- // fallback: continue after a short delay
202
  setTimeout(resolve, 600);
203
  });
204
  }
@@ -206,17 +217,17 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
206
  try {
207
  await pipVideo.play();
208
  } catch (_) {
209
- // Safari may reject autoplay, but PiP can still work
210
  }
211
 
212
- // Safari support (preferred)
213
  if (isSafariPiP) {
214
  try {
215
  pipVideo.webkitSetPresentationMode('picture-in-picture');
216
  console.log('PiP activated (Safari)');
217
  return;
218
  } catch (e) {
219
- // if the canvas stream failed, retry with the camera stream
220
  const cameraStream = localVideoRef.current?.srcObject;
221
  if (cameraStream && cameraStream !== pipVideo.srcObject) {
222
  pipVideo.srcObject = cameraStream;
@@ -231,7 +242,7 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
231
  }
232
  }
233
 
234
- // standard API
235
  if (typeof pipVideo.requestPictureInPicture === 'function') {
236
  await pipVideo.requestPictureInPicture();
237
  console.log('PiP activated');
@@ -263,7 +274,7 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
263
  return;
264
  }
265
 
266
- // get current stats
267
  const currentStats = videoManager.getStats();
268
 
269
  if (!currentStats.sessionId) {
@@ -271,15 +282,15 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
271
  return;
272
  }
273
 
274
- // compute elapsed time since session start
275
  const sessionDuration = Math.floor((Date.now() - (videoManager.sessionStartTime || Date.now())) / 1000);
276
 
277
- // compute the current focus score
278
  const focusScore = currentStats.framesProcessed > 0
279
  ? (currentStats.framesProcessed * (currentStats.currentStatus ? 1 : 0)) / currentStats.framesProcessed
280
  : 0;
281
 
282
- // show current live data
283
  setSessionResult({
284
  duration_seconds: sessionDuration,
285
  focus_score: focusScore,
@@ -320,7 +331,7 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
320
  <main id="page-b" className="page" style={pageStyle}>
321
  {/* 1. Camera / Display Area */}
322
  <section id="display-area" style={{ position: 'relative', overflow: 'hidden' }}>
323
- {/* hidden PiP video element (kept in the DOM for compatibility) */}
324
  <video
325
  ref={pipVideoRef}
326
  muted
@@ -334,7 +345,7 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
334
  pointerEvents: 'none'
335
  }}
336
  />
337
- {/* local video stream (hidden, capture only) */}
338
  <video
339
  ref={localVideoRef}
340
  muted
@@ -343,7 +354,7 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
343
  style={{ display: 'none' }}
344
  />
345
 
346
- {/* processed video display (canvas) */}
347
  <canvas
348
  ref={displayCanvasRef}
349
  width={640}
@@ -356,7 +367,7 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
356
  }}
357
  />
358
 
359
- {/* result overlay */}
360
  {sessionResult && (
361
  <div className="session-result-overlay">
362
  <h3>Session Complete!</h3>
@@ -386,7 +397,7 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
386
  </div>
387
  )}
388
 
389
- {/* performance stats (dev mode) */}
390
  {stats && stats.isStreaming && (
391
  <div style={{
392
  position: 'absolute',
@@ -405,10 +416,36 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
405
  <div>Latency: {stats.avgLatency.toFixed(0)}ms</div>
406
  <div>Status: {stats.currentStatus ? 'Focused' : 'Not Focused'}</div>
407
  <div>Confidence: {(stats.lastConfidence * 100).toFixed(1)}%</div>
 
 
 
 
 
 
408
  </div>
409
  )}
410
  </section>
411
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
412
  {/* 2. Model Selector */}
413
  {availableModels.length > 0 && (
414
  <section style={{
 
4
  const [currentFrame, setCurrentFrame] = useState(15);
5
  const [timelineEvents, setTimelineEvents] = useState([]);
6
  const [stats, setStats] = useState(null);
7
+ const [systemStats, setSystemStats] = useState(null);
8
  const [availableModels, setAvailableModels] = useState([]);
9
  const [currentModel, setCurrentModel] = useState('mlp');
10
 
11
  const localVideoRef = useRef(null);
12
  const displayCanvasRef = useRef(null);
13
+ const pipVideoRef = useRef(null);
14
  const pipStreamRef = useRef(null);
15
 
 
16
  const formatDuration = (seconds) => {
17
  if (seconds === 0) return "0s";
18
  const mins = Math.floor(seconds / 60);
 
23
  useEffect(() => {
24
  if (!videoManager) return;
25
 
 
26
  const originalOnStatusUpdate = videoManager.callbacks.onStatusUpdate;
27
  videoManager.callbacks.onStatusUpdate = (isFocused) => {
28
  setTimelineEvents(prev => {
 
33
  if (originalOnStatusUpdate) originalOnStatusUpdate(isFocused);
34
  };
35
 
 
36
  const statsInterval = setInterval(() => {
37
  if (videoManager && videoManager.getStats) {
38
  setStats(videoManager.getStats());
 
58
  .catch(err => console.error('Failed to fetch models:', err));
59
  }, []);
60
 
61
+ // Poll server CPU/memory for UI
62
+ useEffect(() => {
63
+ const fetchSystem = () => {
64
+ fetch('/api/stats/system')
65
+ .then(res => res.json())
66
+ .then(data => setSystemStats(data))
67
+ .catch(() => setSystemStats(null));
68
+ };
69
+ fetchSystem();
70
+ const interval = setInterval(fetchSystem, 3000);
71
+ return () => clearInterval(interval);
72
+ }, []);
73
+
74
  const handleModelChange = async (modelName) => {
75
  try {
76
  const res = await fetch('/api/settings', {
 
140
 
141
  const handlePiP = async () => {
142
  try {
143
+ // check the video manager exists and is running
144
  if (!videoManager || !videoManager.isStreaming) {
145
  alert('Please start the video first.');
146
  return;
 
151
  return;
152
  }
153
 
154
+ // already in PiP with this video: exit
155
  if (document.pictureInPictureElement === pipVideoRef.current) {
156
  await document.exitPictureInPicture();
157
  console.log('PiP exited');
158
  return;
159
  }
160
 
161
+ // check browser support
162
  if (!document.pictureInPictureEnabled) {
163
  alert('Picture-in-Picture is not supported in this browser.');
164
  return;
165
  }
166
 
167
+ // get the PiP video element
168
  const pipVideo = pipVideoRef.current;
169
  if (!pipVideo) {
170
  alert('PiP video element not ready.');
 
173
 
174
  const isSafariPiP = typeof pipVideo.webkitSetPresentationMode === 'function';
175
 
176
+ // prefer the canvas stream (with detection boxes); fall back to the camera
177
  let stream = pipStreamRef.current;
178
  if (!stream) {
179
  const capture = displayCanvasRef.current.captureStream;
 
191
  pipStreamRef.current = stream;
192
  }
193
 
194
+ // ensure the stream has tracks
195
  if (!stream || stream.getTracks().length === 0) {
196
  alert('Failed to capture video stream from canvas.');
197
  return;
 
199
 
200
  pipVideo.srcObject = stream;
201
 
202
+ // play the video (Safari may not fire onloadedmetadata)
203
  if (pipVideo.readyState < 2) {
204
  await new Promise((resolve) => {
205
  const onReady = () => {
 
209
  };
210
  pipVideo.addEventListener('loadeddata', onReady);
211
  pipVideo.addEventListener('canplay', onReady);
212
+ // fallback: continue after a short delay
213
  setTimeout(resolve, 600);
214
  });
215
  }
 
217
  try {
218
  await pipVideo.play();
219
  } catch (_) {
220
+ // Safari may reject autoplay, but PiP can still work
221
  }
222
 
223
+ // Safari support (preferred)
224
  if (isSafariPiP) {
225
  try {
226
  pipVideo.webkitSetPresentationMode('picture-in-picture');
227
  console.log('PiP activated (Safari)');
228
  return;
229
  } catch (e) {
230
+ // canvas stream failed: retry with the camera stream
231
  const cameraStream = localVideoRef.current?.srcObject;
232
  if (cameraStream && cameraStream !== pipVideo.srcObject) {
233
  pipVideo.srcObject = cameraStream;
 
242
  }
243
  }
244
 
245
+ // standard API
246
  if (typeof pipVideo.requestPictureInPicture === 'function') {
247
  await pipVideo.requestPictureInPicture();
248
  console.log('PiP activated');
 
274
  return;
275
  }
276
 
277
+ // get current stats
278
  const currentStats = videoManager.getStats();
279
 
280
  if (!currentStats.sessionId) {
 
282
  return;
283
  }
284
 
285
+ // elapsed time since session start
286
  const sessionDuration = Math.floor((Date.now() - (videoManager.sessionStartTime || Date.now())) / 1000);
287
 
288
+ // current focus score
289
  const focusScore = currentStats.framesProcessed > 0
290
  ? (currentStats.framesProcessed * (currentStats.currentStatus ? 1 : 0)) / currentStats.framesProcessed
291
  : 0;
292
 
293
+ // show current live data
294
  setSessionResult({
295
  duration_seconds: sessionDuration,
296
  focus_score: focusScore,
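As written above, `(currentStats.framesProcessed * (currentStats.currentStatus ? 1 : 0)) / currentStats.framesProcessed` always reduces to 0 or 1 — the *current* status, not a session average. A session-average variant would need a dedicated counter; `focusedFrames` below is a hypothetical field not present in this diff, so this is a sketch, not the component's actual method:

```javascript
// Fraction of processed frames classified as focused.
// focusedFrames is a hypothetical counter that the stats object
// would have to accumulate alongside framesProcessed.
function computeFocusScore(focusedFrames, framesProcessed) {
  if (framesProcessed <= 0) return 0; // no frames yet
  return focusedFrames / framesProcessed;
}
```

Unlike the inline expression, this yields intermediate values such as 0.5 for a half-focused session.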
 
331
  <main id="page-b" className="page" style={pageStyle}>
332
  {/* 1. Camera / Display Area */}
333
  <section id="display-area" style={{ position: 'relative', overflow: 'hidden' }}>
334
+ {/* hidden PiP video element */}
335
  <video
336
  ref={pipVideoRef}
337
  muted
 
345
  pointerEvents: 'none'
346
  }}
347
  />
348
+ {/* local video (hidden, for capture) */}
349
  <video
350
  ref={localVideoRef}
351
  muted
 
354
  style={{ display: 'none' }}
355
  />
356
 
357
+ {/* processed video (canvas) */}
358
  <canvas
359
  ref={displayCanvasRef}
360
  width={640}
 
367
  }}
368
  />
369
 
370
+ {/* result overlay */}
371
  {sessionResult && (
372
  <div className="session-result-overlay">
373
  <h3>Session Complete!</h3>
 
397
  </div>
398
  )}
399
 
400
+ {/* stats overlay */}
401
  {stats && stats.isStreaming && (
402
  <div style={{
403
  position: 'absolute',
 
416
  <div>Latency: {stats.avgLatency.toFixed(0)}ms</div>
417
  <div>Status: {stats.currentStatus ? 'Focused' : 'Not Focused'}</div>
418
  <div>Confidence: {(stats.lastConfidence * 100).toFixed(1)}%</div>
419
+ {systemStats && systemStats.cpu_percent != null && (
420
+ <div style={{ marginTop: '6px', borderTop: '1px solid #444', paddingTop: '4px' }}>
421
+ <div>CPU: {systemStats.cpu_percent}%</div>
422
+ <div>RAM: {systemStats.memory_percent}% ({systemStats.memory_used_mb}/{systemStats.memory_total_mb} MB)</div>
423
+ </div>
424
+ )}
425
  </div>
426
  )}
427
  </section>
428
 
429
+ {/* Server CPU / Memory (always visible) */}
430
+ {systemStats && (systemStats.cpu_percent != null || systemStats.memory_percent != null) && (
431
+ <section style={{
432
+ display: 'flex',
433
+ alignItems: 'center',
434
+ justifyContent: 'center',
435
+ gap: '16px',
436
+ padding: '6px 12px',
437
+ background: 'rgba(0,0,0,0.3)',
438
+ borderRadius: '8px',
439
+ margin: '6px auto',
440
+ maxWidth: '400px',
441
+ fontSize: '13px',
442
+ color: '#aaa'
443
+ }}>
444
+ <span title="Server CPU">CPU: <strong style={{ color: '#8f8' }}>{systemStats.cpu_percent}%</strong></span>
445
+ <span title="Server memory">RAM: <strong style={{ color: '#8af' }}>{systemStats.memory_percent}%</strong> ({systemStats.memory_used_mb}/{systemStats.memory_total_mb} MB)</span>
446
+ </section>
447
+ )}
448
+
449
  {/* 2. Model Selector */}
450
  {availableModels.length > 0 && (
451
  <section style={{
src/components/Home.jsx CHANGED
@@ -3,13 +3,13 @@ import React, { useRef } from 'react';
3
  function Home({ setActiveTab, role, setRole }) {
4
  const fileInputRef = useRef(null);
5
 
6
- // 1. start fresh (clears history)
7
  const handleNewStart = async () => {
8
  await fetch('/api/history', { method: 'DELETE' });
9
  setActiveTab('focus');
10
  };
11
 
12
- // 2. auto-import (from the localStorage backup)
13
  const handleAutoImport = async () => {
14
  const backup = localStorage.getItem('focus_magic_backup');
15
  if (backup) {
@@ -33,7 +33,7 @@ function Home({ setActiveTab, role, setRole }) {
33
  }
34
  };
35
 
36
- // 3. manual import
37
  const handleFileChange = async (event) => {
38
  const file = event.target.files[0];
39
  if (!file) return;
@@ -57,7 +57,7 @@ function Home({ setActiveTab, role, setRole }) {
57
  reader.readAsText(file);
58
  };
59
 
60
- // 4. toggle Admin/User mode
61
  const handleAdminToggle = async () => {
62
  if (role === 'admin') {
63
  if (window.confirm("Switch back to User mode? Current data will be cleared.")) {
@@ -96,10 +96,10 @@ function Home({ setActiveTab, role, setRole }) {
96
  <h1>FocusGuard</h1>
97
  <p>Your productivity monitor assistant.</p>
98
 
99
- {/* keep the hidden upload input outside the grid so it never occupies a cell */}
100
  <input type="file" ref={fileInputRef} style={{ display: 'none' }} accept=".json" onChange={handleFileChange} />
101
 
102
- {/* fresh 2x2 grid container holding only the 4 buttons */}
103
  <div className="home-button-grid">
104
 
105
  <button className="btn-main" onClick={handleNewStart}>
 
3
  function Home({ setActiveTab, role, setRole }) {
4
  const fileInputRef = useRef(null);
5
 
6
+ // 1. start fresh (clears history)
7
  const handleNewStart = async () => {
8
  await fetch('/api/history', { method: 'DELETE' });
9
  setActiveTab('focus');
10
  };
11
 
12
+ // 2. auto-import (from the localStorage backup)
13
  const handleAutoImport = async () => {
14
  const backup = localStorage.getItem('focus_magic_backup');
15
  if (backup) {
 
33
  }
34
  };
35
 
36
+ // 3. manual import
37
  const handleFileChange = async (event) => {
38
  const file = event.target.files[0];
39
  if (!file) return;
 
57
  reader.readAsText(file);
58
  };
59
 
60
+ // 4. toggle Admin/User mode
61
  const handleAdminToggle = async () => {
62
  if (role === 'admin') {
63
  if (window.confirm("Switch back to User mode? Current data will be cleared.")) {
 
96
  <h1>FocusGuard</h1>
97
  <p>Your productivity monitor assistant.</p>
98
 
99
+ {/* hidden file input outside grid */}
100
  <input type="file" ref={fileInputRef} style={{ display: 'none' }} accept=".json" onChange={handleFileChange} />
101
 
102
+ {/* 2x2 button grid */}
103
  <div className="home-button-grid">
104
 
105
  <button className="btn-main" onClick={handleNewStart}>
src/components/Records.jsx CHANGED
@@ -6,14 +6,14 @@ function Records() {
6
  const [loading, setLoading] = useState(false);
7
  const chartRef = useRef(null);
8
 
9
- // format a duration
10
  const formatDuration = (seconds) => {
11
  const mins = Math.floor(seconds / 60);
12
  const secs = seconds % 60;
13
  return `${mins}m ${secs}s`;
14
  };
15
 
16
- // format a date
17
  const formatDate = (dateString) => {
18
  const date = new Date(dateString);
19
  return date.toLocaleDateString('en-US', {
@@ -24,7 +24,7 @@ function Records() {
24
  });
25
  };
26
 
27
- // load session data
28
  const loadSessions = async (filterType) => {
29
  setLoading(true);
30
  try {
@@ -39,7 +39,7 @@ function Records() {
39
  }
40
  };
41
 
42
- // draw the chart
43
  const drawChart = (data) => {
44
  const canvas = chartRef.current;
45
  if (!canvas) return;
@@ -48,7 +48,7 @@ function Records() {
48
  const width = canvas.width = canvas.offsetWidth;
49
  const height = canvas.height = 300;
50
 
51
- // clear the canvas
52
  ctx.clearRect(0, 0, width, height);
53
 
54
  if (data.length === 0) {
@@ -59,17 +59,17 @@ function Records() {
59
  return;
60
  }
61
 
62
- // prepare data (show at most the 20 most recent sessions)
63
  const displayData = data.slice(0, 20).reverse();
64
  const padding = 50;
65
  const chartWidth = width - padding * 2;
66
  const chartHeight = height - padding * 2;
67
  const barWidth = chartWidth / displayData.length;
68
 
69
- // max value for scaling
70
  const maxScore = 1.0;
71
 
72
- // draw the axes
73
  ctx.strokeStyle = '#E0E0E0';
74
  ctx.lineWidth = 2;
75
  ctx.beginPath();
@@ -78,7 +78,7 @@ function Records() {
78
  ctx.lineTo(width - padding, height - padding);
79
  ctx.stroke();
80
 
81
- // draw the Y-axis ticks
82
  ctx.fillStyle = '#666';
83
  ctx.font = '12px Nunito';
84
  ctx.textAlign = 'right';
@@ -87,7 +87,7 @@ function Records() {
87
  const value = (maxScore * i / 4 * 100).toFixed(0);
88
  ctx.fillText(value + '%', padding - 10, y + 4);
89
 
90
- // draw the gridlines
91
  ctx.strokeStyle = '#F0F0F0';
92
  ctx.lineWidth = 1;
93
  ctx.beginPath();
@@ -96,14 +96,14 @@ function Records() {
96
  ctx.stroke();
97
  }
98
 
99
- // draw the bars
100
  displayData.forEach((session, index) => {
101
  const barHeight = (session.focus_score / maxScore) * chartHeight;
102
  const x = padding + index * barWidth + barWidth * 0.1;
103
  const y = height - padding - barHeight;
104
  const barActualWidth = barWidth * 0.8;
105
 
106
- // color by score - blue theme
107
  const score = session.focus_score;
108
  let color;
109
  if (score >= 0.8) color = '#4A90E2';
@@ -114,32 +114,32 @@ function Records() {
114
  ctx.fillStyle = color;
115
  ctx.fillRect(x, y, barActualWidth, barHeight);
116
 
117
- // draw the bar border
118
  ctx.strokeStyle = color;
119
  ctx.lineWidth = 1;
120
  ctx.strokeRect(x, y, barActualWidth, barHeight);
121
  });
122
 
123
- // draw the legend
124
  ctx.textAlign = 'left';
125
  ctx.font = 'bold 14px Nunito';
126
  ctx.fillStyle = '#4A90E2';
127
  ctx.fillText('Focus Score by Session', padding, 30);
128
  };
129
 
130
- // initial load
131
  useEffect(() => {
132
  loadSessions(filter);
133
  }, [filter]);
134
 
135
- // handle filter button clicks
136
  const handleFilterClick = (filterType) => {
137
  setFilter(filterType);
138
  };
139
 
140
- // view details
141
  const handleViewDetails = (sessionId) => {
142
- // could show a detail dialog with the session's full information here
143
  alert(`View details for session ${sessionId}\n(Feature can be extended later)`);
144
  };
145
 
 
6
  const [loading, setLoading] = useState(false);
7
  const chartRef = useRef(null);
8
 
9
+ // format a duration
10
  const formatDuration = (seconds) => {
11
  const mins = Math.floor(seconds / 60);
12
  const secs = seconds % 60;
13
  return `${mins}m ${secs}s`;
14
  };
15
 
16
+ // format a date
17
  const formatDate = (dateString) => {
18
  const date = new Date(dateString);
19
  return date.toLocaleDateString('en-US', {
 
24
  });
25
  };
26
 
27
+ // load session data
28
  const loadSessions = async (filterType) => {
29
  setLoading(true);
30
  try {
 
39
  }
40
  };
41
 
42
+ // draw the chart
43
  const drawChart = (data) => {
44
  const canvas = chartRef.current;
45
  if (!canvas) return;
 
48
  const width = canvas.width = canvas.offsetWidth;
49
  const height = canvas.height = 300;
50
 
51
+ // clear the canvas
52
  ctx.clearRect(0, 0, width, height);
53
 
54
  if (data.length === 0) {
 
59
  return;
60
  }
61
 
62
+ // prepare data (show at most the 20 most recent sessions)
63
  const displayData = data.slice(0, 20).reverse();
64
  const padding = 50;
65
  const chartWidth = width - padding * 2;
66
  const chartHeight = height - padding * 2;
67
  const barWidth = chartWidth / displayData.length;
68
 
69
+ // max value for scaling
70
  const maxScore = 1.0;
71
 
72
+ // draw the axes
73
  ctx.strokeStyle = '#E0E0E0';
74
  ctx.lineWidth = 2;
75
  ctx.beginPath();
 
78
  ctx.lineTo(width - padding, height - padding);
79
  ctx.stroke();
80
 
81
+ // draw the Y-axis ticks
82
  ctx.fillStyle = '#666';
83
  ctx.font = '12px Nunito';
84
  ctx.textAlign = 'right';
 
87
  const value = (maxScore * i / 4 * 100).toFixed(0);
88
  ctx.fillText(value + '%', padding - 10, y + 4);
89
 
90
+ // draw the gridlines
91
  ctx.strokeStyle = '#F0F0F0';
92
  ctx.lineWidth = 1;
93
  ctx.beginPath();
 
96
  ctx.stroke();
97
  }
98
 
99
+ // draw the bars
100
  displayData.forEach((session, index) => {
101
  const barHeight = (session.focus_score / maxScore) * chartHeight;
102
  const x = padding + index * barWidth + barWidth * 0.1;
103
  const y = height - padding - barHeight;
104
  const barActualWidth = barWidth * 0.8;
105
 
106
+ // color by score - blue theme
107
  const score = session.focus_score;
108
  let color;
109
  if (score >= 0.8) color = '#4A90E2';
 
114
  ctx.fillStyle = color;
115
  ctx.fillRect(x, y, barActualWidth, barHeight);
116
 
117
+ // draw the bar border
118
  ctx.strokeStyle = color;
119
  ctx.lineWidth = 1;
120
  ctx.strokeRect(x, y, barActualWidth, barHeight);
121
  });
122
 
123
+ // draw the legend
124
  ctx.textAlign = 'left';
125
  ctx.font = 'bold 14px Nunito';
126
  ctx.fillStyle = '#4A90E2';
127
  ctx.fillText('Focus Score by Session', padding, 30);
128
  };
129
 
130
+ // initial load
131
  useEffect(() => {
132
  loadSessions(filter);
133
  }, [filter]);
134
 
135
+ // handle filter button clicks
136
  const handleFilterClick = (filterType) => {
137
  setFilter(filterType);
138
  };
139
 
140
+ // view details
141
  const handleViewDetails = (sessionId) => {
142
+ // could show a detail dialog with the session's full information here
143
  alert(`View details for session ${sessionId}\n(Feature can be extended later)`);
144
  };
145
 
src/utils/VideoManager.js CHANGED
@@ -2,12 +2,10 @@
2
 
3
  export class VideoManager {
4
  constructor(callbacks) {
5
- // callbacks notify the React component to update the UI
6
- // e.g. onStatusUpdate, onSessionStart, onSessionEnd
7
  this.callbacks = callbacks || {};
8
 
9
- this.videoElement = null; // shows the backend-processed video
10
- this.stream = null; // local camera stream
11
  this.pc = null;
12
  this.dataChannel = null;
13
 
@@ -15,25 +13,21 @@ export class VideoManager {
15
  this.sessionId = null;
16
  this.frameRate = 30;
17
 
18
- // status smoothing
19
  this.currentStatus = false;
20
  this.statusBuffer = [];
21
  this.bufferSize = 5;
22
 
23
- // detection data
24
  this.latestDetectionData = null;
25
  this.lastConfidence = 0;
26
  this.detectionHoldMs = 30;
27
 
28
- // notification system
29
  this.notificationEnabled = true;
30
- this.notificationThreshold = 30; // default 30 seconds
31
  this.unfocusedStartTime = null;
32
  this.lastNotificationTime = null;
33
- this.notificationCooldown = 60000; // 60 s cooldown between notifications
34
  }
35
 
36
- // init: acquire the camera stream and record the display element
37
  async initCamera(videoRef) {
38
  try {
39
  this.stream = await navigator.mediaDevices.getUserMedia({
@@ -62,9 +56,9 @@ export class VideoManager {
62
 
63
  console.log('📹 Starting streaming...');
64
 
65
- // request notification permission
66
  await this.requestNotificationPermission();
67
- // load notification settings
68
  await this.loadNotificationSettings();
69
 
70
  this.pc = new RTCPeerConnection({
@@ -78,7 +72,7 @@ export class VideoManager {
78
  iceCandidatePoolSize: 10
79
  });
80
 
81
- // 添加连接状态监控
82
  this.pc.onconnectionstatechange = () => {
83
  console.log('🔗 Connection state:', this.pc.connectionState);
84
  };
@@ -199,7 +193,7 @@ export class VideoManager {
199
  requireInteraction: false
200
  });
201
 
202
- // auto-close after 3 seconds
203
  setTimeout(() => notification.close(), 3000);
204
  } catch (error) {
205
  console.error('Failed to send notification:', error);
@@ -247,28 +241,28 @@ export class VideoManager {
247
  this.currentStatus = false;
248
  }
249
 
250
- // notification logic
251
  this.handleNotificationLogic(previousStatus, this.currentStatus);
252
  }
253
 
254
  handleNotificationLogic(previousStatus, currentStatus) {
255
  const now = Date.now();
256
 
257
- // focused -> unfocused: record the start time
258
  if (previousStatus && !currentStatus) {
259
  this.unfocusedStartTime = now;
260
  }
261
 
262
- // unfocused -> focused: clear the timer
263
  if (!previousStatus && currentStatus) {
264
  this.unfocusedStartTime = null;
265
  }
266
 
267
- // while continuously unfocused
268
  if (!currentStatus && this.unfocusedStartTime) {
269
- const unfocusedDuration = (now - this.unfocusedStartTime) / 1000; // seconds
270
 
271
- // check the threshold is exceeded and we are outside the cooldown
272
  if (unfocusedDuration >= this.notificationThreshold) {
273
  const canSendNotification = !this.lastNotificationTime ||
274
  (now - this.lastNotificationTime) >= this.notificationCooldown;
@@ -335,7 +329,7 @@ export class VideoManager {
335
  }
336
  }
337
 
338
- // 清理通知状态
339
  this.unfocusedStartTime = null;
340
  this.lastNotificationTime = null;
341
  this.sessionId = null;
 
2
 
3
  export class VideoManager {
4
  constructor(callbacks) {
 
 
5
  this.callbacks = callbacks || {};
6
 
7
+ this.videoElement = null; // shows the backend-processed video
8
+ this.stream = null; // local camera stream
9
  this.pc = null;
10
  this.dataChannel = null;
11
 
 
13
  this.sessionId = null;
14
  this.frameRate = 30;
15
 
 
16
  this.currentStatus = false;
17
  this.statusBuffer = [];
18
  this.bufferSize = 5;
19
 
 
20
  this.latestDetectionData = null;
21
  this.lastConfidence = 0;
22
  this.detectionHoldMs = 30;
23
 
 
24
  this.notificationEnabled = true;
25
+ this.notificationThreshold = 30; // default 30 seconds
26
  this.unfocusedStartTime = null;
27
  this.lastNotificationTime = null;
28
+ this.notificationCooldown = 60000; // 60 s cooldown between notifications
29
  }
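The `statusBuffer` / `bufferSize` fields above feed a sliding-window smoother over per-frame focus predictions. A minimal standalone sketch (the `smoothStatus` name and the majority-vote rule are assumptions for illustration; the class's actual update logic lives in its status handler, which this diff only shows in part):

```javascript
// Sliding-window majority vote over recent per-frame predictions.
// smoothStatus is a hypothetical helper; VideoManager keeps equivalent
// state in this.statusBuffer / this.bufferSize (5 frames in this diff).
function smoothStatus(buffer, isFocused, bufferSize = 5) {
  const next = [...buffer, isFocused];
  if (next.length > bufferSize) next.shift(); // drop the oldest frame
  const focused = next.filter(Boolean).length;
  return { buffer: next, status: focused > next.length / 2 };
}
```

With this rule a single noisy frame cannot flip the reported status; a majority of the window has to agree.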
30
 
 
31
  async initCamera(videoRef) {
32
  try {
33
  this.stream = await navigator.mediaDevices.getUserMedia({
 
56
 
57
  console.log('📹 Starting streaming...');
58
 
59
+ // request notification permission
60
  await this.requestNotificationPermission();
61
+ // load notification settings
62
  await this.loadNotificationSettings();
63
 
64
  this.pc = new RTCPeerConnection({
 
72
  iceCandidatePoolSize: 10
73
  });
74
 
75
+ // monitor connection state
76
  this.pc.onconnectionstatechange = () => {
77
  console.log('🔗 Connection state:', this.pc.connectionState);
78
  };
 
193
  requireInteraction: false
194
  });
195
 
196
+ // auto-close after 3 seconds
197
  setTimeout(() => notification.close(), 3000);
198
  } catch (error) {
199
  console.error('Failed to send notification:', error);
 
241
  this.currentStatus = false;
242
  }
243
 
244
+ // notification logic
245
  this.handleNotificationLogic(previousStatus, this.currentStatus);
246
  }
247
 
248
  handleNotificationLogic(previousStatus, currentStatus) {
249
  const now = Date.now();
250
 
251
+ // focused -> unfocused: record the start time
252
  if (previousStatus && !currentStatus) {
253
  this.unfocusedStartTime = now;
254
  }
255
 
256
+ // unfocused -> focused: clear the timer
257
  if (!previousStatus && currentStatus) {
258
  this.unfocusedStartTime = null;
259
  }
260
 
261
+ // while continuously unfocused
262
  if (!currentStatus && this.unfocusedStartTime) {
263
+ const unfocusedDuration = (now - this.unfocusedStartTime) / 1000; // seconds
264
 
265
+ // past the threshold and outside the cooldown?
266
  if (unfocusedDuration >= this.notificationThreshold) {
267
  const canSendNotification = !this.lastNotificationTime ||
268
  (now - this.lastNotificationTime) >= this.notificationCooldown;
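The threshold-and-cooldown check above can be isolated into a pure function for testing. `shouldNotify` is a hypothetical helper name; the defaults mirror this diff's `notificationThreshold` (30 s) and `notificationCooldown` (60 000 ms):

```javascript
// Decide whether to fire an "unfocused" notification, mirroring the
// logic in handleNotificationLogic. Timestamps are in milliseconds.
function shouldNotify(now, unfocusedStartTime, lastNotificationTime,
                      thresholdSec = 30, cooldownMs = 60000) {
  if (unfocusedStartTime == null) return false;           // currently focused
  const unfocusedSec = (now - unfocusedStartTime) / 1000; // distraction length
  if (unfocusedSec < thresholdSec) return false;          // below threshold
  // respect the cooldown between consecutive notifications
  return lastNotificationTime == null ||
         (now - lastNotificationTime) >= cooldownMs;
}
```

Keeping the decision pure makes both rules (threshold and cooldown) testable without a browser Notification object.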
 
329
  }
330
  }
331
 
332
+ // reset notification state
333
  this.unfocusedStartTime = null;
334
  this.lastNotificationTime = null;
335
  this.sessionId = null;
src/utils/VideoManagerLocal.js CHANGED
@@ -1,12 +1,12 @@
1
  // src/utils/VideoManagerLocal.js
2
- // local processing variant - WebSocket + Canvas, no WebRTC dependency
3
 
4
  export class VideoManagerLocal {
5
  constructor(callbacks) {
6
  this.callbacks = callbacks || {};
7
 
8
- this.localVideoElement = null; // shows the local camera
9
- this.displayVideoElement = null; // shows the processed video
10
  this.canvas = null;
11
  this.stream = null;
12
  this.ws = null;
@@ -14,15 +14,13 @@ export class VideoManagerLocal {
14
  this.isStreaming = false;
15
  this.sessionId = null;
16
  this.sessionStartTime = null;
17
- this.frameRate = 15; // lower frame rate to reduce network load
18
  this.captureInterval = null;
19
 
20
- // status smoothing
21
  this.currentStatus = false;
22
  this.statusBuffer = [];
23
  this.bufferSize = 3;
24
 
25
- // detection data
26
  this.latestDetectionData = null;
27
  this.lastConfidence = 0;
28
 
@@ -32,14 +30,12 @@ export class VideoManagerLocal {
32
  // Continuous render loop
33
  this._animFrameId = null;
34
 
35
- // notification system
36
  this.notificationEnabled = true;
37
  this.notificationThreshold = 30;
38
  this.unfocusedStartTime = null;
39
  this.lastNotificationTime = null;
40
  this.notificationCooldown = 60000;
41
 
42
- // 性能统计
43
  this.stats = {
44
  framesSent: 0,
45
  framesProcessed: 0,
@@ -48,7 +44,6 @@ export class VideoManagerLocal {
48
  };
49
  }
50
 
51
- // 初始化摄像头
52
  async initCamera(localVideoRef, displayCanvasRef) {
53
  try {
54
  console.log('Initializing local camera...');
@@ -65,13 +60,11 @@ export class VideoManagerLocal {
65
  this.localVideoElement = localVideoRef;
66
  this.displayCanvas = displayCanvasRef;
67
 
68
- // 显示本地视频流
69
  if (this.localVideoElement) {
70
  this.localVideoElement.srcObject = this.stream;
71
  this.localVideoElement.play();
72
  }
73
 
74
- // 创建用于截图的 canvas (smaller for faster encode + transfer)
75
  this.canvas = document.createElement('canvas');
76
  this.canvas.width = 320;
77
  this.canvas.height = 240;
@@ -84,7 +77,6 @@ export class VideoManagerLocal {
84
  }
85
  }
86
 
87
- // 开始流式处理
88
  async startStreaming() {
89
  if (!this.stream) {
90
  throw new Error('Camera not initialized');
@@ -109,14 +101,14 @@ export class VideoManagerLocal {
109
  }
110
  }
111
 
112
- // 请求通知权限
113
  await this.requestNotificationPermission();
114
  await this.loadNotificationSettings();
115
 
116
- // 建立 WebSocket 连接
117
  await this.connectWebSocket();
118
 
119
- // 开始定期截图并发送
120
  this.startCapture();
121
 
122
  // Start continuous render loop for smooth video
@@ -126,7 +118,7 @@ export class VideoManagerLocal {
126
  console.log('Streaming started');
127
  }
128
 
129
- // 建立 WebSocket 连接
130
  async connectWebSocket() {
131
  return new Promise((resolve, reject) => {
132
  const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
@@ -139,7 +131,7 @@ export class VideoManagerLocal {
139
  this.ws.onopen = () => {
140
  console.log('WebSocket connected');
141
 
142
- // 发送开始会话请求
143
  this.ws.send(JSON.stringify({ type: 'start_session' }));
144
  resolve();
145
  };
@@ -168,7 +160,7 @@ export class VideoManagerLocal {
168
  });
169
  }
170
 
171
- // 开始截图并发送 (binary blobs for speed)
172
  startCapture() {
173
  const interval = 1000 / this.frameRate;
174
  this._sendingBlob = false; // prevent overlapping toBlob calls
@@ -272,7 +264,7 @@ export class VideoManagerLocal {
272
  }
273
  }
274
 
275
- // 处理服务器消息
276
  handleServerMessage(data) {
277
  switch (data.type) {
278
  case 'session_started':
@@ -590,17 +582,17 @@ export class VideoManagerLocal {
590
  this._stopRenderLoop();
591
  this._lastDetection = null;
592
 
593
- // 停止截图
594
  if (this.captureInterval) {
595
  clearInterval(this.captureInterval);
596
  this.captureInterval = null;
597
  }
598
 
599
- // 发送结束会话请求并等待响应
600
  if (this.ws && this.ws.readyState === WebSocket.OPEN && this.sessionId) {
601
  const sessionId = this.sessionId;
602
 
603
- // 等待 session_ended 消息
604
  const waitForSessionEnd = new Promise((resolve) => {
605
  const originalHandler = this.ws.onmessage;
606
  const timeout = setTimeout(() => {
@@ -618,7 +610,7 @@ export class VideoManagerLocal {
618
  this.ws.onmessage = originalHandler;
619
  resolve();
620
  } else {
621
- // 仍然处理其他消息
622
  this.handleServerMessage(data);
623
  }
624
  } catch (e) {
@@ -633,37 +625,37 @@ export class VideoManagerLocal {
633
  session_id: sessionId
634
  }));
635
 
636
- // 等待响应或超时
637
  await waitForSessionEnd;
638
  }
639
 
640
- // 延迟关闭 WebSocket 确保消息发送完成
641
  await new Promise(resolve => setTimeout(resolve, 200));
642
 
643
- // 关闭 WebSocket
644
  if (this.ws) {
645
  this.ws.close();
646
  this.ws = null;
647
  }
648
 
649
- // 停止摄像头
650
  if (this.stream) {
651
  this.stream.getTracks().forEach(track => track.stop());
652
  this.stream = null;
653
  }
654
 
655
- // 清空视频
656
  if (this.localVideoElement) {
657
  this.localVideoElement.srcObject = null;
658
  }
659
 
660
- // 清空 canvas
661
  if (this.displayCanvas) {
662
  const ctx = this.displayCanvas.getContext('2d');
663
  ctx.clearRect(0, 0, this.displayCanvas.width, this.displayCanvas.height);
664
  }
665
 
666
- // 清理状态
667
  this.unfocusedStartTime = null;
668
  this.lastNotificationTime = null;
669
 
@@ -675,7 +667,7 @@ export class VideoManagerLocal {
675
  this.frameRate = Math.max(10, Math.min(30, rate));
676
  console.log(`Frame rate set to ${this.frameRate} FPS`);
677
 
678
- // 重启截图(如果正在运行)
679
  if (this.isStreaming && this.captureInterval) {
680
  clearInterval(this.captureInterval);
681
  this.startCapture();
 
  // src/utils/VideoManagerLocal.js
+ // WebSocket + Canvas (no WebRTC)

  export class VideoManagerLocal {
  constructor(callbacks) {
  this.callbacks = callbacks || {};

+ this.localVideoElement = null;
+ this.displayVideoElement = null;
  this.canvas = null;
  this.stream = null;
  this.ws = null;

  this.isStreaming = false;
  this.sessionId = null;
  this.sessionStartTime = null;
+ this.frameRate = 15;
  this.captureInterval = null;

  this.currentStatus = false;
  this.statusBuffer = [];
  this.bufferSize = 3;

  this.latestDetectionData = null;
  this.lastConfidence = 0;

  // Continuous render loop
  this._animFrameId = null;

  this.notificationEnabled = true;
  this.notificationThreshold = 30;
  this.unfocusedStartTime = null;
  this.lastNotificationTime = null;
  this.notificationCooldown = 60000;

  this.stats = {
  framesSent: 0,
  framesProcessed: 0,

  };
  }

  async initCamera(localVideoRef, displayCanvasRef) {
  try {
  console.log('Initializing local camera...');

  this.localVideoElement = localVideoRef;
  this.displayCanvas = displayCanvasRef;

  if (this.localVideoElement) {
  this.localVideoElement.srcObject = this.stream;
  this.localVideoElement.play();
  }

  this.canvas = document.createElement('canvas');
  this.canvas.width = 320;
  this.canvas.height = 240;

  }
  }

  async startStreaming() {
  if (!this.stream) {
  throw new Error('Camera not initialized');

  }
  }

+ // Request notification permission
  await this.requestNotificationPermission();
  await this.loadNotificationSettings();

+ // Open the WebSocket connection
  await this.connectWebSocket();

+ // Start periodic capture and send
  this.startCapture();

  // Start continuous render loop for smooth video

  console.log('Streaming started');
  }

+ // Open the WebSocket connection
  async connectWebSocket() {
  return new Promise((resolve, reject) => {
  const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';

  this.ws.onopen = () => {
  console.log('WebSocket connected');

+ // Send the start-session request
  this.ws.send(JSON.stringify({ type: 'start_session' }));
  resolve();
  };

  });
  }

+ // Capture frames and send them (binary blobs for speed)
  startCapture() {
  const interval = 1000 / this.frameRate;
  this._sendingBlob = false; // prevent overlapping toBlob calls

  }
  }

+ // Handle server messages
  handleServerMessage(data) {
  switch (data.type) {
  case 'session_started':

  this._stopRenderLoop();
  this._lastDetection = null;

+ // Stop capturing
  if (this.captureInterval) {
  clearInterval(this.captureInterval);
  this.captureInterval = null;
  }

+ // Send the end-session request and wait for the response
  if (this.ws && this.ws.readyState === WebSocket.OPEN && this.sessionId) {
  const sessionId = this.sessionId;

+ // Wait for the session_ended message
  const waitForSessionEnd = new Promise((resolve) => {
  const originalHandler = this.ws.onmessage;
  const timeout = setTimeout(() => {

  this.ws.onmessage = originalHandler;
  resolve();
  } else {
+ // Still handle other messages
  this.handleServerMessage(data);
  }
  } catch (e) {

  session_id: sessionId
  }));

+ // Wait for the response or time out
  await waitForSessionEnd;
  }

+ // Delay closing the WebSocket so pending messages are flushed
  await new Promise(resolve => setTimeout(resolve, 200));

+ // Close the WebSocket
  if (this.ws) {
  this.ws.close();
  this.ws = null;
  }

+ // Stop the camera
  if (this.stream) {
  this.stream.getTracks().forEach(track => track.stop());
  this.stream = null;
  }

+ // Clear the video element
  if (this.localVideoElement) {
  this.localVideoElement.srcObject = null;
  }

+ // Clear the canvas
  if (this.displayCanvas) {
  const ctx = this.displayCanvas.getContext('2d');
  ctx.clearRect(0, 0, this.displayCanvas.width, this.displayCanvas.height);
  }

+ // Reset notification state
  this.unfocusedStartTime = null;
  this.lastNotificationTime = null;

  this.frameRate = Math.max(10, Math.min(30, rate));
  console.log(`Frame rate set to ${this.frameRate} FPS`);

+ // Restart capture (if running)
  if (this.isStreaming && this.captureInterval) {
  clearInterval(this.captureInterval);
  this.startCapture();
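The manager keeps a `statusBuffer` of size `bufferSize = 3` to smooth the raw per-frame focus predictions. The exact voting rule is not visible in this diff, so the following is a minimal sketch that assumes a simple majority vote over the window; class and method names are illustrative:

```python
from collections import deque


class StatusSmoother:
    """Majority-vote smoothing over the last N raw predictions.

    Assumed behaviour for a statusBuffer of fixed size: the smoothed
    status is True only when more than half of the buffered raw
    predictions are True.
    """

    def __init__(self, size: int = 3):
        # deque with maxlen drops the oldest prediction automatically
        self.buf = deque(maxlen=size)

    def update(self, raw_status: bool) -> bool:
        self.buf.append(raw_status)
        return sum(self.buf) > len(self.buf) / 2
```

A buffer this small (3 frames at 15 FPS) suppresses single-frame flicker while adding only ~0.2 s of latency to status changes.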
tests/test_data_preparation.py CHANGED
@@ -24,7 +24,7 @@ def test_generate_synthetic_data_shape():
  def test_get_numpy_splits_consistency():
  splits, num_features, num_classes, scaler = get_numpy_splits("face_orientation")

- # number of sample > 0,and split each in train/val/test
  n_train = len(splits["y_train"])
  n_val = len(splits["y_val"])
  n_test = len(splits["y_test"])

  def test_get_numpy_splits_consistency():
  splits, num_features, num_classes, scaler = get_numpy_splits("face_orientation")

+ # train/val/test each have samples
  n_train = len(splits["y_train"])
  n_val = len(splits["y_val"])
  n_test = len(splits["y_test"])
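The updated comment says the test checks that each split is non-empty. A fuller version of that consistency check might also verify that X/y lengths agree and that the feature dimension matches; sketched below with synthetic arrays (`check_splits` and the data layout are illustrative, mirroring the `X_*`/`y_*` keys used in the test above):

```python
import numpy as np


def check_splits(splits, num_features):
    """Sanity checks in the spirit of test_get_numpy_splits_consistency:
    every split non-empty, X/y lengths consistent, feature dims matching."""
    for part in ("train", "val", "test"):
        X, y = splits[f"X_{part}"], splits[f"y_{part}"]
        assert len(y) > 0, f"{part} split is empty"
        assert X.shape[0] == len(y), f"{part}: X/y length mismatch"
        assert X.shape[1] == num_features, f"{part}: wrong feature dim"


# Illustrative data, not the real dataset
rng = np.random.default_rng(0)
sizes = (("train", 80), ("val", 10), ("test", 10))
splits = {f"X_{p}": rng.normal(size=(n, 10)) for p, n in sizes}
splits.update({f"y_{p}": np.zeros(n, dtype=int) for p, n in sizes})
check_splits(splits, num_features=10)
```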
ui/README.md CHANGED
@@ -1,40 +1,16 @@
  # ui/

- Live camera demo and real-time inference pipeline.
-
- ## 1. Files
-
- | File | Description |
- |------|-------------|
- | `pipeline.py` | Inference pipelines: `FaceMeshPipeline`, `MLPPipeline`, `XGBoostPipeline` |
- | `live_demo.py` | OpenCV webcam window with mesh overlay and focus classification |
-
- ## 2. Pipelines
-
- | Pipeline | Features | Model | Source |
- |----------|----------|-------|--------|
- | `FaceMeshPipeline` | Head pose + eye geometry | Rule-based fusion | `models/head_pose.py`, `models/eye_scorer.py` |
- | `MLPPipeline` | 10 selected features | PyTorch MLP (10→64→32→2) | `checkpoints/mlp_best.pt` + `scaler_mlp.joblib` |
- | `XGBoostPipeline` | 10 selected features | XGBoost | `checkpoints/xgboost_face_orientation_best.json` |
-
- ## 3. Running

  ```bash
- # default mode (cycles through available pipelines)
  python ui/live_demo.py
-
- # start directly in XGBoost mode
  python ui/live_demo.py --xgb
  ```

- ### Controls
-
- | Key | Action |
- |-----|--------|
- | `m` | Cycle mesh overlay (full → contours → off) |
- | `p` | Switch pipeline (FaceMesh → MLP → XGBoost) |
- | `q` | Quit |
-
- ## 4. Integration
-
- The same pipelines are used by the FastAPI backend (`main.py`) for WebSocket-based video inference in the React app.

  # ui/

+ Live OpenCV demo and inference pipelines used by the app.

+ **Files:** `pipeline.py` (FaceMesh, MLP, XGBoost, Hybrid pipelines), `live_demo.py` (webcam window with mesh + focus label).

+ **Pipelines:** FaceMesh = rule-based head/eye; MLP = 10 features → PyTorch MLP (checkpoints/mlp_best.pt + scaler); XGBoost = same 10 features → xgboost_face_orientation_best.json. Hybrid combines MLP/XGB with geometric scores.

+ **Run demo:**

  ```bash
  python ui/live_demo.py
  python ui/live_demo.py --xgb
  ```

+ `m` = cycle mesh, `p` = switch pipeline, `q` = quit. Same pipelines back the FastAPI WebSocket video in `main.py`.
ui/pipeline.py CHANGED
@@ -375,7 +375,7 @@ class MLPPipeline:
  out["mar"] = float(vec[_FEAT_IDX["mar"]])

  X = vec[self._indices].reshape(1, -1).astype(np.float32)
- X_sc = self._scaler.transform(X)
  with torch.no_grad():
  x_t = torch.from_numpy(X_sc).float()
  logits = self._mlp(x_t)
@@ -534,7 +534,7 @@ class HybridFocusPipeline:
  focus_score = self._cfg["w_xgb"] * model_prob + self._cfg["w_geo"] * out["geo_score"]
  else:
  X = vec[self._indices].reshape(1, -1).astype(np.float32)
- X_sc = self._scaler.transform(X)
  with torch.no_grad():
  x_t = torch.from_numpy(X_sc).float()
  logits = self._mlp(x_t)

  out["mar"] = float(vec[_FEAT_IDX["mar"]])

  X = vec[self._indices].reshape(1, -1).astype(np.float32)
+ X_sc = self._scaler.transform(X) if self._scaler is not None else X
  with torch.no_grad():
  x_t = torch.from_numpy(X_sc).float()
  logits = self._mlp(x_t)

  focus_score = self._cfg["w_xgb"] * model_prob + self._cfg["w_geo"] * out["geo_score"]
  else:
  X = vec[self._indices].reshape(1, -1).astype(np.float32)
+ X_sc = self._scaler.transform(X) if self._scaler is not None else X
  with torch.no_grad():
  x_t = torch.from_numpy(X_sc).float()
  logits = self._mlp(x_t)