Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- LICENSE +21 -0
- README.md +152 -0
- experiments/__init__.py +0 -0
- experiments/analysis/__init__.py +0 -0
- experiments/analysis/aggregate_new_exps.py +166 -0
- experiments/analysis/aggregate_t1_extended.py +60 -0
- experiments/analysis/analysis_figures.py +444 -0
- experiments/analysis/build_taxonomy.py +136 -0
- experiments/analysis/check_seg_lengths.py +229 -0
- experiments/analysis/data_statistics_figure.py +126 -0
- experiments/analysis/exp_per_subject.py +150 -0
- experiments/analysis/extract_video_features.py +208 -0
- experiments/analysis/extract_videomae_features.py +276 -0
- experiments/analysis/gen_val_comparison.py +74 -0
- experiments/analysis/generate_action_labels.py +133 -0
- experiments/analysis/generate_coarse_annotations.py +296 -0
- experiments/analysis/grasp_phase_analysis.py +442 -0
- experiments/analysis/modality_viz.py +145 -0
- experiments/analysis/reannotate_actions.py +363 -0
- experiments/data/__init__.py +0 -0
- experiments/data/__pycache__/dataset.cpython-312.pyc +0 -0
- experiments/data/dataset.py +332 -0
- experiments/data/dataset_forecast.py +319 -0
- experiments/data/dataset_grasp_state.py +571 -0
- experiments/data/dataset_seqpred.py +533 -0
- experiments/data/dataset_signal_forecast.py +391 -0
- experiments/nets/__init__.py +0 -0
- experiments/nets/__pycache__/models_seqpred.cpython-312.pyc +0 -0
- experiments/nets/baselines_published/__init__.py +0 -0
- experiments/nets/baselines_published/baselines.py +488 -0
- experiments/nets/baselines_published/syncfuse.py +270 -0
- experiments/nets/models.py +648 -0
- experiments/nets/models_forecast.py +269 -0
- experiments/nets/models_forecast_priv.py +76 -0
- experiments/nets/models_seqpred.py +806 -0
- experiments/nets/published_models.py +699 -0
- experiments/s9_primitives.json +76 -0
- experiments/slurm/freeze_all_rows.sh +179 -0
- experiments/slurm/run_ablation_fix.sh +33 -0
- experiments/slurm/run_ablation_fusion.sh +174 -0
- experiments/slurm/run_asformer_exp3.sh +44 -0
- experiments/slurm/run_exp1.sh +40 -0
- experiments/slurm/run_exp1_fusion.sh +36 -0
- experiments/slurm/run_exp1_parallel.sh +67 -0
- experiments/slurm/run_exp1_small.sh +84 -0
- experiments/slurm/run_exp1_small2.sh +85 -0
- experiments/slurm/run_exp1_small3.sh +137 -0
- experiments/slurm/run_exp1_v3.sh +68 -0
- experiments/slurm/run_exp1_v4.sh +69 -0
- experiments/slurm/run_exp1_v5.sh +62 -0
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2026 Anonymous Authors (under double-blind review for NeurIPS 2026)
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: mit
|
| 3 |
+
language:
|
| 4 |
+
- en
|
| 5 |
+
library_name: pytorch
|
| 6 |
+
tags:
|
| 7 |
+
- multi-modal
|
| 8 |
+
- daily-activity
|
| 9 |
+
- wearable-sensors
|
| 10 |
+
- benchmark
|
| 11 |
+
---
|
| 12 |
+
|
| 13 |
+
# PULSE — Code Repository
|
| 14 |
+
|
| 15 |
+
Reference implementation, training scripts, and benchmark baselines for the
|
| 16 |
+
**PULSE** dataset paper (under double-blind review at NeurIPS 2026 Evaluations &
|
| 17 |
+
Datasets Track).
|
| 18 |
+
|
| 19 |
+
> **Dataset:** [`velvet-pine-22/PULSE`](https://huggingface.co/datasets/velvet-pine-22/PULSE)
|
| 20 |
+
> · **Sample subset (≈285 MB):** [`velvet-pine-22/PULSE-sample`](https://huggingface.co/datasets/velvet-pine-22/PULSE-sample)
|
| 21 |
+
|
| 22 |
+
## Repository layout
|
| 23 |
+
|
| 24 |
+
```
|
| 25 |
+
PULSE-code/
|
| 26 |
+
├── experiments/
|
| 27 |
+
│ ├── data/ # PyTorch Dataset wrappers
|
| 28 |
+
│ │ ├── dataset.py # core multi-modal dataset (T1, T2)
|
| 29 |
+
│ │ ├── dataset_seqpred.py # T2 fine-grained action recognition
|
| 30 |
+
│ │ ├── dataset_grasp_state.py # T3 grasp onset anticipation
|
| 31 |
+
│ │ ├── dataset_forecast.py # auxiliary forecasting heads
|
| 32 |
+
│ │ └── dataset_signal_forecast.py # T5 tactile-driven motion forecast
|
| 33 |
+
│ │
|
| 34 |
+
│ ├── nets/ # Model architectures
|
| 35 |
+
│ │ ├── models.py # backbone networks (Transformer / LSTM / 1D-CNN)
|
| 36 |
+
│ │ ├── models_seqpred.py # DailyActFormer (DAF) — multi-modal Transformer
|
| 37 |
+
│ │ ├── models_forecast.py # forecasting heads
|
| 38 |
+
│ │ ├── models_forecast_priv.py # privileged-tactile variants for T5
|
| 39 |
+
│ │ ├── published_models.py # third-party model implementations
|
| 40 |
+
│ │ └── baselines_published/ # 7 published baselines (re-implementation)
|
| 41 |
+
│ │ ├── baselines.py # DeepConvLSTM / InceptionTime / MS-TCN / etc.
|
| 42 |
+
│ │ └── syncfuse.py # under-pressure-style multi-modal fusion
|
| 43 |
+
│ │
|
| 44 |
+
│ ├── tasks/ # Training + evaluation entry points
|
| 45 |
+
│ │ ├── train_exp1.py # T1 — scene recognition
|
| 46 |
+
│ │ ├── train_seqpred.py # T2 — action recognition (DAF + ablations)
|
| 47 |
+
│ │ ├── train_grasp_state.py # T3 — grasp onset anticipation
|
| 48 |
+
│ │ ├── train_pred_cls.py # T3 alt classification head
|
| 49 |
+
│ │ ├── train_exp_missing.py # T4 — missing-modality robustness
|
| 50 |
+
│ │ ├── train_signal_forecast.py # T5 — tactile-driven motion forecasting
|
| 51 |
+
│ │ ├── train_signal_forecast_priv.py # T5 privileged variants
|
| 52 |
+
│ │ ├── train_baselines_t1.py # baselines for T1
|
| 53 |
+
│ │ ├── train_exp{2,3,4}.py # ablation experiments
|
| 54 |
+
│ │ ├── train_exp_{anticipate,grip,pose,retrieval,zeroshot}.py # auxiliary
|
| 55 |
+
│ │ ├── train_pred.py / train_forecast.py
|
| 56 |
+
│ │ ├── eval_baselines.py / eval_combined.py
|
| 57 |
+
│ │ └── published_baselines.py # baseline registry
|
| 58 |
+
│ │
|
| 59 |
+
│ ├── analysis/ # Case study, figures, data prep utilities
|
| 60 |
+
│ │ ├── grasp_phase_analysis.py # case study (gaze→EMG→hand→contact cascade)
|
| 61 |
+
│ │ ├── modality_viz.py / analysis_figures.py / data_statistics_figure.py
|
| 62 |
+
│ │ ├── extract_video_features.py / extract_videomae_features.py
|
| 63 |
+
│ │ ├── build_taxonomy.py / generate_action_labels.py / generate_coarse_annotations.py
|
| 64 |
+
│ │ ├── reannotate_actions.py / gen_val_comparison.py
|
| 65 |
+
│ │ ├── exp_per_subject.py / check_seg_lengths.py
|
| 66 |
+
│ │ └── aggregate_*.py # collate run results
|
| 67 |
+
│ │
|
| 68 |
+
│ ├── slurm/ # 60+ SLURM launch scripts (one per main experiment)
|
| 69 |
+
│ │ └── run_*.sh
|
| 70 |
+
│ │
|
| 71 |
+
│ ├── taxonomy.py # shared 18-primitive taxonomy
|
| 72 |
+
│ ├── s9_primitives.json
|
| 73 |
+
│ └── taxonomy_v3.json
|
| 74 |
+
│
|
| 75 |
+
├── scripts/ # Top-level utilities (not task-specific)
|
| 76 |
+
│ ├── build_paper_tables.py # collates results JSONs into LaTeX tables
|
| 77 |
+
│ ├── eval_macrof1.py / eval_subset.py / eval_topk_v3.py
|
| 78 |
+
│ └── dispatch_eval.sh # batch dispatcher
|
| 79 |
+
│
|
| 80 |
+
├── LICENSE # MIT
|
| 81 |
+
├── requirements.txt # Python deps
|
| 82 |
+
└── README.md
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
## Quick start
|
| 86 |
+
|
| 87 |
+
```bash
|
| 88 |
+
# 1. Set up Python environment
|
| 89 |
+
python -m venv .venv && source .venv/bin/activate
|
| 90 |
+
pip install -r requirements.txt
|
| 91 |
+
|
| 92 |
+
# 2. Point at the PULSE dataset (download from HuggingFace first)
|
| 93 |
+
export PULSE_ROOT=/path/to/PULSE # the dataset root (not this code repo)
|
| 94 |
+
|
| 95 |
+
# 3. Run a training entry point as a module (from the experiments/ directory)
|
| 96 |
+
cd experiments
|
| 97 |
+
python -m tasks.train_seqpred \
|
| 98 |
+
--root $PULSE_ROOT \
|
| 99 |
+
--modalities mocap emg eyetrack imu pressure \
|
| 100 |
+
--output_dir runs/t2_daf
|
| 101 |
+
|
| 102 |
+
# 4. Reproduce paper tables (after training all benchmarks)
|
| 103 |
+
cd ..
|
| 104 |
+
python scripts/build_paper_tables.py \
|
| 105 |
+
--results_root experiments/runs/ \
|
| 106 |
+
--out tables/
|
| 107 |
+
```
|
| 108 |
+
|
| 109 |
+
> **Why `python -m tasks.train_seqpred` and not `python tasks/train_seqpred.py`?**
|
| 110 |
+
> The training scripts import sibling modules (`from data.dataset import …`,
|
| 111 |
+
> `from nets.models import …`). Running with `-m` from the `experiments/`
|
| 112 |
+
> directory makes Python treat `data/`, `nets/`, `tasks/`, and `analysis/` as
|
| 113 |
+
> top-level packages so the imports resolve cleanly.
|
| 114 |
+
|
| 115 |
+
## Reproducing the benchmark tasks
|
| 116 |
+
|
| 117 |
+
| Task | Entry point | Output |
|
| 118 |
+
|---|---|---|
|
| 119 |
+
| T1 — Scene recognition (8-way) | `tasks.train_exp1` | scene-classification metrics |
|
| 120 |
+
| T2 — Fine-grained action recognition | `tasks.train_seqpred` | verb / noun / hand top-k accuracy |
|
| 121 |
+
| T3 — Grasp onset anticipation | `tasks.train_grasp_state` / `tasks.train_pred_cls` | anticipation F1 / time-to-contact |
|
| 122 |
+
| T4 — Missing-modality robustness | `tasks.train_exp_missing` + `tasks.eval_combined` | per-modality ablation table |
|
| 123 |
+
| T5 — Tactile-driven grasp-state recognition | `tasks.train_signal_forecast` (+ `_priv` variants) | sub-second grasp-state metrics |
|
| 124 |
+
| T6 — Cross-modal pressure prediction | `tasks.train_forecast` / `tasks.train_signal_forecast` | pressure reconstruction metrics |
|
| 125 |
+
|
| 126 |
+
The exact command lines (with hyperparameters, seeds, GPU configs) used for
|
| 127 |
+
every paper table are checked in under `experiments/slurm/run_*.sh`, one
|
| 128 |
+
SLURM script per paper experiment. Output JSON files from these runs are
|
| 129 |
+
collated into LaTeX tables by `scripts/build_paper_tables.py`.
|
| 130 |
+
|
| 131 |
+
## Hardware
|
| 132 |
+
|
| 133 |
+
Headline experiments were run on **NVIDIA A800 (80 GB)** GPUs. A single seed of
|
| 134 |
+
DailyActFormer T2 trains in ~6 hours on one A800. Most baselines fit on a
|
| 135 |
+
single 24 GB consumer GPU.
|
| 136 |
+
|
| 137 |
+
## License & attribution
|
| 138 |
+
|
| 139 |
+
Code is released under **MIT** (see `LICENSE`). The PULSE dataset itself is
|
| 140 |
+
released under **CC BY-NC 4.0** (see the dataset repository).
|
| 141 |
+
|
| 142 |
+
## Citation
|
| 143 |
+
|
| 144 |
+
```bibtex
|
| 145 |
+
@inproceedings{anonymous2026pulse,
|
| 146 |
+
title = {PULSE: A Synchronized Five-Modality Dataset for Multi-Modal Daily Activity Understanding},
|
| 147 |
+
author = {Anonymous Authors},
|
| 148 |
+
booktitle = {Submitted to NeurIPS 2026 Evaluations and Datasets Track},
|
| 149 |
+
year = {2026},
|
| 150 |
+
note = {Under double-blind review}
|
| 151 |
+
}
|
| 152 |
+
```
|
experiments/__init__.py
ADDED
|
File without changes
|
experiments/analysis/__init__.py
ADDED
|
File without changes
|
experiments/analysis/aggregate_new_exps.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Aggregate results from the three new benchmark experiments."""
|
| 3 |
+
import os
|
| 4 |
+
import json
|
| 5 |
+
import glob
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
ROOT = '${PULSE_ROOT}/results/exp_new'
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def load_results(pattern):
|
| 12 |
+
files = sorted(glob.glob(pattern))
|
| 13 |
+
results = []
|
| 14 |
+
for f in files:
|
| 15 |
+
try:
|
| 16 |
+
results.append(json.load(open(f)))
|
| 17 |
+
except Exception as e:
|
| 18 |
+
print(f" ERR: {f}: {e}")
|
| 19 |
+
return results
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def aggregate_expA():
|
| 23 |
+
"""Missing modality: average across seeds per eval config."""
|
| 24 |
+
print("\n" + "=" * 70)
|
| 25 |
+
print("EXP A: Missing-modality robustness")
|
| 26 |
+
print("=" * 70)
|
| 27 |
+
|
| 28 |
+
for subdir in ['expA_missing', 'expA_baseline']:
|
| 29 |
+
files = load_results(f'{ROOT}/{subdir}/*/results.json')
|
| 30 |
+
if not files:
|
| 31 |
+
print(f" No results yet for {subdir}")
|
| 32 |
+
continue
|
| 33 |
+
print(f"\n-- {subdir} (n seeds = {len(files)}) --")
|
| 34 |
+
# Group by eval config name; accumulate F1/Acc over seeds
|
| 35 |
+
config_stats = {}
|
| 36 |
+
for r in files:
|
| 37 |
+
if 'eval_configs' not in r:
|
| 38 |
+
continue
|
| 39 |
+
for name, info in r['eval_configs'].items():
|
| 40 |
+
config_stats.setdefault(name, {'f1': [], 'acc': [], 'active': info['active']})
|
| 41 |
+
config_stats[name]['f1'].append(info['f1'])
|
| 42 |
+
config_stats[name]['acc'].append(info['acc'])
|
| 43 |
+
|
| 44 |
+
# Order: full, leave-one-out, singletons
|
| 45 |
+
full_names = [n for n in config_stats if n == 'full']
|
| 46 |
+
drop_names = sorted([n for n in config_stats if n.startswith('drop_')])
|
| 47 |
+
only_names = sorted([n for n in config_stats if n.startswith('only_')])
|
| 48 |
+
|
| 49 |
+
print(f" {'Config':<22s} {'Active modalities':<42s} "
|
| 50 |
+
f"{'F1 mean±std':<14s} {'Acc mean±std':<14s}")
|
| 51 |
+
print(' ' + '-' * 96)
|
| 52 |
+
for grp in [full_names, drop_names, only_names]:
|
| 53 |
+
for name in grp:
|
| 54 |
+
d = config_stats[name]
|
| 55 |
+
f1_m, f1_s = np.mean(d['f1']), np.std(d['f1'])
|
| 56 |
+
ac_m, ac_s = np.mean(d['acc']), np.std(d['acc'])
|
| 57 |
+
active = ','.join(d['active'])
|
| 58 |
+
print(f" {name:<22s} {active:<42s} "
|
| 59 |
+
f"{f1_m:.3f}±{f1_s:.3f} {ac_m:.3f}±{ac_s:.3f}")
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def aggregate_expB():
|
| 63 |
+
"""Grip regression: group by (backbone, mod_config), average over seeds."""
|
| 64 |
+
print("\n" + "=" * 70)
|
| 65 |
+
print("EXP B: Grip force regression")
|
| 66 |
+
print("=" * 70)
|
| 67 |
+
files = load_results(f'{ROOT}/expB_grip/*/results.json')
|
| 68 |
+
if not files:
|
| 69 |
+
print(" No results yet")
|
| 70 |
+
return
|
| 71 |
+
|
| 72 |
+
# Group
|
| 73 |
+
groups = {}
|
| 74 |
+
for r in files:
|
| 75 |
+
if 'best_test_metrics' not in r:
|
| 76 |
+
continue
|
| 77 |
+
key = (r['backbone'], ','.join(r['modalities']))
|
| 78 |
+
groups.setdefault(key, []).append(r)
|
| 79 |
+
|
| 80 |
+
rows = []
|
| 81 |
+
for (bb, mods), rs in groups.items():
|
| 82 |
+
mae_R = [r['best_test_metrics']['right_hand']['mae_g'] for r in rs]
|
| 83 |
+
mae_L = [r['best_test_metrics']['left_hand']['mae_g'] for r in rs]
|
| 84 |
+
r_R = [r['best_test_metrics']['right_hand']['pearson_r'] for r in rs]
|
| 85 |
+
r_L = [r['best_test_metrics']['left_hand']['pearson_r'] for r in rs]
|
| 86 |
+
r2_R = [r['best_test_metrics']['right_hand']['r2'] for r in rs]
|
| 87 |
+
r2_L = [r['best_test_metrics']['left_hand']['r2'] for r in rs]
|
| 88 |
+
mae_avg = [r['best_test_metrics']['avg_mae_g'] for r in rs]
|
| 89 |
+
r_avg = [r['best_test_metrics']['avg_pearson_r'] for r in rs]
|
| 90 |
+
rows.append({
|
| 91 |
+
'backbone': bb,
|
| 92 |
+
'modalities': mods,
|
| 93 |
+
'n_seeds': len(rs),
|
| 94 |
+
'mae_R': (np.mean(mae_R), np.std(mae_R)),
|
| 95 |
+
'mae_L': (np.mean(mae_L), np.std(mae_L)),
|
| 96 |
+
'mae_avg': (np.mean(mae_avg), np.std(mae_avg)),
|
| 97 |
+
'r_R': (np.mean(r_R), np.std(r_R)),
|
| 98 |
+
'r_L': (np.mean(r_L), np.std(r_L)),
|
| 99 |
+
'r_avg': (np.mean(r_avg), np.std(r_avg)),
|
| 100 |
+
'r2_R': (np.mean(r2_R), np.std(r2_R)),
|
| 101 |
+
'r2_L': (np.mean(r2_L), np.std(r2_L)),
|
| 102 |
+
})
|
| 103 |
+
rows.sort(key=lambda r: r['r_avg'][0], reverse=True)
|
| 104 |
+
print(f" {'Backbone':<12s} {'Modalities':<30s} N "
|
| 105 |
+
f"{'MAE(g) avg':<14s} {'Pearson r avg':<14s} {'R²(R)':<12s} {'R²(L)':<12s}")
|
| 106 |
+
print(' ' + '-' * 102)
|
| 107 |
+
for row in rows:
|
| 108 |
+
print(f" {row['backbone']:<12s} {row['modalities']:<30s} {row['n_seeds']} "
|
| 109 |
+
f"{row['mae_avg'][0]:.1f}±{row['mae_avg'][1]:.1f} "
|
| 110 |
+
f"{row['r_avg'][0]:.3f}±{row['r_avg'][1]:.3f} "
|
| 111 |
+
f"{row['r2_R'][0]:.3f}±{row['r2_R'][1]:.3f} "
|
| 112 |
+
f"{row['r2_L'][0]:.3f}±{row['r2_L'][1]:.3f}")
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def aggregate_expC():
|
| 116 |
+
"""T5 retrieval: group by mod config, average over seeds."""
|
| 117 |
+
print("\n" + "=" * 70)
|
| 118 |
+
print("EXP C: T5 Cross-modal text retrieval")
|
| 119 |
+
print("=" * 70)
|
| 120 |
+
files = load_results(f'{ROOT}/expC_retrieval/*/results.json')
|
| 121 |
+
if not files:
|
| 122 |
+
print(" No results yet")
|
| 123 |
+
return
|
| 124 |
+
groups = {}
|
| 125 |
+
for r in files:
|
| 126 |
+
if 'final_avg_over_3_pool_seeds' not in r:
|
| 127 |
+
continue
|
| 128 |
+
key = ','.join(r['modalities'])
|
| 129 |
+
groups.setdefault(key, []).append(r)
|
| 130 |
+
|
| 131 |
+
rows = []
|
| 132 |
+
for mods, rs in groups.items():
|
| 133 |
+
r1 = [r['final_avg_over_3_pool_seeds']['recall@1'] for r in rs]
|
| 134 |
+
r5 = [r['final_avg_over_3_pool_seeds']['recall@5'] for r in rs]
|
| 135 |
+
r10 = [r['final_avg_over_3_pool_seeds']['recall@10'] for r in rs]
|
| 136 |
+
medR = [r['final_avg_over_3_pool_seeds']['median_rank'] for r in rs]
|
| 137 |
+
rows.append({
|
| 138 |
+
'modalities': mods,
|
| 139 |
+
'n_seeds': len(rs),
|
| 140 |
+
'r1': (np.mean(r1), np.std(r1)),
|
| 141 |
+
'r5': (np.mean(r5), np.std(r5)),
|
| 142 |
+
'r10': (np.mean(r10), np.std(r10)),
|
| 143 |
+
'medR': (np.mean(medR), np.std(medR)),
|
| 144 |
+
'n_test': rs[0].get('n_test_segments', 0),
|
| 145 |
+
'K': rs[0].get('K_pool', 100),
|
| 146 |
+
})
|
| 147 |
+
rows.sort(key=lambda r: r['r10'][0], reverse=True)
|
| 148 |
+
print(f" {'Modalities':<30s} N N_test K "
|
| 149 |
+
f"{'R@1':<12s} {'R@5':<12s} {'R@10':<12s} {'medR':<12s}")
|
| 150 |
+
print(' ' + '-' * 100)
|
| 151 |
+
for row in rows:
|
| 152 |
+
print(f" {row['modalities']:<30s} {row['n_seeds']} {row['n_test']:<6d} {row['K']:<2d} "
|
| 153 |
+
f"{row['r1'][0]:.3f}±{row['r1'][1]:.3f} "
|
| 154 |
+
f"{row['r5'][0]:.3f}±{row['r5'][1]:.3f} "
|
| 155 |
+
f"{row['r10'][0]:.3f}±{row['r10'][1]:.3f} "
|
| 156 |
+
f"{row['medR'][0]:.1f}±{row['medR'][1]:.1f}")
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def main():
|
| 160 |
+
aggregate_expA()
|
| 161 |
+
aggregate_expB()
|
| 162 |
+
aggregate_expC()
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
if __name__ == '__main__':
|
| 166 |
+
main()
|
experiments/analysis/aggregate_t1_extended.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Aggregate T1 extended benchmark results.
|
| 3 |
+
Prints a Markdown-style table sorted by F1 desc."""
|
| 4 |
+
import os
|
| 5 |
+
import json
|
| 6 |
+
import glob
|
| 7 |
+
import numpy as np
|
| 8 |
+
from collections import defaultdict
|
| 9 |
+
|
| 10 |
+
ROOT = '${PULSE_ROOT}/results/t1_extended'
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def collect(pattern):
|
| 14 |
+
by_key = defaultdict(list)
|
| 15 |
+
for f in sorted(glob.glob(pattern)):
|
| 16 |
+
try:
|
| 17 |
+
r = json.load(open(f))
|
| 18 |
+
except Exception as e:
|
| 19 |
+
print(f" ERR reading {f}: {e}")
|
| 20 |
+
continue
|
| 21 |
+
key = r.get('method', os.path.basename(os.path.dirname(f)))
|
| 22 |
+
# Distinguish ablations by tag
|
| 23 |
+
tag = r.get('args', {}).get('tag', '')
|
| 24 |
+
if tag:
|
| 25 |
+
key = f"{key}_{tag}"
|
| 26 |
+
by_key[key].append(r)
|
| 27 |
+
return by_key
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def main():
|
| 31 |
+
groups = collect(f'{ROOT}/*/results.json')
|
| 32 |
+
rows = []
|
| 33 |
+
for key, rs in groups.items():
|
| 34 |
+
f1s = [r['test_f1'] for r in rs]
|
| 35 |
+
accs = [r['test_acc'] for r in rs]
|
| 36 |
+
mods = ','.join(rs[0]['modalities'])
|
| 37 |
+
rows.append({
|
| 38 |
+
'method': key,
|
| 39 |
+
'modalities': mods,
|
| 40 |
+
'n_seeds': len(rs),
|
| 41 |
+
'f1_mean': np.mean(f1s),
|
| 42 |
+
'f1_std': np.std(f1s),
|
| 43 |
+
'acc_mean': np.mean(accs),
|
| 44 |
+
'acc_std': np.std(accs),
|
| 45 |
+
'n_params': rs[0].get('n_params', 0),
|
| 46 |
+
})
|
| 47 |
+
rows.sort(key=lambda r: r['f1_mean'], reverse=True)
|
| 48 |
+
|
| 49 |
+
print(f"\n{'Method':<28s} {'Modalities':<32s} N {'F1 mean±std':<14s} "
|
| 50 |
+
f"{'Acc mean±std':<14s} Params")
|
| 51 |
+
print('-' * 110)
|
| 52 |
+
for r in rows:
|
| 53 |
+
print(f"{r['method']:<28s} {r['modalities']:<32s} {r['n_seeds']} "
|
| 54 |
+
f"{r['f1_mean']:.3f}±{r['f1_std']:.3f} "
|
| 55 |
+
f"{r['acc_mean']:.3f}±{r['acc_std']:.3f} "
|
| 56 |
+
f"{r['n_params']:,}")
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
if __name__ == '__main__':
|
| 60 |
+
main()
|
experiments/analysis/analysis_figures.py
ADDED
|
@@ -0,0 +1,444 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Generate three showcase figures for the main paper:
|
| 3 |
+
1. Eye-Hand-Contact coordination (gaze fixation + hand velocity + pressure)
|
| 4 |
+
2. Pressure fingerprints per action category
|
| 5 |
+
3. 3D hand trajectory colored by pressure
|
| 6 |
+
"""
|
| 7 |
+
import os, glob, json, re
|
| 8 |
+
import numpy as np
|
| 9 |
+
import pandas as pd
|
| 10 |
+
import matplotlib
|
| 11 |
+
matplotlib.use('Agg')
|
| 12 |
+
import matplotlib.pyplot as plt
|
| 13 |
+
from scipy.signal import savgol_filter
|
| 14 |
+
|
| 15 |
+
DATASET = "${PULSE_ROOT}/dataset"
|
| 16 |
+
OUT_DIR = "${PULSE_ROOT}/paper/figures"
|
| 17 |
+
os.makedirs(OUT_DIR, exist_ok=True)
|
| 18 |
+
|
| 19 |
+
PRESSURE_THRESHOLD = 5.0
|
| 20 |
+
FPS = 100
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# ============================================================
|
| 24 |
+
# Shared data-loading helpers
|
| 25 |
+
# ============================================================
|
| 26 |
+
|
| 27 |
+
def load_pressure(scenario_dir):
|
| 28 |
+
"""Return (T, 2) array of (right_total, left_total) pressure."""
|
| 29 |
+
f = os.path.join(scenario_dir, "aligned_pressure_100hz.csv")
|
| 30 |
+
if not os.path.exists(f):
|
| 31 |
+
return None
|
| 32 |
+
df = pd.read_csv(f, low_memory=False)
|
| 33 |
+
r_cols = [c for c in df.columns if c.startswith('R') and c.endswith('(g)')]
|
| 34 |
+
l_cols = [c for c in df.columns if c.startswith('L') and c.endswith('(g)')]
|
| 35 |
+
if len(r_cols) < 20 or len(l_cols) < 20:
|
| 36 |
+
return None
|
| 37 |
+
r = df[r_cols].apply(pd.to_numeric, errors='coerce').fillna(0).values
|
| 38 |
+
l = df[l_cols].apply(pd.to_numeric, errors='coerce').fillna(0).values
|
| 39 |
+
return r, l # (T, 25) each
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def load_emg(scenario_dir):
|
| 43 |
+
f = os.path.join(scenario_dir, "aligned_emg_100hz.csv")
|
| 44 |
+
if not os.path.exists(f):
|
| 45 |
+
return None
|
| 46 |
+
df = pd.read_csv(f, low_memory=False)
|
| 47 |
+
numeric = [c for c in df.select_dtypes(include=[np.number]).columns
|
| 48 |
+
if c not in ('time', 'UTC', 'Frame')]
|
| 49 |
+
if len(numeric) < 4:
|
| 50 |
+
return None
|
| 51 |
+
return np.nan_to_num(df[numeric].values.astype(np.float32))
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def load_gaze(scenario_dir):
|
| 55 |
+
f = os.path.join(scenario_dir, "aligned_eyetrack_100hz.csv")
|
| 56 |
+
if not os.path.exists(f):
|
| 57 |
+
return None
|
| 58 |
+
df = pd.read_csv(f, low_memory=False)
|
| 59 |
+
gx_col = [c for c in df.columns if 'Gaze X' in c and 'Scene Cam' in c]
|
| 60 |
+
gy_col = [c for c in df.columns if 'Gaze Y' in c and 'Scene Cam' in c]
|
| 61 |
+
if gx_col and gy_col:
|
| 62 |
+
gx = pd.to_numeric(df[gx_col[0]], errors='coerce').fillna(0).values
|
| 63 |
+
gy = pd.to_numeric(df[gy_col[0]], errors='coerce').fillna(0).values
|
| 64 |
+
return np.stack([gx, gy], axis=1)
|
| 65 |
+
return None
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def load_mocap_hand(scenario_dir, vol, scenario):
|
| 69 |
+
"""Return wrist 3D position (T,3) and tip position summary."""
|
| 70 |
+
f = os.path.join(scenario_dir, f"aligned_{vol}{scenario}_s_Q.tsv")
|
| 71 |
+
if not os.path.exists(f):
|
| 72 |
+
return None, None
|
| 73 |
+
df = pd.read_csv(f, sep='\t', low_memory=False)
|
| 74 |
+
# Right hand wrist (try several naming patterns)
|
| 75 |
+
candidates = [
|
| 76 |
+
['RightHand_X','RightHand_Y','RightHand_Z'],
|
| 77 |
+
['R_Hand_X','R_Hand_Y','R_Hand_Z'],
|
| 78 |
+
['Q_RWristIn_X','Q_RWristIn_Y','Q_RWristIn_Z'],
|
| 79 |
+
]
|
| 80 |
+
r_wrist = None
|
| 81 |
+
for cs in candidates:
|
| 82 |
+
if all(c in df.columns for c in cs):
|
| 83 |
+
r_wrist = df[cs].apply(pd.to_numeric, errors='coerce').fillna(0).values
|
| 84 |
+
break
|
| 85 |
+
l_wrist = None
|
| 86 |
+
for cs_l in [['LeftHand_X','LeftHand_Y','LeftHand_Z'],
|
| 87 |
+
['L_Hand_X','L_Hand_Y','L_Hand_Z'],
|
| 88 |
+
['Q_LWristIn_X','Q_LWristIn_Y','Q_LWristIn_Z']]:
|
| 89 |
+
if all(c in df.columns for c in cs_l):
|
| 90 |
+
l_wrist = df[cs_l].apply(pd.to_numeric, errors='coerce').fillna(0).values
|
| 91 |
+
break
|
| 92 |
+
return r_wrist, l_wrist
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def compute_velocity(position, window=5):
|
| 96 |
+
"""Magnitude of velocity (after smoothing)."""
|
| 97 |
+
vel = np.zeros_like(position)
|
| 98 |
+
vel[1:] = position[1:] - position[:-1]
|
| 99 |
+
mag = np.linalg.norm(vel, axis=1)
|
| 100 |
+
try:
|
| 101 |
+
mag = savgol_filter(mag, window_length=min(window*2+1, len(mag)-1 if len(mag)%2==0 else len(mag)), polyorder=2)
|
| 102 |
+
except:
|
| 103 |
+
pass
|
| 104 |
+
return mag
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def detect_grasp_events(hand_pressure, threshold=PRESSURE_THRESHOLD, min_gap=50):
|
| 108 |
+
"""Detect pressure onset events."""
|
| 109 |
+
total = hand_pressure.sum(axis=1) if hand_pressure.ndim == 2 else hand_pressure
|
| 110 |
+
above = total > threshold
|
| 111 |
+
onsets = []
|
| 112 |
+
last_state = False
|
| 113 |
+
for i, a in enumerate(above):
|
| 114 |
+
if a and not last_state:
|
| 115 |
+
if i + 10 < len(above) and np.mean(above[i:i+10]) > 0.7:
|
| 116 |
+
if not onsets or i - onsets[-1] > min_gap:
|
| 117 |
+
onsets.append(i)
|
| 118 |
+
last_state = True
|
| 119 |
+
elif not a and last_state:
|
| 120 |
+
if i + 5 < len(above) and np.mean(above[i:i+5]) < 0.3:
|
| 121 |
+
last_state = False
|
| 122 |
+
return onsets
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def emg_envelope(emg, window=20):
|
| 126 |
+
rect = np.abs(emg - np.mean(emg, axis=0))
|
| 127 |
+
kernel = np.ones(window) / window
|
| 128 |
+
env = np.stack([np.convolve(rect[:, c], kernel, mode='same') for c in range(rect.shape[1])], axis=1)
|
| 129 |
+
return env.sum(axis=1)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def gaze_velocity(gaze_xy, window=5):
|
| 133 |
+
"""Magnitude of gaze velocity — high = saccade, low = fixation."""
|
| 134 |
+
v = np.zeros_like(gaze_xy)
|
| 135 |
+
v[1:] = gaze_xy[1:] - gaze_xy[:-1]
|
| 136 |
+
mag = np.linalg.norm(v, axis=1)
|
| 137 |
+
try:
|
| 138 |
+
mag = savgol_filter(mag, window_length=min(window*2+1, 15), polyorder=2)
|
| 139 |
+
except:
|
| 140 |
+
pass
|
| 141 |
+
return mag
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
# ============================================================
|
| 145 |
+
# FIGURE 1: Eye-Hand-Contact coordination
|
| 146 |
+
# ============================================================
|
| 147 |
+
def make_eye_hand_contact_figure():
|
| 148 |
+
print("=== Figure 1: Eye-Hand-Contact coordination ===")
|
| 149 |
+
context = 200 # 2s before + 0.5s after
|
| 150 |
+
after = 50
|
| 151 |
+
events = [] # list of dicts: gaze_vel, hand_vel, pressure, all shape (context+after,)
|
| 152 |
+
|
| 153 |
+
for vol_dir in sorted(glob.glob(f"{DATASET}/v*")):
|
| 154 |
+
vol = os.path.basename(vol_dir)
|
| 155 |
+
for sd in sorted(glob.glob(f"{vol_dir}/s*")):
|
| 156 |
+
scenario = os.path.basename(sd)
|
| 157 |
+
meta_path = os.path.join(sd, "alignment_metadata.json")
|
| 158 |
+
if not os.path.exists(meta_path):
|
| 159 |
+
continue
|
| 160 |
+
meta = json.load(open(meta_path))
|
| 161 |
+
if not {'pressure', 'eyetrack', 'mocap'}.issubset(set(meta['modalities'])):
|
| 162 |
+
continue
|
| 163 |
+
|
| 164 |
+
p = load_pressure(sd)
|
| 165 |
+
g = load_gaze(sd)
|
| 166 |
+
r_wrist, _ = load_mocap_hand(sd, vol, scenario)
|
| 167 |
+
if p is None or g is None or r_wrist is None:
|
| 168 |
+
continue
|
| 169 |
+
r_p, _ = p
|
| 170 |
+
min_len = min(len(r_p), len(g), len(r_wrist))
|
| 171 |
+
r_p, g, r_wrist = r_p[:min_len], g[:min_len], r_wrist[:min_len]
|
| 172 |
+
|
| 173 |
+
hand_vel = compute_velocity(r_wrist)
|
| 174 |
+
gvel = gaze_velocity(g)
|
| 175 |
+
total_p = r_p.sum(axis=1)
|
| 176 |
+
|
| 177 |
+
onsets = detect_grasp_events(r_p)
|
| 178 |
+
for o in onsets:
|
| 179 |
+
if o < context or o + after >= min_len:
|
| 180 |
+
continue
|
| 181 |
+
# Require quiescent pre-grasp
|
| 182 |
+
rest_window = gvel[o-150:o-100]
|
| 183 |
+
vel_rest = hand_vel[o-150:o-100]
|
| 184 |
+
if np.mean(vel_rest) > hand_vel[o-50:o].mean() * 0.5:
|
| 185 |
+
continue
|
| 186 |
+
gv_seg = gvel[o-context:o+after]
|
| 187 |
+
hv_seg = hand_vel[o-context:o+after]
|
| 188 |
+
pr_seg = total_p[o-context:o+after]
|
| 189 |
+
if len(gv_seg) != context+after or np.isnan(gv_seg).any():
|
| 190 |
+
continue
|
| 191 |
+
events.append({'gv': gv_seg, 'hv': hv_seg, 'p': pr_seg})
|
| 192 |
+
if len(events) > 400:
|
| 193 |
+
break
|
| 194 |
+
if len(events) > 400:
|
| 195 |
+
break
|
| 196 |
+
|
| 197 |
+
print(f" Collected {len(events)} events")
|
| 198 |
+
if len(events) < 50:
|
| 199 |
+
print(" Not enough events, skipping")
|
| 200 |
+
return
|
| 201 |
+
|
| 202 |
+
# Gaze: fixation = low gaze velocity, so use "1 - normalized gaze velocity"
|
| 203 |
+
# This represents "gaze fixation stability"
|
| 204 |
+
def norm01(arr):
|
| 205 |
+
arr = np.array(arr)
|
| 206 |
+
arr = arr - arr.min(axis=1, keepdims=True)
|
| 207 |
+
mx = arr.max(axis=1, keepdims=True)
|
| 208 |
+
return arr / (mx + 1e-8)
|
| 209 |
+
|
| 210 |
+
gv_stack = norm01([e['gv'] for e in events])
|
| 211 |
+
hv_stack = norm01([e['hv'] for e in events])
|
| 212 |
+
p_stack = norm01([e['p'] for e in events])
|
| 213 |
+
|
| 214 |
+
# Smooth gaze to show fixation trend
|
| 215 |
+
# Gaze fixation = low velocity. Plot (1 - gaze_velocity) -> rises as gaze fixates
|
| 216 |
+
gaze_fix = 1 - gv_stack # high = fixating
|
| 217 |
+
# Normalize each event's fix to [0,1] for display
|
| 218 |
+
gaze_fix_plot = norm01(gaze_fix)
|
| 219 |
+
|
| 220 |
+
time_axis = np.arange(-context, after) * 10 # ms
|
| 221 |
+
|
| 222 |
+
fig, ax = plt.subplots(figsize=(9, 4.5))
|
| 223 |
+
|
| 224 |
+
for stack, color, label in [
|
| 225 |
+
(gaze_fix_plot, '#8E44AD', 'Gaze fixation'),
|
| 226 |
+
(hv_stack, '#3498DB', 'Hand velocity'),
|
| 227 |
+
(p_stack, '#27AE60', 'Pressure (contact)'),
|
| 228 |
+
]:
|
| 229 |
+
mean = stack.mean(axis=0)
|
| 230 |
+
std = stack.std(axis=0)
|
| 231 |
+
ax.plot(time_axis, mean, color=color, linewidth=2.5, label=label)
|
| 232 |
+
ax.fill_between(time_axis, mean - std*0.4, mean + std*0.4, color=color, alpha=0.15)
|
| 233 |
+
|
| 234 |
+
ax.axvline(0, color='black', linestyle='--', linewidth=1.2, alpha=0.7)
|
| 235 |
+
ax.set_xlabel('Time relative to contact onset (ms)', fontsize=12)
|
| 236 |
+
ax.set_ylabel('Normalized amplitude', fontsize=12)
|
| 237 |
+
ax.set_title(f'Gaze → Hand → Contact coordination ({len(events)} events)',
|
| 238 |
+
fontsize=13, fontweight='bold')
|
| 239 |
+
ax.set_xlim(-2000, 500)
|
| 240 |
+
ax.legend(loc='upper left', fontsize=10, frameon=True)
|
| 241 |
+
ax.grid(True, alpha=0.3)
|
| 242 |
+
ax.set_ylim(-0.05, 1.1)
|
| 243 |
+
|
| 244 |
+
plt.tight_layout()
|
| 245 |
+
out_path = os.path.join(OUT_DIR, 'eye_hand_contact.pdf')
|
| 246 |
+
plt.savefig(out_path, dpi=150, bbox_inches='tight')
|
| 247 |
+
plt.savefig(out_path.replace('.pdf', '.png'), dpi=150, bbox_inches='tight')
|
| 248 |
+
plt.close()
|
| 249 |
+
print(f" Saved {out_path}")
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
# ============================================================
|
| 253 |
+
# FIGURE 2: Pressure fingerprints per action category
|
| 254 |
+
# ============================================================
|
| 255 |
+
def make_pressure_fingerprints():
|
| 256 |
+
print("\n=== Figure 2: Pressure fingerprints ===")
|
| 257 |
+
import sys
|
| 258 |
+
sys.path.insert(0, '${PULSE_ROOT}')
|
| 259 |
+
from experiments.train_exp2 import load_annotations
|
| 260 |
+
|
| 261 |
+
# For each action class, accumulate mean pressure profile (50 channels)
|
| 262 |
+
action_r_sum = {} # action -> (sum 25 channels, count)
|
| 263 |
+
action_l_sum = {}
|
| 264 |
+
|
| 265 |
+
for vol_dir in sorted(glob.glob(f"{DATASET}/v*")):
|
| 266 |
+
vol = os.path.basename(vol_dir)
|
| 267 |
+
for sd in sorted(glob.glob(f"{vol_dir}/s*")):
|
| 268 |
+
scenario = os.path.basename(sd)
|
| 269 |
+
meta_path = os.path.join(sd, "alignment_metadata.json")
|
| 270 |
+
if not os.path.exists(meta_path):
|
| 271 |
+
continue
|
| 272 |
+
meta = json.load(open(meta_path))
|
| 273 |
+
if 'pressure' not in set(meta['modalities']):
|
| 274 |
+
continue
|
| 275 |
+
p = load_pressure(sd)
|
| 276 |
+
if p is None:
|
| 277 |
+
continue
|
| 278 |
+
r_p, l_p = p
|
| 279 |
+
labels = load_annotations(vol, scenario, len(r_p), sampling_rate=100, use_coarse=False)
|
| 280 |
+
if labels is None:
|
| 281 |
+
continue
|
| 282 |
+
labels = labels[:len(r_p)]
|
| 283 |
+
from experiments.train_exp2 import ACTION_NAMES
|
| 284 |
+
for a_id, a_name in ACTION_NAMES.items():
|
| 285 |
+
if a_name == 'Idle':
|
| 286 |
+
continue
|
| 287 |
+
mask = labels == a_id
|
| 288 |
+
if mask.sum() < 10:
|
| 289 |
+
continue
|
| 290 |
+
r_mean = r_p[mask].mean(axis=0)
|
| 291 |
+
l_mean = l_p[mask].mean(axis=0)
|
| 292 |
+
if a_name not in action_r_sum:
|
| 293 |
+
action_r_sum[a_name] = [np.zeros(25), 0]
|
| 294 |
+
action_l_sum[a_name] = [np.zeros(25), 0]
|
| 295 |
+
action_r_sum[a_name][0] += r_mean * mask.sum()
|
| 296 |
+
action_r_sum[a_name][1] += mask.sum()
|
| 297 |
+
action_l_sum[a_name][0] += l_mean * mask.sum()
|
| 298 |
+
action_l_sum[a_name][1] += mask.sum()
|
| 299 |
+
|
| 300 |
+
# Compute mean for each action
|
| 301 |
+
results = {}
|
| 302 |
+
for a_name in action_r_sum:
|
| 303 |
+
r_cnt = action_r_sum[a_name][1]
|
| 304 |
+
l_cnt = action_l_sum[a_name][1]
|
| 305 |
+
if r_cnt == 0 or l_cnt == 0:
|
| 306 |
+
continue
|
| 307 |
+
results[a_name] = {
|
| 308 |
+
'r': action_r_sum[a_name][0] / r_cnt,
|
| 309 |
+
'l': action_l_sum[a_name][0] / l_cnt,
|
| 310 |
+
}
|
| 311 |
+
print(f" Action categories: {list(results.keys())}")
|
| 312 |
+
|
| 313 |
+
if not results:
|
| 314 |
+
print(" No data")
|
| 315 |
+
return
|
| 316 |
+
|
| 317 |
+
# Pick top 6 by frequency (they have most data)
|
| 318 |
+
# Sort by right-hand count
|
| 319 |
+
sorted_actions = sorted(results.keys(),
|
| 320 |
+
key=lambda a: action_r_sum[a][1], reverse=True)[:6]
|
| 321 |
+
|
| 322 |
+
# Plot as 2-row grid: top row = right hand, bottom row = left hand (or combine as single image)
|
| 323 |
+
# Use 25 points arranged as a 5x5 grid (stylized hand layout)
|
| 324 |
+
# Actual finger layout is complex; for visualization use simple grid
|
| 325 |
+
# Layout (rough hand analogy): arrange as fingertips at top, palm base at bottom
|
| 326 |
+
# Index mapping — 25 points, organized heuristically:
|
| 327 |
+
# row 0 (fingertips): 1-5
|
| 328 |
+
# row 1-2: finger segments
|
| 329 |
+
# row 3-4: palm area
|
| 330 |
+
def point_to_xy(idx):
|
| 331 |
+
"""Map channel index (0-24) to 2D hand position (stylized)."""
|
| 332 |
+
# Simple 5x5 grid
|
| 333 |
+
row = idx // 5
|
| 334 |
+
col = idx % 5
|
| 335 |
+
return col, 4 - row # flip y so fingertips at top
|
| 336 |
+
|
| 337 |
+
n = len(sorted_actions)
|
| 338 |
+
fig, axes = plt.subplots(2, n, figsize=(2.0 * n, 4.8), squeeze=False)
|
| 339 |
+
vmax = max(max(results[a]['r'].max(), results[a]['l'].max()) for a in sorted_actions)
|
| 340 |
+
|
| 341 |
+
for i, a in enumerate(sorted_actions):
|
| 342 |
+
for row, (hand, title) in enumerate([('r', 'Right'), ('l', 'Left')]):
|
| 343 |
+
ax = axes[row][i]
|
| 344 |
+
data = results[a][hand]
|
| 345 |
+
grid = np.zeros((5, 5))
|
| 346 |
+
for idx, v in enumerate(data):
|
| 347 |
+
x, y = point_to_xy(idx)
|
| 348 |
+
grid[4-y, x] = v
|
| 349 |
+
im = ax.imshow(grid, cmap='hot', vmin=0, vmax=vmax, aspect='equal')
|
| 350 |
+
ax.set_xticks([]); ax.set_yticks([])
|
| 351 |
+
if row == 0:
|
| 352 |
+
ax.set_title(a, fontsize=11, fontweight='bold')
|
| 353 |
+
if i == 0:
|
| 354 |
+
ax.set_ylabel(title, fontsize=10)
|
| 355 |
+
|
| 356 |
+
fig.suptitle('Per-action fingertip pressure signatures (mean across events)',
|
| 357 |
+
fontsize=12, fontweight='bold', y=0.98)
|
| 358 |
+
cbar = fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.7, pad=0.02)
|
| 359 |
+
cbar.set_label('Pressure (g)', fontsize=10)
|
| 360 |
+
plt.savefig(os.path.join(OUT_DIR, 'pressure_fingerprints.pdf'), bbox_inches='tight')
|
| 361 |
+
plt.savefig(os.path.join(OUT_DIR, 'pressure_fingerprints.png'), dpi=150, bbox_inches='tight')
|
| 362 |
+
plt.close()
|
| 363 |
+
print(f" Saved pressure_fingerprints.pdf")
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
# ============================================================
|
| 367 |
+
# FIGURE 3: 3D hand trajectory colored by pressure
|
| 368 |
+
# ============================================================
|
| 369 |
+
def make_3d_trajectory():
|
| 370 |
+
print("\n=== Figure 3: 3D hand trajectory + pressure coloring ===")
|
| 371 |
+
from mpl_toolkits.mplot3d import Axes3D
|
| 372 |
+
# Pick a few illustrative recordings with rich grasping — use v1 s3 (kitchen) or similar
|
| 373 |
+
candidates = [('v1', 's3'), ('v2', 's4'), ('v1', 's5'), ('v1', 's7')]
|
| 374 |
+
picked = []
|
| 375 |
+
|
| 376 |
+
for vol, scn in candidates:
|
| 377 |
+
sd = f"{DATASET}/{vol}/{scn}"
|
| 378 |
+
if not os.path.isdir(sd):
|
| 379 |
+
continue
|
| 380 |
+
p = load_pressure(sd)
|
| 381 |
+
r_wrist, _ = load_mocap_hand(sd, vol, scn)
|
| 382 |
+
if p is None or r_wrist is None:
|
| 383 |
+
continue
|
| 384 |
+
r_p, _ = p
|
| 385 |
+
min_len = min(len(r_p), len(r_wrist))
|
| 386 |
+
total_p = r_p[:min_len].sum(axis=1)
|
| 387 |
+
r_wrist = r_wrist[:min_len]
|
| 388 |
+
# Take a window that contains a grasp
|
| 389 |
+
onsets = detect_grasp_events(r_p[:min_len])
|
| 390 |
+
if not onsets:
|
| 391 |
+
continue
|
| 392 |
+
# Take ~3s centred on first onset
|
| 393 |
+
o = onsets[0]
|
| 394 |
+
start = max(0, o - 150)
|
| 395 |
+
end = min(min_len, o + 150)
|
| 396 |
+
traj = r_wrist[start:end]
|
| 397 |
+
pressure = total_p[start:end]
|
| 398 |
+
picked.append((vol, scn, traj, pressure))
|
| 399 |
+
if len(picked) >= 3:
|
| 400 |
+
break
|
| 401 |
+
|
| 402 |
+
if not picked:
|
| 403 |
+
print(" No valid recordings found")
|
| 404 |
+
return
|
| 405 |
+
|
| 406 |
+
fig = plt.figure(figsize=(3.5 * len(picked), 4))
|
| 407 |
+
for i, (vol, scn, traj, pr) in enumerate(picked):
|
| 408 |
+
ax = fig.add_subplot(1, len(picked), i+1, projection='3d')
|
| 409 |
+
# Normalize pressure for coloring
|
| 410 |
+
pr_norm = pr / (pr.max() + 1e-6)
|
| 411 |
+
# Plot as colored line segments
|
| 412 |
+
for j in range(len(traj) - 1):
|
| 413 |
+
x = traj[j:j+2, 0]
|
| 414 |
+
y = traj[j:j+2, 1]
|
| 415 |
+
z = traj[j:j+2, 2]
|
| 416 |
+
c = plt.cm.coolwarm(pr_norm[j])
|
| 417 |
+
ax.plot(x, y, z, color=c, linewidth=2.5, alpha=0.85)
|
| 418 |
+
# Mark contact point
|
| 419 |
+
contact_idx = np.argmax(pr)
|
| 420 |
+
ax.scatter(traj[contact_idx, 0], traj[contact_idx, 1], traj[contact_idx, 2],
|
| 421 |
+
color='red', s=50, marker='*', zorder=5, label='Peak contact')
|
| 422 |
+
ax.set_title(f'{vol}/{scn}', fontsize=10)
|
| 423 |
+
ax.set_xlabel('X', fontsize=8); ax.set_ylabel('Y', fontsize=8); ax.set_zlabel('Z', fontsize=8)
|
| 424 |
+
ax.tick_params(labelsize=7)
|
| 425 |
+
|
| 426 |
+
# Colorbar
|
| 427 |
+
sm = plt.cm.ScalarMappable(cmap='coolwarm', norm=matplotlib.colors.Normalize(vmin=0, vmax=1))
|
| 428 |
+
sm.set_array([])
|
| 429 |
+
cbar = fig.colorbar(sm, ax=fig.axes, shrink=0.6, pad=0.02)
|
| 430 |
+
cbar.set_label('Normalised pressure', fontsize=10)
|
| 431 |
+
|
| 432 |
+
fig.suptitle('Right-hand wrist 3D trajectory coloured by fingertip pressure',
|
| 433 |
+
fontsize=12, fontweight='bold', y=1.02)
|
| 434 |
+
plt.savefig(os.path.join(OUT_DIR, 'hand_trajectory_3d.pdf'), bbox_inches='tight')
|
| 435 |
+
plt.savefig(os.path.join(OUT_DIR, 'hand_trajectory_3d.png'), dpi=150, bbox_inches='tight')
|
| 436 |
+
plt.close()
|
| 437 |
+
print(f" Saved hand_trajectory_3d.pdf")
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
if __name__ == '__main__':
|
| 441 |
+
make_eye_hand_contact_figure()
|
| 442 |
+
make_pressure_fingerprints()
|
| 443 |
+
make_3d_trajectory()
|
| 444 |
+
print("\nAll figures generated in", OUT_DIR)
|
experiments/analysis/build_taxonomy.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Rebuild the frozen taxonomy JSON from the current annotations_v3/ state.
|
| 4 |
+
|
| 5 |
+
Run this *once* after annotation is complete to lock the 28+ noun list. Later
|
| 6 |
+
experiments load the frozen list via taxonomy.py, so class indices don't
|
| 7 |
+
drift if more annotations are ever added.
|
| 8 |
+
|
| 9 |
+
Usage:
|
| 10 |
+
python3 experiments/build_taxonomy.py
|
| 11 |
+
python3 experiments/build_taxonomy.py --threshold 50 --out experiments/taxonomy_v3.json
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import argparse
|
| 15 |
+
import glob
|
| 16 |
+
import json
|
| 17 |
+
import os
|
| 18 |
+
from collections import Counter
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
|
| 21 |
+
REPO = Path(__file__).resolve().parents[1]
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def main():
|
| 25 |
+
ap = argparse.ArgumentParser()
|
| 26 |
+
ap.add_argument(
|
| 27 |
+
"--annotations_dir",
|
| 28 |
+
default=str(REPO / "annotations_v3"),
|
| 29 |
+
help="Directory containing v*/s*.json annotation files",
|
| 30 |
+
)
|
| 31 |
+
ap.add_argument("--threshold", type=int, default=50,
|
| 32 |
+
help="Minimum noun frequency to keep (Strategy A drops the rest)")
|
| 33 |
+
ap.add_argument(
|
| 34 |
+
"--out",
|
| 35 |
+
default=str(REPO / "experiments" / "taxonomy_v3.json"),
|
| 36 |
+
help="Output frozen taxonomy JSON",
|
| 37 |
+
)
|
| 38 |
+
args = ap.parse_args()
|
| 39 |
+
|
| 40 |
+
# Late import so building the list doesn't depend on the frozen file
|
| 41 |
+
# being present yet.
|
| 42 |
+
import sys
|
| 43 |
+
sys.path.insert(0, str(REPO))
|
| 44 |
+
from experiments.taxonomy import (
|
| 45 |
+
VERB_FINE, VERB_COMPOSITE, HAND, NOUN_CANONICAL, canonical_noun,
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
paths = sorted(glob.glob(os.path.join(args.annotations_dir, "v*", "s*.json")))
|
| 49 |
+
if not paths:
|
| 50 |
+
raise SystemExit(f"No json files under {args.annotations_dir}")
|
| 51 |
+
|
| 52 |
+
verbs, nouns, hands = Counter(), Counter(), Counter()
|
| 53 |
+
total = 0
|
| 54 |
+
dropped_unknown_verb = 0
|
| 55 |
+
dropped_unknown_hand = 0
|
| 56 |
+
for p in paths:
|
| 57 |
+
try:
|
| 58 |
+
with open(p) as f:
|
| 59 |
+
d = json.load(f)
|
| 60 |
+
except Exception as e:
|
| 61 |
+
print(f" WARN: could not parse {p}: {e}")
|
| 62 |
+
continue
|
| 63 |
+
for s in d.get("segments", []):
|
| 64 |
+
a = s.get("action_annotation", {})
|
| 65 |
+
v = a.get("action_name")
|
| 66 |
+
n = a.get("object_name")
|
| 67 |
+
h = a.get("hand_type")
|
| 68 |
+
if not (v and n and h):
|
| 69 |
+
continue
|
| 70 |
+
total += 1
|
| 71 |
+
if v not in VERB_FINE:
|
| 72 |
+
dropped_unknown_verb += 1
|
| 73 |
+
continue
|
| 74 |
+
if h not in HAND:
|
| 75 |
+
dropped_unknown_hand += 1
|
| 76 |
+
continue
|
| 77 |
+
verbs[v] += 1
|
| 78 |
+
nouns[canonical_noun(n)] += 1
|
| 79 |
+
hands[h] += 1
|
| 80 |
+
|
| 81 |
+
kept = [n for n, c in nouns.most_common() if c >= args.threshold]
|
| 82 |
+
|
| 83 |
+
# Stable alphabetical ordering within kept-set, so re-runs that swap two
|
| 84 |
+
# near-tie classes don't flip indices.
|
| 85 |
+
kept = sorted(kept, key=lambda n: (-nouns[n], n))
|
| 86 |
+
|
| 87 |
+
surviving_segs = 0
|
| 88 |
+
for p in paths:
|
| 89 |
+
with open(p) as f:
|
| 90 |
+
d = json.load(f)
|
| 91 |
+
for s in d.get("segments", []):
|
| 92 |
+
a = s.get("action_annotation", {})
|
| 93 |
+
v = a.get("action_name")
|
| 94 |
+
n = a.get("object_name")
|
| 95 |
+
h = a.get("hand_type")
|
| 96 |
+
if not (v and n and h):
|
| 97 |
+
continue
|
| 98 |
+
if v not in VERB_FINE or h not in HAND:
|
| 99 |
+
continue
|
| 100 |
+
if canonical_noun(n) not in kept:
|
| 101 |
+
continue
|
| 102 |
+
surviving_segs += 1
|
| 103 |
+
|
| 104 |
+
out = {
|
| 105 |
+
"threshold": args.threshold,
|
| 106 |
+
"annotation_file_count": len(paths),
|
| 107 |
+
"total_segments": total,
|
| 108 |
+
"dropped_unknown_verb": dropped_unknown_verb,
|
| 109 |
+
"dropped_unknown_hand": dropped_unknown_hand,
|
| 110 |
+
"surviving_segments": surviving_segs,
|
| 111 |
+
"verbs": VERB_FINE,
|
| 112 |
+
"verb_composite": VERB_COMPOSITE,
|
| 113 |
+
"hand": HAND,
|
| 114 |
+
"nouns": kept,
|
| 115 |
+
"noun_counts": {n: nouns[n] for n in kept},
|
| 116 |
+
"verb_counts": dict(verbs),
|
| 117 |
+
"hand_counts": dict(hands),
|
| 118 |
+
}
|
| 119 |
+
Path(args.out).parent.mkdir(parents=True, exist_ok=True)
|
| 120 |
+
with open(args.out, "w") as f:
|
| 121 |
+
json.dump(out, f, ensure_ascii=False, indent=2)
|
| 122 |
+
|
| 123 |
+
print(f"Scanned {len(paths)} files, {total} segments")
|
| 124 |
+
print(f"Dropped (unknown verb / hand): {dropped_unknown_verb} / "
|
| 125 |
+
f"{dropped_unknown_hand}")
|
| 126 |
+
print(f"Kept {len(kept)} nouns (>= {args.threshold}):")
|
| 127 |
+
for n in kept:
|
| 128 |
+
print(f" {n}: {nouns[n]}")
|
| 129 |
+
print(f"Surviving segments (Strategy A): "
|
| 130 |
+
f"{surviving_segs} / {total} "
|
| 131 |
+
f"({100 * surviving_segs / max(1, total):.1f}%)")
|
| 132 |
+
print(f"Wrote {args.out}")
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
if __name__ == "__main__":
|
| 136 |
+
main()
|
experiments/analysis/check_seg_lengths.py
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Analyze segment lengths in the recognition dataset.
|
| 4 |
+
|
| 5 |
+
For each annotation file, computes segment lengths in:
|
| 6 |
+
- Raw frames (at 100Hz sampling rate)
|
| 7 |
+
- Downsampled frames (downsample=5 -> 20Hz effective)
|
| 8 |
+
|
| 9 |
+
Reports statistics and distribution relative to window_frames used in training.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import os
|
| 13 |
+
import sys
|
| 14 |
+
import json
|
| 15 |
+
import re
|
| 16 |
+
import numpy as np
|
| 17 |
+
from collections import defaultdict
|
| 18 |
+
|
| 19 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 20 |
+
from data.dataset import DATASET_DIR, TRAIN_VOLS, VAL_VOLS, TEST_VOLS
|
| 21 |
+
|
| 22 |
+
ANNOTATION_DIR = "${PULSE_ROOT}"
|
| 23 |
+
SAMPLING_RATE = 100 # Hz
|
| 24 |
+
DOWNSAMPLE = 5
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def parse_timestamp(ts_str):
|
| 28 |
+
parts = ts_str.strip().split(':')
|
| 29 |
+
if len(parts) == 2:
|
| 30 |
+
return int(parts[0]) * 60 + int(parts[1])
|
| 31 |
+
elif len(parts) == 3:
|
| 32 |
+
return int(parts[0]) * 3600 + int(parts[1]) * 60 + int(parts[2])
|
| 33 |
+
return 0
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def main():
|
| 37 |
+
all_vols = TRAIN_VOLS + VAL_VOLS + TEST_VOLS
|
| 38 |
+
|
| 39 |
+
# Collect segment lengths
|
| 40 |
+
raw_lengths_sec = [] # in seconds
|
| 41 |
+
raw_lengths_frames = [] # in raw 100Hz frames
|
| 42 |
+
ds_lengths_frames = [] # in downsampled frames (100/5 = 20Hz)
|
| 43 |
+
|
| 44 |
+
split_stats = defaultdict(list) # split -> list of ds_lengths
|
| 45 |
+
|
| 46 |
+
total_scenarios = 0
|
| 47 |
+
total_segments = 0
|
| 48 |
+
skipped_segments = 0
|
| 49 |
+
|
| 50 |
+
for vol in sorted(all_vols):
|
| 51 |
+
# Determine split
|
| 52 |
+
if vol in TRAIN_VOLS:
|
| 53 |
+
split = 'train'
|
| 54 |
+
elif vol in VAL_VOLS:
|
| 55 |
+
split = 'val'
|
| 56 |
+
else:
|
| 57 |
+
split = 'test'
|
| 58 |
+
|
| 59 |
+
ann_vol_dir = os.path.join(ANNOTATION_DIR, vol)
|
| 60 |
+
if not os.path.isdir(ann_vol_dir):
|
| 61 |
+
print(f"WARNING: No annotation dir for {vol}")
|
| 62 |
+
continue
|
| 63 |
+
|
| 64 |
+
for ann_file in sorted(os.listdir(ann_vol_dir)):
|
| 65 |
+
if not ann_file.endswith('.json'):
|
| 66 |
+
continue
|
| 67 |
+
scenario = ann_file.replace('.json', '')
|
| 68 |
+
ann_path = os.path.join(ann_vol_dir, ann_file)
|
| 69 |
+
|
| 70 |
+
# Also check that corresponding dataset dir exists
|
| 71 |
+
scenario_dir = os.path.join(DATASET_DIR, vol, scenario)
|
| 72 |
+
if not os.path.isdir(scenario_dir):
|
| 73 |
+
continue
|
| 74 |
+
|
| 75 |
+
with open(ann_path) as f:
|
| 76 |
+
ann = json.load(f)
|
| 77 |
+
|
| 78 |
+
total_scenarios += 1
|
| 79 |
+
|
| 80 |
+
for seg in ann.get('segments', []):
|
| 81 |
+
m = re.match(r'(\d+:\d+(?::\d+)?)\s*-\s*(\d+:\d+(?::\d+)?)',
|
| 82 |
+
seg['timestamp'])
|
| 83 |
+
if not m:
|
| 84 |
+
skipped_segments += 1
|
| 85 |
+
continue
|
| 86 |
+
|
| 87 |
+
start_sec = parse_timestamp(m.group(1))
|
| 88 |
+
end_sec = parse_timestamp(m.group(2))
|
| 89 |
+
|
| 90 |
+
if end_sec <= start_sec:
|
| 91 |
+
skipped_segments += 1
|
| 92 |
+
continue
|
| 93 |
+
|
| 94 |
+
duration_sec = end_sec - start_sec
|
| 95 |
+
raw_frames = duration_sec * SAMPLING_RATE
|
| 96 |
+
ds_frames = int(end_sec * SAMPLING_RATE / DOWNSAMPLE) - int(start_sec * SAMPLING_RATE / DOWNSAMPLE)
|
| 97 |
+
|
| 98 |
+
raw_lengths_sec.append(duration_sec)
|
| 99 |
+
raw_lengths_frames.append(raw_frames)
|
| 100 |
+
ds_lengths_frames.append(ds_frames)
|
| 101 |
+
split_stats[split].append(ds_frames)
|
| 102 |
+
total_segments += 1
|
| 103 |
+
|
| 104 |
+
# Convert to numpy
|
| 105 |
+
raw_sec = np.array(raw_lengths_sec)
|
| 106 |
+
raw_fr = np.array(raw_lengths_frames)
|
| 107 |
+
ds_fr = np.array(ds_lengths_frames)
|
| 108 |
+
|
| 109 |
+
print("=" * 70)
|
| 110 |
+
print("SEGMENT LENGTH ANALYSIS FOR RECOGNITION DATASET")
|
| 111 |
+
print("=" * 70)
|
| 112 |
+
print(f"\nTotal scenarios: {total_scenarios}")
|
| 113 |
+
print(f"Total valid segments: {total_segments}")
|
| 114 |
+
print(f"Skipped segments (bad timestamp): {skipped_segments}")
|
| 115 |
+
print(f"Sampling rate: {SAMPLING_RATE} Hz")
|
| 116 |
+
print(f"Downsample factor: {DOWNSAMPLE}")
|
| 117 |
+
print(f"Effective rate after downsample: {SAMPLING_RATE / DOWNSAMPLE} Hz")
|
| 118 |
+
|
| 119 |
+
# --- Raw seconds ---
|
| 120 |
+
print("\n" + "-" * 70)
|
| 121 |
+
print("SEGMENT DURATION (seconds)")
|
| 122 |
+
print("-" * 70)
|
| 123 |
+
print(f" Min: {raw_sec.min():.1f}s")
|
| 124 |
+
print(f" Max: {raw_sec.max():.1f}s")
|
| 125 |
+
print(f" Mean: {raw_sec.mean():.2f}s")
|
| 126 |
+
print(f" Median: {np.median(raw_sec):.1f}s")
|
| 127 |
+
print(f" Std: {raw_sec.std():.2f}s")
|
| 128 |
+
|
| 129 |
+
# Percentiles
|
| 130 |
+
for p in [5, 10, 25, 50, 75, 90, 95]:
|
| 131 |
+
print(f" P{p:2d}: {np.percentile(raw_sec, p):.1f}s")
|
| 132 |
+
|
| 133 |
+
# --- Raw frames (100Hz) ---
|
| 134 |
+
print("\n" + "-" * 70)
|
| 135 |
+
print("SEGMENT LENGTH (raw frames @ 100Hz)")
|
| 136 |
+
print("-" * 70)
|
| 137 |
+
print(f" Min: {raw_fr.min()}")
|
| 138 |
+
print(f" Max: {raw_fr.max()}")
|
| 139 |
+
print(f" Mean: {raw_fr.mean():.1f}")
|
| 140 |
+
print(f" Median: {np.median(raw_fr):.0f}")
|
| 141 |
+
|
| 142 |
+
# --- Downsampled frames ---
|
| 143 |
+
print("\n" + "-" * 70)
|
| 144 |
+
print(f"SEGMENT LENGTH (downsampled frames @ {SAMPLING_RATE/DOWNSAMPLE:.0f}Hz)")
|
| 145 |
+
print("-" * 70)
|
| 146 |
+
print(f" Min: {ds_fr.min()}")
|
| 147 |
+
print(f" Max: {ds_fr.max()}")
|
| 148 |
+
print(f" Mean: {ds_fr.mean():.1f}")
|
| 149 |
+
print(f" Median: {np.median(ds_fr):.0f}")
|
| 150 |
+
print(f" Std: {ds_fr.std():.1f}")
|
| 151 |
+
|
| 152 |
+
for p in [5, 10, 25, 50, 75, 90, 95]:
|
| 153 |
+
print(f" P{p:2d}: {np.percentile(ds_fr, p):.0f}")
|
| 154 |
+
|
| 155 |
+
# --- Comparison with window_frames ---
|
| 156 |
+
print("\n" + "-" * 70)
|
| 157 |
+
print("COMPARISON WITH window_frames SETTINGS")
|
| 158 |
+
print("-" * 70)
|
| 159 |
+
|
| 160 |
+
# Common window_sec values and their corresponding window_frames
|
| 161 |
+
for window_sec in [5.0, 10.0, 15.0, 20.0, 30.0]:
|
| 162 |
+
wf = int(window_sec * SAMPLING_RATE / DOWNSAMPLE)
|
| 163 |
+
shorter = (ds_fr < wf).sum()
|
| 164 |
+
equal_or_longer = (ds_fr >= wf).sum()
|
| 165 |
+
longer = (ds_fr > wf).sum()
|
| 166 |
+
pct_shorter = 100.0 * shorter / len(ds_fr)
|
| 167 |
+
pct_longer = 100.0 * longer / len(ds_fr)
|
| 168 |
+
print(f"\n window_sec={window_sec:5.1f}s -> window_frames={wf}")
|
| 169 |
+
print(f" Segments SHORTER than window: {shorter:4d} ({pct_shorter:5.1f}%) -> will be PADDED")
|
| 170 |
+
print(f" Segments LONGER than window: {longer:4d} ({pct_longer:5.1f}%) -> will be CENTER-CROPPED")
|
| 171 |
+
|
| 172 |
+
# --- Thresholds in downsampled frames ---
|
| 173 |
+
print("\n" + "-" * 70)
|
| 174 |
+
print("PERCENTAGE SHORTER THAN THRESHOLDS (downsampled frames)")
|
| 175 |
+
print("-" * 70)
|
| 176 |
+
for thresh in [20, 40, 60, 100, 200, 300, 400, 500, 1000, 2000]:
|
| 177 |
+
pct = 100.0 * (ds_fr < thresh).sum() / len(ds_fr)
|
| 178 |
+
print(f" < {thresh:5d} frames ({thresh * DOWNSAMPLE / SAMPLING_RATE:6.1f}s): {pct:5.1f}%")
|
| 179 |
+
|
| 180 |
+
# --- Per-split stats ---
|
| 181 |
+
print("\n" + "-" * 70)
|
| 182 |
+
print("PER-SPLIT STATISTICS (downsampled frames)")
|
| 183 |
+
print("-" * 70)
|
| 184 |
+
for split in ['train', 'val', 'test']:
|
| 185 |
+
arr = np.array(split_stats[split])
|
| 186 |
+
if len(arr) == 0:
|
| 187 |
+
print(f" {split}: no segments")
|
| 188 |
+
continue
|
| 189 |
+
print(f"\n {split.upper()} ({len(arr)} segments):")
|
| 190 |
+
print(f" Min={arr.min()}, Max={arr.max()}, Mean={arr.mean():.1f}, Median={np.median(arr):.0f}")
|
| 191 |
+
|
| 192 |
+
# --- Histogram (text-based) ---
|
| 193 |
+
print("\n" + "-" * 70)
|
| 194 |
+
print("HISTOGRAM OF SEGMENT DURATIONS (seconds)")
|
| 195 |
+
print("-" * 70)
|
| 196 |
+
bins = [0, 1, 2, 3, 4, 5, 7, 10, 15, 20, 30, 60, 120, 300, 600]
|
| 197 |
+
for i in range(len(bins) - 1):
|
| 198 |
+
count = ((raw_sec >= bins[i]) & (raw_sec < bins[i + 1])).sum()
|
| 199 |
+
pct = 100.0 * count / len(raw_sec)
|
| 200 |
+
bar = '#' * int(pct / 2)
|
| 201 |
+
print(f" [{bins[i]:4d}-{bins[i+1]:4d})s: {count:5d} ({pct:5.1f}%) {bar}")
|
| 202 |
+
# Last bin: >= 600
|
| 203 |
+
count = (raw_sec >= bins[-1]).sum()
|
| 204 |
+
pct = 100.0 * count / len(raw_sec)
|
| 205 |
+
bar = '#' * int(pct / 2)
|
| 206 |
+
print(f" [{bins[-1]:4d}+ )s: {count:5d} ({pct:5.1f}%) {bar}")
|
| 207 |
+
|
| 208 |
+
# --- Key insight ---
|
| 209 |
+
print("\n" + "=" * 70)
|
| 210 |
+
print("KEY INSIGHTS")
|
| 211 |
+
print("=" * 70)
|
| 212 |
+
median_sec = np.median(raw_sec)
|
| 213 |
+
mean_sec = raw_sec.mean()
|
| 214 |
+
print(f" Median segment duration: {median_sec:.1f}s ({median_sec * SAMPLING_RATE / DOWNSAMPLE:.0f} ds-frames)")
|
| 215 |
+
print(f" Mean segment duration: {mean_sec:.1f}s ({mean_sec * SAMPLING_RATE / DOWNSAMPLE:.0f} ds-frames)")
|
| 216 |
+
print()
|
| 217 |
+
# Suggest optimal window
|
| 218 |
+
p95_sec = np.percentile(raw_sec, 95)
|
| 219 |
+
print(f" 95th percentile duration: {p95_sec:.1f}s")
|
| 220 |
+
print(f" -> A window of {p95_sec:.0f}s would cover 95% of segments without cropping")
|
| 221 |
+
print(f" -> Current default window_sec=15.0 -> window_frames={int(15.0 * SAMPLING_RATE / DOWNSAMPLE)}")
|
| 222 |
+
wf15 = int(15.0 * SAMPLING_RATE / DOWNSAMPLE)
|
| 223 |
+
pct_crop = 100.0 * (ds_fr > wf15).sum() / len(ds_fr)
|
| 224 |
+
pct_pad = 100.0 * (ds_fr < wf15).sum() / len(ds_fr)
|
| 225 |
+
print(f" {pct_pad:.1f}% segments padded, {pct_crop:.1f}% center-cropped")
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
if __name__ == '__main__':
|
| 229 |
+
main()
|
experiments/analysis/data_statistics_figure.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Generate dataset statistics figure from the currently-available annotations.
|
| 2 |
+
|
| 3 |
+
Panels (3):
|
| 4 |
+
(a) Recording duration distribution per scene (boxplot)
|
| 5 |
+
(b) Segment length distribution (histogram)
|
| 6 |
+
(c) Top-20 manipulated objects by segment count
|
| 7 |
+
|
| 8 |
+
Note: panel for motor-primitive frequency is deferred until the 18-primitive
|
| 9 |
+
annotation pipeline (anno.py) is rerun across all recordings.
|
| 10 |
+
"""
|
| 11 |
+
import json, re
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from collections import Counter
|
| 14 |
+
import numpy as np
|
| 15 |
+
import matplotlib.pyplot as plt
|
| 16 |
+
|
| 17 |
+
ANNO_DIR = Path("${PULSE_ROOT}/annotations_by_scene")
|
| 18 |
+
OUT = Path("${PULSE_ROOT}/paper/figures/dataset_stats.pdf")
|
| 19 |
+
|
| 20 |
+
# Chinese -> English object name mapping (from anno.py OBJECT_TRANSLATIONS)
|
| 21 |
+
OBJ_EN = {
|
| 22 |
+
"笔记本电脑": "laptop", "有线鼠标": "wired mouse", "有线键盘": "wired keyboard",
|
| 23 |
+
"马克笔": "marker", "胶带": "tape", "笔记本电源": "laptop power", "折叠伞": "umbrella",
|
| 24 |
+
"剪刀": "scissors", "钱包": "wallet", "纸": "paper", "订书机": "stapler",
|
| 25 |
+
"纸箱": "box", "文件": "document", "架子": "rack", "桌布": "tablecloth", "罐子": "jar",
|
| 26 |
+
"调料瓶": "seasoning bottle", "密封罐": "sealed jar", "厨房纸巾": "kitchen paper",
|
| 27 |
+
"抹布": "cloth", "茶包": "tea bag", "饭碗": "rice bowl", "菜盘": "plate",
|
| 28 |
+
"菜锅": "pot", "勺子": "spoon", "水杯": "water cup", "茶杯": "tea cup",
|
| 29 |
+
"茶壶": "teapot", "食物残渣": "food residue", "垃圾桶": "trash bin",
|
| 30 |
+
"纸巾": "tissue", "餐垫": "placemat", "托盘": "tray", "清洁喷雾": "spray",
|
| 31 |
+
"食物": "food", "电源": "power adapter", "移动硬盘": "HDD", "鼠标": "mouse",
|
| 32 |
+
"笔记本充电器": "laptop charger", "转换插头": "plug adapter", "插线板": "power strip",
|
| 33 |
+
"线材收纳包": "cable organizer", "衬衫": "shirt", "裤子": "pants",
|
| 34 |
+
"牙膏": "toothpaste", "牙刷": "toothbrush", "牙刷盒": "toothbrush case",
|
| 35 |
+
"剃须刀": "razor", "毛巾": "towel", "皮鞋": "shoes", "鞋袋": "shoe bag",
|
| 36 |
+
"耳机": "headphones", "护照套": "passport holder", "证件夹": "ID holder",
|
| 37 |
+
"纸巾包": "tissue pack", "行李箱": "suitcase", "马克杯": "mug",
|
| 38 |
+
"调料罐": "seasoning jar", "茶罐": "tea canister", "外套": "coat",
|
| 39 |
+
"围巾": "scarf", "衣架": "hanger",
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def parse_t(ts: str) -> float:
|
| 44 |
+
parts = ts.split(":")
|
| 45 |
+
if len(parts) == 2: # MM:SS
|
| 46 |
+
m, s = parts
|
| 47 |
+
return int(m) * 60 + int(s)
|
| 48 |
+
h, m, s = parts
|
| 49 |
+
return int(h) * 3600 + int(m) * 60 + int(s)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
durations = {f"S{i}": [] for i in range(1, 9)}
|
| 53 |
+
seg_lengths = []
|
| 54 |
+
objects = Counter()
|
| 55 |
+
|
| 56 |
+
for v_dir in sorted(ANNO_DIR.glob("v*")):
|
| 57 |
+
for jf in sorted(v_dir.glob("s*.json")):
|
| 58 |
+
scene = jf.stem.upper()
|
| 59 |
+
try:
|
| 60 |
+
data = json.loads(jf.read_text())
|
| 61 |
+
except Exception:
|
| 62 |
+
continue
|
| 63 |
+
segs = data.get("segments", [])
|
| 64 |
+
if not segs:
|
| 65 |
+
continue
|
| 66 |
+
max_end = 0
|
| 67 |
+
for seg in segs:
|
| 68 |
+
ts = seg.get("timestamp", "")
|
| 69 |
+
if "-" not in ts:
|
| 70 |
+
continue
|
| 71 |
+
try:
|
| 72 |
+
start, end = ts.split("-")
|
| 73 |
+
s_sec, e_sec = parse_t(start), parse_t(end)
|
| 74 |
+
seg_lengths.append(e_sec - s_sec)
|
| 75 |
+
max_end = max(max_end, e_sec)
|
| 76 |
+
for o in seg.get("objects", []) or []:
|
| 77 |
+
nm = o.get("name") if isinstance(o, dict) else o
|
| 78 |
+
if nm:
|
| 79 |
+
objects[OBJ_EN.get(nm, nm)] += 1
|
| 80 |
+
except Exception:
|
| 81 |
+
continue
|
| 82 |
+
if max_end > 0 and scene in durations:
|
| 83 |
+
durations[scene].append(max_end / 60.0)
|
| 84 |
+
|
| 85 |
+
print(f"Per-scene durations: { {s: len(v) for s, v in durations.items()} }")
|
| 86 |
+
print(f"Total segments: {len(seg_lengths)}")
|
| 87 |
+
print(f"Unique objects: {len(objects)}")
|
| 88 |
+
top_obj = objects.most_common(5)
|
| 89 |
+
print(f"Top objects: {top_obj}")
|
| 90 |
+
|
| 91 |
+
fig, axes = plt.subplots(1, 3, figsize=(12, 3.5))
|
| 92 |
+
|
| 93 |
+
# (a) Duration boxplot per scene
|
| 94 |
+
ax = axes[0]
|
| 95 |
+
scene_order = [f"S{i}" for i in range(1, 9)]
|
| 96 |
+
data = [durations[s] for s in scene_order]
|
| 97 |
+
ax.boxplot(data, tick_labels=scene_order, showfliers=False, patch_artist=True,
|
| 98 |
+
boxprops=dict(facecolor="#b3cde3"))
|
| 99 |
+
ax.set_ylabel("Recording duration (min)")
|
| 100 |
+
ax.set_title("(a) Recording duration per scene")
|
| 101 |
+
ax.grid(axis="y", alpha=0.3)
|
| 102 |
+
|
| 103 |
+
# (b) Segment length histogram
|
| 104 |
+
ax = axes[1]
|
| 105 |
+
seg_arr = np.array(seg_lengths)
|
| 106 |
+
seg_arr = seg_arr[seg_arr <= 10]
|
| 107 |
+
ax.hist(seg_arr, bins=np.arange(0, 11) - 0.5, color="#8c96c6", edgecolor="black")
|
| 108 |
+
ax.set_xlabel("Segment length (s)")
|
| 109 |
+
ax.set_ylabel("Segment count")
|
| 110 |
+
ax.set_title(f"(b) Segment length (n={len(seg_lengths)})")
|
| 111 |
+
ax.set_xticks(range(0, 11))
|
| 112 |
+
ax.grid(axis="y", alpha=0.3)
|
| 113 |
+
|
| 114 |
+
# (c) Top-20 objects
|
| 115 |
+
ax = axes[2]
|
| 116 |
+
objs, ocounts = zip(*objects.most_common(20))
|
| 117 |
+
ax.barh(objs[::-1], ocounts[::-1], color="#74c476")
|
| 118 |
+
ax.set_xlabel("Segment count")
|
| 119 |
+
ax.set_title("(c) Top-20 manipulated objects")
|
| 120 |
+
ax.tick_params(axis="y", labelsize=8)
|
| 121 |
+
ax.grid(axis="x", alpha=0.3)
|
| 122 |
+
|
| 123 |
+
fig.tight_layout()
|
| 124 |
+
fig.savefig(OUT, bbox_inches="tight")
|
| 125 |
+
fig.savefig(str(OUT).replace(".pdf", ".png"), dpi=140, bbox_inches="tight")
|
| 126 |
+
print(f"Saved: {OUT}")
|
experiments/analysis/exp_per_subject.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Experiment G: Per-subject diagnostic analysis.
|
| 4 |
+
|
| 5 |
+
Load the best scene-recognition checkpoint(s) from previous T1 runs and
|
| 6 |
+
produce a per-test-volunteer breakdown of F1 and Accuracy. Reveals whether
|
| 7 |
+
aggregate metrics are driven by one or two outlier subjects, as reviewers
|
| 8 |
+
often ask.
|
| 9 |
+
|
| 10 |
+
Runs CPU-side; no training.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import os
|
| 14 |
+
import sys
|
| 15 |
+
import json
|
| 16 |
+
import glob
|
| 17 |
+
import argparse
|
| 18 |
+
import numpy as np
|
| 19 |
+
import torch
|
| 20 |
+
from sklearn.metrics import accuracy_score, f1_score
|
| 21 |
+
|
| 22 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 23 |
+
from data.dataset import (
|
| 24 |
+
MultimodalSceneDataset, TEST_VOLS, SCENE_LABELS, NUM_CLASSES,
|
| 25 |
+
get_dataloaders,
|
| 26 |
+
)
|
| 27 |
+
from nets.models import build_model
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def per_subject_eval(model, device, modalities, stats, downsample):
|
| 31 |
+
"""Evaluate one model across each test volunteer separately."""
|
| 32 |
+
breakdown = {}
|
| 33 |
+
for vol in TEST_VOLS:
|
| 34 |
+
ds = MultimodalSceneDataset([vol], modalities, downsample=downsample,
|
| 35 |
+
stats=stats)
|
| 36 |
+
if len(ds) == 0:
|
| 37 |
+
breakdown[vol] = {'n': 0}
|
| 38 |
+
continue
|
| 39 |
+
preds, ys = [], []
|
| 40 |
+
model.eval()
|
| 41 |
+
with torch.no_grad():
|
| 42 |
+
for i in range(len(ds)):
|
| 43 |
+
x, y = ds[i]
|
| 44 |
+
x = x.to(device).unsqueeze(0)
|
| 45 |
+
mask = torch.ones(1, x.size(1), dtype=torch.bool).to(device)
|
| 46 |
+
logits = model(x, mask)
|
| 47 |
+
preds.append(logits.argmax(dim=1).cpu().item())
|
| 48 |
+
ys.append(y)
|
| 49 |
+
breakdown[vol] = {
|
| 50 |
+
'n': len(ds),
|
| 51 |
+
'acc': float(accuracy_score(ys, preds)),
|
| 52 |
+
'f1': float(f1_score(ys, preds, average='macro', zero_division=0)),
|
| 53 |
+
'preds': preds,
|
| 54 |
+
'labels': ys,
|
| 55 |
+
'samples': ds.sample_info,
|
| 56 |
+
}
|
| 57 |
+
return breakdown
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def run_on_checkpoint(ckpt_path, args_json_path, output_dir):
|
| 61 |
+
ckpt_args = json.load(open(args_json_path))['args']
|
| 62 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 63 |
+
modalities = ckpt_args['modalities'] if isinstance(ckpt_args['modalities'], list) \
|
| 64 |
+
else ckpt_args['modalities'].split(',')
|
| 65 |
+
downsample = ckpt_args.get('downsample', 5)
|
| 66 |
+
# Get train stats
|
| 67 |
+
_, _, _, info = get_dataloaders(modalities,
|
| 68 |
+
batch_size=ckpt_args.get('batch_size', 16),
|
| 69 |
+
downsample=downsample)
|
| 70 |
+
# Need the actual stats object -- re-load train set to compute
|
| 71 |
+
tr_ds = MultimodalSceneDataset(
|
| 72 |
+
__import__('experiments.dataset', fromlist=['TRAIN_VOLS']).TRAIN_VOLS,
|
| 73 |
+
modalities, downsample=downsample)
|
| 74 |
+
stats = tr_ds.get_stats()
|
| 75 |
+
|
| 76 |
+
model = build_model(
|
| 77 |
+
ckpt_args.get('model', 'transformer'),
|
| 78 |
+
ckpt_args.get('fusion', 'late'),
|
| 79 |
+
info['feat_dim'], info['modality_dims'], NUM_CLASSES,
|
| 80 |
+
hidden_dim=ckpt_args.get('hidden_dim', 128),
|
| 81 |
+
proj_dim=ckpt_args.get('proj_dim', 0),
|
| 82 |
+
late_agg=ckpt_args.get('late_agg', 'mean'),
|
| 83 |
+
).to(device)
|
| 84 |
+
try:
|
| 85 |
+
sd = torch.load(ckpt_path, weights_only=True, map_location=device)
|
| 86 |
+
except Exception:
|
| 87 |
+
sd = torch.load(ckpt_path, map_location=device)
|
| 88 |
+
model.load_state_dict(sd, strict=False)
|
| 89 |
+
|
| 90 |
+
breakdown = per_subject_eval(model, device, modalities, stats, downsample)
|
| 91 |
+
|
| 92 |
+
# Overall F1
|
| 93 |
+
all_preds, all_ys = [], []
|
| 94 |
+
for v, info_v in breakdown.items():
|
| 95 |
+
if info_v.get('n', 0) > 0:
|
| 96 |
+
all_preds.extend(info_v['preds'])
|
| 97 |
+
all_ys.extend(info_v['labels'])
|
| 98 |
+
overall_f1 = float(f1_score(all_ys, all_preds, average='macro', zero_division=0))
|
| 99 |
+
overall_acc = float(accuracy_score(all_ys, all_preds))
|
| 100 |
+
|
| 101 |
+
# Per-subject summary
|
| 102 |
+
summary = {
|
| 103 |
+
'ckpt': ckpt_path,
|
| 104 |
+
'modalities': modalities,
|
| 105 |
+
'overall': {'acc': overall_acc, 'f1': overall_f1,
|
| 106 |
+
'n': len(all_preds)},
|
| 107 |
+
'per_subject': {
|
| 108 |
+
v: {'n': b.get('n'), 'acc': b.get('acc'), 'f1': b.get('f1')}
|
| 109 |
+
for v, b in breakdown.items()
|
| 110 |
+
},
|
| 111 |
+
'detail': breakdown,
|
| 112 |
+
}
|
| 113 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 114 |
+
out_path = os.path.join(output_dir, os.path.basename(
|
| 115 |
+
os.path.dirname(ckpt_path)) + '_per_subject.json')
|
| 116 |
+
with open(out_path, 'w') as f:
|
| 117 |
+
json.dump(summary, f, indent=2)
|
| 118 |
+
print(f"Per-subject breakdown saved: {out_path}")
|
| 119 |
+
print(f"Overall F1: {overall_f1:.4f} Acc: {overall_acc:.4f}")
|
| 120 |
+
for v, b in summary['per_subject'].items():
|
| 121 |
+
print(f" {v}: n={b['n']} acc={b.get('acc'):.3f} f1={b.get('f1'):.3f}"
|
| 122 |
+
if b.get('n') else f" {v}: (empty)")
|
| 123 |
+
return summary
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def main():
|
| 127 |
+
p = argparse.ArgumentParser()
|
| 128 |
+
p.add_argument('--exp_root', type=str, required=True,
|
| 129 |
+
help='Directory containing run subdirs with model_best.pt and results.json')
|
| 130 |
+
p.add_argument('--output_dir', type=str, required=True)
|
| 131 |
+
args = p.parse_args()
|
| 132 |
+
|
| 133 |
+
runs = []
|
| 134 |
+
for sub in sorted(os.listdir(args.exp_root)):
|
| 135 |
+
if sub == 'slurm_logs':
|
| 136 |
+
continue
|
| 137 |
+
ckpt = os.path.join(args.exp_root, sub, 'model_best.pt')
|
| 138 |
+
res = os.path.join(args.exp_root, sub, 'results.json')
|
| 139 |
+
if os.path.exists(ckpt) and os.path.exists(res):
|
| 140 |
+
runs.append((ckpt, res))
|
| 141 |
+
print(f"Found {len(runs)} runs with checkpoints.")
|
| 142 |
+
for ckpt, res in runs:
|
| 143 |
+
try:
|
| 144 |
+
run_on_checkpoint(ckpt, res, args.output_dir)
|
| 145 |
+
except Exception as e:
|
| 146 |
+
print(f" FAIL {ckpt}: {e}")
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
if __name__ == '__main__':
|
| 150 |
+
main()
|
experiments/analysis/extract_video_features.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Extract video features from Scene Camera videos using a pretrained backbone.
|
| 4 |
+
Uses CLIP (ViT-B/16) which is lightweight and doesn't need video-specific pretraining.
|
| 5 |
+
|
| 6 |
+
Output: per-frame feature vectors saved as .npy files, aligned to 100Hz sensor data.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import sys
|
| 11 |
+
import json
|
| 12 |
+
import glob
|
| 13 |
+
import argparse
|
| 14 |
+
import numpy as np
|
| 15 |
+
import cv2
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn as nn
|
| 18 |
+
from torchvision import transforms
|
| 19 |
+
|
| 20 |
+
DATASET_DIR = "${PULSE_ROOT}/dataset"
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class CLIPFeatureExtractor:
|
| 24 |
+
"""Extract features using CLIP ViT-B/16 (via torchvision)."""
|
| 25 |
+
|
| 26 |
+
def __init__(self, device='cpu'):
|
| 27 |
+
self.device = device
|
| 28 |
+
# Use torchvision's pretrained ViT
|
| 29 |
+
from torchvision.models import vit_b_16, ViT_B_16_Weights
|
| 30 |
+
weights = ViT_B_16_Weights.IMAGENET1K_V1
|
| 31 |
+
model = vit_b_16(weights=weights)
|
| 32 |
+
# Remove classification head, keep feature extractor
|
| 33 |
+
model.heads = nn.Identity()
|
| 34 |
+
model.eval()
|
| 35 |
+
self.model = model.to(device)
|
| 36 |
+
self.transform = weights.transforms()
|
| 37 |
+
self.feat_dim = 768 # ViT-B/16 feature dimension
|
| 38 |
+
|
| 39 |
+
@torch.no_grad()
|
| 40 |
+
def extract_batch(self, frames):
|
| 41 |
+
"""Extract features from a batch of frames.
|
| 42 |
+
|
| 43 |
+
Args:
|
| 44 |
+
frames: list of numpy arrays (H, W, 3) in BGR format
|
| 45 |
+
Returns:
|
| 46 |
+
features: numpy array (N, feat_dim)
|
| 47 |
+
"""
|
| 48 |
+
tensors = []
|
| 49 |
+
for frame in frames:
|
| 50 |
+
# BGR -> RGB -> PIL-like tensor
|
| 51 |
+
rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 52 |
+
tensor = torch.from_numpy(rgb).permute(2, 0, 1).float() / 255.0
|
| 53 |
+
tensor = self.transform(tensor)
|
| 54 |
+
tensors.append(tensor)
|
| 55 |
+
|
| 56 |
+
batch = torch.stack(tensors).to(self.device)
|
| 57 |
+
features = self.model(batch)
|
| 58 |
+
return features.cpu().numpy()
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def find_scene_video(scenario_dir, vol, scenario):
|
| 62 |
+
"""Find the Scene Camera video file."""
|
| 63 |
+
pattern = os.path.join(scenario_dir, f"trimmed_{vol}{scenario}*Scene Cam.mp4")
|
| 64 |
+
matches = glob.glob(pattern)
|
| 65 |
+
return matches[0] if matches else None
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def extract_features_for_video(extractor, video_path, target_fps=100,
|
| 69 |
+
batch_size=32, sample_fps=2):
|
| 70 |
+
"""Extract features from a video file.
|
| 71 |
+
|
| 72 |
+
Args:
|
| 73 |
+
extractor: feature extractor
|
| 74 |
+
video_path: path to video file
|
| 75 |
+
target_fps: target frame rate to align with sensor data (100Hz)
|
| 76 |
+
batch_size: batch size for feature extraction
|
| 77 |
+
sample_fps: extract features at this rate (e.g., 2 = every 0.5s)
|
| 78 |
+
Features are then interpolated to target_fps.
|
| 79 |
+
Returns:
|
| 80 |
+
features: numpy array (T_target, feat_dim) aligned to target_fps
|
| 81 |
+
"""
|
| 82 |
+
cap = cv2.VideoCapture(video_path)
|
| 83 |
+
video_fps = cap.get(cv2.CAP_PROP_FPS)
|
| 84 |
+
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 85 |
+
duration = total_frames / video_fps
|
| 86 |
+
|
| 87 |
+
# Sample frames at sample_fps
|
| 88 |
+
sample_interval = int(video_fps / sample_fps)
|
| 89 |
+
sample_indices = list(range(0, total_frames, sample_interval))
|
| 90 |
+
|
| 91 |
+
print(f" Video: {total_frames} frames @ {video_fps:.1f}fps = {duration:.1f}s")
|
| 92 |
+
print(f" Sampling {len(sample_indices)} frames @ {sample_fps}fps")
|
| 93 |
+
|
| 94 |
+
# Extract features in batches
|
| 95 |
+
all_features = []
|
| 96 |
+
batch_frames = []
|
| 97 |
+
batch_indices = []
|
| 98 |
+
|
| 99 |
+
for idx in sample_indices:
|
| 100 |
+
cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
|
| 101 |
+
ret, frame = cap.read()
|
| 102 |
+
if not ret:
|
| 103 |
+
break
|
| 104 |
+
batch_frames.append(frame)
|
| 105 |
+
batch_indices.append(idx)
|
| 106 |
+
|
| 107 |
+
if len(batch_frames) >= batch_size:
|
| 108 |
+
feats = extractor.extract_batch(batch_frames)
|
| 109 |
+
all_features.append(feats)
|
| 110 |
+
batch_frames = []
|
| 111 |
+
if len(all_features) % 10 == 0:
|
| 112 |
+
print(f" Processed {len(all_features) * batch_size} frames...")
|
| 113 |
+
|
| 114 |
+
if batch_frames:
|
| 115 |
+
feats = extractor.extract_batch(batch_frames)
|
| 116 |
+
all_features.append(feats)
|
| 117 |
+
|
| 118 |
+
cap.release()
|
| 119 |
+
|
| 120 |
+
if not all_features:
|
| 121 |
+
return None
|
| 122 |
+
|
| 123 |
+
features = np.concatenate(all_features, axis=0) # (N_samples, feat_dim)
|
| 124 |
+
sample_times = np.array(batch_indices[:features.shape[0]]) / video_fps # seconds
|
| 125 |
+
|
| 126 |
+
# Interpolate to target_fps (100Hz)
|
| 127 |
+
target_times = np.arange(0, duration, 1.0 / target_fps)
|
| 128 |
+
n_target = len(target_times)
|
| 129 |
+
|
| 130 |
+
# Linear interpolation per feature dimension
|
| 131 |
+
from scipy.interpolate import interp1d
|
| 132 |
+
if len(sample_times) < 2:
|
| 133 |
+
# Not enough samples, repeat
|
| 134 |
+
interpolated = np.tile(features[0], (n_target, 1))
|
| 135 |
+
else:
|
| 136 |
+
interp_func = interp1d(
|
| 137 |
+
sample_times, features, axis=0,
|
| 138 |
+
kind='linear', fill_value='extrapolate'
|
| 139 |
+
)
|
| 140 |
+
interpolated = interp_func(target_times).astype(np.float32)
|
| 141 |
+
|
| 142 |
+
print(f" Output: {interpolated.shape} @ {target_fps}Hz")
|
| 143 |
+
return interpolated
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def main():
|
| 147 |
+
parser = argparse.ArgumentParser(description='Extract video features')
|
| 148 |
+
parser.add_argument('--sample_fps', type=int, default=2,
|
| 149 |
+
help='Sample rate for feature extraction (default: 2fps)')
|
| 150 |
+
parser.add_argument('--batch_size', type=int, default=16,
|
| 151 |
+
help='Batch size for feature extraction')
|
| 152 |
+
parser.add_argument('--device', type=str, default='cuda',
|
| 153 |
+
help='Device (cuda or cpu)')
|
| 154 |
+
args = parser.parse_args()
|
| 155 |
+
|
| 156 |
+
device = args.device if torch.cuda.is_available() and args.device == 'cuda' else 'cpu'
|
| 157 |
+
print(f"Device: {device}")
|
| 158 |
+
|
| 159 |
+
print("Loading ViT-B/16 feature extractor...")
|
| 160 |
+
extractor = CLIPFeatureExtractor(device=device)
|
| 161 |
+
print(f"Feature dim: {extractor.feat_dim}")
|
| 162 |
+
|
| 163 |
+
# Process all volunteers and scenarios
|
| 164 |
+
processed = 0
|
| 165 |
+
skipped = 0
|
| 166 |
+
|
| 167 |
+
for vol_dir in sorted(glob.glob(f"{DATASET_DIR}/v*")):
|
| 168 |
+
vol = os.path.basename(vol_dir)
|
| 169 |
+
for scenario_dir in sorted(glob.glob(f"{vol_dir}/s*")):
|
| 170 |
+
scenario = os.path.basename(scenario_dir)
|
| 171 |
+
output_path = os.path.join(scenario_dir, "video_features_100hz.npy")
|
| 172 |
+
|
| 173 |
+
# Skip if already extracted
|
| 174 |
+
if os.path.exists(output_path):
|
| 175 |
+
print(f"[{vol}/{scenario}] Already exists, skipping")
|
| 176 |
+
skipped += 1
|
| 177 |
+
continue
|
| 178 |
+
|
| 179 |
+
# Find video
|
| 180 |
+
video_path = find_scene_video(scenario_dir, vol, scenario)
|
| 181 |
+
if video_path is None:
|
| 182 |
+
print(f"[{vol}/{scenario}] No Scene Camera video found, skipping")
|
| 183 |
+
skipped += 1
|
| 184 |
+
continue
|
| 185 |
+
|
| 186 |
+
print(f"\n[{vol}/{scenario}]")
|
| 187 |
+
print(f" Video: {os.path.basename(video_path)}")
|
| 188 |
+
|
| 189 |
+
features = extract_features_for_video(
|
| 190 |
+
extractor, video_path,
|
| 191 |
+
batch_size=args.batch_size,
|
| 192 |
+
sample_fps=args.sample_fps,
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
if features is not None:
|
| 196 |
+
np.save(output_path, features)
|
| 197 |
+
print(f" Saved: {output_path} ({features.shape})")
|
| 198 |
+
processed += 1
|
| 199 |
+
else:
|
| 200 |
+
print(f" FAILED: Could not extract features")
|
| 201 |
+
|
| 202 |
+
print(f"\n{'='*60}")
|
| 203 |
+
print(f"Done! Processed: {processed}, Skipped: {skipped}")
|
| 204 |
+
print(f"Feature files: {DATASET_DIR}/*/*/video_features_100hz.npy")
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
if __name__ == '__main__':
|
| 208 |
+
main()
|
experiments/analysis/extract_videomae_features.py
ADDED
|
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Extract video features using VideoMAE (pretrained on Kinetics-400).
|
| 4 |
+
Process 16-frame video clips to capture temporal dynamics.
|
| 5 |
+
|
| 6 |
+
Output: per-frame feature vectors aligned to 100Hz sensor data.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import sys
|
| 11 |
+
import json
|
| 12 |
+
import glob
|
| 13 |
+
import argparse
|
| 14 |
+
import numpy as np
|
| 15 |
+
import cv2
|
| 16 |
+
import torch
|
| 17 |
+
|
| 18 |
+
DATASET_DIR = "${PULSE_ROOT}/dataset"
|
| 19 |
+
MODEL_NAME = "${PULSE_ROOT}/models/videomae-base-kinetics"
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class VideoMAEFeatureExtractor:
|
| 23 |
+
"""Extract features using VideoMAE-Base (16-frame clips). Multi-GPU enabled."""
|
| 24 |
+
|
| 25 |
+
def __init__(self, device='cpu'):
|
| 26 |
+
from transformers import VideoMAEModel, VideoMAEImageProcessor
|
| 27 |
+
import torch.nn as nn
|
| 28 |
+
self.device = device
|
| 29 |
+
self.processor = VideoMAEImageProcessor.from_pretrained(MODEL_NAME)
|
| 30 |
+
model = VideoMAEModel.from_pretrained(MODEL_NAME).to(device)
|
| 31 |
+
model.eval()
|
| 32 |
+
# Wrap with DataParallel if multiple GPUs available
|
| 33 |
+
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
|
| 34 |
+
self.n_gpus = torch.cuda.device_count()
|
| 35 |
+
print(f" Using DataParallel across {self.n_gpus} GPUs")
|
| 36 |
+
self.model = nn.DataParallel(model)
|
| 37 |
+
self.num_frames = model.config.num_frames
|
| 38 |
+
self.feat_dim = model.config.hidden_size
|
| 39 |
+
else:
|
| 40 |
+
self.n_gpus = 1
|
| 41 |
+
self.model = model
|
| 42 |
+
self.num_frames = model.config.num_frames
|
| 43 |
+
self.feat_dim = model.config.hidden_size
|
| 44 |
+
|
| 45 |
+
@torch.no_grad()
|
| 46 |
+
def extract_clip(self, frames):
|
| 47 |
+
"""Extract feature from a single 16-frame clip.
|
| 48 |
+
|
| 49 |
+
Args:
|
| 50 |
+
frames: list of 16 RGB numpy arrays (H, W, 3)
|
| 51 |
+
Returns:
|
| 52 |
+
feature: numpy array (feat_dim,) - mean-pooled patch tokens
|
| 53 |
+
"""
|
| 54 |
+
# Pad/truncate to exactly num_frames
|
| 55 |
+
if len(frames) < self.num_frames:
|
| 56 |
+
frames = frames + [frames[-1]] * (self.num_frames - len(frames))
|
| 57 |
+
elif len(frames) > self.num_frames:
|
| 58 |
+
# uniform sampling
|
| 59 |
+
indices = np.linspace(0, len(frames) - 1, self.num_frames, dtype=int)
|
| 60 |
+
frames = [frames[i] for i in indices]
|
| 61 |
+
|
| 62 |
+
inputs = self.processor(frames, return_tensors="pt")
|
| 63 |
+
pixel_values = inputs["pixel_values"].to(self.device)
|
| 64 |
+
outputs = self.model(pixel_values)
|
| 65 |
+
# Average pool over all patch tokens
|
| 66 |
+
feature = outputs.last_hidden_state.mean(dim=1).squeeze(0) # (768,)
|
| 67 |
+
return feature.cpu().numpy()
|
| 68 |
+
|
| 69 |
+
@torch.no_grad()
|
| 70 |
+
def extract_clip_batch(self, clips):
|
| 71 |
+
"""Extract features from a batch of clips.
|
| 72 |
+
|
| 73 |
+
Args:
|
| 74 |
+
clips: list of clips, each is a list of 16 RGB frames
|
| 75 |
+
Returns:
|
| 76 |
+
features: numpy array (B, feat_dim)
|
| 77 |
+
"""
|
| 78 |
+
# Process each clip
|
| 79 |
+
all_pixel_values = []
|
| 80 |
+
for frames in clips:
|
| 81 |
+
if len(frames) < self.num_frames:
|
| 82 |
+
frames = frames + [frames[-1]] * (self.num_frames - len(frames))
|
| 83 |
+
elif len(frames) > self.num_frames:
|
| 84 |
+
indices = np.linspace(0, len(frames) - 1, self.num_frames, dtype=int)
|
| 85 |
+
frames = [frames[i] for i in indices]
|
| 86 |
+
inputs = self.processor(frames, return_tensors="pt")
|
| 87 |
+
all_pixel_values.append(inputs["pixel_values"])
|
| 88 |
+
|
| 89 |
+
batch = torch.cat(all_pixel_values, dim=0).to(self.device)
|
| 90 |
+
outputs = self.model(batch)
|
| 91 |
+
features = outputs.last_hidden_state.mean(dim=1) # (B, 768)
|
| 92 |
+
return features.cpu().numpy()
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def find_scene_video(scenario_dir, vol, scenario):
|
| 96 |
+
pattern = os.path.join(scenario_dir, f"trimmed_{vol}{scenario}*Scene Cam.mp4")
|
| 97 |
+
matches = glob.glob(pattern)
|
| 98 |
+
return matches[0] if matches else None
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def extract_features_for_video(extractor, video_path, target_fps=100,
|
| 102 |
+
clip_stride_sec=0.5, batch_size=4):
|
| 103 |
+
"""Extract VideoMAE features from a video.
|
| 104 |
+
|
| 105 |
+
Strategy (fast):
|
| 106 |
+
- Sequentially decode video ONCE, downsample to 8fps and store frames in RAM
|
| 107 |
+
- Build clips by indexing into the in-memory frame array (no random seeks)
|
| 108 |
+
"""
|
| 109 |
+
import time
|
| 110 |
+
t0 = time.time()
|
| 111 |
+
cap = cv2.VideoCapture(video_path)
|
| 112 |
+
video_fps = cap.get(cv2.CAP_PROP_FPS)
|
| 113 |
+
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 114 |
+
duration = total_frames / video_fps
|
| 115 |
+
|
| 116 |
+
# Read all frames sequentially, downsample to ~16fps (every video_fps/16 frame)
|
| 117 |
+
decode_fps = 16 # we sample frames at this rate from the video
|
| 118 |
+
decode_stride = max(1, int(round(video_fps / decode_fps)))
|
| 119 |
+
print(f" Video: {total_frames} frames @ {video_fps:.1f}fps = {duration:.1f}s")
|
| 120 |
+
print(f" Decoding sequentially with stride {decode_stride} (~{video_fps/decode_stride:.1f}fps)...")
|
| 121 |
+
|
| 122 |
+
# Pre-resize to model input size during decoding to save memory
|
| 123 |
+
# VideoMAE expects 224x224
|
| 124 |
+
target_size = 224
|
| 125 |
+
|
| 126 |
+
decoded_frames = [] # list of (H, W, 3) uint8 RGB arrays
|
| 127 |
+
decoded_times = [] # corresponding timestamps in seconds
|
| 128 |
+
frame_idx = 0
|
| 129 |
+
while True:
|
| 130 |
+
ret, frame = cap.read()
|
| 131 |
+
if not ret:
|
| 132 |
+
break
|
| 133 |
+
if frame_idx % decode_stride == 0:
|
| 134 |
+
# Resize early to save memory
|
| 135 |
+
resized = cv2.resize(frame, (target_size, target_size), interpolation=cv2.INTER_AREA)
|
| 136 |
+
rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
|
| 137 |
+
decoded_frames.append(rgb)
|
| 138 |
+
decoded_times.append(frame_idx / video_fps)
|
| 139 |
+
frame_idx += 1
|
| 140 |
+
cap.release()
|
| 141 |
+
|
| 142 |
+
decoded_frames = np.array(decoded_frames) # (N, 224, 224, 3)
|
| 143 |
+
decoded_times = np.array(decoded_times)
|
| 144 |
+
decode_time = time.time() - t0
|
| 145 |
+
print(f" Decoded {len(decoded_frames)} frames in {decode_time:.1f}s")
|
| 146 |
+
|
| 147 |
+
# Build clips: each clip = 16 frames spanning ~1 second
|
| 148 |
+
# Sample 16 consecutive frames from in-memory array
|
| 149 |
+
frames_per_clip = 16
|
| 150 |
+
n_decoded = len(decoded_frames)
|
| 151 |
+
if n_decoded < 4:
|
| 152 |
+
return None
|
| 153 |
+
|
| 154 |
+
# Each clip occupies 16 frames at ~16fps = 1 second
|
| 155 |
+
clip_centers_sec = np.arange(0.5, duration - 0.5, clip_stride_sec)
|
| 156 |
+
n_clips = len(clip_centers_sec)
|
| 157 |
+
print(f" Building {n_clips} clips (stride={clip_stride_sec}s, {frames_per_clip} frames each)")
|
| 158 |
+
|
| 159 |
+
all_features = []
|
| 160 |
+
clip_times = []
|
| 161 |
+
batch_clips = []
|
| 162 |
+
batch_times = []
|
| 163 |
+
|
| 164 |
+
t1 = time.time()
|
| 165 |
+
for center_sec in clip_centers_sec:
|
| 166 |
+
# Find decoded frames within ±0.5s window
|
| 167 |
+
center_idx = np.searchsorted(decoded_times, center_sec)
|
| 168 |
+
half = frames_per_clip // 2
|
| 169 |
+
start = max(0, center_idx - half)
|
| 170 |
+
end = min(n_decoded, start + frames_per_clip)
|
| 171 |
+
start = max(0, end - frames_per_clip)
|
| 172 |
+
|
| 173 |
+
if end - start < 4:
|
| 174 |
+
continue
|
| 175 |
+
|
| 176 |
+
clip = list(decoded_frames[start:end])
|
| 177 |
+
# Pad if needed
|
| 178 |
+
if len(clip) < frames_per_clip:
|
| 179 |
+
clip = clip + [clip[-1]] * (frames_per_clip - len(clip))
|
| 180 |
+
|
| 181 |
+
batch_clips.append(clip)
|
| 182 |
+
batch_times.append(center_sec)
|
| 183 |
+
|
| 184 |
+
if len(batch_clips) >= batch_size:
|
| 185 |
+
feats = extractor.extract_clip_batch(batch_clips)
|
| 186 |
+
all_features.append(feats)
|
| 187 |
+
clip_times.extend(batch_times)
|
| 188 |
+
batch_clips = []
|
| 189 |
+
batch_times = []
|
| 190 |
+
|
| 191 |
+
if batch_clips:
|
| 192 |
+
feats = extractor.extract_clip_batch(batch_clips)
|
| 193 |
+
all_features.append(feats)
|
| 194 |
+
clip_times.extend(batch_times)
|
| 195 |
+
inference_time = time.time() - t1
|
| 196 |
+
print(f" Inference time: {inference_time:.1f}s ({len(clip_times)} clips)")
|
| 197 |
+
|
| 198 |
+
if not all_features:
|
| 199 |
+
return None
|
| 200 |
+
|
| 201 |
+
features = np.concatenate(all_features, axis=0) # (N_clips, 768)
|
| 202 |
+
clip_times = np.array(clip_times[:features.shape[0]])
|
| 203 |
+
|
| 204 |
+
# Interpolate to target_fps (100Hz)
|
| 205 |
+
target_times = np.arange(0, duration, 1.0 / target_fps)
|
| 206 |
+
n_target = len(target_times)
|
| 207 |
+
|
| 208 |
+
from scipy.interpolate import interp1d
|
| 209 |
+
if len(clip_times) < 2:
|
| 210 |
+
interpolated = np.tile(features[0], (n_target, 1))
|
| 211 |
+
else:
|
| 212 |
+
interp_func = interp1d(
|
| 213 |
+
clip_times, features, axis=0,
|
| 214 |
+
kind='linear', fill_value='extrapolate'
|
| 215 |
+
)
|
| 216 |
+
interpolated = interp_func(target_times).astype(np.float32)
|
| 217 |
+
|
| 218 |
+
print(f" Output: {interpolated.shape} @ {target_fps}Hz")
|
| 219 |
+
return interpolated
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def main():
|
| 223 |
+
parser = argparse.ArgumentParser()
|
| 224 |
+
parser.add_argument('--clip_stride', type=float, default=0.5,
|
| 225 |
+
help='Clip extraction stride in seconds (default: 0.5)')
|
| 226 |
+
parser.add_argument('--batch_size', type=int, default=4)
|
| 227 |
+
parser.add_argument('--device', type=str, default='cuda')
|
| 228 |
+
parser.add_argument('--output_name', type=str, default='video_features_videomae_100hz.npy')
|
| 229 |
+
args = parser.parse_args()
|
| 230 |
+
|
| 231 |
+
device = args.device if torch.cuda.is_available() and args.device == 'cuda' else 'cpu'
|
| 232 |
+
print(f"Device: {device}")
|
| 233 |
+
|
| 234 |
+
print(f"Loading VideoMAE from {MODEL_NAME}...")
|
| 235 |
+
extractor = VideoMAEFeatureExtractor(device=device)
|
| 236 |
+
print(f"Feature dim: {extractor.feat_dim}, num frames per clip: {extractor.num_frames}")
|
| 237 |
+
|
| 238 |
+
processed = 0
|
| 239 |
+
skipped = 0
|
| 240 |
+
|
| 241 |
+
for vol_dir in sorted(glob.glob(f"{DATASET_DIR}/v*")):
|
| 242 |
+
vol = os.path.basename(vol_dir)
|
| 243 |
+
for scenario_dir in sorted(glob.glob(f"{vol_dir}/s*")):
|
| 244 |
+
scenario = os.path.basename(scenario_dir)
|
| 245 |
+
output_path = os.path.join(scenario_dir, args.output_name)
|
| 246 |
+
|
| 247 |
+
if os.path.exists(output_path):
|
| 248 |
+
print(f"[{vol}/{scenario}] exists, skip")
|
| 249 |
+
skipped += 1
|
| 250 |
+
continue
|
| 251 |
+
|
| 252 |
+
video_path = find_scene_video(scenario_dir, vol, scenario)
|
| 253 |
+
if video_path is None:
|
| 254 |
+
print(f"[{vol}/{scenario}] no video, skip")
|
| 255 |
+
skipped += 1
|
| 256 |
+
continue
|
| 257 |
+
|
| 258 |
+
print(f"\n[{vol}/{scenario}]")
|
| 259 |
+
features = extract_features_for_video(
|
| 260 |
+
extractor, video_path,
|
| 261 |
+
clip_stride_sec=args.clip_stride,
|
| 262 |
+
batch_size=args.batch_size,
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
if features is not None:
|
| 266 |
+
np.save(output_path, features)
|
| 267 |
+
print(f" Saved: {output_path} ({features.shape})")
|
| 268 |
+
processed += 1
|
| 269 |
+
else:
|
| 270 |
+
print(f" FAILED")
|
| 271 |
+
|
| 272 |
+
print(f"\nDone! Processed: {processed}, Skipped: {skipped}")
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
if __name__ == '__main__':
|
| 276 |
+
main()
|
experiments/analysis/gen_val_comparison.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, sys, json, torch
|
| 2 |
+
sys.path.insert(0, '${PULSE_ROOT}')
|
| 3 |
+
os.environ['HF_HUB_OFFLINE'] = '1'
|
| 4 |
+
os.environ['TRANSFORMERS_OFFLINE'] = '1'
|
| 5 |
+
|
| 6 |
+
from tasks.train_pred import (
|
| 7 |
+
TextPredictionDataset, SensorToTextModel, apply_lora, set_seed
|
| 8 |
+
)
|
| 9 |
+
from data.dataset import TRAIN_VOLS, VAL_VOLS, TEST_VOLS
|
| 10 |
+
|
| 11 |
+
set_seed(42)
|
| 12 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 13 |
+
|
| 14 |
+
# Load tokenizer & LLM
|
| 15 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 16 |
+
llm_path = '${PULSE_ROOT}/models/qwen2.5-0.5b'
|
| 17 |
+
tokenizer = AutoTokenizer.from_pretrained(llm_path, trust_remote_code=True, local_files_only=True)
|
| 18 |
+
if tokenizer.pad_token is None:
|
| 19 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 20 |
+
|
| 21 |
+
llm = AutoModelForCausalLM.from_pretrained(
|
| 22 |
+
llm_path, trust_remote_code=True, torch_dtype=torch.float32, local_files_only=True
|
| 23 |
+
).to(device)
|
| 24 |
+
llm.config.pad_token_id = tokenizer.pad_token_id
|
| 25 |
+
for p in llm.parameters():
|
| 26 |
+
p.requires_grad = False
|
| 27 |
+
lora_params = apply_lora(llm, r=8, alpha=16)
|
| 28 |
+
|
| 29 |
+
modalities = ['mocap', 'emg', 'imu']
|
| 30 |
+
|
| 31 |
+
# Build datasets
|
| 32 |
+
train_ds = TextPredictionDataset(TRAIN_VOLS, modalities, tokenizer, window_sec=15.0, downsample=5)
|
| 33 |
+
stats = train_ds.get_stats()
|
| 34 |
+
val_ds = TextPredictionDataset(VAL_VOLS, modalities, tokenizer, window_sec=15.0, downsample=5, stats=stats)
|
| 35 |
+
test_ds = TextPredictionDataset(TEST_VOLS, modalities, tokenizer, window_sec=15.0, downsample=5, stats=stats)
|
| 36 |
+
|
| 37 |
+
# Build model & load weights
|
| 38 |
+
model = SensorToTextModel(train_ds.feat_dim, llm, tokenizer, n_sensor_tokens=8, d_model=64)
|
| 39 |
+
model.to(device)
|
| 40 |
+
|
| 41 |
+
ckpt_path = '${PULSE_ROOT}/results/pred_llm2/pred_llm_mocap-emg-imu/model_best.pt'
|
| 42 |
+
sd = torch.load(ckpt_path, weights_only=True, map_location=device)
|
| 43 |
+
model.load_state_dict(sd, strict=False)
|
| 44 |
+
model.eval()
|
| 45 |
+
|
| 46 |
+
out_path = '${PULSE_ROOT}/docs/pred_llm2_val_comparison.txt'
|
| 47 |
+
|
| 48 |
+
from torch.utils.data import DataLoader
|
| 49 |
+
|
| 50 |
+
with open(out_path, 'w') as f:
|
| 51 |
+
for split_name, ds in [('Validation', val_ds), ('Test', test_ds)]:
|
| 52 |
+
loader = DataLoader(ds, batch_size=8, shuffle=False)
|
| 53 |
+
f.write(f"{'='*70}\n")
|
| 54 |
+
f.write(f"{split_name} Set — mocap,emg,imu (best charF1=0.0324)\n")
|
| 55 |
+
f.write(f"Samples: {len(ds)}\n")
|
| 56 |
+
f.write(f"{'='*70}\n\n")
|
| 57 |
+
|
| 58 |
+
idx = 0
|
| 59 |
+
for batch in loader:
|
| 60 |
+
sensor = batch['sensor'].to(device)
|
| 61 |
+
preds = model.generate_text(sensor, tokenizer, max_new_tokens=20)
|
| 62 |
+
refs = [ds.texts[idx + i] for i in range(len(preds))]
|
| 63 |
+
for p, r in zip(preds, refs):
|
| 64 |
+
match = "OK" if p.strip() == r.strip() else "XX"
|
| 65 |
+
f.write(f"[{match}] #{idx+1}\n")
|
| 66 |
+
f.write(f" Pred: {p.strip()}\n")
|
| 67 |
+
f.write(f" Ref: {r.strip()}\n\n")
|
| 68 |
+
idx += 1
|
| 69 |
+
|
| 70 |
+
# Stats
|
| 71 |
+
f.write(f"\n--- {split_name} Summary ---\n")
|
| 72 |
+
f.write(f"Total: {idx}\n\n")
|
| 73 |
+
|
| 74 |
+
print(f"Written to {out_path}")
|
experiments/analysis/generate_action_labels.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Generate action labels by clustering task descriptions using text embeddings.
|
| 4 |
+
No manual rules — uses sentence-transformers + K-Means clustering.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import json
|
| 9 |
+
import glob
|
| 10 |
+
import argparse
|
| 11 |
+
import numpy as np
|
| 12 |
+
from collections import Counter
|
| 13 |
+
from sklearn.cluster import KMeans
|
| 14 |
+
from sklearn.metrics import silhouette_score
|
| 15 |
+
|
| 16 |
+
ANNOTATION_DIR = "${PULSE_ROOT}"
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def collect_tasks():
|
| 20 |
+
"""Collect all task descriptions from all annotation files."""
|
| 21 |
+
tasks = []
|
| 22 |
+
for path in sorted(glob.glob(os.path.join(ANNOTATION_DIR, 'v*/s*.json'))):
|
| 23 |
+
with open(path) as f:
|
| 24 |
+
data = json.load(f)
|
| 25 |
+
for seg in data.get('segments', []):
|
| 26 |
+
tasks.append(seg['task'])
|
| 27 |
+
return tasks
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def embed_texts(texts):
|
| 31 |
+
"""Encode texts using sentence-transformers (multilingual model)."""
|
| 32 |
+
try:
|
| 33 |
+
from sentence_transformers import SentenceTransformer
|
| 34 |
+
model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
|
| 35 |
+
embeddings = model.encode(texts, show_progress_bar=True, batch_size=128)
|
| 36 |
+
print(f"Encoded {len(texts)} texts with sentence-transformers, dim={embeddings.shape[1]}")
|
| 37 |
+
return embeddings
|
| 38 |
+
except Exception as e:
|
| 39 |
+
print(f"sentence-transformers failed ({e}), falling back to TF-IDF")
|
| 40 |
+
from sklearn.feature_extraction.text import TfidfVectorizer
|
| 41 |
+
vec = TfidfVectorizer(analyzer='char', ngram_range=(1, 3), max_features=3000)
|
| 42 |
+
X = vec.fit_transform(texts).toarray()
|
| 43 |
+
print(f"Encoded {len(texts)} texts with TF-IDF char n-grams, dim={X.shape[1]}")
|
| 44 |
+
return X
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def cluster_tasks(tasks, k_range=(10, 30)):
|
| 48 |
+
unique_tasks = sorted(set(tasks))
|
| 49 |
+
print(f"Total segments: {len(tasks)}, Unique task texts: {len(unique_tasks)}")
|
| 50 |
+
|
| 51 |
+
X = embed_texts(unique_tasks)
|
| 52 |
+
|
| 53 |
+
# Find optimal K via silhouette score
|
| 54 |
+
best_k, best_score = k_range[0], -1
|
| 55 |
+
scores = {}
|
| 56 |
+
for k in range(k_range[0], k_range[1] + 1):
|
| 57 |
+
km = KMeans(n_clusters=k, random_state=42, n_init=10)
|
| 58 |
+
labels = km.fit_predict(X)
|
| 59 |
+
score = silhouette_score(X, labels, sample_size=min(2000, len(unique_tasks)))
|
| 60 |
+
scores[k] = score
|
| 61 |
+
if score > best_score:
|
| 62 |
+
best_score = score
|
| 63 |
+
best_k = k
|
| 64 |
+
print(f" K={k}: silhouette={score:.4f}" + (" *" if k == best_k else ""))
|
| 65 |
+
|
| 66 |
+
print(f"\nBest K={best_k} (silhouette={best_score:.4f})")
|
| 67 |
+
|
| 68 |
+
# Final clustering
|
| 69 |
+
km = KMeans(n_clusters=best_k, random_state=42, n_init=10)
|
| 70 |
+
labels = km.fit_predict(X)
|
| 71 |
+
|
| 72 |
+
task_to_cluster = {task: int(labels[i]) for i, task in enumerate(unique_tasks)}
|
| 73 |
+
|
| 74 |
+
# Representative task per cluster (closest to centroid)
|
| 75 |
+
cluster_representatives = {}
|
| 76 |
+
cluster_members = {}
|
| 77 |
+
for cid in range(best_k):
|
| 78 |
+
member_idx = [i for i, l in enumerate(labels) if l == cid]
|
| 79 |
+
members = [unique_tasks[i] for i in member_idx]
|
| 80 |
+
cluster_members[cid] = members
|
| 81 |
+
centroid = km.cluster_centers_[cid]
|
| 82 |
+
dists = np.linalg.norm(X[member_idx] - centroid, axis=1)
|
| 83 |
+
closest = member_idx[np.argmin(dists)]
|
| 84 |
+
cluster_representatives[cid] = unique_tasks[closest]
|
| 85 |
+
|
| 86 |
+
return task_to_cluster, cluster_representatives, cluster_members, best_k, scores
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def main():
|
| 90 |
+
parser = argparse.ArgumentParser()
|
| 91 |
+
parser.add_argument('--output_dir', type=str,
|
| 92 |
+
default='${PULSE_ROOT}/results/pred')
|
| 93 |
+
parser.add_argument('--k_min', type=int, default=10)
|
| 94 |
+
parser.add_argument('--k_max', type=int, default=30)
|
| 95 |
+
args = parser.parse_args()
|
| 96 |
+
|
| 97 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 98 |
+
|
| 99 |
+
tasks = collect_tasks()
|
| 100 |
+
task_to_cluster, representatives, members, K, scores = cluster_tasks(
|
| 101 |
+
tasks, k_range=(args.k_min, args.k_max)
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
# Print summary
|
| 105 |
+
segment_counts = Counter(task_to_cluster[t] for t in tasks)
|
| 106 |
+
print(f"\n{'='*60}")
|
| 107 |
+
print(f"Clusters (K={K}):")
|
| 108 |
+
for cid in range(K):
|
| 109 |
+
rep = representatives[cid]
|
| 110 |
+
n_unique = len(members[cid])
|
| 111 |
+
n_segs = segment_counts.get(cid, 0)
|
| 112 |
+
examples = [m for m in members[cid] if m != rep][:3]
|
| 113 |
+
print(f"\n [{cid:2d}] ({n_segs:4d} segs, {n_unique:3d} unique) \"{rep}\"")
|
| 114 |
+
for ex in examples:
|
| 115 |
+
print(f" - {ex}")
|
| 116 |
+
|
| 117 |
+
# Save
|
| 118 |
+
output = {
|
| 119 |
+
'num_classes': K,
|
| 120 |
+
'task_to_cluster': task_to_cluster,
|
| 121 |
+
'cluster_representatives': {str(k): v for k, v in representatives.items()},
|
| 122 |
+
'cluster_sizes_unique': {str(k): len(v) for k, v in members.items()},
|
| 123 |
+
'cluster_sizes_segments': {str(k): v for k, v in segment_counts.items()},
|
| 124 |
+
'silhouette_scores': {str(k): v for k, v in scores.items()},
|
| 125 |
+
}
|
| 126 |
+
out_path = os.path.join(args.output_dir, 'action_labels.json')
|
| 127 |
+
with open(out_path, 'w') as f:
|
| 128 |
+
json.dump(output, f, indent=2, ensure_ascii=False)
|
| 129 |
+
print(f"\nSaved to {out_path}")
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
if __name__ == '__main__':
|
| 133 |
+
main()
|
experiments/analysis/generate_coarse_annotations.py
ADDED
|
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Generate coarse-grained annotations by merging consecutive fine-grained segments
|
| 4 |
+
into composite actions (8-15s duration) using LLM.
|
| 5 |
+
|
| 6 |
+
Input: annotations_v2/ (fine-grained, ~2-3s segments, 11 classes)
|
| 7 |
+
Output: annotations_coarse/ (coarse-grained, ~8-15s segments, ~6 classes)
|
| 8 |
+
|
| 9 |
+
Does NOT modify annotations_v2/.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import os
|
| 13 |
+
import json
|
| 14 |
+
import re
|
| 15 |
+
import time
|
| 16 |
+
import glob
|
| 17 |
+
import urllib.request
|
| 18 |
+
from collections import Counter
|
| 19 |
+
|
| 20 |
+
INPUT_DIR = "${PULSE_ROOT}/annotations_v2"
|
| 21 |
+
OUTPUT_DIR = "${PULSE_ROOT}/annotations_coarse"
|
| 22 |
+
|
| 23 |
+
API_URL = "https://api.chatanywhere.tech/v1/chat/completions"
|
| 24 |
+
API_KEYS = [
|
| 25 |
+
"sk-MN5n1uEETyaky96fLJdHqZobXF1f7KmOrZHzwD3lt585asFQ",
|
| 26 |
+
"sk-YnYrtPdAXwlE12hRpi6dYqlE1RRVR3LDVBka6wKaefU4iQRY",
|
| 27 |
+
"sk-jOZtodDv6OxUOMu3NuJ8lzffjwBlshn9OHY5KSmqmPTtc9qs",
|
| 28 |
+
"sk-qAaKTKYIRF24btu1oQWgubWG4UdA92bILNtzOkHNEPAcCxdB",
|
| 29 |
+
"sk-MgCBBonblMrCFnSXd6fJZaBLTCfCJ5FjYZfSe2e46bgmyktk",
|
| 30 |
+
"sk-79e30kYRgduuf2fSU0Lsc814YjNkClXXzQqIbx0iLS40IOEH",
|
| 31 |
+
"sk-h9Tej4tW6AQC6fT0njfzrPKXEk6fBwpiSvvQd0aJAhw4UwLz",
|
| 32 |
+
"sk-k2QNHt5wAH26Fw8hZuPWuVXw8Psd1jX09qusiA6PdBj5Vzuu",
|
| 33 |
+
"sk-w7EkTblciNI44cwosHXi0PGZNUf1hnJmpzOQ85va9VPdAKbz",
|
| 34 |
+
"sk-Dexs5ZF7OjFCq7CZW45wJ8EKoGtIswv6rsLUMzUXXkWBDBBJ",
|
| 35 |
+
]
|
| 36 |
+
|
| 37 |
+
SCENE_DESCRIPTIONS = {
|
| 38 |
+
"s1": "办公桌面整理与工作准备",
|
| 39 |
+
"s2": "快递打包发送",
|
| 40 |
+
"s3": "厨房调料整理",
|
| 41 |
+
"s4": "清理餐后桌面",
|
| 42 |
+
"s5": "餐前桌面布置",
|
| 43 |
+
"s6": "商务旅行行李箱打包",
|
| 44 |
+
"s7": "冲泡咖啡/饮品",
|
| 45 |
+
"s8": "晾衣架整理与衣物收纳",
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
COARSE_CATEGORIES = """粗粒度动作类别(共6类):
|
| 49 |
+
|
| 50 |
+
1. Manipulate - 操作物体(抓取、调整、放置某个物体的完整过程,包含拿起→操作→放下的组合)
|
| 51 |
+
2. CleanOrganize - 清洁/整理(擦桌子、理线、整理桌面、叠衣服等持续性整理活动)
|
| 52 |
+
3. Transfer - 搬运/传递(将物体从一个位置搬到另一个位置的过程)
|
| 53 |
+
4. Assemble - 组装/连接/包装(封箱、贴胶带、盖盖子、插电源、拧瓶盖等需要精细对准的操作)
|
| 54 |
+
5. FoodPrep - 食物/饮品准备(倒水、倒调料、搅拌、冲泡等与食物饮品相关的操作)
|
| 55 |
+
6. Idle - 空闲/过渡(无明确操作的间隔)
|
| 56 |
+
"""
|
| 57 |
+
|
| 58 |
+
current_key_idx = 0
|
| 59 |
+
call_count = 0
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def call_llm(prompt, max_tokens=1500, retries=3):
|
| 63 |
+
global current_key_idx, call_count
|
| 64 |
+
for attempt in range(retries * len(API_KEYS)):
|
| 65 |
+
key = API_KEYS[current_key_idx]
|
| 66 |
+
try:
|
| 67 |
+
data = json.dumps({
|
| 68 |
+
"model": "gpt-4o-mini",
|
| 69 |
+
"messages": [{"role": "user", "content": prompt}],
|
| 70 |
+
"max_tokens": max_tokens,
|
| 71 |
+
"temperature": 0.1,
|
| 72 |
+
}).encode()
|
| 73 |
+
req = urllib.request.Request(
|
| 74 |
+
API_URL, data=data,
|
| 75 |
+
headers={"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
|
| 76 |
+
)
|
| 77 |
+
resp = urllib.request.urlopen(req, timeout=30)
|
| 78 |
+
result = json.loads(resp.read())
|
| 79 |
+
call_count += 1
|
| 80 |
+
return result["choices"][0]["message"]["content"]
|
| 81 |
+
except Exception as e:
|
| 82 |
+
err = str(e)
|
| 83 |
+
if any(k in err for k in ["429", "quota", "limit", "402", "403"]):
|
| 84 |
+
current_key_idx = (current_key_idx + 1) % len(API_KEYS)
|
| 85 |
+
else:
|
| 86 |
+
time.sleep(0.5)
|
| 87 |
+
current_key_idx = (current_key_idx + 1) % len(API_KEYS)
|
| 88 |
+
return None
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def parse_ts(ts_str):
|
| 92 |
+
"""Parse 'MM:SS' to seconds."""
|
| 93 |
+
m = re.match(r'(\d+):(\d+)', ts_str.strip())
|
| 94 |
+
if m:
|
| 95 |
+
return int(m.group(1)) * 60 + int(m.group(2))
|
| 96 |
+
return 0
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def format_ts(sec):
|
| 100 |
+
"""Format seconds to 'MM:SS'."""
|
| 101 |
+
return f"{sec//60:02d}:{sec%60:02d}"
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def merge_segments_with_llm(segments, scene_id):
|
| 105 |
+
"""Use LLM to merge fine-grained segments into coarse composite actions."""
|
| 106 |
+
scene_desc = SCENE_DESCRIPTIONS.get(scene_id, "日常活动")
|
| 107 |
+
|
| 108 |
+
# Build segment list
|
| 109 |
+
seg_lines = []
|
| 110 |
+
for i, seg in enumerate(segments):
|
| 111 |
+
label = seg.get("action_label", "Idle")
|
| 112 |
+
seg_lines.append(f"{i+1}. [{seg['timestamp']}] {label}: {seg['task']}")
|
| 113 |
+
seg_text = "\n".join(seg_lines)
|
| 114 |
+
|
| 115 |
+
prompt = f"""你是一个动作标注专家。以下是一段"{scene_desc}"录制中的细粒度动作序列(每个2-3秒)。
|
| 116 |
+
请将相关的连续动作合并为粗粒度复合动作,每个复合动作持续5-15秒。
|
| 117 |
+
|
| 118 |
+
合并规则:
|
| 119 |
+
- 围绕同一个物体的连续操作合并为一个(如"抓取杯子→调整→放下"合并为一个Manipulate)
|
| 120 |
+
- 连续的整理/清洁动作合并
|
| 121 |
+
- 合并后的时间范围 = 第一个子动作的开始时间 到 最后一个子动作的结束时间
|
| 122 |
+
- 如果中间有短暂Idle(≤3秒),可以包含进去
|
| 123 |
+
- 每个复合动作必须从6个类别中选一个
|
| 124 |
+
|
| 125 |
+
{COARSE_CATEGORIES}
|
| 126 |
+
|
| 127 |
+
细粒度动作序列:
|
| 128 |
+
{seg_text}
|
| 129 |
+
|
| 130 |
+
请严格按以下JSON格式返回,不要添加任何额外文字:
|
| 131 |
+
[{{"timestamp": "MM:SS-MM:SS", "coarse_action": "类别名", "description": "简��描述这段复合动作", "fine_segments": [子动作编号列表]}}]"""
|
| 132 |
+
|
| 133 |
+
response = call_llm(prompt, max_tokens=2000)
|
| 134 |
+
if response is None:
|
| 135 |
+
return None
|
| 136 |
+
|
| 137 |
+
try:
|
| 138 |
+
match = re.search(r'\[.*\]', response, re.DOTALL)
|
| 139 |
+
if match:
|
| 140 |
+
results = json.loads(match.group())
|
| 141 |
+
valid = []
|
| 142 |
+
for r in results:
|
| 143 |
+
if all(k in r for k in ["timestamp", "coarse_action", "description"]):
|
| 144 |
+
# Validate category
|
| 145 |
+
if r["coarse_action"] in {"Manipulate", "CleanOrganize", "Transfer",
|
| 146 |
+
"Assemble", "FoodPrep", "Idle"}:
|
| 147 |
+
valid.append(r)
|
| 148 |
+
return valid
|
| 149 |
+
except (json.JSONDecodeError, KeyError) as e:
|
| 150 |
+
print(f" Parse error: {e}")
|
| 151 |
+
return None
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def process_file(input_path, vol, scenario):
|
| 155 |
+
"""Process one annotation file."""
|
| 156 |
+
data = json.load(open(input_path))
|
| 157 |
+
segments = data["segments"]
|
| 158 |
+
|
| 159 |
+
if not segments:
|
| 160 |
+
return {"fine_segments": segments, "coarse_segments": []}, 0
|
| 161 |
+
|
| 162 |
+
print(f" Merging {len(segments)} fine segments...")
|
| 163 |
+
coarse = merge_segments_with_llm(segments, scenario)
|
| 164 |
+
|
| 165 |
+
if coarse is None:
|
| 166 |
+
# Fallback: simple time-based merging without LLM
|
| 167 |
+
print(f" LLM failed, using fallback merge")
|
| 168 |
+
coarse = fallback_merge(segments)
|
| 169 |
+
|
| 170 |
+
result = {
|
| 171 |
+
"fine_segments": segments,
|
| 172 |
+
"coarse_segments": coarse,
|
| 173 |
+
}
|
| 174 |
+
return result, len(coarse)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def fallback_merge(segments):
|
| 178 |
+
"""Simple rule-based merging as fallback."""
|
| 179 |
+
if not segments:
|
| 180 |
+
return []
|
| 181 |
+
|
| 182 |
+
coarse = []
|
| 183 |
+
group = [segments[0]]
|
| 184 |
+
|
| 185 |
+
for seg in segments[1:]:
|
| 186 |
+
# Parse timestamps
|
| 187 |
+
prev_ts = group[-1]["timestamp"]
|
| 188 |
+
curr_ts = seg["timestamp"]
|
| 189 |
+
m1 = re.match(r'(\d+:\d+)\s*-\s*(\d+:\d+)', prev_ts)
|
| 190 |
+
m2 = re.match(r'(\d+:\d+)\s*-\s*(\d+:\d+)', curr_ts)
|
| 191 |
+
if not m1 or not m2:
|
| 192 |
+
group.append(seg)
|
| 193 |
+
continue
|
| 194 |
+
|
| 195 |
+
prev_end = parse_ts(m1.group(2))
|
| 196 |
+
curr_start = parse_ts(m2.group(1))
|
| 197 |
+
gap = curr_start - prev_end
|
| 198 |
+
|
| 199 |
+
# Merge if gap ≤ 3s and group duration < 15s
|
| 200 |
+
group_start = parse_ts(re.match(r'(\d+:\d+)', group[0]["timestamp"]).group(1))
|
| 201 |
+
curr_end = parse_ts(m2.group(2))
|
| 202 |
+
group_duration = curr_end - group_start
|
| 203 |
+
|
| 204 |
+
if gap <= 3 and group_duration <= 15:
|
| 205 |
+
group.append(seg)
|
| 206 |
+
else:
|
| 207 |
+
# Emit current group
|
| 208 |
+
coarse.append(_emit_group(group))
|
| 209 |
+
group = [seg]
|
| 210 |
+
|
| 211 |
+
if group:
|
| 212 |
+
coarse.append(_emit_group(group))
|
| 213 |
+
|
| 214 |
+
return coarse
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def _emit_group(group):
|
| 218 |
+
"""Create a coarse segment from a group of fine segments."""
|
| 219 |
+
m_start = re.match(r'(\d+:\d+)', group[0]["timestamp"])
|
| 220 |
+
m_end = re.match(r'\d+:\d+\s*-\s*(\d+:\d+)', group[-1]["timestamp"])
|
| 221 |
+
start = m_start.group(1) if m_start else "00:00"
|
| 222 |
+
end = m_end.group(1) if m_end else "00:00"
|
| 223 |
+
|
| 224 |
+
labels = [seg.get("action_label", "Idle") for seg in group]
|
| 225 |
+
label_counts = Counter(labels)
|
| 226 |
+
dominant = label_counts.most_common(1)[0][0]
|
| 227 |
+
|
| 228 |
+
# Map fine label to coarse
|
| 229 |
+
label_map = {
|
| 230 |
+
"Grasp": "Manipulate", "Place": "Manipulate", "Arrange": "CleanOrganize",
|
| 231 |
+
"Wipe": "CleanOrganize", "Fold": "CleanOrganize", "Transport": "Transfer",
|
| 232 |
+
"OpenClose": "Assemble", "TearCut": "Assemble",
|
| 233 |
+
"Pour": "FoodPrep", "Stir": "FoodPrep", "Idle": "Idle",
|
| 234 |
+
}
|
| 235 |
+
coarse_label = label_map.get(dominant, "Manipulate")
|
| 236 |
+
|
| 237 |
+
tasks = [seg["task"] for seg in group]
|
| 238 |
+
desc = tasks[0] if len(tasks) == 1 else f"{tasks[0]}...{tasks[-1]}"
|
| 239 |
+
|
| 240 |
+
return {
|
| 241 |
+
"timestamp": f"{start}-{end}",
|
| 242 |
+
"coarse_action": coarse_label,
|
| 243 |
+
"description": desc[:80],
|
| 244 |
+
"fine_segments": list(range(1, len(group) + 1)),
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def main():
|
| 249 |
+
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
| 250 |
+
|
| 251 |
+
total_fine = 0
|
| 252 |
+
total_coarse = 0
|
| 253 |
+
total_files = 0
|
| 254 |
+
coarse_labels = Counter()
|
| 255 |
+
|
| 256 |
+
for vol_dir in sorted(glob.glob(f"{INPUT_DIR}/v*")):
|
| 257 |
+
vol = os.path.basename(vol_dir)
|
| 258 |
+
out_dir = os.path.join(OUTPUT_DIR, vol)
|
| 259 |
+
os.makedirs(out_dir, exist_ok=True)
|
| 260 |
+
|
| 261 |
+
for ann_file in sorted(glob.glob(f"{vol_dir}/s*.json")):
|
| 262 |
+
scenario = os.path.basename(ann_file).replace(".json", "")
|
| 263 |
+
print(f"[{vol}/{scenario}]", flush=True)
|
| 264 |
+
|
| 265 |
+
result, n_coarse = process_file(ann_file, vol, scenario)
|
| 266 |
+
|
| 267 |
+
out_path = os.path.join(out_dir, f"{scenario}.json")
|
| 268 |
+
with open(out_path, "w", encoding="utf-8") as f:
|
| 269 |
+
json.dump(result, f, ensure_ascii=False, indent=2)
|
| 270 |
+
|
| 271 |
+
n_fine = len(result["fine_segments"])
|
| 272 |
+
total_fine += n_fine
|
| 273 |
+
total_coarse += n_coarse
|
| 274 |
+
total_files += 1
|
| 275 |
+
|
| 276 |
+
for seg in result["coarse_segments"]:
|
| 277 |
+
coarse_labels[seg["coarse_action"]] += 1
|
| 278 |
+
|
| 279 |
+
print(f" {n_fine} fine → {n_coarse} coarse segments", flush=True)
|
| 280 |
+
|
| 281 |
+
print(f"\n{'='*60}")
|
| 282 |
+
print(f"Total: {total_files} files")
|
| 283 |
+
print(f" Fine segments: {total_fine}")
|
| 284 |
+
print(f" Coarse segments: {total_coarse}")
|
| 285 |
+
print(f" Compression: {total_fine/max(total_coarse,1):.1f}x")
|
| 286 |
+
print(f" API calls: {call_count}")
|
| 287 |
+
|
| 288 |
+
print(f"\n Coarse label distribution:")
|
| 289 |
+
for label, count in coarse_labels.most_common():
|
| 290 |
+
print(f" {label:<20} {count:>5} ({count/max(total_coarse,1)*100:.1f}%)")
|
| 291 |
+
|
| 292 |
+
print(f"\n Output: {OUTPUT_DIR}")
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
if __name__ == "__main__":
|
| 296 |
+
main()
|
experiments/analysis/grasp_phase_analysis.py
ADDED
|
@@ -0,0 +1,442 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Grasp Phase Timing Analysis — Flagship visualization for the paper.
|
| 4 |
+
|
| 5 |
+
Classic neuroscience finding:
|
| 6 |
+
Eye gaze → EMG activation → Hand motion → Pressure contact
|
| 7 |
+
|
| 8 |
+
This script:
|
| 9 |
+
1. Detects grasp events (pressure onset: 0 → >5g)
|
| 10 |
+
2. Looks back in time to find:
|
| 11 |
+
- EMG envelope activation onset
|
| 12 |
+
- Hand velocity peak (from MoCap)
|
| 13 |
+
- Eye gaze fixation (if available)
|
| 14 |
+
3. Computes statistics over all grasp events
|
| 15 |
+
4. Produces the canonical "grasp phase" timing figure
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import os
|
| 19 |
+
import glob
|
| 20 |
+
import json
|
| 21 |
+
import numpy as np
|
| 22 |
+
import pandas as pd
|
| 23 |
+
import matplotlib
|
| 24 |
+
matplotlib.use('Agg')
|
| 25 |
+
import matplotlib.pyplot as plt
|
| 26 |
+
from scipy import signal as scisig
|
| 27 |
+
from collections import defaultdict
|
| 28 |
+
|
| 29 |
+
DATASET_DIR = "${PULSE_ROOT}/dataset"
|
| 30 |
+
OUTPUT_DIR = "${PULSE_ROOT}/results/grasp_phase"
|
| 31 |
+
SAMPLING_RATE = 100 # Hz
|
| 32 |
+
PRESSURE_THRESHOLD = 5.0 # grams
|
| 33 |
+
CONTEXT_WINDOW_SEC = 2.0 # look back 2s before contact
|
| 34 |
+
CONTEXT_FRAMES = int(CONTEXT_WINDOW_SEC * SAMPLING_RATE)
|
| 35 |
+
|
| 36 |
+
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def load_pressure(scenario_dir):
|
| 40 |
+
"""Load pressure data and return (T, 2) array: [right_total, left_total]."""
|
| 41 |
+
f = os.path.join(scenario_dir, "aligned_pressure_100hz.csv")
|
| 42 |
+
if not os.path.exists(f):
|
| 43 |
+
return None
|
| 44 |
+
df = pd.read_csv(f, low_memory=False)
|
| 45 |
+
r_cols = [c for c in df.columns if c.startswith('R') and c.endswith('(g)')]
|
| 46 |
+
l_cols = [c for c in df.columns if c.startswith('L') and c.endswith('(g)')]
|
| 47 |
+
if not r_cols or not l_cols:
|
| 48 |
+
return None
|
| 49 |
+
r = df[r_cols].apply(pd.to_numeric, errors='coerce').fillna(0).values.sum(axis=1)
|
| 50 |
+
l = df[l_cols].apply(pd.to_numeric, errors='coerce').fillna(0).values.sum(axis=1)
|
| 51 |
+
return np.stack([r, l], axis=1) # (T, 2)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def load_emg(scenario_dir):
|
| 55 |
+
"""Load EMG data: (T, 8) array."""
|
| 56 |
+
f = os.path.join(scenario_dir, "aligned_emg_100hz.csv")
|
| 57 |
+
if not os.path.exists(f):
|
| 58 |
+
return None
|
| 59 |
+
df = pd.read_csv(f, low_memory=False)
|
| 60 |
+
# Find EMG channel columns (e.g., EMG1...EMG8 or channels)
|
| 61 |
+
numeric_cols = df.select_dtypes(include=[np.number]).columns
|
| 62 |
+
numeric_cols = [c for c in numeric_cols if c not in ('Frame', 'Time', 'time', 'UTC')]
|
| 63 |
+
if len(numeric_cols) < 4:
|
| 64 |
+
return None
|
| 65 |
+
arr = df[numeric_cols].values.astype(np.float32)
|
| 66 |
+
arr = np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)
|
| 67 |
+
return arr
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def load_mocap(scenario_dir, vol, scenario):
|
| 71 |
+
"""Load MoCap hand position, return (T, 3) right hand velocity magnitude, (T, 3) left hand."""
|
| 72 |
+
f = os.path.join(scenario_dir, f"aligned_{vol}{scenario}_s_Q.tsv")
|
| 73 |
+
if not os.path.exists(f):
|
| 74 |
+
return None, None
|
| 75 |
+
df = pd.read_csv(f, sep='\t', low_memory=False)
|
| 76 |
+
# Find right/left hand position columns
|
| 77 |
+
# Try common naming patterns
|
| 78 |
+
r_cols = [c for c in df.columns if 'RightHand' in c and (c.endswith('_X') or c.endswith('_Y') or c.endswith('_Z'))]
|
| 79 |
+
l_cols = [c for c in df.columns if 'LeftHand' in c and (c.endswith('_X') or c.endswith('_Y') or c.endswith('_Z'))]
|
| 80 |
+
if not r_cols or not l_cols:
|
| 81 |
+
# Try alternative naming
|
| 82 |
+
r_cols = [c for c in df.columns if 'R_Hand' in c or 'RHand' in c][:3]
|
| 83 |
+
l_cols = [c for c in df.columns if 'L_Hand' in c or 'LHand' in c][:3]
|
| 84 |
+
if not r_cols or not l_cols:
|
| 85 |
+
return None, None
|
| 86 |
+
|
| 87 |
+
r_pos = df[r_cols[:3]].apply(pd.to_numeric, errors='coerce').fillna(0).values
|
| 88 |
+
l_pos = df[l_cols[:3]].apply(pd.to_numeric, errors='coerce').fillna(0).values
|
| 89 |
+
return r_pos, l_pos
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def compute_emg_envelope(emg, window_size=20):
|
| 93 |
+
"""Rectify and low-pass filter EMG to get envelope."""
|
| 94 |
+
# Rectify
|
| 95 |
+
rectified = np.abs(emg - np.mean(emg, axis=0))
|
| 96 |
+
# Moving average
|
| 97 |
+
kernel = np.ones(window_size) / window_size
|
| 98 |
+
envelope = np.zeros_like(rectified)
|
| 99 |
+
for ch in range(rectified.shape[1]):
|
| 100 |
+
envelope[:, ch] = np.convolve(rectified[:, ch], kernel, mode='same')
|
| 101 |
+
# Sum across channels and normalize
|
| 102 |
+
total = envelope.sum(axis=1)
|
| 103 |
+
if total.max() > total.min():
|
| 104 |
+
total = (total - total.min()) / (total.max() - total.min() + 1e-8)
|
| 105 |
+
return total # (T,)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def compute_velocity(position, window=3):
|
| 109 |
+
"""Compute velocity magnitude from 3D position."""
|
| 110 |
+
vel = np.zeros_like(position)
|
| 111 |
+
vel[1:] = position[1:] - position[:-1]
|
| 112 |
+
vel_mag = np.linalg.norm(vel, axis=1)
|
| 113 |
+
# Smooth
|
| 114 |
+
kernel = np.ones(window) / window
|
| 115 |
+
vel_mag = np.convolve(vel_mag, kernel, mode='same')
|
| 116 |
+
return vel_mag # (T,)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def detect_grasp_events(pressure_1d, threshold=5.0, min_duration=10, min_gap=50):
|
| 120 |
+
"""Detect pressure onset events (0 → >threshold).
|
| 121 |
+
|
| 122 |
+
Returns list of onset frame indices.
|
| 123 |
+
"""
|
| 124 |
+
above = pressure_1d > threshold
|
| 125 |
+
# Hysteresis smoothing: require persistence
|
| 126 |
+
onsets = []
|
| 127 |
+
last_state = False
|
| 128 |
+
stable_counter = 0
|
| 129 |
+
for i, a in enumerate(above):
|
| 130 |
+
if a and not last_state:
|
| 131 |
+
# Candidate onset, check persistence
|
| 132 |
+
if i + min_duration < len(above) and np.mean(above[i:i+min_duration]) > 0.7:
|
| 133 |
+
if not onsets or i - onsets[-1] > min_gap:
|
| 134 |
+
onsets.append(i)
|
| 135 |
+
last_state = True
|
| 136 |
+
elif not a and last_state:
|
| 137 |
+
# Check if really released
|
| 138 |
+
if i + 5 < len(above) and np.mean(above[i:i+5]) < 0.3:
|
| 139 |
+
last_state = False
|
| 140 |
+
return onsets
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def find_signal_onset(signal, ref_idx, window_frames, threshold_ratio=0.3):
|
| 144 |
+
"""Find the LATEST pre-contact onset of signal activation.
|
| 145 |
+
|
| 146 |
+
Strategy: walk backward from ref_idx. Look for the last sample that's
|
| 147 |
+
still 'active' (> baseline + threshold_ratio * (peak-baseline)).
|
| 148 |
+
The first 'inactive' sample going backward marks the onset.
|
| 149 |
+
|
| 150 |
+
Returns: frame index of onset relative to ref_idx (negative = before).
|
| 151 |
+
"""
|
| 152 |
+
start = max(0, ref_idx - window_frames)
|
| 153 |
+
segment = signal[start:ref_idx + 1] # pre-contact window
|
| 154 |
+
if len(segment) < 10:
|
| 155 |
+
return None
|
| 156 |
+
|
| 157 |
+
# Baseline: lower quartile of the pre-contact window (robust to activation)
|
| 158 |
+
# Only use the earliest 30% as baseline estimate
|
| 159 |
+
early_part = segment[:max(10, int(len(segment) * 0.3))]
|
| 160 |
+
baseline = np.percentile(early_part, 25)
|
| 161 |
+
|
| 162 |
+
# Peak of the pre-contact activation
|
| 163 |
+
peak = np.max(segment)
|
| 164 |
+
if peak - baseline < 1e-4:
|
| 165 |
+
return None
|
| 166 |
+
|
| 167 |
+
threshold = baseline + (peak - baseline) * threshold_ratio
|
| 168 |
+
|
| 169 |
+
# Walk BACKWARD from ref_idx: find the last consecutive 'active' region
|
| 170 |
+
# ending at ref_idx, then the onset is where that region starts
|
| 171 |
+
above = segment > threshold
|
| 172 |
+
if not above[-1]:
|
| 173 |
+
# Not active at contact - use threshold crossing pattern
|
| 174 |
+
# Find the rising edge closest to ref_idx
|
| 175 |
+
rising = np.where(np.diff(above.astype(int)) == 1)[0]
|
| 176 |
+
if len(rising) == 0:
|
| 177 |
+
return None
|
| 178 |
+
onset_local = rising[-1] + 1 # first active frame
|
| 179 |
+
else:
|
| 180 |
+
# Active at contact - walk back to find onset
|
| 181 |
+
onset_local = len(segment) - 1
|
| 182 |
+
while onset_local > 0 and above[onset_local - 1]:
|
| 183 |
+
onset_local -= 1
|
| 184 |
+
|
| 185 |
+
onset_global = start + onset_local
|
| 186 |
+
return onset_global - ref_idx # negative = before contact
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def is_clean_grasp(emg_env, velocity, pressure_trace, onset, look_back=150, rest_window=50):
|
| 190 |
+
"""Check if this is a CLEAN grasp starting from rest.
|
| 191 |
+
|
| 192 |
+
Requires: EMG and velocity are both low in the REST window (onset-150 ~ onset-100).
|
| 193 |
+
"""
|
| 194 |
+
rest_start = onset - look_back
|
| 195 |
+
rest_end = onset - (look_back - rest_window)
|
| 196 |
+
if rest_start < 0:
|
| 197 |
+
return False
|
| 198 |
+
|
| 199 |
+
# Quiescent rest period: EMG and velocity both low
|
| 200 |
+
emg_rest = emg_env[rest_start:rest_end].mean()
|
| 201 |
+
vel_rest = velocity[rest_start:rest_end].mean()
|
| 202 |
+
|
| 203 |
+
# Compare to the entire pre-contact activation
|
| 204 |
+
emg_pre = emg_env[rest_end:onset]
|
| 205 |
+
vel_pre = velocity[rest_end:onset]
|
| 206 |
+
|
| 207 |
+
if len(emg_pre) < 10:
|
| 208 |
+
return False
|
| 209 |
+
|
| 210 |
+
# The rest period should be significantly lower than the activation period
|
| 211 |
+
emg_active = np.percentile(emg_pre, 75)
|
| 212 |
+
vel_active = np.percentile(vel_pre, 75)
|
| 213 |
+
|
| 214 |
+
emg_increase = emg_active - emg_rest
|
| 215 |
+
vel_increase = vel_active - vel_rest
|
| 216 |
+
|
| 217 |
+
# Require meaningful increase from rest to activation
|
| 218 |
+
emg_dyn = emg_env.max() - emg_env.min()
|
| 219 |
+
vel_dyn = velocity.max() - velocity.min()
|
| 220 |
+
|
| 221 |
+
if emg_dyn < 1e-6 or vel_dyn < 1e-6:
|
| 222 |
+
return False
|
| 223 |
+
|
| 224 |
+
return (emg_increase / emg_dyn > 0.1) and (vel_increase / vel_dyn > 0.1)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def analyze_one_scenario(vol, scenario):
|
| 228 |
+
"""Analyze clean grasp events starting from rest."""
|
| 229 |
+
scenario_dir = os.path.join(DATASET_DIR, vol, scenario)
|
| 230 |
+
|
| 231 |
+
pressure = load_pressure(scenario_dir)
|
| 232 |
+
emg = load_emg(scenario_dir)
|
| 233 |
+
mocap_r, mocap_l = load_mocap(scenario_dir, vol, scenario)
|
| 234 |
+
|
| 235 |
+
if pressure is None or emg is None or mocap_r is None:
|
| 236 |
+
return None
|
| 237 |
+
|
| 238 |
+
min_len = min(pressure.shape[0], emg.shape[0], mocap_r.shape[0])
|
| 239 |
+
pressure = pressure[:min_len]
|
| 240 |
+
emg = emg[:min_len]
|
| 241 |
+
mocap_r = mocap_r[:min_len]
|
| 242 |
+
mocap_l = mocap_l[:min_len]
|
| 243 |
+
|
| 244 |
+
emg_env = compute_emg_envelope(emg)
|
| 245 |
+
vel_r = compute_velocity(mocap_r)
|
| 246 |
+
vel_l = compute_velocity(mocap_l)
|
| 247 |
+
|
| 248 |
+
events = []
|
| 249 |
+
|
| 250 |
+
for hand_name, hand_pressure, hand_vel in [
|
| 251 |
+
('right', pressure[:, 0], vel_r),
|
| 252 |
+
('left', pressure[:, 1], vel_l),
|
| 253 |
+
]:
|
| 254 |
+
onsets = detect_grasp_events(hand_pressure, threshold=PRESSURE_THRESHOLD)
|
| 255 |
+
for onset in onsets:
|
| 256 |
+
if onset < CONTEXT_FRAMES:
|
| 257 |
+
continue
|
| 258 |
+
|
| 259 |
+
# Filter: only clean grasps starting from rest
|
| 260 |
+
if not is_clean_grasp(emg_env, hand_vel, hand_pressure, onset):
|
| 261 |
+
continue
|
| 262 |
+
|
| 263 |
+
# Find EMG onset: look for sustained activation rising from rest
|
| 264 |
+
emg_delay = find_signal_onset(emg_env, onset, CONTEXT_FRAMES, threshold_ratio=0.3)
|
| 265 |
+
motion_delay = find_signal_onset(hand_vel, onset, CONTEXT_FRAMES, threshold_ratio=0.3)
|
| 266 |
+
if emg_delay is None or motion_delay is None:
|
| 267 |
+
continue
|
| 268 |
+
|
| 269 |
+
# Sanity check: delays should be within [-1500, 0] ms
|
| 270 |
+
if emg_delay * 10 < -1500 or emg_delay * 10 > 0:
|
| 271 |
+
continue
|
| 272 |
+
if motion_delay * 10 < -1500 or motion_delay * 10 > 0:
|
| 273 |
+
continue
|
| 274 |
+
|
| 275 |
+
start = onset - CONTEXT_FRAMES
|
| 276 |
+
end = onset + 50
|
| 277 |
+
events.append({
|
| 278 |
+
'pressure': hand_pressure[start:end],
|
| 279 |
+
'emg': emg_env[start:end],
|
| 280 |
+
'velocity': hand_vel[start:end],
|
| 281 |
+
'hand': hand_name,
|
| 282 |
+
'onset_idx': onset,
|
| 283 |
+
'emg_delay_ms': emg_delay * 10,
|
| 284 |
+
'motion_delay_ms': motion_delay * 10,
|
| 285 |
+
})
|
| 286 |
+
|
| 287 |
+
return events
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
def main():
|
| 291 |
+
all_events = []
|
| 292 |
+
stats = defaultdict(int)
|
| 293 |
+
|
| 294 |
+
for vol_dir in sorted(glob.glob(f"{DATASET_DIR}/v*")):
|
| 295 |
+
vol = os.path.basename(vol_dir)
|
| 296 |
+
for scenario_dir in sorted(glob.glob(f"{vol_dir}/s*")):
|
| 297 |
+
scenario = os.path.basename(scenario_dir)
|
| 298 |
+
meta_path = os.path.join(scenario_dir, 'alignment_metadata.json')
|
| 299 |
+
if not os.path.exists(meta_path):
|
| 300 |
+
continue
|
| 301 |
+
meta = json.load(open(meta_path))
|
| 302 |
+
# Need all 3 modalities
|
| 303 |
+
if not {'pressure', 'emg', 'mocap'}.issubset(set(meta['modalities'])):
|
| 304 |
+
stats['no_modality'] += 1
|
| 305 |
+
continue
|
| 306 |
+
|
| 307 |
+
events = analyze_one_scenario(vol, scenario)
|
| 308 |
+
if events is None:
|
| 309 |
+
stats['load_error'] += 1
|
| 310 |
+
continue
|
| 311 |
+
all_events.extend(events)
|
| 312 |
+
stats['scenarios'] += 1
|
| 313 |
+
stats['events'] += len(events)
|
| 314 |
+
print(f"[{vol}/{scenario}] {len(events)} grasp events", flush=True)
|
| 315 |
+
|
| 316 |
+
print(f"\n=== Summary ===")
|
| 317 |
+
print(f"Scenarios processed: {stats['scenarios']}")
|
| 318 |
+
print(f"Total grasp events: {stats['events']}")
|
| 319 |
+
print(f"Loading errors: {stats['load_error']}")
|
| 320 |
+
print(f"Missing modality: {stats['no_modality']}")
|
| 321 |
+
|
| 322 |
+
if not all_events:
|
| 323 |
+
print("No events found!")
|
| 324 |
+
return
|
| 325 |
+
|
| 326 |
+
# Extract delays
|
| 327 |
+
emg_delays = np.array([e['emg_delay_ms'] for e in all_events])
|
| 328 |
+
motion_delays = np.array([e['motion_delay_ms'] for e in all_events])
|
| 329 |
+
|
| 330 |
+
print(f"\n=== Timing Statistics (ms, negative = before contact) ===")
|
| 331 |
+
print(f"EMG onset delay: mean={emg_delays.mean():.1f} median={np.median(emg_delays):.1f} std={emg_delays.std():.1f}")
|
| 332 |
+
print(f"Motion peak delay: mean={motion_delays.mean():.1f} median={np.median(motion_delays):.1f} std={motion_delays.std():.1f}")
|
| 333 |
+
|
| 334 |
+
# Save statistics
|
| 335 |
+
stats_dict = {
|
| 336 |
+
'n_events': len(all_events),
|
| 337 |
+
'emg_delay_ms': {'mean': float(emg_delays.mean()), 'median': float(np.median(emg_delays)),
|
| 338 |
+
'std': float(emg_delays.std()), 'p25': float(np.percentile(emg_delays, 25)),
|
| 339 |
+
'p75': float(np.percentile(emg_delays, 75))},
|
| 340 |
+
'motion_delay_ms': {'mean': float(motion_delays.mean()), 'median': float(np.median(motion_delays)),
|
| 341 |
+
'std': float(motion_delays.std()), 'p25': float(np.percentile(motion_delays, 25)),
|
| 342 |
+
'p75': float(np.percentile(motion_delays, 75))},
|
| 343 |
+
}
|
| 344 |
+
with open(os.path.join(OUTPUT_DIR, 'timing_stats.json'), 'w') as f:
|
| 345 |
+
json.dump(stats_dict, f, indent=2)
|
| 346 |
+
|
| 347 |
+
# ============ Figure 1: Aligned signal traces (averaged) ============
|
| 348 |
+
# Filter to events that have sufficient context
|
| 349 |
+
valid = [e for e in all_events if len(e['pressure']) == CONTEXT_FRAMES + 50]
|
| 350 |
+
print(f"\nEvents with full context: {len(valid)} / {len(all_events)}")
|
| 351 |
+
|
| 352 |
+
if len(valid) < 10:
|
| 353 |
+
print("Not enough events for plotting")
|
| 354 |
+
return
|
| 355 |
+
|
| 356 |
+
# Normalize signals (per-event max)
|
| 357 |
+
def normalize(sigs):
|
| 358 |
+
sigs = np.stack(sigs)
|
| 359 |
+
# Normalize each to [0, 1]
|
| 360 |
+
sigs = sigs - sigs.min(axis=1, keepdims=True)
|
| 361 |
+
maxs = sigs.max(axis=1, keepdims=True)
|
| 362 |
+
sigs = sigs / (maxs + 1e-8)
|
| 363 |
+
return sigs
|
| 364 |
+
|
| 365 |
+
pressure_stack = normalize([e['pressure'] for e in valid])
|
| 366 |
+
emg_stack = normalize([e['emg'] for e in valid])
|
| 367 |
+
vel_stack = normalize([e['velocity'] for e in valid])
|
| 368 |
+
|
| 369 |
+
time_axis = np.arange(-CONTEXT_FRAMES, 50) * 10 # ms
|
| 370 |
+
|
| 371 |
+
fig, ax = plt.subplots(figsize=(9, 5))
|
| 372 |
+
|
| 373 |
+
# Plot mean ± std
|
| 374 |
+
for sigs, color, label in [
|
| 375 |
+
(emg_stack, '#E74C3C', 'EMG envelope'),
|
| 376 |
+
(vel_stack, '#3498DB', 'Hand velocity'),
|
| 377 |
+
(pressure_stack, '#27AE60', 'Pressure (contact)'),
|
| 378 |
+
]:
|
| 379 |
+
mean = sigs.mean(axis=0)
|
| 380 |
+
std = sigs.std(axis=0)
|
| 381 |
+
ax.plot(time_axis, mean, color=color, linewidth=2.5, label=label)
|
| 382 |
+
ax.fill_between(time_axis, mean - std * 0.5, mean + std * 0.5, color=color, alpha=0.15)
|
| 383 |
+
|
| 384 |
+
ax.axvline(0, color='black', linestyle='--', linewidth=1.2, alpha=0.7, label='Contact onset')
|
| 385 |
+
ax.axvline(emg_delays.mean(), color='#E74C3C', linestyle=':', alpha=0.8)
|
| 386 |
+
ax.axvline(motion_delays.mean(), color='#3498DB', linestyle=':', alpha=0.8)
|
| 387 |
+
|
| 388 |
+
# Annotations
|
| 389 |
+
ax.annotate(f'EMG\n{emg_delays.mean():.0f}ms',
|
| 390 |
+
xy=(emg_delays.mean(), 0.85), ha='center', fontsize=10, color='#C0392B',
|
| 391 |
+
bbox=dict(boxstyle="round,pad=0.3", fc='#FADBD8', ec='#E74C3C', alpha=0.9))
|
| 392 |
+
ax.annotate(f'Motion\n{motion_delays.mean():.0f}ms',
|
| 393 |
+
xy=(motion_delays.mean(), 0.65), ha='center', fontsize=10, color='#1F618D',
|
| 394 |
+
bbox=dict(boxstyle="round,pad=0.3", fc='#D6EAF8', ec='#3498DB', alpha=0.9))
|
| 395 |
+
|
| 396 |
+
ax.set_xlabel('Time relative to contact onset (ms)', fontsize=12)
|
| 397 |
+
ax.set_ylabel('Normalized amplitude', fontsize=12)
|
| 398 |
+
ax.set_title(f'Grasp Phase Timing ({len(valid)} events, {stats["scenarios"]} recordings)',
|
| 399 |
+
fontsize=13, fontweight='bold')
|
| 400 |
+
ax.set_xlim(-CONTEXT_WINDOW_SEC * 1000, 500)
|
| 401 |
+
ax.legend(loc='upper left', frameon=True, fontsize=10)
|
| 402 |
+
ax.grid(True, alpha=0.3)
|
| 403 |
+
ax.set_ylim(-0.05, 1.1)
|
| 404 |
+
|
| 405 |
+
plt.tight_layout()
|
| 406 |
+
fig_path = os.path.join(OUTPUT_DIR, 'grasp_phase_timing.png')
|
| 407 |
+
plt.savefig(fig_path, dpi=150, bbox_inches='tight')
|
| 408 |
+
plt.savefig(fig_path.replace('.png', '.pdf'), bbox_inches='tight')
|
| 409 |
+
print(f"Saved figure: {fig_path}")
|
| 410 |
+
|
| 411 |
+
# ============ Figure 2: Delay distributions ============
|
| 412 |
+
fig, axes = plt.subplots(1, 2, figsize=(11, 4))
|
| 413 |
+
|
| 414 |
+
axes[0].hist(emg_delays, bins=30, color='#E74C3C', alpha=0.7, edgecolor='black')
|
| 415 |
+
axes[0].axvline(emg_delays.mean(), color='black', linestyle='--', linewidth=2, label=f'Mean: {emg_delays.mean():.0f}ms')
|
| 416 |
+
axes[0].axvline(np.median(emg_delays), color='grey', linestyle=':', linewidth=2, label=f'Median: {np.median(emg_delays):.0f}ms')
|
| 417 |
+
axes[0].set_xlabel('EMG onset - Contact onset (ms)', fontsize=11)
|
| 418 |
+
axes[0].set_ylabel('Count', fontsize=11)
|
| 419 |
+
axes[0].set_title('EMG → Contact Delay', fontsize=12, fontweight='bold')
|
| 420 |
+
axes[0].legend(fontsize=10)
|
| 421 |
+
axes[0].grid(True, alpha=0.3)
|
| 422 |
+
|
| 423 |
+
axes[1].hist(motion_delays, bins=30, color='#3498DB', alpha=0.7, edgecolor='black')
|
| 424 |
+
axes[1].axvline(motion_delays.mean(), color='black', linestyle='--', linewidth=2, label=f'Mean: {motion_delays.mean():.0f}ms')
|
| 425 |
+
axes[1].axvline(np.median(motion_delays), color='grey', linestyle=':', linewidth=2, label=f'Median: {np.median(motion_delays):.0f}ms')
|
| 426 |
+
axes[1].set_xlabel('Motion onset - Contact onset (ms)', fontsize=11)
|
| 427 |
+
axes[1].set_ylabel('Count', fontsize=11)
|
| 428 |
+
axes[1].set_title('Hand Motion → Contact Delay', fontsize=12, fontweight='bold')
|
| 429 |
+
axes[1].legend(fontsize=10)
|
| 430 |
+
axes[1].grid(True, alpha=0.3)
|
| 431 |
+
|
| 432 |
+
plt.tight_layout()
|
| 433 |
+
fig2_path = os.path.join(OUTPUT_DIR, 'delay_distributions.png')
|
| 434 |
+
plt.savefig(fig2_path, dpi=150, bbox_inches='tight')
|
| 435 |
+
plt.savefig(fig2_path.replace('.png', '.pdf'), bbox_inches='tight')
|
| 436 |
+
print(f"Saved figure: {fig2_path}")
|
| 437 |
+
|
| 438 |
+
print(f"\nAll outputs saved to: {OUTPUT_DIR}")
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
if __name__ == '__main__':
|
| 442 |
+
main()
|
experiments/analysis/modality_viz.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Visualize mocap skeleton frames, IMU waveforms, EMG waveforms."""
|
| 2 |
+
import os, numpy as np, pandas as pd, matplotlib.pyplot as plt
|
| 3 |
+
from mpl_toolkits.mplot3d import Axes3D # noqa
|
| 4 |
+
|
| 5 |
+
REC = "${PULSE_ROOT}/dataset/v1/s1"
|
| 6 |
+
OUT = "${PULSE_ROOT}/paper/figures"
|
| 7 |
+
os.makedirs(OUT, exist_ok=True)
|
| 8 |
+
|
| 9 |
+
# ---- Skeleton bone definition (marker pairs) ----
|
| 10 |
+
BONES = [
|
| 11 |
+
# torso
|
| 12 |
+
("HeadTop","HeadFront"),("HeadL","HeadR"),("HeadFront","SpineTop"),
|
| 13 |
+
("SpineTop","Chest"),("Chest","WaistLFront"),("Chest","WaistRFront"),
|
| 14 |
+
("WaistLFront","WaistLBack"),("WaistRFront","WaistRBack"),
|
| 15 |
+
("WaistLBack","BackL"),("WaistRBack","BackR"),("BackL","BackR"),
|
| 16 |
+
("SpineTop","LShoulderTop"),("SpineTop","RShoulderTop"),
|
| 17 |
+
("LShoulderTop","LShoulderBack"),("RShoulderTop","RShoulderBack"),
|
| 18 |
+
# left arm
|
| 19 |
+
("LShoulderTop","LArm"),("LArm","LElbowOut"),("LElbowOut","LElbowBack"),
|
| 20 |
+
("LElbowOut","LForearmRoll"),("LForearmRoll","LWristOut"),
|
| 21 |
+
("LWristOut","LWristIn"),("LWristOut","LHandOut"),("LWristIn","LHandIn"),
|
| 22 |
+
("LHandOut","LIndex2"),("LIndex2","LIndexTip"),
|
| 23 |
+
("LHandOut","LMiddle2"),("LMiddle2","LMiddleTip"),
|
| 24 |
+
("LHandIn","LRing2"),("LRing2","LRingTip"),
|
| 25 |
+
("LHandIn","LPinky2"),("LPinky2","LPinkyTip"),
|
| 26 |
+
("LWristIn","LThumb1"),("LThumb1","LThumbTip"),
|
| 27 |
+
# right arm
|
| 28 |
+
("RShoulderTop","RArm"),("RArm","RElbowOut"),("RElbowOut","RElbowBack"),
|
| 29 |
+
("RElbowOut","RForearmRoll"),("RForearmRoll","RWristOut"),
|
| 30 |
+
("RWristOut","RWristIn"),("RWristOut","RHandOut"),("RWristIn","RHandIn"),
|
| 31 |
+
("RHandOut","RIndex2"),("RIndex2","RIndexTip"),
|
| 32 |
+
("RHandOut","RMiddle2"),("RMiddle2","RMiddleTip"),
|
| 33 |
+
("RHandIn","RRing2"),("RRing2","RRingTip"),
|
| 34 |
+
("RHandIn","RPinky2"),("RPinky2","RPinkyTip"),
|
| 35 |
+
("RWristIn","RThumb1"),("RThumb1","RThumbTip"),
|
| 36 |
+
]
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def load_mocap(path):
|
| 40 |
+
df = pd.read_csv(path)
|
| 41 |
+
# Extract x,y,z for each marker ignoring Type cols
|
| 42 |
+
markers = {}
|
| 43 |
+
for col in df.columns:
|
| 44 |
+
if col.startswith("Q_") and col.endswith(" X"):
|
| 45 |
+
name = col[2:-2]
|
| 46 |
+
xs = df[f"Q_{name} X"].to_numpy()
|
| 47 |
+
ys = df[f"Q_{name} Y"].to_numpy()
|
| 48 |
+
zs = df[f"Q_{name} Z"].to_numpy()
|
| 49 |
+
markers[name] = np.stack([xs, ys, zs], axis=-1)
|
| 50 |
+
return df["Time"].to_numpy(), markers
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def plot_skeletons():
|
| 54 |
+
t, mk = load_mocap(os.path.join(REC, "aligned_mocap_100hz.csv"))
|
| 55 |
+
N = len(t)
|
| 56 |
+
# pick 4 time frames well spread through the recording with valid data
|
| 57 |
+
candidate = np.linspace(int(0.1*N), int(0.9*N), 4).astype(int)
|
| 58 |
+
|
| 59 |
+
fig = plt.figure(figsize=(12, 3.2))
|
| 60 |
+
for i, fr in enumerate(candidate):
|
| 61 |
+
ax = fig.add_subplot(1, 4, i+1, projection='3d')
|
| 62 |
+
# gather all points at this frame
|
| 63 |
+
pts = np.array([mk[n][fr] for n in mk])
|
| 64 |
+
pts = pts[~np.isnan(pts).any(axis=1)]
|
| 65 |
+
if len(pts) == 0:
|
| 66 |
+
continue
|
| 67 |
+
# draw bones
|
| 68 |
+
for a, b in BONES:
|
| 69 |
+
if a in mk and b in mk:
|
| 70 |
+
pa, pb = mk[a][fr], mk[b][fr]
|
| 71 |
+
if np.isnan(pa).any() or np.isnan(pb).any():
|
| 72 |
+
continue
|
| 73 |
+
ax.plot([pa[0], pb[0]], [pa[1], pb[1]], [pa[2], pb[2]],
|
| 74 |
+
color='#2266aa', lw=1.2)
|
| 75 |
+
ax.scatter(pts[:, 0], pts[:, 1], pts[:, 2], s=4, c='#cc3333', alpha=0.8)
|
| 76 |
+
# equal aspect
|
| 77 |
+
c = pts.mean(0)
|
| 78 |
+
r = np.ptp(pts, axis=0).max() / 2
|
| 79 |
+
ax.set_xlim(c[0]-r, c[0]+r); ax.set_ylim(c[1]-r, c[1]+r); ax.set_zlim(c[2]-r, c[2]+r)
|
| 80 |
+
ax.set_xticks([]); ax.set_yticks([]); ax.set_zticks([])
|
| 81 |
+
ax.set_title(f"t={t[fr]:.1f}s", fontsize=9)
|
| 82 |
+
ax.view_init(elev=12, azim=-75)
|
| 83 |
+
fig.suptitle("MoCap skeleton frames (56-marker Qualisys, v1/s1)", fontsize=11)
|
| 84 |
+
fig.tight_layout()
|
| 85 |
+
out = os.path.join(OUT, "mocap_skeleton.pdf")
|
| 86 |
+
fig.savefig(out, bbox_inches='tight'); fig.savefig(out.replace('.pdf', '.png'), dpi=150, bbox_inches='tight')
|
| 87 |
+
plt.close(fig)
|
| 88 |
+
print("Saved", out)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def plot_imu():
|
| 92 |
+
df = pd.read_csv(os.path.join(REC, "aligned_imu_100hz.csv"))
|
| 93 |
+
t = df["time"].to_numpy(); t = t - t[0]
|
| 94 |
+
# pick 5 body locations (WT0..WT9 order roughly: wrists, forearms, upper arms, shins, thighs, torso)
|
| 95 |
+
sites = [("WT0", "Wrist R"), ("WT2", "Forearm R"),
|
| 96 |
+
("WT4", "Upper arm R"), ("WT6", "Shin R"), ("WT9", "Torso")]
|
| 97 |
+
fig, axes = plt.subplots(len(sites), 1, figsize=(9, 6), sharex=True)
|
| 98 |
+
# crop to 20s window mid-recording
|
| 99 |
+
mid = len(t)//2
|
| 100 |
+
sl = slice(max(0, mid-1000), min(len(t), mid+1000))
|
| 101 |
+
for ax, (sid, lbl) in zip(axes, sites):
|
| 102 |
+
for comp, col in zip(["x", "y", "z"], ["#d62728", "#2ca02c", "#1f77b4"]):
|
| 103 |
+
ax.plot(t[sl], df[f"{sid}_acc_{comp}"].to_numpy()[sl], color=col, lw=0.8, label=f"acc_{comp}")
|
| 104 |
+
ax.set_ylabel(lbl, fontsize=9)
|
| 105 |
+
ax.grid(alpha=0.3)
|
| 106 |
+
axes[0].legend(loc="upper right", ncol=3, fontsize=8)
|
| 107 |
+
axes[-1].set_xlabel("Time (s)")
|
| 108 |
+
fig.suptitle("IMU 3-axis acceleration across 5 body sites (v1/s1, 20s window)", fontsize=11)
|
| 109 |
+
fig.tight_layout()
|
| 110 |
+
out = os.path.join(OUT, "imu_waveforms.pdf")
|
| 111 |
+
fig.savefig(out, bbox_inches='tight'); fig.savefig(out.replace('.pdf', '.png'), dpi=150, bbox_inches='tight')
|
| 112 |
+
plt.close(fig)
|
| 113 |
+
print("Saved", out)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def plot_emg():
|
| 117 |
+
df = pd.read_csv(os.path.join(REC, "aligned_emg_100hz.csv"))
|
| 118 |
+
t = df["time"].to_numpy(); t = t - t[0]
|
| 119 |
+
ch = [f"emg_{i}" for i in range(1, 9)]
|
| 120 |
+
# 20s window mid-recording
|
| 121 |
+
mid = len(t)//2
|
| 122 |
+
sl = slice(max(0, mid-1000), min(len(t), mid+1000))
|
| 123 |
+
fig, axes = plt.subplots(8, 1, figsize=(9, 7), sharex=True)
|
| 124 |
+
for ax, c in zip(axes, ch):
|
| 125 |
+
sig = df[c].to_numpy()[sl]
|
| 126 |
+
ax.plot(t[sl], sig, color="#555", lw=0.5)
|
| 127 |
+
# envelope overlay
|
| 128 |
+
env = pd.Series(np.abs(sig)).rolling(20, min_periods=1).mean().to_numpy()
|
| 129 |
+
ax.plot(t[sl], env, color="#d62728", lw=0.9)
|
| 130 |
+
ax.set_ylabel(c, fontsize=8)
|
| 131 |
+
ax.grid(alpha=0.3)
|
| 132 |
+
axes[-1].set_xlabel("Time (s)")
|
| 133 |
+
fig.suptitle("Surface EMG 8-channel raw (grey) with rectified envelope (red), v1/s1, 20s window",
|
| 134 |
+
fontsize=11)
|
| 135 |
+
fig.tight_layout()
|
| 136 |
+
out = os.path.join(OUT, "emg_waveforms.pdf")
|
| 137 |
+
fig.savefig(out, bbox_inches='tight'); fig.savefig(out.replace('.pdf', '.png'), dpi=150, bbox_inches='tight')
|
| 138 |
+
plt.close(fig)
|
| 139 |
+
print("Saved", out)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
if __name__ == "__main__":
|
| 143 |
+
plot_skeletons()
|
| 144 |
+
plot_imu()
|
| 145 |
+
plot_emg()
|
experiments/analysis/reannotate_actions.py
ADDED
|
@@ -0,0 +1,363 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Re-annotate action segments using LLM (GPT-4o-mini).
|
| 4 |
+
1. Re-classify existing segments with better accuracy
|
| 5 |
+
2. Infer actions in unlabeled gaps based on context (scene, surrounding actions)
|
| 6 |
+
3. Output improved annotations with higher coverage
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import sys
|
| 11 |
+
import json
|
| 12 |
+
import re
|
| 13 |
+
import time
|
| 14 |
+
import copy
|
| 15 |
+
import glob
|
| 16 |
+
import urllib.request
|
| 17 |
+
from collections import Counter
|
| 18 |
+
|
| 19 |
+
ANN_DIR = "${PULSE_ROOT}/annotations_by_scene"
|
| 20 |
+
OUTPUT_DIR = "${PULSE_ROOT}/annotations_v2"
|
| 21 |
+
DATASET_DIR = "${PULSE_ROOT}/dataset"
|
| 22 |
+
|
| 23 |
+
API_URL = "https://api.chatanywhere.tech/v1/chat/completions"
|
| 24 |
+
API_KEYS = [
|
| 25 |
+
"sk-MN5n1uEETyaky96fLJdHqZobXF1f7KmOrZHzwD3lt585asFQ",
|
| 26 |
+
"sk-YnYrtPdAXwlE12hRpi6dYqlE1RRVR3LDVBka6wKaefU4iQRY",
|
| 27 |
+
"sk-jOZtodDv6OxUOMu3NuJ8lzffjwBlshn9OHY5KSmqmPTtc9qs",
|
| 28 |
+
"sk-qAaKTKYIRF24btu1oQWgubWG4UdA92bILNtzOkHNEPAcCxdB",
|
| 29 |
+
"sk-MgCBBonblMrCFnSXd6fJZaBLTCfCJ5FjYZfSe2e46bgmyktk",
|
| 30 |
+
"sk-79e30kYRgduuf2fSU0Lsc814YjNkClXXzQqIbx0iLS40IOEH",
|
| 31 |
+
"sk-h9Tej4tW6AQC6fT0njfzrPKXEk6fBwpiSvvQd0aJAhw4UwLz",
|
| 32 |
+
"sk-k2QNHt5wAH26Fw8hZuPWuVXw8Psd1jX09qusiA6PdBj5Vzuu",
|
| 33 |
+
"sk-w7EkTblciNI44cwosHXi0PGZNUf1hnJmpzOQ85va9VPdAKbz",
|
| 34 |
+
"sk-Dexs5ZF7OjFCq7CZW45wJ8EKoGtIswv6rsLUMzUXXkWBDBBJ",
|
| 35 |
+
]
|
| 36 |
+
|
| 37 |
+
SCENE_DESCRIPTIONS = {
|
| 38 |
+
"s1": "办公桌面整理与工作准备(整理文件、电源线、鼠标、笔记本电脑等)",
|
| 39 |
+
"s2": "快递打包发送(折叠纸箱、放入物品、封箱、贴标签等)",
|
| 40 |
+
"s3": "厨房调料整理(拿取调料瓶、倒调料、拧瓶盖、擦拭等)",
|
| 41 |
+
"s4": "清理餐后桌面(收碗碟、擦桌子、整理餐具、倒残渣等)",
|
| 42 |
+
"s5": "餐前桌面布置(铺桌布、摆放餐具碗碟、放杯子等)",
|
| 43 |
+
"s6": "商务旅行行李箱打包(折叠衣物、放入行李箱、整理物品等)",
|
| 44 |
+
"s7": "冲泡咖啡/饮品(取杯子、放咖啡粉/茶包、倒热水、搅拌等)",
|
| 45 |
+
"s8": "晾衣架整理与衣物收纳(取衣架、挂衣服、折叠衣物等)",
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
ACTION_CATEGORIES = """动作类别定义(共11类):
|
| 49 |
+
|
| 50 |
+
1. Grasp - 抓取/拿起物体(手从无接触到接触并握住物体)
|
| 51 |
+
2. Place - 放置/放下物体(将物体放到某个位置并释放)
|
| 52 |
+
3. Pour - 倾倒/注入液体或颗粒(倒水、倒调料、倒咖啡粉等)
|
| 53 |
+
4. Wipe - 擦拭/清洁表面(用抹布或手擦桌面、瓶身等)
|
| 54 |
+
5. Fold - 折叠/卷起(折衣服、折桌布、折纸箱等)
|
| 55 |
+
6. OpenClose - 打开/关闭/旋开/旋紧(开盒子、拧瓶盖、拉拉链、合箱盖等)
|
| 56 |
+
7. Stir - 搅拌(搅拌咖啡、搅拌饮品等)
|
| 57 |
+
8. TearCut - 撕/剪/粘贴(撕胶带、剪快递单、贴标签等)
|
| 58 |
+
9. Arrange - 整理/摆放/调整位置(摆餐具、整理文件、调整物品位置、理线等)
|
| 59 |
+
10. Transport - 搬运/移动物体到较远位置(把包裹搬到架子、把碗端到水槽等)
|
| 60 |
+
11. Idle - 空闲/过渡/无明确操作(双手无目的性动作、等待、观察等)
|
| 61 |
+
|
| 62 |
+
注意:
|
| 63 |
+
- 只有真正没有任何手部操作时才标Idle
|
| 64 |
+
- "调整姿态"、"检查物体"等属于Arrange
|
| 65 |
+
- "插入"、"装入"等属于Place
|
| 66 |
+
- "提起并移动"如果距离短属于Grasp,距离远属于Transport
|
| 67 |
+
"""
|
| 68 |
+
|
| 69 |
+
current_key_idx = 0
|
| 70 |
+
call_count = 0
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def call_llm(prompt, max_tokens=1000, retries=3):
|
| 74 |
+
"""Call LLM API with automatic key rotation."""
|
| 75 |
+
global current_key_idx, call_count
|
| 76 |
+
|
| 77 |
+
for attempt in range(retries * len(API_KEYS)):
|
| 78 |
+
key = API_KEYS[current_key_idx]
|
| 79 |
+
try:
|
| 80 |
+
data = json.dumps({
|
| 81 |
+
"model": "gpt-4o-mini",
|
| 82 |
+
"messages": [{"role": "user", "content": prompt}],
|
| 83 |
+
"max_tokens": max_tokens,
|
| 84 |
+
"temperature": 0.1,
|
| 85 |
+
}).encode()
|
| 86 |
+
req = urllib.request.Request(
|
| 87 |
+
API_URL, data=data,
|
| 88 |
+
headers={
|
| 89 |
+
"Content-Type": "application/json",
|
| 90 |
+
"Authorization": f"Bearer {key}",
|
| 91 |
+
}
|
| 92 |
+
)
|
| 93 |
+
resp = urllib.request.urlopen(req, timeout=30)
|
| 94 |
+
result = json.loads(resp.read())
|
| 95 |
+
call_count += 1
|
| 96 |
+
return result["choices"][0]["message"]["content"]
|
| 97 |
+
except Exception as e:
|
| 98 |
+
err = str(e)
|
| 99 |
+
if "429" in err or "quota" in err or "limit" in err or "402" in err:
|
| 100 |
+
# Key exhausted, rotate
|
| 101 |
+
print(f" Key {current_key_idx+1} exhausted, rotating...")
|
| 102 |
+
current_key_idx = (current_key_idx + 1) % len(API_KEYS)
|
| 103 |
+
elif "timeout" in err.lower():
|
| 104 |
+
time.sleep(1)
|
| 105 |
+
else:
|
| 106 |
+
print(f" API error: {err[:100]}")
|
| 107 |
+
current_key_idx = (current_key_idx + 1) % len(API_KEYS)
|
| 108 |
+
time.sleep(0.5)
|
| 109 |
+
|
| 110 |
+
print(" WARNING: All API keys failed!")
|
| 111 |
+
return None
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def reclassify_segments(segments, scene_id):
|
| 115 |
+
"""Use LLM to reclassify all segments in a recording."""
|
| 116 |
+
scene_desc = SCENE_DESCRIPTIONS.get(scene_id, "日常活动")
|
| 117 |
+
|
| 118 |
+
# Build segment list for prompt
|
| 119 |
+
seg_list = []
|
| 120 |
+
for i, seg in enumerate(segments):
|
| 121 |
+
seg_list.append(f"{i+1}. [{seg['timestamp']}] {seg['task']}")
|
| 122 |
+
seg_text = "\n".join(seg_list)
|
| 123 |
+
|
| 124 |
+
prompt = f"""你是一个人体动作标注专家。请为以下每个动作片段分配一个动作类别。
|
| 125 |
+
|
| 126 |
+
场景:{scene_desc}
|
| 127 |
+
|
| 128 |
+
{ACTION_CATEGORIES}
|
| 129 |
+
|
| 130 |
+
动作片段列表:
|
| 131 |
+
{seg_text}
|
| 132 |
+
|
| 133 |
+
请严格按以下JSON格式返回,不要添加任何额外文字:
|
| 134 |
+
[{{"id": 1, "action": "类别名"}}, {{"id": 2, "action": "类别名"}}, ...]
|
| 135 |
+
|
| 136 |
+
每个action必须是以下之一:Grasp, Place, Pour, Wipe, Fold, OpenClose, Stir, TearCut, Arrange, Transport, Idle"""
|
| 137 |
+
|
| 138 |
+
response = call_llm(prompt, max_tokens=len(segments) * 40)
|
| 139 |
+
if response is None:
|
| 140 |
+
return None
|
| 141 |
+
|
| 142 |
+
# Parse response
|
| 143 |
+
try:
|
| 144 |
+
# Extract JSON from response
|
| 145 |
+
match = re.search(r'\[.*\]', response, re.DOTALL)
|
| 146 |
+
if match:
|
| 147 |
+
results = json.loads(match.group())
|
| 148 |
+
return {r["id"]: r["action"] for r in results}
|
| 149 |
+
except (json.JSONDecodeError, KeyError) as e:
|
| 150 |
+
print(f" Parse error: {e}, response: {response[:200]}")
|
| 151 |
+
return None
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def infer_gap_actions(scene_id, before_seg, after_seg, gap_start, gap_end):
|
| 155 |
+
"""Use LLM to infer what actions likely happened in an unlabeled gap."""
|
| 156 |
+
scene_desc = SCENE_DESCRIPTIONS.get(scene_id, "日常活动")
|
| 157 |
+
gap_duration = gap_end - gap_start
|
| 158 |
+
|
| 159 |
+
before_text = f"[{before_seg['timestamp']}] {before_seg['task']}" if before_seg else "(录制开始)"
|
| 160 |
+
after_text = f"[{after_seg['timestamp']}] {after_seg['task']}" if after_seg else "(录制结束)"
|
| 161 |
+
|
| 162 |
+
prompt = f"""你是一个人体动作标注专家。在一段日常活动录制中,有一段时间没有被标注。请根据场景和前后动作推断这段时间内最可能发生的动作。
|
| 163 |
+
|
| 164 |
+
场景:{scene_desc}
|
| 165 |
+
未标注时间段:{gap_start//60:02d}:{gap_start%60:02d} - {gap_end//60:02d}:{gap_end%60:02d}(共{gap_duration}秒)
|
| 166 |
+
前一个标注动作:{before_text}
|
| 167 |
+
后一个标注动作:{after_text}
|
| 168 |
+
|
| 169 |
+
{ACTION_CATEGORIES}
|
| 170 |
+
|
| 171 |
+
请推断这段时间内可能发生的动作序列。每个动作段落2-4秒,时间用MM:SS格式。
|
| 172 |
+
如果确实是空闲等待,标注为Idle。
|
| 173 |
+
|
| 174 |
+
严格按以下JSON格式返回,不要添加任何额外文字:
|
| 175 |
+
[{{"timestamp": "MM:SS-MM:SS", "task": "动作描述", "action": "类别名"}}]
|
| 176 |
+
|
| 177 |
+
每个action必须是以下之一:Grasp, Place, Pour, Wipe, Fold, OpenClose, Stir, TearCut, Arrange, Transport, Idle"""
|
| 178 |
+
|
| 179 |
+
response = call_llm(prompt, max_tokens=500)
|
| 180 |
+
if response is None:
|
| 181 |
+
return []
|
| 182 |
+
|
| 183 |
+
try:
|
| 184 |
+
match = re.search(r'\[.*\]', response, re.DOTALL)
|
| 185 |
+
if match:
|
| 186 |
+
results = json.loads(match.group())
|
| 187 |
+
# Validate timestamps
|
| 188 |
+
valid = []
|
| 189 |
+
for r in results:
|
| 190 |
+
if "timestamp" in r and "action" in r and "task" in r:
|
| 191 |
+
ts_match = re.match(r'(\d+):(\d+)\s*-\s*(\d+):(\d+)', r["timestamp"])
|
| 192 |
+
if ts_match:
|
| 193 |
+
s = int(ts_match.group(1))*60 + int(ts_match.group(2))
|
| 194 |
+
e = int(ts_match.group(3))*60 + int(ts_match.group(4))
|
| 195 |
+
if gap_start <= s < e <= gap_end:
|
| 196 |
+
valid.append(r)
|
| 197 |
+
return valid
|
| 198 |
+
except (json.JSONDecodeError, KeyError) as e:
|
| 199 |
+
print(f" Parse error: {e}")
|
| 200 |
+
return []
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def get_recording_duration(vol, scenario):
|
| 204 |
+
"""Get total recording duration in seconds."""
|
| 205 |
+
meta_path = os.path.join(DATASET_DIR, vol, scenario, "alignment_metadata.json")
|
| 206 |
+
if os.path.exists(meta_path):
|
| 207 |
+
meta = json.load(open(meta_path))
|
| 208 |
+
if "aligned_length_sec" in meta:
|
| 209 |
+
return meta["aligned_length_sec"]
|
| 210 |
+
if "aligned_length_frames" in meta:
|
| 211 |
+
return meta["aligned_length_frames"] / 100.0
|
| 212 |
+
return None
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def process_one_file(ann_path, vol, scenario):
|
| 216 |
+
"""Process one annotation file: reclassify + fill gaps."""
|
| 217 |
+
data = json.load(open(ann_path))
|
| 218 |
+
segments = data["segments"]
|
| 219 |
+
|
| 220 |
+
if not segments:
|
| 221 |
+
return data, {"reclassified": 0, "gaps_filled": 0}
|
| 222 |
+
|
| 223 |
+
# Step 1: Reclassify existing segments
|
| 224 |
+
print(f" Reclassifying {len(segments)} segments...")
|
| 225 |
+
classifications = reclassify_segments(segments, scenario)
|
| 226 |
+
|
| 227 |
+
if classifications:
|
| 228 |
+
for i, seg in enumerate(segments):
|
| 229 |
+
action = classifications.get(i + 1)
|
| 230 |
+
if action and action in {"Grasp", "Place", "Pour", "Wipe", "Fold",
|
| 231 |
+
"OpenClose", "Stir", "TearCut", "Arrange",
|
| 232 |
+
"Transport", "Idle"}:
|
| 233 |
+
seg["action_label"] = action
|
| 234 |
+
else:
|
| 235 |
+
seg["action_label"] = "Idle"
|
| 236 |
+
else:
|
| 237 |
+
# Fallback: keep without label
|
| 238 |
+
for seg in segments:
|
| 239 |
+
seg["action_label"] = "Idle"
|
| 240 |
+
|
| 241 |
+
reclassified = sum(1 for s in segments if "action_label" in s)
|
| 242 |
+
|
| 243 |
+
# Step 2: Find and fill gaps ≥ 3 seconds
|
| 244 |
+
# Parse all timestamps
|
| 245 |
+
parsed = []
|
| 246 |
+
for seg in segments:
|
| 247 |
+
m = re.match(r'(\d+):(\d+)\s*-\s*(\d+):(\d+)', seg["timestamp"])
|
| 248 |
+
if m:
|
| 249 |
+
s = int(m.group(1))*60 + int(m.group(2))
|
| 250 |
+
e = int(m.group(3))*60 + int(m.group(4))
|
| 251 |
+
parsed.append((s, e, seg))
|
| 252 |
+
parsed.sort()
|
| 253 |
+
|
| 254 |
+
total_dur = get_recording_duration(vol, scenario)
|
| 255 |
+
|
| 256 |
+
new_segments = []
|
| 257 |
+
gaps_filled = 0
|
| 258 |
+
|
| 259 |
+
for i in range(len(parsed)):
|
| 260 |
+
new_segments.append(parsed[i][2])
|
| 261 |
+
|
| 262 |
+
# Check gap after this segment
|
| 263 |
+
if i < len(parsed) - 1:
|
| 264 |
+
gap_start = parsed[i][1]
|
| 265 |
+
gap_end = parsed[i + 1][0]
|
| 266 |
+
elif total_dur:
|
| 267 |
+
gap_start = parsed[i][1]
|
| 268 |
+
gap_end = int(total_dur)
|
| 269 |
+
else:
|
| 270 |
+
continue
|
| 271 |
+
|
| 272 |
+
gap_duration = gap_end - gap_start
|
| 273 |
+
if gap_duration >= 3:
|
| 274 |
+
before_seg = parsed[i][2]
|
| 275 |
+
after_seg = parsed[i + 1][2] if i < len(parsed) - 1 else None
|
| 276 |
+
|
| 277 |
+
print(f" Filling gap {gap_start}s-{gap_end}s ({gap_duration}s)...")
|
| 278 |
+
inferred = infer_gap_actions(scenario, before_seg, after_seg, gap_start, gap_end)
|
| 279 |
+
|
| 280 |
+
for inf in inferred:
|
| 281 |
+
new_seg = {
|
| 282 |
+
"timestamp": inf["timestamp"],
|
| 283 |
+
"task": inf["task"],
|
| 284 |
+
"action_label": inf["action"],
|
| 285 |
+
"source": "llm_inferred",
|
| 286 |
+
"left_hand": "",
|
| 287 |
+
"right_hand": "",
|
| 288 |
+
"bimanual_interaction": "",
|
| 289 |
+
"objects": [],
|
| 290 |
+
}
|
| 291 |
+
new_segments.append(new_seg)
|
| 292 |
+
gaps_filled += 1
|
| 293 |
+
|
| 294 |
+
# Also check gap at the beginning
|
| 295 |
+
if parsed and parsed[0][0] >= 3:
|
| 296 |
+
print(f" Filling start gap 0s-{parsed[0][0]}s...")
|
| 297 |
+
inferred = infer_gap_actions(scenario, None, parsed[0][2], 0, parsed[0][0])
|
| 298 |
+
for inf in inferred:
|
| 299 |
+
new_seg = {
|
| 300 |
+
"timestamp": inf["timestamp"],
|
| 301 |
+
"task": inf["task"],
|
| 302 |
+
"action_label": inf["action"],
|
| 303 |
+
"source": "llm_inferred",
|
| 304 |
+
"left_hand": "",
|
| 305 |
+
"right_hand": "",
|
| 306 |
+
"bimanual_interaction": "",
|
| 307 |
+
"objects": [],
|
| 308 |
+
}
|
| 309 |
+
new_segments.insert(0, new_seg)
|
| 310 |
+
gaps_filled += 1
|
| 311 |
+
|
| 312 |
+
# Sort by timestamp
|
| 313 |
+
def sort_key(seg):
|
| 314 |
+
m = re.match(r'(\d+):(\d+)', seg["timestamp"])
|
| 315 |
+
return int(m.group(1))*60 + int(m.group(2)) if m else 0
|
| 316 |
+
new_segments.sort(key=sort_key)
|
| 317 |
+
|
| 318 |
+
result = copy.deepcopy(data)
|
| 319 |
+
result["segments"] = new_segments
|
| 320 |
+
|
| 321 |
+
return result, {"reclassified": reclassified, "gaps_filled": gaps_filled}
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
def main():
|
| 325 |
+
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
| 326 |
+
|
| 327 |
+
total_reclassified = 0
|
| 328 |
+
total_gaps_filled = 0
|
| 329 |
+
total_files = 0
|
| 330 |
+
|
| 331 |
+
for vol_dir in sorted(glob.glob(f"{ANN_DIR}/v*")):
|
| 332 |
+
vol = os.path.basename(vol_dir)
|
| 333 |
+
out_vol_dir = os.path.join(OUTPUT_DIR, vol)
|
| 334 |
+
os.makedirs(out_vol_dir, exist_ok=True)
|
| 335 |
+
|
| 336 |
+
for ann_file in sorted(glob.glob(f"{vol_dir}/s*.json")):
|
| 337 |
+
scenario = os.path.basename(ann_file).replace(".json", "")
|
| 338 |
+
print(f"\n[{vol}/{scenario}]", flush=True)
|
| 339 |
+
|
| 340 |
+
result, stats = process_one_file(ann_file, vol, scenario)
|
| 341 |
+
|
| 342 |
+
# Save
|
| 343 |
+
out_path = os.path.join(out_vol_dir, f"{scenario}.json")
|
| 344 |
+
with open(out_path, "w", encoding="utf-8") as f:
|
| 345 |
+
json.dump(result, f, ensure_ascii=False, indent=2)
|
| 346 |
+
|
| 347 |
+
total_reclassified += stats["reclassified"]
|
| 348 |
+
total_gaps_filled += stats["gaps_filled"]
|
| 349 |
+
total_files += 1
|
| 350 |
+
|
| 351 |
+
print(f" Done: {stats['reclassified']} reclassified, {stats['gaps_filled']} gaps filled",
|
| 352 |
+
flush=True)
|
| 353 |
+
|
| 354 |
+
print(f"\n{'='*60}")
|
| 355 |
+
print(f"Total: {total_files} files processed")
|
| 356 |
+
print(f" Reclassified: {total_reclassified} segments")
|
| 357 |
+
print(f" Gap-filled: {total_gaps_filled} new segments")
|
| 358 |
+
print(f" API calls: {call_count}")
|
| 359 |
+
print(f" Output: {OUTPUT_DIR}")
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
if __name__ == "__main__":
|
| 363 |
+
main()
|
experiments/data/__init__.py
ADDED
|
File without changes
|
experiments/data/__pycache__/dataset.cpython-312.pyc
ADDED
|
Binary file (18.8 kB). View file
|
|
|
experiments/data/dataset.py
ADDED
|
@@ -0,0 +1,332 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Multimodal scene dataset for Experiment 1: Activity Recognition.
|
| 3 |
+
Loads aligned 100Hz multi-modal data, supports modality selection,
|
| 4 |
+
subject-independent splits, and variable-length sequence handling.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import json
|
| 9 |
+
import numpy as np
|
| 10 |
+
import pandas as pd
|
| 11 |
+
import torch
|
| 12 |
+
from torch.utils.data import Dataset, DataLoader
|
| 13 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 14 |
+
|
| 15 |
+
DATASET_DIR = "${PULSE_ROOT}/dataset"
|
| 16 |
+
|
| 17 |
+
MODALITY_FILES = {
|
| 18 |
+
'mocap': None, # Special: uses aligned_{vol}{scene}_s_Q.tsv (skeleton data)
|
| 19 |
+
'emg': 'aligned_emg_100hz.csv',
|
| 20 |
+
'eyetrack': 'aligned_eyetrack_100hz.csv',
|
| 21 |
+
'imu': 'aligned_imu_100hz.csv',
|
| 22 |
+
'pressure': 'aligned_pressure_100hz.csv',
|
| 23 |
+
'video': 'video_features_100hz.npy', # ViT-B/16 (ImageNet)
|
| 24 |
+
'videomae': 'video_features_videomae_100hz.npy', # VideoMAE (Kinetics-400)
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def get_modality_filepath(scenario_dir, modality, vol=None, scenario=None):
|
| 29 |
+
"""Return the file path for a given modality.
|
| 30 |
+
|
| 31 |
+
Mocap uses a special naming pattern: aligned_{vol}{scene}_s_Q.tsv
|
| 32 |
+
All other modalities use MODALITY_FILES directly.
|
| 33 |
+
"""
|
| 34 |
+
if modality == 'mocap':
|
| 35 |
+
if vol is None or scenario is None:
|
| 36 |
+
raise ValueError("vol and scenario required for mocap modality")
|
| 37 |
+
return os.path.join(scenario_dir, f"aligned_{vol}{scenario}_s_Q.tsv")
|
| 38 |
+
return os.path.join(scenario_dir, MODALITY_FILES[modality])
|
| 39 |
+
|
| 40 |
+
SKIP_COLS = {'Frame', 'Time', 'time', 'UTC'}
|
| 41 |
+
SKIP_COL_SUFFIXES = (' Type',)
|
| 42 |
+
|
| 43 |
+
# Eyetrack exports sometimes include volunteer-specific marker/ICA columns.
|
| 44 |
+
# Benchmark inputs use the fixed 24 core gaze columns below; recordings missing
|
| 45 |
+
# any core column are skipped instead of truncating the full dataset.
|
| 46 |
+
EYETRACK_SKIP_PATTERNS = ('Index Of Cognitive Activity', 'Marker Coordinates', 'Markers_')
|
| 47 |
+
EYETRACK_CORE_COLS = [
|
| 48 |
+
'Dikablis Glasses 3_Eye Data_Original_Pupil X',
|
| 49 |
+
'Dikablis Glasses 3_Eye Data_Original_Pupil Y',
|
| 50 |
+
'Dikablis Glasses 3_Eye Data_Original_Left Eye_Pupil X',
|
| 51 |
+
'Dikablis Glasses 3_Eye Data_Original_Left Eye_Pupil Y',
|
| 52 |
+
'Dikablis Glasses 3_Eye Data_Original_Left Eye_Pupil Area',
|
| 53 |
+
'Dikablis Glasses 3_Eye Data_Original_Left Eye_Pupil Height',
|
| 54 |
+
'Dikablis Glasses 3_Eye Data_Original_Left Eye_Pupil Width',
|
| 55 |
+
'Dikablis Glasses 3_Eye Data_Original_Left Eye_Fixations_Fixations',
|
| 56 |
+
'Dikablis Glasses 3_Eye Data_Original_Left Eye_Fixations_Fixations Duration',
|
| 57 |
+
'Dikablis Glasses 3_Eye Data_Original_Left Eye_Saccades_Saccades',
|
| 58 |
+
'Dikablis Glasses 3_Eye Data_Original_Left Eye_Saccades_Saccades Duration',
|
| 59 |
+
'Dikablis Glasses 3_Eye Data_Original_Left Eye_Saccades_Saccades Angle',
|
| 60 |
+
'Dikablis Glasses 3_Eye Data_Original_Right Eye_Pupil X',
|
| 61 |
+
'Dikablis Glasses 3_Eye Data_Original_Right Eye_Pupil Y',
|
| 62 |
+
'Dikablis Glasses 3_Eye Data_Original_Right Eye_Pupil Area',
|
| 63 |
+
'Dikablis Glasses 3_Eye Data_Original_Right Eye_Pupil Height',
|
| 64 |
+
'Dikablis Glasses 3_Eye Data_Original_Right Eye_Pupil Width',
|
| 65 |
+
'Dikablis Glasses 3_Eye Data_Original_Right Eye_Fixations_Fixations',
|
| 66 |
+
'Dikablis Glasses 3_Eye Data_Original_Right Eye_Fixations_Fixations Duration',
|
| 67 |
+
'Dikablis Glasses 3_Eye Data_Original_Right Eye_Saccades_Saccades',
|
| 68 |
+
'Dikablis Glasses 3_Eye Data_Original_Right Eye_Saccades_Saccades Duration',
|
| 69 |
+
'Dikablis Glasses 3_Eye Data_Original_Right Eye_Saccades_Saccades Angle',
|
| 70 |
+
'Dikablis Glasses 3_Field Data_Scene Cam_Original_Gaze_Gaze X',
|
| 71 |
+
'Dikablis Glasses 3_Field Data_Scene Cam_Original_Gaze_Gaze Y',
|
| 72 |
+
]
|
| 73 |
+
EYETRACK_EXCLUDED_RECORDINGS = {('v1', 's1'), ('v14', 's8')}
|
| 74 |
+
|
| 75 |
+
SCENE_LABELS = {f's{i}': i - 1 for i in range(1, 9)}
|
| 76 |
+
NUM_CLASSES = 8
|
| 77 |
+
|
| 78 |
+
TRAIN_VOLS = ['v1', 'v2', 'v11', 'v12', 'v13', 'v15', 'v16', 'v17', 'v19', 'v20', 'v21', 'v22', 'v23', 'v24']
|
| 79 |
+
VAL_VOLS = [] # No separate val set; use train for early stopping or cross-val
|
| 80 |
+
TEST_VOLS = ['v25', 'v26', 'v27', 'v3']
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def _preprocess_mocap_skeleton(arr, feat_cols):
|
| 84 |
+
"""Convert absolute skeleton coords to hip-relative positions + velocity.
|
| 85 |
+
|
| 86 |
+
Input: (T, F) with absolute XYZ + quaternions
|
| 87 |
+
Output: (T, F + N_pos) where N_pos = number of XYZ position features
|
| 88 |
+
[hip-relative features, XYZ velocity]
|
| 89 |
+
"""
|
| 90 |
+
col_to_idx = {c: i for i, c in enumerate(feat_cols)}
|
| 91 |
+
|
| 92 |
+
# Find hip position for subtraction
|
| 93 |
+
hip_x_idx = col_to_idx.get('Hips_X')
|
| 94 |
+
hip_y_idx = col_to_idx.get('Hips_Y')
|
| 95 |
+
hip_z_idx = col_to_idx.get('Hips_Z')
|
| 96 |
+
if hip_x_idx is None:
|
| 97 |
+
return arr # No hip joint found, skip preprocessing
|
| 98 |
+
|
| 99 |
+
# Identify all position columns (_X, _Y, _Z)
|
| 100 |
+
x_indices = [i for i, c in enumerate(feat_cols) if c.endswith('_X')]
|
| 101 |
+
y_indices = [i for i, c in enumerate(feat_cols) if c.endswith('_Y')]
|
| 102 |
+
z_indices = [i for i, c in enumerate(feat_cols) if c.endswith('_Z')]
|
| 103 |
+
all_pos_indices = sorted(x_indices + y_indices + z_indices)
|
| 104 |
+
|
| 105 |
+
# 1. Make XYZ positions hip-relative
|
| 106 |
+
arr_rel = arr.copy()
|
| 107 |
+
hip_xyz = arr[:, [hip_x_idx, hip_y_idx, hip_z_idx]] # (T, 3)
|
| 108 |
+
for idx in x_indices:
|
| 109 |
+
arr_rel[:, idx] -= hip_xyz[:, 0]
|
| 110 |
+
for idx in y_indices:
|
| 111 |
+
arr_rel[:, idx] -= hip_xyz[:, 1]
|
| 112 |
+
for idx in z_indices:
|
| 113 |
+
arr_rel[:, idx] -= hip_xyz[:, 2]
|
| 114 |
+
|
| 115 |
+
# 2. Compute velocity of position features only
|
| 116 |
+
pos_data = arr_rel[:, all_pos_indices] # (T, N_pos)
|
| 117 |
+
velocity = np.zeros_like(pos_data)
|
| 118 |
+
velocity[1:] = pos_data[1:] - pos_data[:-1]
|
| 119 |
+
|
| 120 |
+
# 3. Concatenate: [hip-relative features (pos+quat), position velocity]
|
| 121 |
+
return np.concatenate([arr_rel, velocity], axis=1)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def load_modality_array(filepath, modality):
|
| 125 |
+
"""Load a modality CSV/TSV/NPY and return numpy_array.
|
| 126 |
+
Returns None if data is corrupted (extreme values or mostly zeros)."""
|
| 127 |
+
# Video features stored as .npy
|
| 128 |
+
if filepath.endswith('.npy'):
|
| 129 |
+
if not os.path.exists(filepath):
|
| 130 |
+
return None
|
| 131 |
+
arr = np.load(filepath).astype(np.float32)
|
| 132 |
+
arr = np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)
|
| 133 |
+
return arr
|
| 134 |
+
# Mocap uses TSV with tab separator
|
| 135 |
+
sep = '\t' if filepath.endswith('.tsv') else ','
|
| 136 |
+
df = pd.read_csv(filepath, sep=sep, low_memory=False)
|
| 137 |
+
df.columns = [str(c).strip() for c in df.columns]
|
| 138 |
+
if modality == 'eyetrack':
|
| 139 |
+
parts = os.path.normpath(filepath).split(os.sep)
|
| 140 |
+
if len(parts) >= 3 and (parts[-3], parts[-2]) in EYETRACK_EXCLUDED_RECORDINGS:
|
| 141 |
+
return None
|
| 142 |
+
feat_cols = [c for c in df.columns
|
| 143 |
+
if c not in SKIP_COLS
|
| 144 |
+
and not any(c.endswith(s) for s in SKIP_COL_SUFFIXES)]
|
| 145 |
+
if modality == 'eyetrack':
|
| 146 |
+
feat_cols = [c for c in EYETRACK_CORE_COLS if c in feat_cols]
|
| 147 |
+
if len(feat_cols) != len(EYETRACK_CORE_COLS):
|
| 148 |
+
return None
|
| 149 |
+
sub = df[feat_cols]
|
| 150 |
+
# Coerce non-numeric columns
|
| 151 |
+
obj_cols = sub.select_dtypes(include=['object']).columns
|
| 152 |
+
if len(obj_cols) > 0:
|
| 153 |
+
sub = sub.copy()
|
| 154 |
+
sub[obj_cols] = sub[obj_cols].apply(pd.to_numeric, errors='coerce')
|
| 155 |
+
arr = sub.values.astype(np.float64)
|
| 156 |
+
arr = np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)
|
| 157 |
+
# Quality check: reject samples with extreme values (corrupted data)
|
| 158 |
+
max_abs = np.max(np.abs(arr))
|
| 159 |
+
if max_abs > 1e6:
|
| 160 |
+
return None # Corrupted
|
| 161 |
+
# Quality check: reject samples that are mostly zeros (sensor dropout).
|
| 162 |
+
# Pressure and EMG are legitimately zero for long periods (rest, no grip)
|
| 163 |
+
# so we only apply the strict near-total-loss check to the modalities
|
| 164 |
+
# where a flat-zero stream is a clear dropout signal.
|
| 165 |
+
if modality not in ("pressure", "emg"):
|
| 166 |
+
zero_ratio = np.mean(arr == 0.0)
|
| 167 |
+
if zero_ratio > 0.9:
|
| 168 |
+
return None # Near-total data loss
|
| 169 |
+
# Mocap skeleton: convert to hip-relative + velocity
|
| 170 |
+
if modality == 'mocap' and filepath.endswith('.tsv'):
|
| 171 |
+
arr = _preprocess_mocap_skeleton(arr, feat_cols)
|
| 172 |
+
arr = arr.astype(np.float32)
|
| 173 |
+
return arr
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
class MultimodalSceneDataset(Dataset):
|
| 177 |
+
"""Dataset for scene-level classification from multimodal time series."""
|
| 178 |
+
|
| 179 |
+
def __init__(self, volunteers, modalities, downsample=5, stats=None):
|
| 180 |
+
self.modalities = modalities
|
| 181 |
+
self.downsample = downsample
|
| 182 |
+
self.data = []
|
| 183 |
+
self.labels = []
|
| 184 |
+
self.sample_info = []
|
| 185 |
+
self._modality_dims = {}
|
| 186 |
+
|
| 187 |
+
for vol in volunteers:
|
| 188 |
+
vol_dir = os.path.join(DATASET_DIR, vol)
|
| 189 |
+
if not os.path.isdir(vol_dir):
|
| 190 |
+
continue
|
| 191 |
+
for scenario in sorted(os.listdir(vol_dir)):
|
| 192 |
+
scenario_dir = os.path.join(vol_dir, scenario)
|
| 193 |
+
if not os.path.isdir(scenario_dir) or scenario not in SCENE_LABELS:
|
| 194 |
+
continue
|
| 195 |
+
meta_path = os.path.join(scenario_dir, 'alignment_metadata.json')
|
| 196 |
+
if not os.path.exists(meta_path):
|
| 197 |
+
continue
|
| 198 |
+
with open(meta_path) as f:
|
| 199 |
+
meta = json.load(f)
|
| 200 |
+
available = set(meta['modalities'])
|
| 201 |
+
if not set(modalities).issubset(available):
|
| 202 |
+
continue
|
| 203 |
+
|
| 204 |
+
parts = []
|
| 205 |
+
skip = False
|
| 206 |
+
for mod in modalities:
|
| 207 |
+
if mod == 'mocap':
|
| 208 |
+
# Skeleton data: aligned_{vol}{scene}_s_Q.tsv
|
| 209 |
+
tsv_name = f"aligned_{vol}{scenario}_s_Q.tsv"
|
| 210 |
+
filepath = os.path.join(scenario_dir, tsv_name)
|
| 211 |
+
else:
|
| 212 |
+
filepath = os.path.join(scenario_dir, MODALITY_FILES[mod])
|
| 213 |
+
if not os.path.exists(filepath):
|
| 214 |
+
skip = True
|
| 215 |
+
break
|
| 216 |
+
arr = load_modality_array(filepath, mod)
|
| 217 |
+
if arr is None:
|
| 218 |
+
print(f" SKIP {vol}/{scenario} {mod}: corrupted data", flush=True)
|
| 219 |
+
skip = True
|
| 220 |
+
break
|
| 221 |
+
# Validate dimension consistency
|
| 222 |
+
if mod in self._modality_dims and arr.shape[1] != self._modality_dims[mod]:
|
| 223 |
+
print(f" WARNING: {vol}/{scenario} {mod} dim {arr.shape[1]} "
|
| 224 |
+
f"!= expected {self._modality_dims[mod]}, padding/truncating",
|
| 225 |
+
flush=True)
|
| 226 |
+
expected = self._modality_dims[mod]
|
| 227 |
+
if arr.shape[1] < expected:
|
| 228 |
+
pad = np.zeros((arr.shape[0], expected - arr.shape[1]), dtype=np.float32)
|
| 229 |
+
arr = np.concatenate([arr, pad], axis=1)
|
| 230 |
+
else:
|
| 231 |
+
arr = arr[:, :expected]
|
| 232 |
+
if mod not in self._modality_dims:
|
| 233 |
+
self._modality_dims[mod] = arr.shape[1]
|
| 234 |
+
parts.append(arr)
|
| 235 |
+
|
| 236 |
+
if skip:
|
| 237 |
+
continue
|
| 238 |
+
|
| 239 |
+
min_len = min(p.shape[0] for p in parts)
|
| 240 |
+
parts = [p[:min_len] for p in parts]
|
| 241 |
+
combined = np.concatenate(parts, axis=1)
|
| 242 |
+
combined = combined[::downsample]
|
| 243 |
+
|
| 244 |
+
self.data.append(combined)
|
| 245 |
+
self.labels.append(SCENE_LABELS[scenario])
|
| 246 |
+
self.sample_info.append(f"{vol}/{scenario}")
|
| 247 |
+
|
| 248 |
+
print(f" Loaded {len(self.data)} samples, modality dims: {self._modality_dims}, "
|
| 249 |
+
f"total feat dim: {sum(self._modality_dims.values())}", flush=True)
|
| 250 |
+
|
| 251 |
+
# Normalization (compute in float64 to avoid overflow)
|
| 252 |
+
if stats is not None:
|
| 253 |
+
self.mean, self.std = stats
|
| 254 |
+
else:
|
| 255 |
+
self._compute_stats()
|
| 256 |
+
for i in range(len(self.data)):
|
| 257 |
+
self.data[i] = ((self.data[i].astype(np.float64) - self.mean) / self.std).astype(np.float32)
|
| 258 |
+
self.data[i] = np.nan_to_num(self.data[i], nan=0.0, posinf=0.0, neginf=0.0)
|
| 259 |
+
|
| 260 |
+
def _compute_stats(self):
|
| 261 |
+
# Use float64 for accumulation to prevent overflow
|
| 262 |
+
all_frames = np.concatenate(self.data, axis=0).astype(np.float64)
|
| 263 |
+
self.mean = np.mean(all_frames, axis=0, keepdims=True)
|
| 264 |
+
self.std = np.std(all_frames, axis=0, keepdims=True)
|
| 265 |
+
self.std[self.std < 1e-8] = 1.0
|
| 266 |
+
|
| 267 |
+
def get_stats(self):
|
| 268 |
+
return (self.mean, self.std)
|
| 269 |
+
|
| 270 |
+
@property
|
| 271 |
+
def feat_dim(self):
|
| 272 |
+
return sum(self._modality_dims.values())
|
| 273 |
+
|
| 274 |
+
@property
|
| 275 |
+
def modality_dims(self):
|
| 276 |
+
return dict(self._modality_dims)
|
| 277 |
+
|
| 278 |
+
def get_class_weights(self):
|
| 279 |
+
counts = np.bincount(self.labels, minlength=NUM_CLASSES).astype(np.float32)
|
| 280 |
+
counts[counts == 0] = 1.0
|
| 281 |
+
weights = 1.0 / counts
|
| 282 |
+
weights = weights / weights.sum() * NUM_CLASSES
|
| 283 |
+
return torch.FloatTensor(weights)
|
| 284 |
+
|
| 285 |
+
def __len__(self):
|
| 286 |
+
return len(self.data)
|
| 287 |
+
|
| 288 |
+
def __getitem__(self, idx):
|
| 289 |
+
return torch.from_numpy(self.data[idx]), self.labels[idx]
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
def collate_fn(batch):
|
| 293 |
+
"""Pad variable-length sequences and create masks."""
|
| 294 |
+
sequences, labels = zip(*batch)
|
| 295 |
+
lengths = torch.LongTensor([s.shape[0] for s in sequences])
|
| 296 |
+
padded = pad_sequence(sequences, batch_first=True, padding_value=0.0)
|
| 297 |
+
max_len = padded.shape[1]
|
| 298 |
+
mask = torch.arange(max_len).unsqueeze(0) < lengths.unsqueeze(1)
|
| 299 |
+
labels = torch.LongTensor(labels)
|
| 300 |
+
return padded, labels, mask, lengths
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
def get_dataloaders(modalities, batch_size=16, downsample=5, num_workers=0):
|
| 304 |
+
"""Create train/val/test DataLoaders with proper normalization."""
|
| 305 |
+
print("Loading training data...", flush=True)
|
| 306 |
+
train_ds = MultimodalSceneDataset(TRAIN_VOLS, modalities, downsample)
|
| 307 |
+
stats = train_ds.get_stats()
|
| 308 |
+
|
| 309 |
+
print("Loading validation data...", flush=True)
|
| 310 |
+
val_ds = MultimodalSceneDataset(VAL_VOLS, modalities, downsample, stats=stats)
|
| 311 |
+
|
| 312 |
+
print("Loading test data...", flush=True)
|
| 313 |
+
test_ds = MultimodalSceneDataset(TEST_VOLS, modalities, downsample, stats=stats)
|
| 314 |
+
|
| 315 |
+
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
|
| 316 |
+
collate_fn=collate_fn, num_workers=num_workers,
|
| 317 |
+
drop_last=False)
|
| 318 |
+
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,
|
| 319 |
+
collate_fn=collate_fn, num_workers=num_workers)
|
| 320 |
+
test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False,
|
| 321 |
+
collate_fn=collate_fn, num_workers=num_workers)
|
| 322 |
+
|
| 323 |
+
info = {
|
| 324 |
+
'feat_dim': train_ds.feat_dim,
|
| 325 |
+
'modality_dims': train_ds.modality_dims,
|
| 326 |
+
'num_classes': NUM_CLASSES,
|
| 327 |
+
'train_size': len(train_ds),
|
| 328 |
+
'val_size': len(val_ds),
|
| 329 |
+
'test_size': len(test_ds),
|
| 330 |
+
'class_weights': train_ds.get_class_weights(),
|
| 331 |
+
}
|
| 332 |
+
return train_loader, val_loader, test_loader, info
|
experiments/data/dataset_forecast.py
ADDED
|
@@ -0,0 +1,319 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Frame-level future motor-primitive forecasting dataset.
|
| 2 |
+
|
| 3 |
+
Task definition
|
| 4 |
+
---------------
|
| 5 |
+
At a sampled anchor time t in a recording:
|
| 6 |
+
past = sensor frames over [t - T_obs, t] ← input
|
| 7 |
+
future = per-frame verb_fine labels over (t, t + T_fut] ← target
|
| 8 |
+
|
| 9 |
+
We use NUM_VERB_FINE (= 17) as a sentinel "idle / no segment" class for
|
| 10 |
+
frames not covered by any annotated segment, so every future frame has a
|
| 11 |
+
valid label (output cardinality = NUM_VERB_FINE + 1 = 18).
|
| 12 |
+
|
| 13 |
+
Anchors are sampled at fixed stride within each recording so the model
|
| 14 |
+
sees both intra-segment future (mostly stationary) and across-boundary
|
| 15 |
+
future (where the next-action label changes — the interesting cases).
|
| 16 |
+
"""
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
import os
|
| 20 |
+
import sys
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
from typing import Dict, List, Optional, Sequence, Tuple
|
| 23 |
+
|
| 24 |
+
import numpy as np
|
| 25 |
+
import torch
|
| 26 |
+
from torch.utils.data import Dataset
|
| 27 |
+
|
| 28 |
+
THIS = Path(__file__).resolve()
|
| 29 |
+
sys.path.insert(0, str(THIS.parent))
|
| 30 |
+
sys.path.insert(0, str(THIS.parents[1]))
|
| 31 |
+
|
| 32 |
+
try:
|
| 33 |
+
from experiments.dataset_seqpred import (
|
| 34 |
+
SAMPLING_RATE_HZ, _load_recording_sensors, _load_annotations,
|
| 35 |
+
parse_ts_range, TRAIN_VOLS_V3, TEST_VOLS_V3,
|
| 36 |
+
DEFAULT_DATASET_DIR, DEFAULT_ANNOT_DIR,
|
| 37 |
+
)
|
| 38 |
+
from experiments.taxonomy import (
|
| 39 |
+
classify_segment, NUM_VERB_FINE,
|
| 40 |
+
)
|
| 41 |
+
except ModuleNotFoundError:
|
| 42 |
+
from dataset_seqpred import (
|
| 43 |
+
SAMPLING_RATE_HZ, _load_recording_sensors, _load_annotations,
|
| 44 |
+
parse_ts_range, TRAIN_VOLS_V3, TEST_VOLS_V3,
|
| 45 |
+
DEFAULT_DATASET_DIR, DEFAULT_ANNOT_DIR,
|
| 46 |
+
)
|
| 47 |
+
from taxonomy import classify_segment, NUM_VERB_FINE
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
IDLE_LABEL = NUM_VERB_FINE # = 17, sentinel for "no segment covers this frame"
|
| 51 |
+
NUM_FORECAST_CLASSES = NUM_VERB_FINE + 1 # = 18
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class ForecastDataset(Dataset):
|
| 55 |
+
"""Forecast next T_fut seconds of per-frame verb_fine given past T_obs."""
|
| 56 |
+
|
| 57 |
+
def __init__(
|
| 58 |
+
self,
|
| 59 |
+
volunteers: Sequence[str],
|
| 60 |
+
modalities: Sequence[str],
|
| 61 |
+
t_obs_sec: float = 1.5,
|
| 62 |
+
t_fut_sec: float = 0.5,
|
| 63 |
+
anchor_stride_sec: float = 0.25,
|
| 64 |
+
downsample: int = 5,
|
| 65 |
+
dataset_dir: Path = DEFAULT_DATASET_DIR,
|
| 66 |
+
annot_dir: Path = DEFAULT_ANNOT_DIR,
|
| 67 |
+
stats: Optional[Dict[str, Tuple[np.ndarray, np.ndarray]]] = None,
|
| 68 |
+
expected_dims: Optional[Dict[str, int]] = None,
|
| 69 |
+
contact_only: bool = False,
|
| 70 |
+
contact_threshold_g: float = 5.0,
|
| 71 |
+
log: bool = True,
|
| 72 |
+
):
|
| 73 |
+
super().__init__()
|
| 74 |
+
self.modalities = list(modalities)
|
| 75 |
+
self.t_obs_sec = float(t_obs_sec)
|
| 76 |
+
self.t_fut_sec = float(t_fut_sec)
|
| 77 |
+
self.anchor_stride_sec = float(anchor_stride_sec)
|
| 78 |
+
self.downsample = int(downsample)
|
| 79 |
+
self.sr = SAMPLING_RATE_HZ // self.downsample
|
| 80 |
+
self.dataset_dir = Path(dataset_dir)
|
| 81 |
+
self.annot_dir = Path(annot_dir)
|
| 82 |
+
self.contact_only = bool(contact_only)
|
| 83 |
+
self.contact_threshold_g = float(contact_threshold_g)
|
| 84 |
+
|
| 85 |
+
# Output time-step counts (after downsample)
|
| 86 |
+
self.T_obs = int(round(self.t_obs_sec * self.sr))
|
| 87 |
+
self.T_fut = int(round(self.t_fut_sec * self.sr))
|
| 88 |
+
|
| 89 |
+
self._items: List[dict] = []
|
| 90 |
+
# Pre-seed modality dims if caller (e.g. test set) provides them
|
| 91 |
+
self._modality_dims: Dict[str, int] = dict(expected_dims) if expected_dims else {}
|
| 92 |
+
|
| 93 |
+
for vol in volunteers:
|
| 94 |
+
vol_dir = self.dataset_dir / vol
|
| 95 |
+
if not vol_dir.is_dir():
|
| 96 |
+
continue
|
| 97 |
+
for scenario_dir in sorted(vol_dir.glob("s*")):
|
| 98 |
+
if not scenario_dir.is_dir():
|
| 99 |
+
continue
|
| 100 |
+
scene = scenario_dir.name
|
| 101 |
+
annot_path = self.annot_dir / vol / f"{scene}.json"
|
| 102 |
+
if not annot_path.exists():
|
| 103 |
+
continue
|
| 104 |
+
|
| 105 |
+
# Always include pressure for the filter, even if model
|
| 106 |
+
# doesn't see it as input. We separate "filter sensors"
|
| 107 |
+
# (load_mods) from "model input sensors" (self.modalities).
|
| 108 |
+
load_mods = list(dict.fromkeys(list(self.modalities) + ["pressure"]))
|
| 109 |
+
try:
|
| 110 |
+
sensors_all = _load_recording_sensors(
|
| 111 |
+
scenario_dir, vol, scene, load_mods
|
| 112 |
+
)
|
| 113 |
+
except Exception:
|
| 114 |
+
continue
|
| 115 |
+
if sensors_all is None or any(a is None for a in sensors_all.values()):
|
| 116 |
+
continue
|
| 117 |
+
pressure_full = sensors_all.get("pressure") # (T, 50)
|
| 118 |
+
# Subset to model-input modalities for everything downstream
|
| 119 |
+
sensors = {m: sensors_all[m] for m in self.modalities}
|
| 120 |
+
|
| 121 |
+
# Track modality dim consistency
|
| 122 |
+
for m, arr in sensors.items():
|
| 123 |
+
if m in self._modality_dims:
|
| 124 |
+
target = self._modality_dims[m]
|
| 125 |
+
if arr.shape[1] != target:
|
| 126 |
+
if arr.shape[1] < target:
|
| 127 |
+
pad = np.zeros((arr.shape[0], target - arr.shape[1]),
|
| 128 |
+
dtype=np.float32)
|
| 129 |
+
sensors[m] = np.concatenate([arr, pad], axis=1)
|
| 130 |
+
else:
|
| 131 |
+
sensors[m] = arr[:, :target]
|
| 132 |
+
else:
|
| 133 |
+
self._modality_dims[m] = arr.shape[1]
|
| 134 |
+
|
| 135 |
+
T_avail = min(a.shape[0] for a in sensors.values())
|
| 136 |
+
if T_avail < (self.T_obs + self.T_fut) * self.downsample:
|
| 137 |
+
continue
|
| 138 |
+
|
| 139 |
+
# Build per-frame verb_fine timeline at full 100 Hz
|
| 140 |
+
timeline = np.full(T_avail, IDLE_LABEL, dtype=np.int64)
|
| 141 |
+
segs = _load_annotations(annot_path)
|
| 142 |
+
for seg in segs:
|
| 143 |
+
a = seg.get("action_annotation", {})
|
| 144 |
+
labels = classify_segment(a)
|
| 145 |
+
if labels is None:
|
| 146 |
+
continue
|
| 147 |
+
start_sec, end_sec = parse_ts_range(seg.get("timestamp", ""))
|
| 148 |
+
s = int(round(start_sec * SAMPLING_RATE_HZ))
|
| 149 |
+
e = int(round(end_sec * SAMPLING_RATE_HZ))
|
| 150 |
+
s = max(0, s); e = min(T_avail, e)
|
| 151 |
+
if e > s:
|
| 152 |
+
timeline[s:e] = labels["verb_fine"]
|
| 153 |
+
|
| 154 |
+
# Downsample timeline to 20 Hz
|
| 155 |
+
timeline_ds = timeline[::self.downsample]
|
| 156 |
+
T_ds = len(timeline_ds)
|
| 157 |
+
|
| 158 |
+
# Downsample sensors to 20 Hz (kept as full record;
|
| 159 |
+
# we'll slice windows below)
|
| 160 |
+
sensors_ds = {m: arr[::self.downsample] for m, arr in sensors.items()}
|
| 161 |
+
|
| 162 |
+
# Build contact mask at 20 Hz (per-frame): is pressure-sum > thr?
|
| 163 |
+
# Pressure is 50 channels; we follow the T2 contact convention
|
| 164 |
+
# (sum across all fingertips and threshold at 5 g).
|
| 165 |
+
if pressure_full is not None:
|
| 166 |
+
pressure_ds = pressure_full[::self.downsample]
|
| 167 |
+
contact_ds = pressure_ds.sum(axis=1) > self.contact_threshold_g
|
| 168 |
+
else:
|
| 169 |
+
contact_ds = np.zeros(T_ds, dtype=bool)
|
| 170 |
+
|
| 171 |
+
# Sample anchors at fixed stride (in 20 Hz frames)
|
| 172 |
+
stride = max(1, int(round(self.anchor_stride_sec * self.sr)))
|
| 173 |
+
first_anchor = self.T_obs
|
| 174 |
+
last_anchor = T_ds - self.T_fut
|
| 175 |
+
if last_anchor <= first_anchor:
|
| 176 |
+
continue
|
| 177 |
+
|
| 178 |
+
for anchor in range(first_anchor, last_anchor + 1, stride):
|
| 179 |
+
# contact-rich filter: any contact frame in past or future window?
|
| 180 |
+
if self.contact_only:
|
| 181 |
+
win = contact_ds[max(0, anchor - self.T_obs):
|
| 182 |
+
min(T_ds, anchor + self.T_fut)]
|
| 183 |
+
if not win.any():
|
| 184 |
+
continue
|
| 185 |
+
past_slice = {m: arr[anchor - self.T_obs:anchor]
|
| 186 |
+
for m, arr in sensors_ds.items()}
|
| 187 |
+
fut_labels = timeline_ds[anchor:anchor + self.T_fut].copy()
|
| 188 |
+
# length sanity
|
| 189 |
+
if any(w.shape[0] != self.T_obs for w in past_slice.values()):
|
| 190 |
+
continue
|
| 191 |
+
if fut_labels.shape[0] != self.T_fut:
|
| 192 |
+
continue
|
| 193 |
+
self._items.append({
|
| 194 |
+
"x": past_slice, # dict[mod] -> (T_obs, F_mod)
|
| 195 |
+
"y_seq": fut_labels, # (T_fut,) int in [0..17]
|
| 196 |
+
"meta": {"vol": vol, "scene": scene, "anchor_idx": int(anchor)},
|
| 197 |
+
})
|
| 198 |
+
|
| 199 |
+
if not self._items:
|
| 200 |
+
raise RuntimeError("ForecastDataset: collected 0 anchors. Check annot_dir / modalities.")
|
| 201 |
+
|
| 202 |
+
# Per-modality z-score using training stats
|
| 203 |
+
if stats is None:
|
| 204 |
+
stats = self._compute_stats()
|
| 205 |
+
self._stats = stats
|
| 206 |
+
self._apply_stats(stats)
|
| 207 |
+
|
| 208 |
+
if log:
|
| 209 |
+
print(f"[ForecastDataset] vols={len(volunteers)} "
|
| 210 |
+
f"anchors={len(self._items)} "
|
| 211 |
+
f"T_obs={self.T_obs} T_fut={self.T_fut} "
|
| 212 |
+
f"contact_only={self.contact_only} "
|
| 213 |
+
f"modality_dims={self._modality_dims} "
|
| 214 |
+
f"sr={self.sr}Hz", flush=True)
|
| 215 |
+
|
| 216 |
+
# ----- Stats / normalization -----
|
| 217 |
+
def _compute_stats(self) -> Dict[str, Tuple[np.ndarray, np.ndarray]]:
|
| 218 |
+
accs = {m: [] for m in self._modality_dims}
|
| 219 |
+
for it in self._items:
|
| 220 |
+
for m, w in it["x"].items():
|
| 221 |
+
accs[m].append(w)
|
| 222 |
+
out = {}
|
| 223 |
+
for m, ws in accs.items():
|
| 224 |
+
cat = np.concatenate(ws, axis=0)
|
| 225 |
+
mu = cat.mean(axis=0)
|
| 226 |
+
sd = cat.std(axis=0); sd = np.where(sd < 1e-6, 1.0, sd)
|
| 227 |
+
out[m] = (mu.astype(np.float32), sd.astype(np.float32))
|
| 228 |
+
return out
|
| 229 |
+
|
| 230 |
+
def _apply_stats(self, stats):
|
| 231 |
+
for it in self._items:
|
| 232 |
+
for m, w in it["x"].items():
|
| 233 |
+
if m in stats:
|
| 234 |
+
mu, sd = stats[m]
|
| 235 |
+
it["x"][m] = ((w - mu) / sd).astype(np.float32)
|
| 236 |
+
|
| 237 |
+
# ----- Dataset protocol -----
|
| 238 |
+
def __len__(self):
|
| 239 |
+
return len(self._items)
|
| 240 |
+
|
| 241 |
+
def __getitem__(self, idx):
|
| 242 |
+
it = self._items[idx]
|
| 243 |
+
x = {m: torch.from_numpy(np.ascontiguousarray(w)) for m, w in it["x"].items()}
|
| 244 |
+
y_seq = torch.from_numpy(np.ascontiguousarray(it["y_seq"])) # (T_fut,)
|
| 245 |
+
return x, y_seq, it["meta"]
|
| 246 |
+
|
| 247 |
+
@property
|
| 248 |
+
def modality_dims(self):
|
| 249 |
+
return dict(self._modality_dims)
|
| 250 |
+
|
| 251 |
+
def class_freq(self) -> np.ndarray:
|
| 252 |
+
c = np.zeros(NUM_FORECAST_CLASSES, dtype=np.int64)
|
| 253 |
+
for it in self._items:
|
| 254 |
+
for v in it["y_seq"]:
|
| 255 |
+
c[int(v)] += 1
|
| 256 |
+
return c
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def collate_forecast(batch):
|
| 260 |
+
"""Stack (x_dict, y_seq, meta) -> batched tensors. All samples share T_obs/T_fut."""
|
| 261 |
+
xs, ys, metas = zip(*batch)
|
| 262 |
+
B = len(batch)
|
| 263 |
+
mods = list(xs[0].keys())
|
| 264 |
+
x_out: Dict[str, torch.Tensor] = {}
|
| 265 |
+
for m in mods:
|
| 266 |
+
x_out[m] = torch.stack([x[m] for x in xs], dim=0) # (B, T_obs, F_mod)
|
| 267 |
+
y_out = torch.stack(ys, dim=0) # (B, T_fut)
|
| 268 |
+
return x_out, y_out, list(metas)
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def build_train_test(
|
| 272 |
+
modalities: Sequence[str],
|
| 273 |
+
t_obs_sec: float = 1.5,
|
| 274 |
+
t_fut_sec: float = 0.5,
|
| 275 |
+
anchor_stride_sec: float = 0.25,
|
| 276 |
+
downsample: int = 5,
|
| 277 |
+
dataset_dir: Path = DEFAULT_DATASET_DIR,
|
| 278 |
+
annot_dir: Path = DEFAULT_ANNOT_DIR,
|
| 279 |
+
contact_only: bool = False,
|
| 280 |
+
contact_threshold_g: float = 5.0,
|
| 281 |
+
):
|
| 282 |
+
train = ForecastDataset(
|
| 283 |
+
TRAIN_VOLS_V3, modalities=modalities,
|
| 284 |
+
t_obs_sec=t_obs_sec, t_fut_sec=t_fut_sec,
|
| 285 |
+
anchor_stride_sec=anchor_stride_sec, downsample=downsample,
|
| 286 |
+
dataset_dir=dataset_dir, annot_dir=annot_dir,
|
| 287 |
+
contact_only=contact_only, contact_threshold_g=contact_threshold_g,
|
| 288 |
+
stats=None, log=True,
|
| 289 |
+
)
|
| 290 |
+
test = ForecastDataset(
|
| 291 |
+
TEST_VOLS_V3, modalities=modalities,
|
| 292 |
+
t_obs_sec=t_obs_sec, t_fut_sec=t_fut_sec,
|
| 293 |
+
anchor_stride_sec=anchor_stride_sec, downsample=downsample,
|
| 294 |
+
dataset_dir=dataset_dir, annot_dir=annot_dir,
|
| 295 |
+
contact_only=contact_only, contact_threshold_g=contact_threshold_g,
|
| 296 |
+
stats=train._stats, expected_dims=train._modality_dims, log=True,
|
| 297 |
+
)
|
| 298 |
+
return train, test
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
if __name__ == "__main__":
|
| 302 |
+
import argparse
|
| 303 |
+
ap = argparse.ArgumentParser()
|
| 304 |
+
ap.add_argument("--modalities", type=str, default="imu,emg,eyetrack,mocap,pressure")
|
| 305 |
+
ap.add_argument("--t_obs", type=float, default=1.5)
|
| 306 |
+
ap.add_argument("--t_fut", type=float, default=0.5)
|
| 307 |
+
ap.add_argument("--stride", type=float, default=0.25)
|
| 308 |
+
args = ap.parse_args()
|
| 309 |
+
mods = args.modalities.split(",")
|
| 310 |
+
tr, te = build_train_test(
|
| 311 |
+
modalities=mods,
|
| 312 |
+
t_obs_sec=args.t_obs, t_fut_sec=args.t_fut,
|
| 313 |
+
anchor_stride_sec=args.stride,
|
| 314 |
+
)
|
| 315 |
+
print(f"\nTrain={len(tr)} Test={len(te)} T_obs={tr.T_obs} T_fut={tr.T_fut}")
|
| 316 |
+
print(f"Train class freq:\n{tr.class_freq()}")
|
| 317 |
+
print(f"Test class freq:\n{te.class_freq()}")
|
| 318 |
+
x, y, meta = tr[0]
|
| 319 |
+
print(f"Sample: x={ {m: tuple(v.shape) for m,v in x.items()} } y_seq={tuple(y.shape)}")
|
experiments/data/dataset_grasp_state.py
ADDED
|
@@ -0,0 +1,571 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Anchor-based binary "is_grasping" classification dataset (T5 v3 / TGSR).
|
| 2 |
+
|
| 3 |
+
At each sampled anchor t in a recording:
|
| 4 |
+
past = sensor frames over [t - T_obs, t] ← input
|
| 5 |
+
label = majority vote of grasp-annotation mask over (t, t+T_fut] ← binary class
|
| 6 |
+
|
| 7 |
+
Ground-truth source: annotations_v3 verb segments. A frame is marked
|
| 8 |
+
"is_grasp" if it falls inside a segment whose action_name belongs to
|
| 9 |
+
GRASP_VERBS (set below). The label is annotation-derived, completely
|
| 10 |
+
independent of pressure — so adding/removing pressure as input does
|
| 11 |
+
NOT leak the label.
|
| 12 |
+
|
| 13 |
+
This is the cleanest test of "does pressure improve recognition of
|
| 14 |
+
object-interaction state when human-annotated grasp segments are GT?"
|
| 15 |
+
"""
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
import json
|
| 19 |
+
import sys
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
from typing import Dict, List, Optional, Sequence, Tuple
|
| 22 |
+
|
| 23 |
+
import numpy as np
|
| 24 |
+
import torch
|
| 25 |
+
from torch.utils.data import Dataset
|
| 26 |
+
|
| 27 |
+
THIS = Path(__file__).resolve()
|
| 28 |
+
sys.path.insert(0, str(THIS.parent))
|
| 29 |
+
sys.path.insert(0, str(THIS.parents[1]))
|
| 30 |
+
|
| 31 |
+
try:
|
| 32 |
+
from experiments.dataset_seqpred import (
|
| 33 |
+
SAMPLING_RATE_HZ, _load_recording_sensors,
|
| 34 |
+
TRAIN_VOLS_V3, TEST_VOLS_V3,
|
| 35 |
+
DEFAULT_DATASET_DIR, DEFAULT_ANNOT_DIR,
|
| 36 |
+
)
|
| 37 |
+
except ModuleNotFoundError:
|
| 38 |
+
from dataset_seqpred import (
|
| 39 |
+
SAMPLING_RATE_HZ, _load_recording_sensors,
|
| 40 |
+
TRAIN_VOLS_V3, TEST_VOLS_V3,
|
| 41 |
+
DEFAULT_DATASET_DIR, DEFAULT_ANNOT_DIR,
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
GRASP_VERBS = {
|
| 46 |
+
"grasp", "hold", "pick_up", "move", "place", "put_down",
|
| 47 |
+
"pull", "rotate", "insert", "remove",
|
| 48 |
+
}
|
| 49 |
+
# User-specified subset of action verbs that mean "the object has been lifted
|
| 50 |
+
# off its resting surface and held in hand" (used as Class 2 stricter definition).
|
| 51 |
+
LIFT_VERBS = {"grasp", "open", "move", "pick_up", "hold"}
|
| 52 |
+
|
| 53 |
+
# Multi-class verb taxonomy (annotations_v3 verb_fine universe).
|
| 54 |
+
# Verb 0 = background (anchor outside any segment).
|
| 55 |
+
VERB_LIST = [
|
| 56 |
+
"background",
|
| 57 |
+
"grasp", "move", "place", "adjust", "pick_up",
|
| 58 |
+
"close", "put_down", "pull", "hold", "open",
|
| 59 |
+
"rotate", "release", "push", "insert", "remove",
|
| 60 |
+
"align", "stabilize",
|
| 61 |
+
]
|
| 62 |
+
VERB_TO_IDX = {v: i for i, v in enumerate(VERB_LIST)}
|
| 63 |
+
|
| 64 |
+
# Top-15 most common object categories with non-zero coverage in the
|
| 65 |
+
# pressure-bearing test set (annotations_v3 survey of TRAIN+TEST_VOLS_V3).
|
| 66 |
+
# Index 0 = "_other": anchor outside any segment OR object not in top-15.
|
| 67 |
+
# Note: "coat" excluded because it appears only in v14, which has no
|
| 68 |
+
# pressure-aligned sessions and is silently dropped by the loader.
|
| 69 |
+
OBJECT_TOP_LIST = [
|
| 70 |
+
"_other",
|
| 71 |
+
"sealed jar", "towel", "tablecloth", "box", "pot",
|
| 72 |
+
"rice bowl", "tape", "pants", "spoon", "plate",
|
| 73 |
+
"marker", "cloth", "laptop", "toothbrush case", "tea canister",
|
| 74 |
+
]
|
| 75 |
+
OBJECT_TO_IDX = {o: i for i, o in enumerate(OBJECT_TOP_LIST)}
|
| 76 |
+
EVENT_NAMES = {0: "non-contact", 1: "pre-contact", 2: "steady-grip", 3: "release"}
|
| 77 |
+
CLASS_NAMES_BINARY = {0: "non-grasp", 1: "grasp"}
|
| 78 |
+
CLASS_NAMES_THREE = {0: "no-grasp", 1: "attempted", 2: "sustained"}
|
| 79 |
+
# Back-compat default (used by binary code paths)
|
| 80 |
+
CLASS_NAMES = CLASS_NAMES_BINARY
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def _parse_one(x: str, fmt_mode: str) -> float:
|
| 84 |
+
p = x.split(":")
|
| 85 |
+
if len(p) == 2:
|
| 86 |
+
return int(p[0]) * 60 + int(p[1])
|
| 87 |
+
if fmt_mode == "hhmmss":
|
| 88 |
+
return int(p[0]) * 3600 + int(p[1]) * 60 + int(p[2])
|
| 89 |
+
return int(p[0]) * 60 + int(p[1]) + int(p[2]) / 30.0 # mmssff @ 30fps
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def _detect_fmt(segments, rec_sec: float) -> str:
|
| 93 |
+
for s in segments:
|
| 94 |
+
b = s["timestamp"].split("-")[1]
|
| 95 |
+
p = b.split(":")
|
| 96 |
+
if len(p) == 3:
|
| 97 |
+
hh = int(p[0]) * 3600 + int(p[1]) * 60 + int(p[2])
|
| 98 |
+
if hh > rec_sec * 1.05:
|
| 99 |
+
return "mmssff"
|
| 100 |
+
return "hhmmss"
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def build_object_label(annot_path: Path, n_frames: int,
|
| 104 |
+
sr: int = SAMPLING_RATE_HZ) -> np.ndarray:
|
| 105 |
+
"""Per-frame object index (top-15 + '_other' fallback as class 0)."""
|
| 106 |
+
label = np.zeros(n_frames, dtype=np.int8)
|
| 107 |
+
if not annot_path.exists():
|
| 108 |
+
return label
|
| 109 |
+
try:
|
| 110 |
+
ann = json.load(open(annot_path))
|
| 111 |
+
except Exception:
|
| 112 |
+
return label
|
| 113 |
+
segments = ann.get("segments", [])
|
| 114 |
+
if not segments:
|
| 115 |
+
return label
|
| 116 |
+
rec_sec = n_frames / sr
|
| 117 |
+
fmt = _detect_fmt(segments, rec_sec)
|
| 118 |
+
for s in segments:
|
| 119 |
+
obj = s.get("action_annotation", {}).get("object_name")
|
| 120 |
+
idx = OBJECT_TO_IDX.get(obj, 0)
|
| 121 |
+
if idx == 0:
|
| 122 |
+
continue # leave as 0 ("_other"/background)
|
| 123 |
+
try:
|
| 124 |
+
a, b = s["timestamp"].split("-")
|
| 125 |
+
t0 = _parse_one(a, fmt); t1 = _parse_one(b, fmt)
|
| 126 |
+
except Exception:
|
| 127 |
+
continue
|
| 128 |
+
if t1 <= t0 or t1 > rec_sec * 1.10:
|
| 129 |
+
continue
|
| 130 |
+
i0 = max(0, int(round(t0 * sr)))
|
| 131 |
+
i1 = min(n_frames, int(round(t1 * sr)))
|
| 132 |
+
label[i0:i1] = idx
|
| 133 |
+
return label
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def build_lift_eligible_mask(annot_path: Path, n_frames: int,
|
| 137 |
+
sr: int = SAMPLING_RATE_HZ) -> np.ndarray:
|
| 138 |
+
"""Per-frame bool: True if frame is inside a segment that meets the
|
| 139 |
+
lifted-grasp criterion: verb ∈ LIFT_VERBS OR hand_type == 'both'.
|
| 140 |
+
Used by 3-class label_mode when require_lift_for_sustained=True."""
|
| 141 |
+
mask = np.zeros(n_frames, dtype=bool)
|
| 142 |
+
if not annot_path.exists():
|
| 143 |
+
return mask
|
| 144 |
+
try:
|
| 145 |
+
ann = json.load(open(annot_path))
|
| 146 |
+
except Exception:
|
| 147 |
+
return mask
|
| 148 |
+
segments = ann.get("segments", [])
|
| 149 |
+
if not segments:
|
| 150 |
+
return mask
|
| 151 |
+
rec_sec = n_frames / sr
|
| 152 |
+
fmt = _detect_fmt(segments, rec_sec)
|
| 153 |
+
for s in segments:
|
| 154 |
+
a = s.get("action_annotation", {})
|
| 155 |
+
verb = a.get("action_name")
|
| 156 |
+
hand = a.get("hand_type", "")
|
| 157 |
+
is_lift = (verb in LIFT_VERBS) or (hand == "both")
|
| 158 |
+
if not is_lift:
|
| 159 |
+
continue
|
| 160 |
+
try:
|
| 161 |
+
ts0, ts1 = s["timestamp"].split("-")
|
| 162 |
+
t0 = _parse_one(ts0, fmt); t1 = _parse_one(ts1, fmt)
|
| 163 |
+
except Exception:
|
| 164 |
+
continue
|
| 165 |
+
if t1 <= t0 or t1 > rec_sec * 1.10:
|
| 166 |
+
continue
|
| 167 |
+
i0 = max(0, int(round(t0 * sr)))
|
| 168 |
+
i1 = min(n_frames, int(round(t1 * sr)))
|
| 169 |
+
mask[i0:i1] = True
|
| 170 |
+
return mask
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def build_verb_label(annot_path: Path, n_frames: int,
|
| 174 |
+
sr: int = SAMPLING_RATE_HZ) -> np.ndarray:
|
| 175 |
+
"""Per-frame verb index (int8). Default (no segment) = 0 (background)."""
|
| 176 |
+
label = np.zeros(n_frames, dtype=np.int8)
|
| 177 |
+
if not annot_path.exists():
|
| 178 |
+
return label
|
| 179 |
+
try:
|
| 180 |
+
ann = json.load(open(annot_path))
|
| 181 |
+
except Exception:
|
| 182 |
+
return label
|
| 183 |
+
segments = ann.get("segments", [])
|
| 184 |
+
if not segments:
|
| 185 |
+
return label
|
| 186 |
+
rec_sec = n_frames / sr
|
| 187 |
+
fmt = _detect_fmt(segments, rec_sec)
|
| 188 |
+
for s in segments:
|
| 189 |
+
verb = s.get("action_annotation", {}).get("action_name")
|
| 190 |
+
v_idx = VERB_TO_IDX.get(verb, 0) # unknown verb → background
|
| 191 |
+
if v_idx == 0:
|
| 192 |
+
continue
|
| 193 |
+
try:
|
| 194 |
+
a, b = s["timestamp"].split("-")
|
| 195 |
+
t0 = _parse_one(a, fmt); t1 = _parse_one(b, fmt)
|
| 196 |
+
except Exception:
|
| 197 |
+
continue
|
| 198 |
+
if t1 <= t0 or t1 > rec_sec * 1.10:
|
| 199 |
+
continue
|
| 200 |
+
i0 = max(0, int(round(t0 * sr)))
|
| 201 |
+
i1 = min(n_frames, int(round(t1 * sr)))
|
| 202 |
+
label[i0:i1] = v_idx
|
| 203 |
+
return label
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def build_grasp_mask(annot_path: Path, n_frames: int,
|
| 207 |
+
sr: int = SAMPLING_RATE_HZ) -> np.ndarray:
|
| 208 |
+
"""Return bool array of shape (n_frames,)."""
|
| 209 |
+
mask = np.zeros(n_frames, dtype=bool)
|
| 210 |
+
if not annot_path.exists():
|
| 211 |
+
return mask
|
| 212 |
+
try:
|
| 213 |
+
ann = json.load(open(annot_path))
|
| 214 |
+
except Exception:
|
| 215 |
+
return mask
|
| 216 |
+
segments = ann.get("segments", [])
|
| 217 |
+
if not segments:
|
| 218 |
+
return mask
|
| 219 |
+
rec_sec = n_frames / sr
|
| 220 |
+
fmt = _detect_fmt(segments, rec_sec)
|
| 221 |
+
for s in segments:
|
| 222 |
+
verb = s.get("action_annotation", {}).get("action_name")
|
| 223 |
+
if verb not in GRASP_VERBS:
|
| 224 |
+
continue
|
| 225 |
+
try:
|
| 226 |
+
a, b = s["timestamp"].split("-")
|
| 227 |
+
t0 = _parse_one(a, fmt); t1 = _parse_one(b, fmt)
|
| 228 |
+
except Exception:
|
| 229 |
+
continue
|
| 230 |
+
if t1 <= t0 or t1 > rec_sec * 1.10:
|
| 231 |
+
continue
|
| 232 |
+
i0 = max(0, int(round(t0 * sr)))
|
| 233 |
+
i1 = min(n_frames, int(round(t1 * sr)))
|
| 234 |
+
mask[i0:i1] = True
|
| 235 |
+
return mask
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
class GraspStateDataset(Dataset):
|
| 239 |
+
"""Predict binary 'is_grasping' label over future window from past sensor signals."""
|
| 240 |
+
|
| 241 |
+
def __init__(
|
| 242 |
+
self,
|
| 243 |
+
volunteers: Sequence[str],
|
| 244 |
+
input_modalities: Sequence[str],
|
| 245 |
+
t_obs_sec: float = 1.0,
|
| 246 |
+
t_fut_sec: float = 0.5,
|
| 247 |
+
anchor_stride_sec: float = 0.25,
|
| 248 |
+
downsample: int = 5,
|
| 249 |
+
dataset_dir: Path = DEFAULT_DATASET_DIR,
|
| 250 |
+
annot_dir: Path = DEFAULT_ANNOT_DIR,
|
| 251 |
+
contact_threshold_g: float = 5.0, # legacy sum-threshold (kept for back-compat, unused if use_per_cell_contact=True)
|
| 252 |
+
per_cell_threshold_g: float = 10.0, # per-cell threshold to declare a sensor cell "active"
|
| 253 |
+
min_active_cells: int = 3, # need ≥ this many active cells to declare contact
|
| 254 |
+
use_per_cell_contact: bool = True, # NEW default: use per-cell active-count for event_type
|
| 255 |
+
label_mode: str = "binary", # "binary", "three_class", or "verb"
|
| 256 |
+
sustained_threshold_sec: float = 0.3, # (3-class only) min contiguous contact for "Sustained"
|
| 257 |
+
require_lift_for_sustained: bool = False, # (3-class only) Class 2 also requires verb ∈ LIFT_VERBS
|
| 258 |
+
per_class_max: Optional[int] = None,
|
| 259 |
+
input_stats: Optional[Dict[str, Tuple[np.ndarray, np.ndarray]]] = None,
|
| 260 |
+
expected_input_dims: Optional[Dict[str, int]] = None,
|
| 261 |
+
majority_threshold: float = 0.5,
|
| 262 |
+
rng_seed: int = 0,
|
| 263 |
+
log: bool = True,
|
| 264 |
+
):
|
| 265 |
+
super().__init__()
|
| 266 |
+
self.input_modalities = list(input_modalities)
|
| 267 |
+
self.t_obs_sec = float(t_obs_sec)
|
| 268 |
+
self.t_fut_sec = float(t_fut_sec)
|
| 269 |
+
self.anchor_stride_sec = float(anchor_stride_sec)
|
| 270 |
+
self.downsample = int(downsample)
|
| 271 |
+
self.sr = SAMPLING_RATE_HZ // self.downsample
|
| 272 |
+
self.dataset_dir = Path(dataset_dir)
|
| 273 |
+
self.annot_dir = Path(annot_dir)
|
| 274 |
+
self.contact_threshold_g = float(contact_threshold_g)
|
| 275 |
+
self.per_cell_threshold_g = float(per_cell_threshold_g)
|
| 276 |
+
self.min_active_cells = int(min_active_cells)
|
| 277 |
+
self.use_per_cell_contact = bool(use_per_cell_contact)
|
| 278 |
+
self.label_mode = str(label_mode)
|
| 279 |
+
if self.label_mode not in ("binary", "three_class", "verb", "object"):
|
| 280 |
+
raise ValueError(f"label_mode must be binary|three_class|verb|object, got {label_mode}")
|
| 281 |
+
if self.label_mode == "binary":
|
| 282 |
+
self.num_classes = 2
|
| 283 |
+
elif self.label_mode == "three_class":
|
| 284 |
+
self.num_classes = 3
|
| 285 |
+
elif self.label_mode == "verb":
|
| 286 |
+
self.num_classes = len(VERB_LIST)
|
| 287 |
+
else: # object
|
| 288 |
+
self.num_classes = len(OBJECT_TOP_LIST)
|
| 289 |
+
self.sustained_threshold_sec = float(sustained_threshold_sec)
|
| 290 |
+
self.require_lift_for_sustained = bool(require_lift_for_sustained)
|
| 291 |
+
self.per_class_max = per_class_max
|
| 292 |
+
self.majority_threshold = float(majority_threshold)
|
| 293 |
+
self.T_obs = int(round(self.t_obs_sec * self.sr))
|
| 294 |
+
self.T_fut = int(round(self.t_fut_sec * self.sr))
|
| 295 |
+
|
| 296 |
+
self._items: List[dict] = []
|
| 297 |
+
self._modality_dims: Dict[str, int] = dict(expected_input_dims) if expected_input_dims else {}
|
| 298 |
+
rng = np.random.default_rng(rng_seed)
|
| 299 |
+
|
| 300 |
+
# Load pressure even if not in inputs, for event_type stratification.
|
| 301 |
+
load_mods = list(dict.fromkeys(list(self.input_modalities) + ["pressure"]))
|
| 302 |
+
|
| 303 |
+
# Per-class anchor pool
|
| 304 |
+
pools: Dict[int, List[dict]] = {c: [] for c in range(self.num_classes)}
|
| 305 |
+
sustained_thresh_frames = int(round(self.sustained_threshold_sec * self.sr))
|
| 306 |
+
|
| 307 |
+
for vol in volunteers:
|
| 308 |
+
vol_dir = self.dataset_dir / vol
|
| 309 |
+
if not vol_dir.is_dir():
|
| 310 |
+
continue
|
| 311 |
+
for scenario_dir in sorted(vol_dir.glob("s*")):
|
| 312 |
+
if not scenario_dir.is_dir():
|
| 313 |
+
continue
|
| 314 |
+
scene = scenario_dir.name
|
| 315 |
+
annot_path = self.annot_dir / vol / f"{scene}.json"
|
| 316 |
+
if not annot_path.exists():
|
| 317 |
+
continue
|
| 318 |
+
try:
|
| 319 |
+
sensors_all = _load_recording_sensors(
|
| 320 |
+
scenario_dir, vol, scene, load_mods
|
| 321 |
+
)
|
| 322 |
+
except Exception:
|
| 323 |
+
continue
|
| 324 |
+
if sensors_all is None or any(a is None for a in sensors_all.values()):
|
| 325 |
+
continue
|
| 326 |
+
|
| 327 |
+
pressure_full = sensors_all["pressure"] # (T, 50)
|
| 328 |
+
input_arrs = {m: sensors_all[m] for m in self.input_modalities}
|
| 329 |
+
for m, arr in input_arrs.items():
|
| 330 |
+
self._enforce_dim(input_arrs, m, arr, self._modality_dims)
|
| 331 |
+
|
| 332 |
+
T_avail = min(a.shape[0] for a in input_arrs.values())
|
| 333 |
+
T_avail = min(T_avail, pressure_full.shape[0])
|
| 334 |
+
if T_avail < (self.T_obs + self.T_fut) * self.downsample:
|
| 335 |
+
continue
|
| 336 |
+
|
| 337 |
+
# Build grasp mask at 100 Hz, then downsample.
|
| 338 |
+
mask_full = build_grasp_mask(annot_path, T_avail,
|
| 339 |
+
sr=SAMPLING_RATE_HZ)
|
| 340 |
+
if self.label_mode == "verb":
|
| 341 |
+
verb_full = build_verb_label(annot_path, T_avail, sr=SAMPLING_RATE_HZ)
|
| 342 |
+
verb_ds = verb_full[:T_avail:self.downsample]
|
| 343 |
+
else:
|
| 344 |
+
verb_ds = None
|
| 345 |
+
if self.label_mode == "object":
|
| 346 |
+
obj_full = build_object_label(annot_path, T_avail, sr=SAMPLING_RATE_HZ)
|
| 347 |
+
obj_ds = obj_full[:T_avail:self.downsample]
|
| 348 |
+
else:
|
| 349 |
+
obj_ds = None
|
| 350 |
+
if self.label_mode == "three_class" and self.require_lift_for_sustained:
|
| 351 |
+
lift_full = build_lift_eligible_mask(annot_path, T_avail, sr=SAMPLING_RATE_HZ)
|
| 352 |
+
lift_eligible_ds = lift_full[:T_avail:self.downsample]
|
| 353 |
+
else:
|
| 354 |
+
lift_eligible_ds = None
|
| 355 |
+
input_ds = {m: arr[:T_avail:self.downsample] for m, arr in input_arrs.items()}
|
| 356 |
+
pressure_ds = pressure_full[:T_avail:self.downsample]
|
| 357 |
+
mask_ds = mask_full[:T_avail:self.downsample]
|
| 358 |
+
T_ds = mask_ds.shape[0]
|
| 359 |
+
if self.use_per_cell_contact:
|
| 360 |
+
# n_active per frame: count cells with value > per_cell_threshold_g
|
| 361 |
+
n_active = (pressure_ds > self.per_cell_threshold_g).sum(axis=1)
|
| 362 |
+
contact_frame = n_active >= self.min_active_cells
|
| 363 |
+
else:
|
| 364 |
+
pressure_sum = pressure_ds.sum(axis=1)
|
| 365 |
+
contact_frame = pressure_sum > self.contact_threshold_g
|
| 366 |
+
|
| 367 |
+
stride = max(1, int(round(self.anchor_stride_sec * self.sr)))
|
| 368 |
+
first_anchor = self.T_obs
|
| 369 |
+
last_anchor = T_ds - self.T_fut
|
| 370 |
+
if last_anchor <= first_anchor:
|
| 371 |
+
continue
|
| 372 |
+
|
| 373 |
+
for anchor in range(first_anchor, last_anchor + 1, stride):
|
| 374 |
+
fut_mask = mask_ds[anchor:anchor + self.T_fut]
|
| 375 |
+
if fut_mask.shape[0] != self.T_fut:
|
| 376 |
+
continue
|
| 377 |
+
annotation_is_grasp = fut_mask.mean() >= self.majority_threshold
|
| 378 |
+
|
| 379 |
+
if self.label_mode == "binary":
|
| 380 |
+
label = int(annotation_is_grasp)
|
| 381 |
+
elif self.label_mode == "three_class":
|
| 382 |
+
if not annotation_is_grasp:
|
| 383 |
+
label = 0 # NoGrasp
|
| 384 |
+
else:
|
| 385 |
+
# longest contiguous run of contact in future window
|
| 386 |
+
fut_contact = contact_frame[anchor:anchor + self.T_fut]
|
| 387 |
+
longest = 0; cur = 0
|
| 388 |
+
for v in fut_contact:
|
| 389 |
+
if v: cur += 1; longest = max(longest, cur)
|
| 390 |
+
else: cur = 0
|
| 391 |
+
is_sustained = longest >= sustained_thresh_frames
|
| 392 |
+
if is_sustained and self.require_lift_for_sustained:
|
| 393 |
+
# Demote to Class 1 unless majority of future window is in
|
| 394 |
+
# a "lift-eligible" segment (verb ∈ LIFT_VERBS or hand=both).
|
| 395 |
+
fut_lift = lift_eligible_ds[anchor:anchor + self.T_fut]
|
| 396 |
+
if fut_lift.mean() < 0.5:
|
| 397 |
+
is_sustained = False
|
| 398 |
+
label = 2 if is_sustained else 1
|
| 399 |
+
elif self.label_mode == "verb":
|
| 400 |
+
fut_v = verb_ds[anchor:anchor + self.T_fut]
|
| 401 |
+
counts = np.bincount(fut_v, minlength=self.num_classes)
|
| 402 |
+
label = int(np.argmax(counts))
|
| 403 |
+
else: # object — majority object in future window
|
| 404 |
+
fut_o = obj_ds[anchor:anchor + self.T_fut]
|
| 405 |
+
counts = np.bincount(fut_o, minlength=self.num_classes)
|
| 406 |
+
label = int(np.argmax(counts))
|
| 407 |
+
|
| 408 |
+
# event_type for stratification (4-class transition taxonomy)
|
| 409 |
+
past_high = contact_frame[anchor - self.T_obs:anchor].mean() > 0.5
|
| 410 |
+
fut_high = contact_frame[anchor:anchor + self.T_fut].mean() > 0.5
|
| 411 |
+
if not past_high and not fut_high: et = 0
|
| 412 |
+
elif not past_high and fut_high: et = 1
|
| 413 |
+
elif past_high and fut_high: et = 2
|
| 414 |
+
else: et = 3
|
| 415 |
+
|
| 416 |
+
past_slice = {m: arr[anchor - self.T_obs:anchor]
|
| 417 |
+
for m, arr in input_ds.items()}
|
| 418 |
+
if any(w.shape[0] != self.T_obs for w in past_slice.values()):
|
| 419 |
+
continue
|
| 420 |
+
|
| 421 |
+
item = {
|
| 422 |
+
"x": past_slice,
|
| 423 |
+
"label": label,
|
| 424 |
+
"event_type": et,
|
| 425 |
+
"meta": {"vol": vol, "scene": scene, "anchor_idx": int(anchor)},
|
| 426 |
+
}
|
| 427 |
+
pools[label].append(item)
|
| 428 |
+
|
| 429 |
+
# Balance classes if requested (cap larger pool to per_class_max)
|
| 430 |
+
if self.per_class_max is not None:
|
| 431 |
+
for c, pool in pools.items():
|
| 432 |
+
if len(pool) > self.per_class_max:
|
| 433 |
+
idx = rng.choice(len(pool), size=self.per_class_max, replace=False)
|
| 434 |
+
pools[c] = [pool[i] for i in sorted(idx)]
|
| 435 |
+
self._items = [it for c in range(self.num_classes) for it in pools[c]]
|
| 436 |
+
|
| 437 |
+
if not self._items:
|
| 438 |
+
raise RuntimeError("GraspStateDataset: collected 0 anchors.")
|
| 439 |
+
|
| 440 |
+
# Z-score inputs
|
| 441 |
+
if input_stats is None:
|
| 442 |
+
input_stats = self._compute_input_stats()
|
| 443 |
+
self._input_stats = input_stats
|
| 444 |
+
self._apply_input_stats(input_stats)
|
| 445 |
+
|
| 446 |
+
if log:
|
| 447 |
+
if self.label_mode == "binary":
|
| 448 |
+
class_names = CLASS_NAMES_BINARY
|
| 449 |
+
elif self.label_mode == "three_class":
|
| 450 |
+
class_names = CLASS_NAMES_THREE
|
| 451 |
+
elif self.label_mode == "verb":
|
| 452 |
+
class_names = {i: v for i, v in enumerate(VERB_LIST)}
|
| 453 |
+
else: # object
|
| 454 |
+
class_names = {i: v for i, v in enumerate(OBJECT_TOP_LIST)}
|
| 455 |
+
counts_class = {class_names[c]: sum(1 for it in self._items if it["label"] == c)
|
| 456 |
+
for c in range(self.num_classes)}
|
| 457 |
+
counts_event = {EVENT_NAMES[k]: sum(1 for it in self._items if it["event_type"] == k)
|
| 458 |
+
for k in (0, 1, 2, 3)}
|
| 459 |
+
print(f"[GraspStateDataset] vols={len(volunteers)} "
|
| 460 |
+
f"inputs={self.input_modalities} "
|
| 461 |
+
f"anchors={len(self._items)} class={counts_class} "
|
| 462 |
+
f"event={counts_event} "
|
| 463 |
+
f"T_obs={self.T_obs} T_fut={self.T_fut} sr={self.sr}Hz "
|
| 464 |
+
f"input_dims={self._modality_dims}", flush=True)
|
| 465 |
+
|
| 466 |
+
@staticmethod
|
| 467 |
+
def _enforce_dim(arrs, m, arr, dim_dict):
|
| 468 |
+
if m in dim_dict:
|
| 469 |
+
tgt = dim_dict[m]
|
| 470 |
+
if arr.shape[1] != tgt:
|
| 471 |
+
if arr.shape[1] < tgt:
|
| 472 |
+
pad = np.zeros((arr.shape[0], tgt - arr.shape[1]), dtype=np.float32)
|
| 473 |
+
arrs[m] = np.concatenate([arr, pad], axis=1)
|
| 474 |
+
else:
|
| 475 |
+
arrs[m] = arr[:, :tgt]
|
| 476 |
+
else:
|
| 477 |
+
dim_dict[m] = arr.shape[1]
|
| 478 |
+
|
| 479 |
+
def _compute_input_stats(self):
|
| 480 |
+
accs = {m: [] for m in self._modality_dims}
|
| 481 |
+
for it in self._items:
|
| 482 |
+
for m, w in it["x"].items():
|
| 483 |
+
accs[m].append(w)
|
| 484 |
+
out = {}
|
| 485 |
+
for m, ws in accs.items():
|
| 486 |
+
cat = np.concatenate(ws, axis=0)
|
| 487 |
+
mu = cat.mean(axis=0).astype(np.float32)
|
| 488 |
+
sd = cat.std(axis=0); sd = np.where(sd < 1e-6, 1.0, sd)
|
| 489 |
+
out[m] = (mu, sd.astype(np.float32))
|
| 490 |
+
return out
|
| 491 |
+
|
| 492 |
+
def _apply_input_stats(self, stats):
|
| 493 |
+
for it in self._items:
|
| 494 |
+
for m, w in it["x"].items():
|
| 495 |
+
if m in stats:
|
| 496 |
+
mu, sd = stats[m]
|
| 497 |
+
it["x"][m] = ((w - mu) / sd).astype(np.float32)
|
| 498 |
+
|
| 499 |
+
def __len__(self): return len(self._items)
|
| 500 |
+
|
| 501 |
+
def __getitem__(self, idx):
|
| 502 |
+
it = self._items[idx]
|
| 503 |
+
x = {m: torch.from_numpy(np.ascontiguousarray(w)) for m, w in it["x"].items()}
|
| 504 |
+
label = int(it["label"])
|
| 505 |
+
et = int(it["event_type"])
|
| 506 |
+
return x, label, et, it["meta"]
|
| 507 |
+
|
| 508 |
+
@property
|
| 509 |
+
def modality_dims(self): return dict(self._modality_dims)
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
def collate_grasp_state(batch):
|
| 513 |
+
xs, labels, ets, metas = zip(*batch)
|
| 514 |
+
mods = list(xs[0].keys())
|
| 515 |
+
x_out = {m: torch.stack([x[m] for x in xs], dim=0) for m in mods}
|
| 516 |
+
y_out = torch.tensor(labels, dtype=torch.long)
|
| 517 |
+
et_out = torch.tensor(ets, dtype=torch.long)
|
| 518 |
+
return x_out, y_out, et_out, list(metas)
|
| 519 |
+
|
| 520 |
+
|
| 521 |
+
def build_grasp_train_test(
|
| 522 |
+
input_modalities,
|
| 523 |
+
t_obs_sec=1.0, t_fut_sec=0.5, anchor_stride_sec=0.25,
|
| 524 |
+
downsample=5,
|
| 525 |
+
dataset_dir=DEFAULT_DATASET_DIR, annot_dir=DEFAULT_ANNOT_DIR,
|
| 526 |
+
contact_threshold_g=5.0, per_class_max=None,
|
| 527 |
+
label_mode="binary", sustained_threshold_sec=0.3,
|
| 528 |
+
require_lift_for_sustained=False,
|
| 529 |
+
rng_seed=0,
|
| 530 |
+
train_vols=None, test_vols=None,
|
| 531 |
+
):
|
| 532 |
+
if train_vols is None: train_vols = TRAIN_VOLS_V3
|
| 533 |
+
if test_vols is None: test_vols = TEST_VOLS_V3
|
| 534 |
+
train = GraspStateDataset(
|
| 535 |
+
train_vols, input_modalities=input_modalities,
|
| 536 |
+
t_obs_sec=t_obs_sec, t_fut_sec=t_fut_sec,
|
| 537 |
+
anchor_stride_sec=anchor_stride_sec, downsample=downsample,
|
| 538 |
+
dataset_dir=dataset_dir, annot_dir=annot_dir,
|
| 539 |
+
contact_threshold_g=contact_threshold_g, per_class_max=per_class_max,
|
| 540 |
+
label_mode=label_mode, sustained_threshold_sec=sustained_threshold_sec,
|
| 541 |
+
require_lift_for_sustained=require_lift_for_sustained,
|
| 542 |
+
rng_seed=rng_seed, log=True,
|
| 543 |
+
)
|
| 544 |
+
test = GraspStateDataset(
|
| 545 |
+
test_vols, input_modalities=input_modalities,
|
| 546 |
+
t_obs_sec=t_obs_sec, t_fut_sec=t_fut_sec,
|
| 547 |
+
anchor_stride_sec=anchor_stride_sec, downsample=downsample,
|
| 548 |
+
dataset_dir=dataset_dir, annot_dir=annot_dir,
|
| 549 |
+
contact_threshold_g=contact_threshold_g, per_class_max=None, # don't cap test
|
| 550 |
+
label_mode=label_mode, sustained_threshold_sec=sustained_threshold_sec,
|
| 551 |
+
require_lift_for_sustained=require_lift_for_sustained,
|
| 552 |
+
input_stats=train._input_stats,
|
| 553 |
+
expected_input_dims=train._modality_dims,
|
| 554 |
+
rng_seed=rng_seed + 1, log=True,
|
| 555 |
+
)
|
| 556 |
+
return train, test
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
if __name__ == "__main__":
|
| 560 |
+
import argparse
|
| 561 |
+
ap = argparse.ArgumentParser()
|
| 562 |
+
ap.add_argument("--input_modalities", default="emg,imu,mocap")
|
| 563 |
+
ap.add_argument("--t_obs", type=float, default=1.0)
|
| 564 |
+
ap.add_argument("--t_fut", type=float, default=0.5)
|
| 565 |
+
args = ap.parse_args()
|
| 566 |
+
tr, te = build_grasp_train_test(
|
| 567 |
+
input_modalities=args.input_modalities.split(","),
|
| 568 |
+
t_obs_sec=args.t_obs, t_fut_sec=args.t_fut,
|
| 569 |
+
)
|
| 570 |
+
x, y, et, meta = tr[0]
|
| 571 |
+
print(f"sample: x={ {m: tuple(v.shape) for m,v in x.items()} } y={y} et={et}")
|
experiments/data/dataset_seqpred.py
ADDED
|
@@ -0,0 +1,533 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Segment-to-Next-Segment Triplet Prediction dataset (T10).
|
| 3 |
+
|
| 4 |
+
For every annotated action segment k in every recording:
|
| 5 |
+
anchor_t = start_time(segment_k) - T_fut (seconds)
|
| 6 |
+
observation = sensor frames in [anchor_t - T_obs, anchor_t]
|
| 7 |
+
target = triplet labels of segment_k: (verb_fine, verb_composite,
|
| 8 |
+
noun, hand)
|
| 9 |
+
|
| 10 |
+
Segments whose observation window would spill before t=0 of the recording
|
| 11 |
+
are skipped (no left-padding), so we never mix noise with real sensor data.
|
| 12 |
+
|
| 13 |
+
Strategy A is enforced in taxonomy.classify_segment(): segments whose noun is
|
| 14 |
+
not in the kept set (<50 occurrences) are dropped entirely.
|
| 15 |
+
|
| 16 |
+
Per-modality tensors are returned as a dict so downstream models can either
|
| 17 |
+
concat them (single-flow baselines) or keep them separate (our cross-modal
|
| 18 |
+
fusion model). A float mask is returned alongside the sensor tensor so
|
| 19 |
+
variable-length obs windows can be padded within a batch.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
from __future__ import annotations
|
| 23 |
+
|
| 24 |
+
# pandas must be imported BEFORE torch/numpy to avoid a GLIBCXX load-order bug
|
| 25 |
+
# on this cluster.
|
| 26 |
+
import pandas as pd
|
| 27 |
+
|
| 28 |
+
import json
|
| 29 |
+
import os
|
| 30 |
+
import sys
|
| 31 |
+
from pathlib import Path
|
| 32 |
+
from typing import Dict, List, Optional, Sequence, Tuple
|
| 33 |
+
|
| 34 |
+
import numpy as np
|
| 35 |
+
import torch
|
| 36 |
+
from torch.utils.data import Dataset
|
| 37 |
+
|
| 38 |
+
# Make sibling modules importable from either (a) the neurips26 root, or
|
| 39 |
+
# (b) the frozen row/code/ folder (populated by setup_row.sh).
|
| 40 |
+
_THIS = Path(__file__).resolve()
|
| 41 |
+
sys.path.insert(0, str(_THIS.parent)) # code/ itself
|
| 42 |
+
sys.path.insert(0, str(_THIS.parent.parent)) # neurips26/
|
| 43 |
+
|
| 44 |
+
try:
|
| 45 |
+
from data.dataset import ( # noqa: E402
|
| 46 |
+
MODALITY_FILES, load_modality_array,
|
| 47 |
+
)
|
| 48 |
+
from experiments.taxonomy import ( # noqa: E402
|
| 49 |
+
classify_segment, NOUN, NUM_VERB_FINE, NUM_VERB_COMPOSITE, NUM_NOUN,
|
| 50 |
+
NUM_HAND,
|
| 51 |
+
)
|
| 52 |
+
except ModuleNotFoundError:
|
| 53 |
+
from dataset import ( # noqa: E402
|
| 54 |
+
MODALITY_FILES, load_modality_array,
|
| 55 |
+
)
|
| 56 |
+
from taxonomy import ( # noqa: E402
|
| 57 |
+
classify_segment, NOUN, NUM_VERB_FINE, NUM_VERB_COMPOSITE, NUM_NOUN,
|
| 58 |
+
NUM_HAND,
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
# ---------------------------------------------------------------------------
|
| 62 |
+
# Constants
|
| 63 |
+
# ---------------------------------------------------------------------------
|
| 64 |
+
|
| 65 |
+
# Hard-code the dataset and annotation paths. The frozen row/code/ folders sit
|
| 66 |
+
# at arbitrary depths under the repo, so relative-to-__file__ discovery is
|
| 67 |
+
# unreliable. An env override is available for e.g. running on a mirror.
|
| 68 |
+
REPO = Path(os.environ.get(
|
| 69 |
+
"DAILYACT_REPO", "${PULSE_ROOT}"
|
| 70 |
+
))
|
| 71 |
+
DEFAULT_DATASET_DIR = REPO / "aligned_gy"
|
| 72 |
+
DEFAULT_ANNOT_DIR = REPO / "annotations_v3"
|
| 73 |
+
|
| 74 |
+
SAMPLING_RATE_HZ = 100
|
| 75 |
+
# 5x downsample -> 20 Hz. Matches the existing pipeline in dataset.py.
|
| 76 |
+
DEFAULT_DOWNSAMPLE = 5
|
| 77 |
+
|
| 78 |
+
VALID_MODALITIES = ("mocap", "emg", "eyetrack", "imu", "pressure")
|
| 79 |
+
|
| 80 |
+
# Fixed subject-independent split. Hand-picked 5 test volunteers with full
|
| 81 |
+
# 8-scene coverage, spread across the ID range. Any volunteer not listed
|
| 82 |
+
# below but annotated in v3 is assumed to be train data (so the lists stay
|
| 83 |
+
# stable as more volunteers get annotated).
|
| 84 |
+
TEST_VOLS_V3 = ["v14", "v30", "v34", "v38", "v41"]
|
| 85 |
+
TRAIN_VOLS_V3 = [
|
| 86 |
+
"v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
|
| 87 |
+
"v11", "v12", "v13", "v15", "v16", "v17", "v18", "v19", "v20",
|
| 88 |
+
"v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29",
|
| 89 |
+
"v31", "v32", "v33", "v35", "v36", "v37", "v39", "v40",
|
| 90 |
+
]
|
| 91 |
+
assert set(TRAIN_VOLS_V3).isdisjoint(TEST_VOLS_V3), "Split must be disjoint"
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
# ---------------------------------------------------------------------------
|
| 95 |
+
# Helpers
|
| 96 |
+
# ---------------------------------------------------------------------------
|
| 97 |
+
|
| 98 |
+
def _parse_ts(ts: str) -> float:
|
| 99 |
+
"""Parse 'HH:MM:SS' or 'MM:SS' (or 'M:S') into seconds."""
|
| 100 |
+
parts = ts.strip().split(":")
|
| 101 |
+
try:
|
| 102 |
+
if len(parts) == 2:
|
| 103 |
+
return float(parts[0]) * 60 + float(parts[1])
|
| 104 |
+
if len(parts) == 3:
|
| 105 |
+
return float(parts[0]) * 3600 + float(parts[1]) * 60 + float(parts[2])
|
| 106 |
+
except ValueError:
|
| 107 |
+
return 0.0
|
| 108 |
+
return 0.0
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def parse_ts_range(ts_range: str) -> Tuple[float, float]:
|
| 112 |
+
"""Parse 'MM:SS-MM:SS' or 'HH:MM:SS-HH:MM:SS' into (start_sec, end_sec)."""
|
| 113 |
+
if "-" not in ts_range:
|
| 114 |
+
return 0.0, 0.0
|
| 115 |
+
a, b = ts_range.split("-", 1)
|
| 116 |
+
return _parse_ts(a), _parse_ts(b)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def _load_recording_sensors(
|
| 120 |
+
scenario_dir: Path, vol: str, scenario: str,
|
| 121 |
+
modalities: Sequence[str],
|
| 122 |
+
) -> Optional[Dict[str, np.ndarray]]:
|
| 123 |
+
"""Load each requested modality as a (T, F_mod) float32 array at 100 Hz.
|
| 124 |
+
|
| 125 |
+
Returns None if any requested modality is missing or corrupted."""
|
| 126 |
+
out: Dict[str, np.ndarray] = {}
|
| 127 |
+
for mod in modalities:
|
| 128 |
+
if mod == "mocap":
|
| 129 |
+
fp = scenario_dir / f"aligned_{vol}{scenario}_s_Q.tsv"
|
| 130 |
+
else:
|
| 131 |
+
fp = scenario_dir / MODALITY_FILES[mod]
|
| 132 |
+
if not fp.exists():
|
| 133 |
+
return None
|
| 134 |
+
arr = load_modality_array(str(fp), mod)
|
| 135 |
+
if arr is None:
|
| 136 |
+
return None
|
| 137 |
+
out[mod] = arr.astype(np.float32)
|
| 138 |
+
# Align lengths across modalities (take min); all start at sensor t=0.
|
| 139 |
+
T = min(a.shape[0] for a in out.values())
|
| 140 |
+
for m in out:
|
| 141 |
+
out[m] = out[m][:T]
|
| 142 |
+
return out
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def _load_annotations(annot_path: Path) -> List[dict]:
|
| 146 |
+
with open(annot_path) as f:
|
| 147 |
+
d = json.load(f)
|
| 148 |
+
return d.get("segments", [])
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
# ---------------------------------------------------------------------------
|
| 152 |
+
# Dataset
|
| 153 |
+
# ---------------------------------------------------------------------------
|
| 154 |
+
|
| 155 |
+
class TripletSeqPredDataset(Dataset):
|
| 156 |
+
"""One sample per (annotated segment, recording) pair.
|
| 157 |
+
|
| 158 |
+
Sample schema returned by __getitem__:
|
| 159 |
+
x: dict {mod_name: FloatTensor(T_frames, F_mod)}
|
| 160 |
+
y: dict {'verb_fine': int, 'verb_composite': int,
|
| 161 |
+
'noun': int, 'hand': int}
|
| 162 |
+
meta: dict {'vol', 'scene', 'seg_idx', 'anchor_sec'}
|
| 163 |
+
"""
|
| 164 |
+
|
| 165 |
+
def __init__(
|
| 166 |
+
self,
|
| 167 |
+
volunteers: Sequence[str],
|
| 168 |
+
modalities: Sequence[str] = ("imu", "mocap", "emg", "eyetrack", "pressure"),
|
| 169 |
+
t_obs_sec: float = 8.0,
|
| 170 |
+
t_fut_sec: float = 2.0,
|
| 171 |
+
downsample: int = DEFAULT_DOWNSAMPLE,
|
| 172 |
+
dataset_dir: Path = DEFAULT_DATASET_DIR,
|
| 173 |
+
annot_dir: Path = DEFAULT_ANNOT_DIR,
|
| 174 |
+
stats: Optional[Dict[str, Tuple[np.ndarray, np.ndarray]]] = None,
|
| 175 |
+
min_seg_duration_sec: float = 0.4,
|
| 176 |
+
log: bool = True,
|
| 177 |
+
mode: str = "recognition",
|
| 178 |
+
):
|
| 179 |
+
for m in modalities:
|
| 180 |
+
if m not in VALID_MODALITIES:
|
| 181 |
+
raise ValueError(f"Unknown modality: {m}")
|
| 182 |
+
if mode not in ("recognition", "anticipation"):
|
| 183 |
+
raise ValueError(f"mode must be 'recognition' or 'anticipation', got {mode!r}")
|
| 184 |
+
|
| 185 |
+
self.modalities = tuple(modalities)
|
| 186 |
+
self.t_obs_sec = float(t_obs_sec)
|
| 187 |
+
self.t_fut_sec = float(t_fut_sec)
|
| 188 |
+
self.downsample = int(downsample)
|
| 189 |
+
self.dataset_dir = Path(dataset_dir)
|
| 190 |
+
self.annot_dir = Path(annot_dir)
|
| 191 |
+
self.mode = mode
|
| 192 |
+
|
| 193 |
+
# Effective obs-window length in frames at the post-downsample rate.
|
| 194 |
+
sr = SAMPLING_RATE_HZ // self.downsample # 20 Hz
|
| 195 |
+
self.T_frames = int(round(self.t_obs_sec * sr)) # used only for anticipation
|
| 196 |
+
self._sr_down = sr
|
| 197 |
+
|
| 198 |
+
self._items: List[dict] = []
|
| 199 |
+
self._modality_dims: Dict[str, int] = {}
|
| 200 |
+
|
| 201 |
+
# If re-using training-set stats, force each modality's feature
|
| 202 |
+
# layout to match so we never apply a (14,)-mean to (24,)-data.
|
| 203 |
+
if stats is not None:
|
| 204 |
+
for m, (mu, _) in stats.items():
|
| 205 |
+
self._modality_dims[m] = mu.shape[1]
|
| 206 |
+
|
| 207 |
+
stats_counts = {
|
| 208 |
+
"recordings_scanned": 0,
|
| 209 |
+
"recordings_used": 0,
|
| 210 |
+
"segments_seen": 0,
|
| 211 |
+
"seg_dropped_label": 0, # Strategy A + invalid verb/hand
|
| 212 |
+
"seg_dropped_too_early": 0, # obs window before t=0
|
| 213 |
+
"seg_dropped_short": 0,
|
| 214 |
+
"seg_kept": 0,
|
| 215 |
+
}
|
| 216 |
+
|
| 217 |
+
for vol in volunteers:
|
| 218 |
+
vol_dir = self.dataset_dir / vol
|
| 219 |
+
if not vol_dir.is_dir():
|
| 220 |
+
continue
|
| 221 |
+
for scenario_dir in sorted(vol_dir.glob("s*")):
|
| 222 |
+
if not scenario_dir.is_dir():
|
| 223 |
+
continue
|
| 224 |
+
scene = scenario_dir.name
|
| 225 |
+
if scene not in {f"s{i}" for i in range(1, 9)}:
|
| 226 |
+
continue
|
| 227 |
+
|
| 228 |
+
annot_path = self.annot_dir / vol / f"{scene}.json"
|
| 229 |
+
if not annot_path.exists():
|
| 230 |
+
continue
|
| 231 |
+
|
| 232 |
+
stats_counts["recordings_scanned"] += 1
|
| 233 |
+
|
| 234 |
+
sensors = _load_recording_sensors(scenario_dir, vol, scene,
|
| 235 |
+
self.modalities)
|
| 236 |
+
if sensors is None:
|
| 237 |
+
continue
|
| 238 |
+
|
| 239 |
+
# Store / validate per-modality dim
|
| 240 |
+
for m, arr in sensors.items():
|
| 241 |
+
if m in self._modality_dims:
|
| 242 |
+
if arr.shape[1] != self._modality_dims[m]:
|
| 243 |
+
# Pad or truncate to match the first seen dim.
|
| 244 |
+
target = self._modality_dims[m]
|
| 245 |
+
if arr.shape[1] < target:
|
| 246 |
+
pad = np.zeros((arr.shape[0], target - arr.shape[1]),
|
| 247 |
+
dtype=np.float32)
|
| 248 |
+
sensors[m] = np.concatenate([arr, pad], axis=1)
|
| 249 |
+
else:
|
| 250 |
+
sensors[m] = arr[:, :target]
|
| 251 |
+
else:
|
| 252 |
+
self._modality_dims[m] = arr.shape[1]
|
| 253 |
+
|
| 254 |
+
segs = _load_annotations(annot_path)
|
| 255 |
+
rec_used = False
|
| 256 |
+
# BOS index for first segment in a recording (or after dropped segs).
|
| 257 |
+
BOS_VC = NUM_VERB_COMPOSITE # = 6
|
| 258 |
+
BOS_N = NUM_NOUN # = 34
|
| 259 |
+
prev_vc, prev_n = BOS_VC, BOS_N
|
| 260 |
+
for seg_idx, seg in enumerate(segs):
|
| 261 |
+
stats_counts["segments_seen"] += 1
|
| 262 |
+
a = seg.get("action_annotation", {})
|
| 263 |
+
labels = classify_segment(a)
|
| 264 |
+
if labels is None:
|
| 265 |
+
stats_counts["seg_dropped_label"] += 1
|
| 266 |
+
# do not advance prev (skipped segment doesn't update context)
|
| 267 |
+
continue
|
| 268 |
+
|
| 269 |
+
start_sec, end_sec = parse_ts_range(seg.get("timestamp", ""))
|
| 270 |
+
if end_sec - start_sec < min_seg_duration_sec:
|
| 271 |
+
stats_counts["seg_dropped_short"] += 1
|
| 272 |
+
continue
|
| 273 |
+
|
| 274 |
+
if self.mode == "anticipation":
|
| 275 |
+
anchor_sec = start_sec - self.t_fut_sec
|
| 276 |
+
obs_start_sec = anchor_sec - self.t_obs_sec
|
| 277 |
+
if obs_start_sec < 0:
|
| 278 |
+
stats_counts["seg_dropped_too_early"] += 1
|
| 279 |
+
continue
|
| 280 |
+
i0 = int(round(obs_start_sec * SAMPLING_RATE_HZ))
|
| 281 |
+
i1 = int(round(anchor_sec * SAMPLING_RATE_HZ))
|
| 282 |
+
meta_extra = {"anchor_sec": anchor_sec}
|
| 283 |
+
else: # recognition
|
| 284 |
+
# Use the segment's own [start, end] as the input window.
|
| 285 |
+
i0 = int(round(start_sec * SAMPLING_RATE_HZ))
|
| 286 |
+
i1 = int(round(end_sec * SAMPLING_RATE_HZ))
|
| 287 |
+
meta_extra = {"start_sec": start_sec, "end_sec": end_sec}
|
| 288 |
+
|
| 289 |
+
T_avail = min(a.shape[0] for a in sensors.values())
|
| 290 |
+
if i1 > T_avail:
|
| 291 |
+
stats_counts["seg_dropped_too_early"] += 1
|
| 292 |
+
continue
|
| 293 |
+
if i0 < 0:
|
| 294 |
+
i0 = 0 # safety; recognition mode shouldn't hit this
|
| 295 |
+
|
| 296 |
+
window: Dict[str, np.ndarray] = {}
|
| 297 |
+
for m, arr in sensors.items():
|
| 298 |
+
w = arr[i0:i1]
|
| 299 |
+
# Downsample: decimate every `downsample`-th frame.
|
| 300 |
+
w = w[::self.downsample]
|
| 301 |
+
window[m] = w
|
| 302 |
+
|
| 303 |
+
# Must have at least 4 post-downsample frames to be useful.
|
| 304 |
+
min_T = min(w.shape[0] for w in window.values())
|
| 305 |
+
if min_T < 4:
|
| 306 |
+
stats_counts["seg_dropped_short"] += 1
|
| 307 |
+
continue
|
| 308 |
+
|
| 309 |
+
self._items.append({
|
| 310 |
+
"x": window,
|
| 311 |
+
"y": labels,
|
| 312 |
+
"prev": {"verb_composite": prev_vc, "noun": prev_n},
|
| 313 |
+
"meta": {
|
| 314 |
+
"vol": vol, "scene": scene,
|
| 315 |
+
"seg_idx": seg_idx, **meta_extra,
|
| 316 |
+
},
|
| 317 |
+
})
|
| 318 |
+
stats_counts["seg_kept"] += 1
|
| 319 |
+
# Update context for next kept segment in this recording.
|
| 320 |
+
prev_vc = labels["verb_composite"]
|
| 321 |
+
prev_n = labels["noun"]
|
| 322 |
+
rec_used = True
|
| 323 |
+
|
| 324 |
+
if rec_used:
|
| 325 |
+
stats_counts["recordings_used"] += 1
|
| 326 |
+
|
| 327 |
+
if len(self._items) == 0:
|
| 328 |
+
raise RuntimeError(
|
| 329 |
+
"No samples collected. Check annot_dir, modalities, t_obs, t_fut."
|
| 330 |
+
)
|
| 331 |
+
|
| 332 |
+
# Per-modality z-score normalization using training-set stats.
|
| 333 |
+
if stats is None:
|
| 334 |
+
stats = self._compute_stats()
|
| 335 |
+
self._stats = stats
|
| 336 |
+
self._apply_stats(stats)
|
| 337 |
+
|
| 338 |
+
if log:
|
| 339 |
+
print(f"[TripletSeqPredDataset:{self.mode}] "
|
| 340 |
+
f"vols={len(volunteers)} "
|
| 341 |
+
f"recs_scan={stats_counts['recordings_scanned']} "
|
| 342 |
+
f"recs_used={stats_counts['recordings_used']} "
|
| 343 |
+
f"segs_seen={stats_counts['segments_seen']} "
|
| 344 |
+
f"kept={stats_counts['seg_kept']} "
|
| 345 |
+
f"drop_label={stats_counts['seg_dropped_label']} "
|
| 346 |
+
f"drop_early={stats_counts['seg_dropped_too_early']} "
|
| 347 |
+
f"drop_short={stats_counts['seg_dropped_short']}",
|
| 348 |
+
flush=True)
|
| 349 |
+
print(f" modality_dims={self._modality_dims} "
|
| 350 |
+
f"T_frames={self.T_frames} sr_down={sr}Hz",
|
| 351 |
+
flush=True)
|
| 352 |
+
self.stats_counts = stats_counts
|
| 353 |
+
|
| 354 |
+
# ----- stats (per-modality mean/std on training split) -----
|
| 355 |
+
def _compute_stats(self) -> Dict[str, Tuple[np.ndarray, np.ndarray]]:
|
| 356 |
+
acc: Dict[str, List[np.ndarray]] = {m: [] for m in self.modalities}
|
| 357 |
+
for it in self._items:
|
| 358 |
+
for m, w in it["x"].items():
|
| 359 |
+
acc[m].append(w.astype(np.float64))
|
| 360 |
+
out: Dict[str, Tuple[np.ndarray, np.ndarray]] = {}
|
| 361 |
+
for m, arrs in acc.items():
|
| 362 |
+
cat = np.concatenate(arrs, axis=0)
|
| 363 |
+
mu = cat.mean(axis=0, keepdims=True)
|
| 364 |
+
sd = cat.std(axis=0, keepdims=True)
|
| 365 |
+
sd[sd < 1e-8] = 1.0
|
| 366 |
+
out[m] = (mu.astype(np.float32), sd.astype(np.float32))
|
| 367 |
+
return out
|
| 368 |
+
|
| 369 |
+
def _apply_stats(self, stats: Dict[str, Tuple[np.ndarray, np.ndarray]]) -> None:
|
| 370 |
+
for it in self._items:
|
| 371 |
+
for m, w in it["x"].items():
|
| 372 |
+
mu, sd = stats[m]
|
| 373 |
+
z = (w.astype(np.float32) - mu) / sd
|
| 374 |
+
z = np.nan_to_num(z, nan=0.0, posinf=0.0, neginf=0.0)
|
| 375 |
+
it["x"][m] = z.astype(np.float32)
|
| 376 |
+
|
| 377 |
+
def get_stats(self) -> Dict[str, Tuple[np.ndarray, np.ndarray]]:
|
| 378 |
+
return self._stats
|
| 379 |
+
|
| 380 |
+
# ----- Dataset protocol -----
|
| 381 |
+
def __len__(self) -> int:
|
| 382 |
+
return len(self._items)
|
| 383 |
+
|
| 384 |
+
def __getitem__(self, idx: int):
|
| 385 |
+
it = self._items[idx]
|
| 386 |
+
x = {m: torch.from_numpy(w) for m, w in it["x"].items()}
|
| 387 |
+
y = it["y"]
|
| 388 |
+
meta = it["meta"]
|
| 389 |
+
prev = it.get("prev", {"verb_composite": NUM_VERB_COMPOSITE, "noun": NUM_NOUN})
|
| 390 |
+
return x, y, meta, prev
|
| 391 |
+
|
| 392 |
+
# ----- convenience -----
|
| 393 |
+
@property
|
| 394 |
+
def modality_dims(self) -> Dict[str, int]:
|
| 395 |
+
return dict(self._modality_dims)
|
| 396 |
+
|
| 397 |
+
@property
|
| 398 |
+
def total_feat_dim(self) -> int:
|
| 399 |
+
return sum(self._modality_dims.values())
|
| 400 |
+
|
| 401 |
+
def class_counts(self) -> Dict[str, np.ndarray]:
|
| 402 |
+
vf = np.zeros(NUM_VERB_FINE, dtype=np.int64)
|
| 403 |
+
vc = np.zeros(NUM_VERB_COMPOSITE, dtype=np.int64)
|
| 404 |
+
n = np.zeros(NUM_NOUN, dtype=np.int64)
|
| 405 |
+
h = np.zeros(NUM_HAND, dtype=np.int64)
|
| 406 |
+
for it in self._items:
|
| 407 |
+
y = it["y"]
|
| 408 |
+
vf[y["verb_fine"]] += 1
|
| 409 |
+
vc[y["verb_composite"]] += 1
|
| 410 |
+
n[y["noun"]] += 1
|
| 411 |
+
h[y["hand"]] += 1
|
| 412 |
+
return {"verb_fine": vf, "verb_composite": vc, "noun": n, "hand": h}
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
# ---------------------------------------------------------------------------
|
| 416 |
+
# Collate: pad each modality to the max T_frames in the batch
|
| 417 |
+
# ---------------------------------------------------------------------------
|
| 418 |
+
|
| 419 |
+
def collate_triplet(batch):
|
| 420 |
+
"""Stack samples into batched tensors. Backward-compatible: accepts
|
| 421 |
+
samples of either (x, y, meta) or (x, y, meta, prev) form.
|
| 422 |
+
|
| 423 |
+
Returned:
|
| 424 |
+
x: dict[mod] -> FloatTensor (B, T_max, F_mod)
|
| 425 |
+
mask: BoolTensor (B, T_max)
|
| 426 |
+
lens: LongTensor (B,)
|
| 427 |
+
y: dict (each -> LongTensor (B,))
|
| 428 |
+
meta: list of dicts
|
| 429 |
+
prev: dict {'verb_composite': LongTensor (B,), 'noun': LongTensor (B,)}
|
| 430 |
+
values are class indices, with NUM_VERB_COMPOSITE / NUM_NOUN
|
| 431 |
+
used as a BOS sentinel for the first segment in a recording.
|
| 432 |
+
"""
|
| 433 |
+
has_prev = len(batch[0]) >= 4
|
| 434 |
+
if has_prev:
|
| 435 |
+
xs, ys, metas, prevs = zip(*batch)
|
| 436 |
+
else:
|
| 437 |
+
xs, ys, metas = zip(*batch)
|
| 438 |
+
prevs = [{"verb_composite": NUM_VERB_COMPOSITE, "noun": NUM_NOUN} for _ in batch]
|
| 439 |
+
B = len(batch)
|
| 440 |
+
mods = list(xs[0].keys())
|
| 441 |
+
lens = torch.tensor([x[mods[0]].shape[0] for x in xs], dtype=torch.long)
|
| 442 |
+
T_max = int(lens.max().item())
|
| 443 |
+
|
| 444 |
+
x_out: Dict[str, torch.Tensor] = {}
|
| 445 |
+
for m in mods:
|
| 446 |
+
F = xs[0][m].shape[1]
|
| 447 |
+
padded = torch.zeros(B, T_max, F, dtype=torch.float32)
|
| 448 |
+
for i, x in enumerate(xs):
|
| 449 |
+
w = x[m]
|
| 450 |
+
padded[i, :w.shape[0]] = w
|
| 451 |
+
x_out[m] = padded
|
| 452 |
+
|
| 453 |
+
ar = torch.arange(T_max).unsqueeze(0)
|
| 454 |
+
mask = ar < lens.unsqueeze(1)
|
| 455 |
+
|
| 456 |
+
y_out = {
|
| 457 |
+
k: torch.tensor([y[k] for y in ys], dtype=torch.long)
|
| 458 |
+
for k in ("verb_fine", "verb_composite", "noun", "hand")
|
| 459 |
+
}
|
| 460 |
+
prev_out = {
|
| 461 |
+
"verb_composite": torch.tensor([p["verb_composite"] for p in prevs], dtype=torch.long),
|
| 462 |
+
"noun": torch.tensor([p["noun"] for p in prevs], dtype=torch.long),
|
| 463 |
+
}
|
| 464 |
+
return x_out, mask, lens, y_out, list(metas), prev_out
|
| 465 |
+
|
| 466 |
+
|
| 467 |
+
# ---------------------------------------------------------------------------
|
| 468 |
+
# Convenience: build paired train/test datasets with shared normalization
|
| 469 |
+
# ---------------------------------------------------------------------------
|
| 470 |
+
|
| 471 |
+
def build_train_test(
|
| 472 |
+
modalities: Sequence[str] = ("imu", "mocap", "emg", "eyetrack", "pressure"),
|
| 473 |
+
t_obs_sec: float = 8.0,
|
| 474 |
+
t_fut_sec: float = 2.0,
|
| 475 |
+
downsample: int = DEFAULT_DOWNSAMPLE,
|
| 476 |
+
dataset_dir: Path = DEFAULT_DATASET_DIR,
|
| 477 |
+
annot_dir: Path = DEFAULT_ANNOT_DIR,
|
| 478 |
+
mode: str = "recognition",
|
| 479 |
+
) -> Tuple["TripletSeqPredDataset", "TripletSeqPredDataset"]:
|
| 480 |
+
train = TripletSeqPredDataset(
|
| 481 |
+
TRAIN_VOLS_V3, modalities=modalities,
|
| 482 |
+
t_obs_sec=t_obs_sec, t_fut_sec=t_fut_sec, downsample=downsample,
|
| 483 |
+
dataset_dir=dataset_dir, annot_dir=annot_dir, mode=mode,
|
| 484 |
+
)
|
| 485 |
+
test = TripletSeqPredDataset(
|
| 486 |
+
TEST_VOLS_V3, modalities=modalities,
|
| 487 |
+
t_obs_sec=t_obs_sec, t_fut_sec=t_fut_sec, downsample=downsample,
|
| 488 |
+
dataset_dir=dataset_dir, annot_dir=annot_dir,
|
| 489 |
+
stats=train.get_stats(), mode=mode,
|
| 490 |
+
)
|
| 491 |
+
return train, test
|
| 492 |
+
|
| 493 |
+
|
| 494 |
+
# ---------------------------------------------------------------------------
|
| 495 |
+
# CLI: quick sanity check
|
| 496 |
+
# ---------------------------------------------------------------------------
|
| 497 |
+
|
| 498 |
+
if __name__ == "__main__":
|
| 499 |
+
import argparse
|
| 500 |
+
|
| 501 |
+
ap = argparse.ArgumentParser()
|
| 502 |
+
ap.add_argument("--modalities", type=str, default="imu,emg,eyetrack")
|
| 503 |
+
ap.add_argument("--t_obs", type=float, default=8.0)
|
| 504 |
+
ap.add_argument("--t_fut", type=float, default=2.0)
|
| 505 |
+
ap.add_argument("--smoke_n", type=int, default=3,
|
| 506 |
+
help="Inspect first N samples per split")
|
| 507 |
+
args = ap.parse_args()
|
| 508 |
+
|
| 509 |
+
mods = args.modalities.split(",")
|
| 510 |
+
print(f"Building train/test with modalities={mods} "
|
| 511 |
+
f"t_obs={args.t_obs}s t_fut={args.t_fut}s ...")
|
| 512 |
+
train, test = build_train_test(
|
| 513 |
+
modalities=mods,
|
| 514 |
+
t_obs_sec=args.t_obs,
|
| 515 |
+
t_fut_sec=args.t_fut,
|
| 516 |
+
)
|
| 517 |
+
print(f"train: {len(train)} samples | test: {len(test)} samples")
|
| 518 |
+
|
| 519 |
+
for name, ds in [("train", train), ("test", test)]:
|
| 520 |
+
counts = ds.class_counts()
|
| 521 |
+
print(f"\n[{name}] class counts:")
|
| 522 |
+
print(" verb_fine:", counts["verb_fine"].tolist())
|
| 523 |
+
print(" verb_composite:", counts["verb_composite"].tolist())
|
| 524 |
+
print(" noun (sum):", int(counts["noun"].sum()),
|
| 525 |
+
"nonzero:", int((counts["noun"] > 0).sum()))
|
| 526 |
+
print(" hand:", counts["hand"].tolist())
|
| 527 |
+
|
| 528 |
+
print(f"\n[{name}] first {args.smoke_n} samples:")
|
| 529 |
+
for i in range(min(args.smoke_n, len(ds))):
|
| 530 |
+
x, y, meta = ds[i]
|
| 531 |
+
shape_str = " ".join(f"{m}:{tuple(x[m].shape)}" for m in x)
|
| 532 |
+
print(f" {i:3d} {meta['vol']}/{meta['scene']}#{meta['seg_idx']:3d} "
|
| 533 |
+
f"anchor={meta['anchor_sec']:.2f}s y={y} {shape_str}")
|
experiments/data/dataset_signal_forecast.py
ADDED
|
@@ -0,0 +1,391 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Frame-level future *signal* forecasting dataset (T8 v2).
|
| 2 |
+
|
| 3 |
+
Task definition
|
| 4 |
+
---------------
|
| 5 |
+
At a sampled anchor t in a recording:
|
| 6 |
+
past = sensor frames over [t - T_obs, t] ← input
|
| 7 |
+
future = target-modality frames over (t, t + T_fut] ← regression target
|
| 8 |
+
|
| 9 |
+
Unlike the v1 ForecastDataset (which targets per-frame verb-fine class), this
|
| 10 |
+
predicts the raw *signal* values of one chosen target modality. This directly
|
| 11 |
+
tests the Johansson 1984 / monzee 2003 hypothesis that cutaneous force
|
| 12 |
+
feedback drives sub-second motor planning at the *signal* level (motor
|
| 13 |
+
commands / kinematics), not at the level of slow-changing semantic verbs.
|
| 14 |
+
|
| 15 |
+
Anchor stratification (4 event types based on contact transitions)
|
| 16 |
+
------------------------------------------------------------------
|
| 17 |
+
For each candidate anchor, we compute pressure_sum on past and future windows
|
| 18 |
+
and label it by the (past_majority_contact, future_majority_contact) pair:
|
| 19 |
+
|
| 20 |
+
type 0 = non-contact (past low, future low) — control: pressure ~ 0
|
| 21 |
+
type 1 = pre-contact (past low, future high) — pressure foretells onset
|
| 22 |
+
type 2 = steady-grip (past high, future high) — sustained contact dynamics
|
| 23 |
+
type 3 = release (past high, future low) — letting-go dynamics
|
| 24 |
+
|
| 25 |
+
Per-event-type counts are reported and (optionally) capped to balance.
|
| 26 |
+
Evaluation is broken down per event type so we can see WHERE pressure helps.
|
| 27 |
+
"""
|
| 28 |
+
from __future__ import annotations
|
| 29 |
+
|
| 30 |
+
import sys
|
| 31 |
+
from pathlib import Path
|
| 32 |
+
from typing import Dict, List, Optional, Sequence, Tuple
|
| 33 |
+
|
| 34 |
+
import numpy as np
|
| 35 |
+
import torch
|
| 36 |
+
from torch.utils.data import Dataset
|
| 37 |
+
|
| 38 |
+
THIS = Path(__file__).resolve()
|
| 39 |
+
sys.path.insert(0, str(THIS.parent))
|
| 40 |
+
sys.path.insert(0, str(THIS.parents[1]))
|
| 41 |
+
|
| 42 |
+
try:
|
| 43 |
+
from experiments.dataset_seqpred import (
|
| 44 |
+
SAMPLING_RATE_HZ, _load_recording_sensors,
|
| 45 |
+
TRAIN_VOLS_V3, TEST_VOLS_V3,
|
| 46 |
+
DEFAULT_DATASET_DIR, DEFAULT_ANNOT_DIR,
|
| 47 |
+
)
|
| 48 |
+
except ModuleNotFoundError:
|
| 49 |
+
from dataset_seqpred import (
|
| 50 |
+
SAMPLING_RATE_HZ, _load_recording_sensors,
|
| 51 |
+
TRAIN_VOLS_V3, TEST_VOLS_V3,
|
| 52 |
+
DEFAULT_DATASET_DIR, DEFAULT_ANNOT_DIR,
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
EVENT_NAMES = {0: "non-contact", 1: "pre-contact", 2: "steady-grip", 3: "release"}
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class SignalForecastDataset(Dataset):
|
| 60 |
+
"""Predict future T_fut frames of `target_modality` from past T_obs of `input_modalities`."""
|
| 61 |
+
|
| 62 |
+
def __init__(
|
| 63 |
+
self,
|
| 64 |
+
volunteers: Sequence[str],
|
| 65 |
+
input_modalities: Sequence[str],
|
| 66 |
+
target_modality: str,
|
| 67 |
+
t_obs_sec: float = 1.5,
|
| 68 |
+
t_fut_sec: float = 0.5,
|
| 69 |
+
anchor_stride_sec: float = 0.25,
|
| 70 |
+
downsample: int = 5,
|
| 71 |
+
dataset_dir: Path = DEFAULT_DATASET_DIR,
|
| 72 |
+
annot_dir: Path = DEFAULT_ANNOT_DIR,
|
| 73 |
+
contact_threshold_g: float = 5.0,
|
| 74 |
+
per_event_max: Optional[int] = None,
|
| 75 |
+
input_stats: Optional[Dict[str, Tuple[np.ndarray, np.ndarray]]] = None,
|
| 76 |
+
target_stats: Optional[Tuple[np.ndarray, np.ndarray]] = None,
|
| 77 |
+
future_pressure_stats: Optional[Tuple[np.ndarray, np.ndarray]] = None,
|
| 78 |
+
expected_input_dims: Optional[Dict[str, int]] = None,
|
| 79 |
+
expected_target_dim: Optional[int] = None,
|
| 80 |
+
include_future_pressure: bool = False,
|
| 81 |
+
rng_seed: int = 0,
|
| 82 |
+
log: bool = True,
|
| 83 |
+
):
|
| 84 |
+
super().__init__()
|
| 85 |
+
self.input_modalities = list(input_modalities)
|
| 86 |
+
self.target_modality = str(target_modality)
|
| 87 |
+
self.t_obs_sec = float(t_obs_sec)
|
| 88 |
+
self.t_fut_sec = float(t_fut_sec)
|
| 89 |
+
self.anchor_stride_sec = float(anchor_stride_sec)
|
| 90 |
+
self.downsample = int(downsample)
|
| 91 |
+
self.sr = SAMPLING_RATE_HZ // self.downsample
|
| 92 |
+
self.dataset_dir = Path(dataset_dir)
|
| 93 |
+
self.annot_dir = Path(annot_dir)
|
| 94 |
+
self.contact_threshold_g = float(contact_threshold_g)
|
| 95 |
+
self.per_event_max = per_event_max
|
| 96 |
+
self.include_future_pressure = bool(include_future_pressure)
|
| 97 |
+
self.T_obs = int(round(self.t_obs_sec * self.sr))
|
| 98 |
+
self.T_fut = int(round(self.t_fut_sec * self.sr))
|
| 99 |
+
|
| 100 |
+
self._items: List[dict] = []
|
| 101 |
+
self._modality_dims: Dict[str, int] = dict(expected_input_dims) if expected_input_dims else {}
|
| 102 |
+
self._target_dim: int = int(expected_target_dim) if expected_target_dim else -1
|
| 103 |
+
rng = np.random.default_rng(rng_seed)
|
| 104 |
+
|
| 105 |
+
# Modalities to load: union of inputs + target + pressure (for filter)
|
| 106 |
+
load_mods = list(dict.fromkeys(
|
| 107 |
+
list(self.input_modalities) + [self.target_modality, "pressure"]
|
| 108 |
+
))
|
| 109 |
+
|
| 110 |
+
# Per-event-type pool of candidate anchor records
|
| 111 |
+
pools: Dict[int, List[dict]] = {0: [], 1: [], 2: [], 3: []}
|
| 112 |
+
|
| 113 |
+
for vol in volunteers:
|
| 114 |
+
vol_dir = self.dataset_dir / vol
|
| 115 |
+
if not vol_dir.is_dir():
|
| 116 |
+
continue
|
| 117 |
+
for scenario_dir in sorted(vol_dir.glob("s*")):
|
| 118 |
+
if not scenario_dir.is_dir():
|
| 119 |
+
continue
|
| 120 |
+
scene = scenario_dir.name
|
| 121 |
+
annot_path = self.annot_dir / vol / f"{scene}.json"
|
| 122 |
+
if not annot_path.exists():
|
| 123 |
+
continue
|
| 124 |
+
try:
|
| 125 |
+
sensors_all = _load_recording_sensors(
|
| 126 |
+
scenario_dir, vol, scene, load_mods
|
| 127 |
+
)
|
| 128 |
+
except Exception:
|
| 129 |
+
continue
|
| 130 |
+
if sensors_all is None or any(a is None for a in sensors_all.values()):
|
| 131 |
+
continue
|
| 132 |
+
|
| 133 |
+
pressure_full = sensors_all["pressure"] # (T, 50)
|
| 134 |
+
target_full = sensors_all[self.target_modality]
|
| 135 |
+
input_arrs = {m: sensors_all[m] for m in self.input_modalities}
|
| 136 |
+
|
| 137 |
+
# Track input modality dims
|
| 138 |
+
for m, arr in input_arrs.items():
|
| 139 |
+
self._enforce_dim(input_arrs, m, arr, self._modality_dims)
|
| 140 |
+
# Track target dim
|
| 141 |
+
if self._target_dim < 0:
|
| 142 |
+
self._target_dim = target_full.shape[1]
|
| 143 |
+
elif target_full.shape[1] != self._target_dim:
|
| 144 |
+
if target_full.shape[1] < self._target_dim:
|
| 145 |
+
pad = np.zeros((target_full.shape[0], self._target_dim - target_full.shape[1]),
|
| 146 |
+
dtype=np.float32)
|
| 147 |
+
target_full = np.concatenate([target_full, pad], axis=1)
|
| 148 |
+
else:
|
| 149 |
+
target_full = target_full[:, :self._target_dim]
|
| 150 |
+
|
| 151 |
+
T_avail = min(a.shape[0] for a in input_arrs.values())
|
| 152 |
+
T_avail = min(T_avail, target_full.shape[0], pressure_full.shape[0])
|
| 153 |
+
if T_avail < (self.T_obs + self.T_fut) * self.downsample:
|
| 154 |
+
continue
|
| 155 |
+
|
| 156 |
+
# Downsample to 20 Hz
|
| 157 |
+
input_ds = {m: arr[:T_avail:self.downsample] for m, arr in input_arrs.items()}
|
| 158 |
+
target_ds = target_full[:T_avail:self.downsample]
|
| 159 |
+
pressure_ds = pressure_full[:T_avail:self.downsample]
|
| 160 |
+
T_ds = target_ds.shape[0]
|
| 161 |
+
pressure_sum = pressure_ds.sum(axis=1) # (T_ds,)
|
| 162 |
+
|
| 163 |
+
stride = max(1, int(round(self.anchor_stride_sec * self.sr)))
|
| 164 |
+
first_anchor = self.T_obs
|
| 165 |
+
last_anchor = T_ds - self.T_fut
|
| 166 |
+
if last_anchor <= first_anchor:
|
| 167 |
+
continue
|
| 168 |
+
|
| 169 |
+
for anchor in range(first_anchor, last_anchor + 1, stride):
|
| 170 |
+
past_p = pressure_sum[anchor - self.T_obs:anchor]
|
| 171 |
+
fut_p = pressure_sum[anchor:anchor + self.T_fut]
|
| 172 |
+
past_high = (past_p > self.contact_threshold_g).mean() > 0.5
|
| 173 |
+
fut_high = (fut_p > self.contact_threshold_g).mean() > 0.5
|
| 174 |
+
if not past_high and not fut_high:
|
| 175 |
+
et = 0
|
| 176 |
+
elif not past_high and fut_high:
|
| 177 |
+
et = 1
|
| 178 |
+
elif past_high and fut_high:
|
| 179 |
+
et = 2
|
| 180 |
+
else:
|
| 181 |
+
et = 3
|
| 182 |
+
|
| 183 |
+
past_slice = {m: arr[anchor - self.T_obs:anchor]
|
| 184 |
+
for m, arr in input_ds.items()}
|
| 185 |
+
past_target_last = target_ds[anchor - 1].copy() # (target_dim,)
|
| 186 |
+
fut_target = target_ds[anchor:anchor + self.T_fut].copy()
|
| 187 |
+
if any(w.shape[0] != self.T_obs for w in past_slice.values()):
|
| 188 |
+
continue
|
| 189 |
+
if fut_target.shape[0] != self.T_fut:
|
| 190 |
+
continue
|
| 191 |
+
|
| 192 |
+
item = {
|
| 193 |
+
"x": past_slice,
|
| 194 |
+
"y": fut_target,
|
| 195 |
+
"y_last": past_target_last, # for persistence
|
| 196 |
+
"event_type": int(et),
|
| 197 |
+
"meta": {"vol": vol, "scene": scene, "anchor_idx": int(anchor)},
|
| 198 |
+
}
|
| 199 |
+
if self.include_future_pressure:
|
| 200 |
+
fut_press = pressure_ds[anchor:anchor + self.T_fut].copy()
|
| 201 |
+
if fut_press.shape[0] != self.T_fut:
|
| 202 |
+
continue
|
| 203 |
+
item["fp"] = fut_press # (T_fut, 50)
|
| 204 |
+
pools[et].append(item)
|
| 205 |
+
|
| 206 |
+
# Cap per-event count if requested (uniform downsample for balance)
|
| 207 |
+
for et, pool in pools.items():
|
| 208 |
+
if self.per_event_max is not None and len(pool) > self.per_event_max:
|
| 209 |
+
idx = rng.choice(len(pool), size=self.per_event_max, replace=False)
|
| 210 |
+
pools[et] = [pool[i] for i in sorted(idx)]
|
| 211 |
+
self._items = [it for et in (0, 1, 2, 3) for it in pools[et]]
|
| 212 |
+
|
| 213 |
+
if not self._items:
|
| 214 |
+
raise RuntimeError("SignalForecastDataset: collected 0 anchors.")
|
| 215 |
+
|
| 216 |
+
# Z-score inputs and target separately
|
| 217 |
+
if input_stats is None:
|
| 218 |
+
input_stats = self._compute_input_stats()
|
| 219 |
+
self._input_stats = input_stats
|
| 220 |
+
self._apply_input_stats(input_stats)
|
| 221 |
+
if target_stats is None:
|
| 222 |
+
target_stats = self._compute_target_stats()
|
| 223 |
+
self._target_stats = target_stats
|
| 224 |
+
self._apply_target_stats(target_stats)
|
| 225 |
+
if self.include_future_pressure:
|
| 226 |
+
if future_pressure_stats is None:
|
| 227 |
+
future_pressure_stats = self._compute_fp_stats()
|
| 228 |
+
self._fp_stats = future_pressure_stats
|
| 229 |
+
self._apply_fp_stats(future_pressure_stats)
|
| 230 |
+
else:
|
| 231 |
+
self._fp_stats = None
|
| 232 |
+
|
| 233 |
+
if log:
|
| 234 |
+
counts = {EVENT_NAMES[k]: sum(1 for it in self._items if it["event_type"] == k)
|
| 235 |
+
for k in (0, 1, 2, 3)}
|
| 236 |
+
print(f"[SignalForecastDataset] vols={len(volunteers)} "
|
| 237 |
+
f"target={self.target_modality} inputs={self.input_modalities} "
|
| 238 |
+
f"anchors={len(self._items)} {counts} "
|
| 239 |
+
f"T_obs={self.T_obs} T_fut={self.T_fut} sr={self.sr}Hz "
|
| 240 |
+
f"input_dims={self._modality_dims} target_dim={self._target_dim}",
|
| 241 |
+
flush=True)
|
| 242 |
+
|
| 243 |
+
@staticmethod
|
| 244 |
+
def _enforce_dim(arrs, m, arr, dim_dict):
|
| 245 |
+
if m in dim_dict:
|
| 246 |
+
target = dim_dict[m]
|
| 247 |
+
if arr.shape[1] != target:
|
| 248 |
+
if arr.shape[1] < target:
|
| 249 |
+
pad = np.zeros((arr.shape[0], target - arr.shape[1]), dtype=np.float32)
|
| 250 |
+
arrs[m] = np.concatenate([arr, pad], axis=1)
|
| 251 |
+
else:
|
| 252 |
+
arrs[m] = arr[:, :target]
|
| 253 |
+
else:
|
| 254 |
+
dim_dict[m] = arr.shape[1]
|
| 255 |
+
|
| 256 |
+
def _compute_input_stats(self):
|
| 257 |
+
accs = {m: [] for m in self._modality_dims}
|
| 258 |
+
for it in self._items:
|
| 259 |
+
for m, w in it["x"].items():
|
| 260 |
+
accs[m].append(w)
|
| 261 |
+
out = {}
|
| 262 |
+
for m, ws in accs.items():
|
| 263 |
+
cat = np.concatenate(ws, axis=0)
|
| 264 |
+
mu = cat.mean(axis=0).astype(np.float32)
|
| 265 |
+
sd = cat.std(axis=0); sd = np.where(sd < 1e-6, 1.0, sd)
|
| 266 |
+
out[m] = (mu, sd.astype(np.float32))
|
| 267 |
+
return out
|
| 268 |
+
|
| 269 |
+
def _apply_input_stats(self, stats):
|
| 270 |
+
for it in self._items:
|
| 271 |
+
for m, w in it["x"].items():
|
| 272 |
+
if m in stats:
|
| 273 |
+
mu, sd = stats[m]
|
| 274 |
+
it["x"][m] = ((w - mu) / sd).astype(np.float32)
|
| 275 |
+
|
| 276 |
+
def _compute_target_stats(self):
|
| 277 |
+
ys = np.concatenate([it["y"] for it in self._items], axis=0)
|
| 278 |
+
mu = ys.mean(axis=0).astype(np.float32)
|
| 279 |
+
sd = ys.std(axis=0); sd = np.where(sd < 1e-6, 1.0, sd)
|
| 280 |
+
return (mu, sd.astype(np.float32))
|
| 281 |
+
|
| 282 |
+
def _apply_target_stats(self, stats):
|
| 283 |
+
mu, sd = stats
|
| 284 |
+
for it in self._items:
|
| 285 |
+
it["y"] = ((it["y"] - mu) / sd).astype(np.float32)
|
| 286 |
+
it["y_last"] = ((it["y_last"] - mu) / sd).astype(np.float32)
|
| 287 |
+
|
| 288 |
+
def _compute_fp_stats(self):
|
| 289 |
+
fps = np.concatenate([it["fp"] for it in self._items], axis=0)
|
| 290 |
+
mu = fps.mean(axis=0).astype(np.float32)
|
| 291 |
+
sd = fps.std(axis=0); sd = np.where(sd < 1e-6, 1.0, sd)
|
| 292 |
+
return (mu, sd.astype(np.float32))
|
| 293 |
+
|
| 294 |
+
def _apply_fp_stats(self, stats):
|
| 295 |
+
mu, sd = stats
|
| 296 |
+
for it in self._items:
|
| 297 |
+
it["fp"] = ((it["fp"] - mu) / sd).astype(np.float32)
|
| 298 |
+
|
| 299 |
+
def __len__(self):
|
| 300 |
+
return len(self._items)
|
| 301 |
+
|
| 302 |
+
def __getitem__(self, idx):
|
| 303 |
+
it = self._items[idx]
|
| 304 |
+
x = {m: torch.from_numpy(np.ascontiguousarray(w)) for m, w in it["x"].items()}
|
| 305 |
+
y = torch.from_numpy(np.ascontiguousarray(it["y"])) # (T_fut, target_dim)
|
| 306 |
+
y_last = torch.from_numpy(np.ascontiguousarray(it["y_last"])) # (target_dim,)
|
| 307 |
+
et = int(it["event_type"])
|
| 308 |
+
if self.include_future_pressure:
|
| 309 |
+
fp = torch.from_numpy(np.ascontiguousarray(it["fp"])) # (T_fut, 50)
|
| 310 |
+
return x, y, y_last, fp, et, it["meta"]
|
| 311 |
+
return x, y, y_last, et, it["meta"]
|
| 312 |
+
|
| 313 |
+
@property
|
| 314 |
+
def modality_dims(self):
|
| 315 |
+
return dict(self._modality_dims)
|
| 316 |
+
|
| 317 |
+
@property
|
| 318 |
+
def target_dim(self):
|
| 319 |
+
return self._target_dim
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
def collate_signal_forecast(batch):
|
| 323 |
+
if len(batch[0]) == 6: # has future pressure
|
| 324 |
+
xs, ys, ylasts, fps, ets, metas = zip(*batch)
|
| 325 |
+
mods = list(xs[0].keys())
|
| 326 |
+
x_out = {m: torch.stack([x[m] for x in xs], dim=0) for m in mods}
|
| 327 |
+
y_out = torch.stack(ys, dim=0)
|
| 328 |
+
yl_out = torch.stack(ylasts, dim=0)
|
| 329 |
+
fp_out = torch.stack(fps, dim=0) # (B, T_fut, 50)
|
| 330 |
+
et_out = torch.tensor(ets, dtype=torch.long)
|
| 331 |
+
return x_out, y_out, yl_out, fp_out, et_out, list(metas)
|
| 332 |
+
xs, ys, ylasts, ets, metas = zip(*batch)
|
| 333 |
+
mods = list(xs[0].keys())
|
| 334 |
+
x_out = {m: torch.stack([x[m] for x in xs], dim=0) for m in mods}
|
| 335 |
+
y_out = torch.stack(ys, dim=0)
|
| 336 |
+
yl_out = torch.stack(ylasts, dim=0)
|
| 337 |
+
et_out = torch.tensor(ets, dtype=torch.long)
|
| 338 |
+
return x_out, y_out, yl_out, et_out, list(metas)
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
def build_signal_train_test(
|
| 342 |
+
input_modalities, target_modality,
|
| 343 |
+
t_obs_sec=1.5, t_fut_sec=0.5, anchor_stride_sec=0.25,
|
| 344 |
+
downsample=5,
|
| 345 |
+
dataset_dir=DEFAULT_DATASET_DIR, annot_dir=DEFAULT_ANNOT_DIR,
|
| 346 |
+
contact_threshold_g=5.0, per_event_max=None,
|
| 347 |
+
include_future_pressure=False,
|
| 348 |
+
rng_seed=0,
|
| 349 |
+
):
|
| 350 |
+
train = SignalForecastDataset(
|
| 351 |
+
TRAIN_VOLS_V3, input_modalities=input_modalities,
|
| 352 |
+
target_modality=target_modality,
|
| 353 |
+
t_obs_sec=t_obs_sec, t_fut_sec=t_fut_sec,
|
| 354 |
+
anchor_stride_sec=anchor_stride_sec, downsample=downsample,
|
| 355 |
+
dataset_dir=dataset_dir, annot_dir=annot_dir,
|
| 356 |
+
contact_threshold_g=contact_threshold_g, per_event_max=per_event_max,
|
| 357 |
+
include_future_pressure=include_future_pressure,
|
| 358 |
+
rng_seed=rng_seed, log=True,
|
| 359 |
+
)
|
| 360 |
+
test = SignalForecastDataset(
|
| 361 |
+
TEST_VOLS_V3, input_modalities=input_modalities,
|
| 362 |
+
target_modality=target_modality,
|
| 363 |
+
t_obs_sec=t_obs_sec, t_fut_sec=t_fut_sec,
|
| 364 |
+
anchor_stride_sec=anchor_stride_sec, downsample=downsample,
|
| 365 |
+
dataset_dir=dataset_dir, annot_dir=annot_dir,
|
| 366 |
+
contact_threshold_g=contact_threshold_g, per_event_max=per_event_max,
|
| 367 |
+
input_stats=train._input_stats, target_stats=train._target_stats,
|
| 368 |
+
future_pressure_stats=train._fp_stats,
|
| 369 |
+
expected_input_dims=train._modality_dims,
|
| 370 |
+
expected_target_dim=train._target_dim,
|
| 371 |
+
include_future_pressure=include_future_pressure,
|
| 372 |
+
rng_seed=rng_seed + 1, log=True,
|
| 373 |
+
)
|
| 374 |
+
return train, test
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
if __name__ == "__main__":
|
| 378 |
+
import argparse
|
| 379 |
+
ap = argparse.ArgumentParser()
|
| 380 |
+
ap.add_argument("--input_modalities", default="imu")
|
| 381 |
+
ap.add_argument("--target_modality", default="imu")
|
| 382 |
+
ap.add_argument("--t_obs", type=float, default=1.5)
|
| 383 |
+
ap.add_argument("--t_fut", type=float, default=0.5)
|
| 384 |
+
args = ap.parse_args()
|
| 385 |
+
tr, te = build_signal_train_test(
|
| 386 |
+
input_modalities=args.input_modalities.split(","),
|
| 387 |
+
target_modality=args.target_modality,
|
| 388 |
+
t_obs_sec=args.t_obs, t_fut_sec=args.t_fut,
|
| 389 |
+
)
|
| 390 |
+
x, y, y_last, et, meta = tr[0]
|
| 391 |
+
print(f"Sample: x={ {m: tuple(v.shape) for m,v in x.items()} } y={tuple(y.shape)} y_last={tuple(y_last.shape)} event_type={et}")
|
experiments/nets/__init__.py
ADDED
|
File without changes
|
experiments/nets/__pycache__/models_seqpred.cpython-312.pyc
ADDED
|
Binary file (44.4 kB). View file
|
|
|
experiments/nets/baselines_published/__init__.py
ADDED
|
File without changes
|
experiments/nets/baselines_published/baselines.py
ADDED
|
@@ -0,0 +1,488 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Published baselines for T1 Scene Recognition, reproduced on DailyAct-5M.
|
| 3 |
+
|
| 4 |
+
Each method accepts a concatenated feature tensor (B, T, F_total) where F_total
|
| 5 |
+
is the sum of the active modality dims; the per-modality slices are recorded in
|
| 6 |
+
the `modality_dims` dict. Each method then uses the subset of modalities its
|
| 7 |
+
original paper intended.
|
| 8 |
+
|
| 9 |
+
All methods output an (B, num_classes) logit tensor.
|
| 10 |
+
"""
|
| 11 |
+
import math
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _slice(x, mod_dims, wanted):
|
| 18 |
+
"""Slice the concatenated feature tensor to keep only `wanted` modalities,
|
| 19 |
+
in the order given. mod_dims is an ordered dict. Returns
|
| 20 |
+
{name: tensor(B,T,d_name)} plus the concat."""
|
| 21 |
+
parts = {}
|
| 22 |
+
offset = 0
|
| 23 |
+
for name, d in mod_dims.items():
|
| 24 |
+
if name in wanted:
|
| 25 |
+
parts[name] = x[..., offset:offset + d]
|
| 26 |
+
offset += d
|
| 27 |
+
assert len(parts) > 0, f"None of {wanted} in {list(mod_dims.keys())}"
|
| 28 |
+
return parts
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# ---------------------------------------------------------------------------
|
| 32 |
+
# 1) ST-GCN (Yan et al., AAAI 2018)
|
| 33 |
+
# Spatio-temporal graph CNN for skeleton action recognition.
|
| 34 |
+
# We treat the 56-joint MoCap skeleton as the graph.
|
| 35 |
+
# ---------------------------------------------------------------------------
|
| 36 |
+
|
| 37 |
+
class STGCNBlock(nn.Module):
|
| 38 |
+
def __init__(self, in_ch, out_ch, n_joints, stride=1, dropout=0.2):
|
| 39 |
+
super().__init__()
|
| 40 |
+
# Spatial graph conv: learnable adjacency (fully learned, no handcrafted A)
|
| 41 |
+
self.A = nn.Parameter(torch.eye(n_joints) + 0.1 * torch.randn(n_joints, n_joints))
|
| 42 |
+
self.spatial = nn.Conv2d(in_ch, out_ch, kernel_size=(1, 1), bias=False)
|
| 43 |
+
self.spatial_bn = nn.BatchNorm2d(out_ch)
|
| 44 |
+
self.temporal = nn.Conv2d(out_ch, out_ch, kernel_size=(9, 1),
|
| 45 |
+
padding=(4, 0), stride=(stride, 1))
|
| 46 |
+
self.temporal_bn = nn.BatchNorm2d(out_ch)
|
| 47 |
+
self.dropout = nn.Dropout(dropout)
|
| 48 |
+
if in_ch != out_ch or stride != 1:
|
| 49 |
+
self.res = nn.Conv2d(in_ch, out_ch, kernel_size=1,
|
| 50 |
+
stride=(stride, 1))
|
| 51 |
+
else:
|
| 52 |
+
self.res = nn.Identity()
|
| 53 |
+
|
| 54 |
+
def forward(self, x):
|
| 55 |
+
# x: (B, C, T, V)
|
| 56 |
+
res = self.res(x)
|
| 57 |
+
# spatial: aggregate along joints via A
|
| 58 |
+
h = self.spatial(x)
|
| 59 |
+
h = torch.einsum('bctv,vw->bctw', h, F.softmax(self.A, dim=-1))
|
| 60 |
+
h = self.spatial_bn(h)
|
| 61 |
+
h = F.relu(h)
|
| 62 |
+
# temporal
|
| 63 |
+
h = self.temporal(h)
|
| 64 |
+
h = self.temporal_bn(h)
|
| 65 |
+
h = self.dropout(h)
|
| 66 |
+
return F.relu(h + res)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class STGCN(nn.Module):
|
| 70 |
+
"""ST-GCN on MoCap skeleton. We assume the MoCap modality is 620-dim
|
| 71 |
+
(hip-relative + velocity) and reshape to ~56 joints."""
|
| 72 |
+
def __init__(self, feat_dim_mocap, num_classes, hidden=64, n_joints=52):
|
| 73 |
+
super().__init__()
|
| 74 |
+
self.n_joints = n_joints
|
| 75 |
+
# MoCap feat is (T, 620). 52 joints × 4 (xyz+quat_type), or we take per-joint xyz-only = 156.
|
| 76 |
+
# In this repo, 620 = 52 markers * 4 cols + velocity features. We'll
|
| 77 |
+
# reshape by slicing to 3*52=156 "primary" coords, padded if needed.
|
| 78 |
+
self.coord_dim = 3 # we'll treat each joint as having 3 coords (XYZ)
|
| 79 |
+
self.proj_in = nn.Linear(feat_dim_mocap, n_joints * self.coord_dim)
|
| 80 |
+
|
| 81 |
+
self.blocks = nn.ModuleList([
|
| 82 |
+
STGCNBlock(self.coord_dim, hidden, n_joints),
|
| 83 |
+
STGCNBlock(hidden, hidden, n_joints),
|
| 84 |
+
STGCNBlock(hidden, hidden * 2, n_joints, stride=2),
|
| 85 |
+
STGCNBlock(hidden * 2, hidden * 2, n_joints),
|
| 86 |
+
STGCNBlock(hidden * 2, hidden * 4, n_joints, stride=2),
|
| 87 |
+
STGCNBlock(hidden * 4, hidden * 4, n_joints),
|
| 88 |
+
])
|
| 89 |
+
self.head = nn.Sequential(
|
| 90 |
+
nn.Dropout(0.3),
|
| 91 |
+
nn.Linear(hidden * 4, num_classes),
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
def forward(self, x_mocap, mask=None):
|
| 95 |
+
# x_mocap: (B, T, feat_dim_mocap)
|
| 96 |
+
B, T, _ = x_mocap.shape
|
| 97 |
+
h = self.proj_in(x_mocap) # (B, T, n_joints * 3)
|
| 98 |
+
h = h.reshape(B, T, self.n_joints, self.coord_dim).permute(0, 3, 1, 2) # (B, C, T, V)
|
| 99 |
+
for blk in self.blocks:
|
| 100 |
+
h = blk(h)
|
| 101 |
+
# Global mean pool over time & joints (with mask if provided)
|
| 102 |
+
if mask is not None:
|
| 103 |
+
# mask: (B, T), h: (B, C, T', V) where T' may be < T due to stride
|
| 104 |
+
T_ = h.shape[2]
|
| 105 |
+
m = mask[:, :T_].float().unsqueeze(1).unsqueeze(-1) # (B, 1, T', 1)
|
| 106 |
+
h = (h * m).sum(dim=(2, 3)) / (m.sum(dim=(2, 3)) * h.shape[3] + 1e-8)
|
| 107 |
+
else:
|
| 108 |
+
h = h.mean(dim=(2, 3))
|
| 109 |
+
return self.head(h)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
# ---------------------------------------------------------------------------
|
| 113 |
+
# 2) CTR-GCN (Chen et al., ICCV 2021)
|
| 114 |
+
# Channel-wise Topology Refinement GCN — learns a separate adjacency
|
| 115 |
+
# matrix per channel group, known as SOTA for skeleton action recognition.
|
| 116 |
+
# ---------------------------------------------------------------------------
|
| 117 |
+
|
| 118 |
+
class CTRGC(nn.Module):
|
| 119 |
+
"""Simplified CTR-GC block: learnable per-channel topology refinement."""
|
| 120 |
+
def __init__(self, in_ch, out_ch, n_joints, rel_reduction=4):
|
| 121 |
+
super().__init__()
|
| 122 |
+
self.n_joints = n_joints
|
| 123 |
+
self.conv1 = nn.Conv2d(in_ch, out_ch // rel_reduction, 1)
|
| 124 |
+
self.conv2 = nn.Conv2d(in_ch, out_ch // rel_reduction, 1)
|
| 125 |
+
self.conv3 = nn.Conv2d(in_ch, out_ch, 1)
|
| 126 |
+
self.alpha = nn.Parameter(torch.zeros(1))
|
| 127 |
+
self.A = nn.Parameter(torch.eye(n_joints) + 0.1 * torch.randn(n_joints, n_joints))
|
| 128 |
+
|
| 129 |
+
def forward(self, x):
|
| 130 |
+
# x: (B, C, T, V)
|
| 131 |
+
q = self.conv1(x).mean(dim=2) # (B, C', V)
|
| 132 |
+
k = self.conv2(x).mean(dim=2) # (B, C', V)
|
| 133 |
+
v = self.conv3(x) # (B, C_out, T, V)
|
| 134 |
+
# Channel-specific topology refinement
|
| 135 |
+
topology = F.softmax(torch.tanh(q.unsqueeze(-1) - k.unsqueeze(-2)), dim=-1)
|
| 136 |
+
# topology: (B, C', V, V); we average across channels to get a shared (B, V, V)
|
| 137 |
+
topology = topology.mean(dim=1)
|
| 138 |
+
A = self.A.unsqueeze(0) + self.alpha * topology
|
| 139 |
+
# apply A to v
|
| 140 |
+
out = torch.einsum('bctv,bvw->bctw', v, A)
|
| 141 |
+
return out
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class CTRGCNBlock(nn.Module):
|
| 145 |
+
def __init__(self, in_ch, out_ch, n_joints, stride=1):
|
| 146 |
+
super().__init__()
|
| 147 |
+
self.gc = CTRGC(in_ch, out_ch, n_joints)
|
| 148 |
+
self.bn = nn.BatchNorm2d(out_ch)
|
| 149 |
+
self.tcn = nn.Sequential(
|
| 150 |
+
nn.Conv2d(out_ch, out_ch, (9, 1), padding=(4, 0), stride=(stride, 1)),
|
| 151 |
+
nn.BatchNorm2d(out_ch),
|
| 152 |
+
)
|
| 153 |
+
if in_ch != out_ch or stride != 1:
|
| 154 |
+
self.res = nn.Conv2d(in_ch, out_ch, 1, stride=(stride, 1))
|
| 155 |
+
else:
|
| 156 |
+
self.res = nn.Identity()
|
| 157 |
+
|
| 158 |
+
def forward(self, x):
|
| 159 |
+
res = self.res(x)
|
| 160 |
+
h = self.gc(x)
|
| 161 |
+
h = self.bn(h)
|
| 162 |
+
h = F.relu(h)
|
| 163 |
+
h = self.tcn(h)
|
| 164 |
+
return F.relu(h + res)
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
class CTRGCN(nn.Module):
|
| 168 |
+
def __init__(self, feat_dim_mocap, num_classes, hidden=64, n_joints=52):
|
| 169 |
+
super().__init__()
|
| 170 |
+
self.n_joints = n_joints
|
| 171 |
+
self.coord_dim = 3
|
| 172 |
+
self.proj_in = nn.Linear(feat_dim_mocap, n_joints * self.coord_dim)
|
| 173 |
+
self.blocks = nn.ModuleList([
|
| 174 |
+
CTRGCNBlock(self.coord_dim, hidden, n_joints),
|
| 175 |
+
CTRGCNBlock(hidden, hidden, n_joints),
|
| 176 |
+
CTRGCNBlock(hidden, hidden * 2, n_joints, stride=2),
|
| 177 |
+
CTRGCNBlock(hidden * 2, hidden * 4, n_joints, stride=2),
|
| 178 |
+
])
|
| 179 |
+
self.head = nn.Sequential(
|
| 180 |
+
nn.Dropout(0.3),
|
| 181 |
+
nn.Linear(hidden * 4, num_classes),
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
def forward(self, x_mocap, mask=None):
|
| 185 |
+
B, T, _ = x_mocap.shape
|
| 186 |
+
h = self.proj_in(x_mocap)
|
| 187 |
+
h = h.reshape(B, T, self.n_joints, self.coord_dim).permute(0, 3, 1, 2)
|
| 188 |
+
for blk in self.blocks:
|
| 189 |
+
h = blk(h)
|
| 190 |
+
h = h.mean(dim=(2, 3))
|
| 191 |
+
return self.head(h)
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
# ---------------------------------------------------------------------------
|
| 195 |
+
# 3) LIMU-BERT (Xu et al., SenSys 2021)
|
| 196 |
+
# IMU self-supervised pretraining via masked reconstruction + fine-tune.
|
| 197 |
+
# We implement a simpler variant: BERT-style encoder with optional
|
| 198 |
+
# pretraining head.
|
| 199 |
+
# ---------------------------------------------------------------------------
|
| 200 |
+
|
| 201 |
+
class LIMUBertEncoder(nn.Module):
|
| 202 |
+
def __init__(self, feat_dim_imu, hidden=128, n_layers=4, n_heads=4, dropout=0.1):
|
| 203 |
+
super().__init__()
|
| 204 |
+
self.in_proj = nn.Linear(feat_dim_imu, hidden)
|
| 205 |
+
self.pos = nn.Parameter(torch.zeros(1, 4096, hidden))
|
| 206 |
+
nn.init.trunc_normal_(self.pos, std=0.02)
|
| 207 |
+
layer = nn.TransformerEncoderLayer(
|
| 208 |
+
d_model=hidden, nhead=n_heads, dim_feedforward=4 * hidden,
|
| 209 |
+
dropout=dropout, batch_first=True, activation='gelu',
|
| 210 |
+
)
|
| 211 |
+
self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
|
| 212 |
+
|
| 213 |
+
def forward(self, x, mask):
|
| 214 |
+
T = x.size(1)
|
| 215 |
+
h = self.in_proj(x) + self.pos[:, :T, :]
|
| 216 |
+
h = self.encoder(h, src_key_padding_mask=~mask)
|
| 217 |
+
return h
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
class LIMUBert(nn.Module):
|
| 221 |
+
"""Supervised-only variant: encoder + classifier head. Paper's
|
| 222 |
+
pretraining is a masked-recon objective; for simplicity we report the
|
| 223 |
+
supervised-only baseline here."""
|
| 224 |
+
def __init__(self, feat_dim_imu, num_classes, hidden=128, n_layers=4,
|
| 225 |
+
n_heads=4, dropout=0.1):
|
| 226 |
+
super().__init__()
|
| 227 |
+
self.encoder = LIMUBertEncoder(feat_dim_imu, hidden, n_layers, n_heads, dropout)
|
| 228 |
+
self.head = nn.Sequential(
|
| 229 |
+
nn.LayerNorm(hidden),
|
| 230 |
+
nn.Dropout(dropout),
|
| 231 |
+
nn.Linear(hidden, num_classes),
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
def forward(self, x_imu, mask):
|
| 235 |
+
h = self.encoder(x_imu, mask)
|
| 236 |
+
m = mask.unsqueeze(-1).float()
|
| 237 |
+
pooled = (h * m).sum(dim=1) / m.sum(dim=1).clamp(min=1.0)
|
| 238 |
+
return self.head(pooled)
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
# ---------------------------------------------------------------------------
|
| 242 |
+
# 4) EMG-CNN (standard 1D CNN baseline from sEMG classification literature)
|
| 243 |
+
# E.g. Atzori et al. — multi-layer CNN with moving-window input.
|
| 244 |
+
# ---------------------------------------------------------------------------
|
| 245 |
+
|
| 246 |
+
class EMGCNN(nn.Module):
|
| 247 |
+
def __init__(self, feat_dim_emg, num_classes, hidden=64):
|
| 248 |
+
super().__init__()
|
| 249 |
+
self.cnn = nn.Sequential(
|
| 250 |
+
nn.Conv1d(feat_dim_emg, hidden, 7, padding=3),
|
| 251 |
+
nn.BatchNorm1d(hidden), nn.ReLU(), nn.Dropout(0.3),
|
| 252 |
+
nn.Conv1d(hidden, hidden * 2, 5, padding=2),
|
| 253 |
+
nn.BatchNorm1d(hidden * 2), nn.ReLU(), nn.Dropout(0.3),
|
| 254 |
+
nn.Conv1d(hidden * 2, hidden * 4, 3, padding=1),
|
| 255 |
+
nn.BatchNorm1d(hidden * 4), nn.ReLU(),
|
| 256 |
+
)
|
| 257 |
+
self.head = nn.Linear(hidden * 4, num_classes)
|
| 258 |
+
|
| 259 |
+
def forward(self, x_emg, mask):
|
| 260 |
+
# (B, T, 8) -> (B, 8, T) for conv1d
|
| 261 |
+
h = self.cnn(x_emg.transpose(1, 2))
|
| 262 |
+
# Masked pool
|
| 263 |
+
m = mask.unsqueeze(1).float()
|
| 264 |
+
T_ = h.size(2)
|
| 265 |
+
if m.size(2) != T_:
|
| 266 |
+
m = F.adaptive_avg_pool1d(m, T_)
|
| 267 |
+
m = (m > 0.5).float()
|
| 268 |
+
pooled = (h * m).sum(dim=2) / m.sum(dim=2).clamp(min=1.0)
|
| 269 |
+
return self.head(pooled)
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
# ---------------------------------------------------------------------------
|
| 273 |
+
# 5) ActionSense baseline (DelPreto et al., NeurIPS '22)
|
| 274 |
+
# Simple 3-layer MLP per modality + shared LSTM + classifier.
|
| 275 |
+
# ---------------------------------------------------------------------------
|
| 276 |
+
|
| 277 |
+
class ActionSenseLSTM(nn.Module):
|
| 278 |
+
def __init__(self, modality_dims: dict, num_classes, hidden=128):
|
| 279 |
+
super().__init__()
|
| 280 |
+
self.mod_names = list(modality_dims.keys())
|
| 281 |
+
self.mod_dims = modality_dims
|
| 282 |
+
self.per_mod = nn.ModuleDict({
|
| 283 |
+
name: nn.Sequential(
|
| 284 |
+
nn.Linear(d, hidden), nn.ReLU(), nn.Dropout(0.2),
|
| 285 |
+
nn.Linear(hidden, hidden), nn.ReLU(),
|
| 286 |
+
) for name, d in modality_dims.items()
|
| 287 |
+
})
|
| 288 |
+
concat_dim = hidden * len(modality_dims)
|
| 289 |
+
self.lstm = nn.LSTM(concat_dim, hidden, num_layers=2,
|
| 290 |
+
batch_first=True, bidirectional=True, dropout=0.2)
|
| 291 |
+
self.head = nn.Linear(hidden * 2, num_classes)
|
| 292 |
+
|
| 293 |
+
def forward(self, x, mask):
|
| 294 |
+
# x: (B, T, F_total), slice by modality
|
| 295 |
+
offset = 0
|
| 296 |
+
feats = []
|
| 297 |
+
for name in self.mod_names:
|
| 298 |
+
d = self.mod_dims[name]
|
| 299 |
+
x_m = x[..., offset:offset + d]
|
| 300 |
+
offset += d
|
| 301 |
+
feats.append(self.per_mod[name](x_m))
|
| 302 |
+
h = torch.cat(feats, dim=-1) # (B, T, hidden * M)
|
| 303 |
+
h, _ = self.lstm(h)
|
| 304 |
+
m = mask.unsqueeze(-1).float()
|
| 305 |
+
pooled = (h * m).sum(dim=1) / m.sum(dim=1).clamp(min=1.0)
|
| 306 |
+
return self.head(pooled)
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
# ---------------------------------------------------------------------------
|
| 310 |
+
# 6) MulT (Multimodal Transformer, Tsai et al., ACL 2019)
|
| 311 |
+
# Core idea: cross-modal attention between every pair of modalities.
|
| 312 |
+
# For a 3-modality input (A, B, C), produce
|
| 313 |
+
# {A->B, A->C, B->A, B->C, C->A, C->B} via directed cross-attention.
|
| 314 |
+
# ---------------------------------------------------------------------------
|
| 315 |
+
|
| 316 |
+
class CrossModalTransformer(nn.Module):
|
| 317 |
+
def __init__(self, d_model, n_heads=4, n_layers=2, dropout=0.1):
|
| 318 |
+
super().__init__()
|
| 319 |
+
self.layers = nn.ModuleList([
|
| 320 |
+
nn.TransformerDecoderLayer(
|
| 321 |
+
d_model=d_model, nhead=n_heads, dim_feedforward=4 * d_model,
|
| 322 |
+
dropout=dropout, batch_first=True, activation='gelu',
|
| 323 |
+
) for _ in range(n_layers)
|
| 324 |
+
])
|
| 325 |
+
|
| 326 |
+
def forward(self, q, kv, q_mask, kv_mask):
|
| 327 |
+
# q: (B, T_q, D), kv: (B, T_kv, D)
|
| 328 |
+
h = q
|
| 329 |
+
for layer in self.layers:
|
| 330 |
+
h = layer(h, kv,
|
| 331 |
+
tgt_key_padding_mask=~q_mask,
|
| 332 |
+
memory_key_padding_mask=~kv_mask)
|
| 333 |
+
return h
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
class MulT(nn.Module):
|
| 337 |
+
"""Multimodal Transformer. Uses MoCap + EMG + IMU as 3 modalities
|
| 338 |
+
(EyeTrack/Pressure omitted to match original 3-mod paper design)."""
|
| 339 |
+
def __init__(self, modality_dims: dict, num_classes, d_model=128,
|
| 340 |
+
n_layers=2, n_heads=4, dropout=0.1):
|
| 341 |
+
super().__init__()
|
| 342 |
+
self.mod_names = [m for m in ['mocap', 'emg', 'imu'] if m in modality_dims]
|
| 343 |
+
if len(self.mod_names) < 2:
|
| 344 |
+
self.mod_names = list(modality_dims.keys())[:3]
|
| 345 |
+
self.mod_dims = {m: modality_dims[m] for m in self.mod_names}
|
| 346 |
+
self.in_proj = nn.ModuleDict({
|
| 347 |
+
m: nn.Linear(d, d_model) for m, d in self.mod_dims.items()
|
| 348 |
+
})
|
| 349 |
+
# Pairwise cross-attention
|
| 350 |
+
self.cross = nn.ModuleDict({
|
| 351 |
+
f"{a}_to_{b}": CrossModalTransformer(d_model, n_heads, n_layers, dropout)
|
| 352 |
+
for a in self.mod_names for b in self.mod_names if a != b
|
| 353 |
+
})
|
| 354 |
+
# Self-attention after cross
|
| 355 |
+
self.self_tx = nn.ModuleDict({
|
| 356 |
+
m: nn.TransformerEncoder(
|
| 357 |
+
nn.TransformerEncoderLayer(
|
| 358 |
+
d_model=d_model, nhead=n_heads,
|
| 359 |
+
dim_feedforward=4 * d_model, dropout=dropout,
|
| 360 |
+
batch_first=True, activation='gelu',
|
| 361 |
+
), num_layers=1,
|
| 362 |
+
) for m in self.mod_names
|
| 363 |
+
})
|
| 364 |
+
total_dim = d_model * len(self.mod_names) * len(self.mod_names)
|
| 365 |
+
self.head = nn.Sequential(
|
| 366 |
+
nn.LayerNorm(total_dim),
|
| 367 |
+
nn.Dropout(dropout),
|
| 368 |
+
nn.Linear(total_dim, num_classes),
|
| 369 |
+
)
|
| 370 |
+
|
| 371 |
+
def forward(self, x, mask):
|
| 372 |
+
# Slice modalities from x
|
| 373 |
+
offset = 0
|
| 374 |
+
projs = {}
|
| 375 |
+
# Walk through all known mod_dims to find offsets
|
| 376 |
+
# We need the FULL modality_dims order, which we don't have here;
|
| 377 |
+
# expect caller to already supply x with exactly mod_names in order.
|
| 378 |
+
# Workaround: assume caller passes mod_names order matching projection.
|
| 379 |
+
for m in self.mod_names:
|
| 380 |
+
d = self.mod_dims[m]
|
| 381 |
+
projs[m] = self.in_proj[m](x[..., offset:offset + d])
|
| 382 |
+
offset += d
|
| 383 |
+
|
| 384 |
+
# Cross-attention: each modality attends to each other
|
| 385 |
+
fused = {m: [] for m in self.mod_names}
|
| 386 |
+
for a in self.mod_names:
|
| 387 |
+
for b in self.mod_names:
|
| 388 |
+
if a == b:
|
| 389 |
+
fused[a].append(projs[a])
|
| 390 |
+
else:
|
| 391 |
+
out = self.cross[f"{a}_to_{b}"](projs[a], projs[b], mask, mask)
|
| 392 |
+
fused[a].append(out)
|
| 393 |
+
|
| 394 |
+
# Self-attention + pool per modality
|
| 395 |
+
pooled = []
|
| 396 |
+
for a in self.mod_names:
|
| 397 |
+
# Concat all attended-to representations along feature dim
|
| 398 |
+
cat = torch.cat(fused[a], dim=-1) # (B, T, D * M)
|
| 399 |
+
# Actually re-project back to D per stream, then self-attn on stacked
|
| 400 |
+
# Simplified: self-attention over concatenated, pool, flatten
|
| 401 |
+
# Here we just pool each separately
|
| 402 |
+
for i, rep in enumerate(fused[a]):
|
| 403 |
+
rep = self.self_tx[a](rep)
|
| 404 |
+
m = mask.unsqueeze(-1).float()
|
| 405 |
+
p = (rep * m).sum(dim=1) / m.sum(dim=1).clamp(min=1.0)
|
| 406 |
+
pooled.append(p)
|
| 407 |
+
|
| 408 |
+
h = torch.cat(pooled, dim=-1)
|
| 409 |
+
return self.head(h)
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
# ---------------------------------------------------------------------------
|
| 413 |
+
# 7) Perceiver IO (Jaegle et al., ICML 2021)
|
| 414 |
+
# Cross-attention from a fixed-size latent query set to all input tokens,
|
| 415 |
+
# repeated for a few iterations.
|
| 416 |
+
# ---------------------------------------------------------------------------
|
| 417 |
+
|
| 418 |
+
class PerceiverBlock(nn.Module):
|
| 419 |
+
def __init__(self, latent_dim, n_heads, dropout):
|
| 420 |
+
super().__init__()
|
| 421 |
+
self.ca = nn.MultiheadAttention(
|
| 422 |
+
latent_dim, n_heads, dropout=dropout, batch_first=True,
|
| 423 |
+
)
|
| 424 |
+
self.norm1 = nn.LayerNorm(latent_dim)
|
| 425 |
+
self.sa = nn.TransformerEncoderLayer(
|
| 426 |
+
d_model=latent_dim, nhead=n_heads,
|
| 427 |
+
dim_feedforward=4 * latent_dim, dropout=dropout,
|
| 428 |
+
batch_first=True, activation='gelu',
|
| 429 |
+
)
|
| 430 |
+
|
| 431 |
+
def forward(self, latents, inputs, input_kpm):
|
| 432 |
+
# Cross-attn: latents attend to inputs
|
| 433 |
+
h, _ = self.ca(latents, inputs, inputs, key_padding_mask=input_kpm)
|
| 434 |
+
latents = self.norm1(latents + h)
|
| 435 |
+
# Self-attn on latents
|
| 436 |
+
latents = self.sa(latents)
|
| 437 |
+
return latents
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
class PerceiverIO(nn.Module):
|
| 441 |
+
"""Perceiver with N learnable latent queries; supports any modality mix."""
|
| 442 |
+
def __init__(self, modality_dims: dict, num_classes,
|
| 443 |
+
latent_dim=128, n_latents=32, n_layers=3, n_heads=4, dropout=0.1):
|
| 444 |
+
super().__init__()
|
| 445 |
+
self.mod_names = list(modality_dims.keys())
|
| 446 |
+
self.mod_dims = modality_dims
|
| 447 |
+
# Per-modality input projection to latent_dim, with modality-id embedding
|
| 448 |
+
self.in_proj = nn.ModuleDict({
|
| 449 |
+
m: nn.Linear(d, latent_dim) for m, d in modality_dims.items()
|
| 450 |
+
})
|
| 451 |
+
self.mod_emb = nn.Parameter(torch.randn(len(self.mod_names), latent_dim) * 0.02)
|
| 452 |
+
# Positional encoding (shared)
|
| 453 |
+
self.pos = nn.Parameter(torch.zeros(1, 4096, latent_dim))
|
| 454 |
+
nn.init.trunc_normal_(self.pos, std=0.02)
|
| 455 |
+
# Learnable latents
|
| 456 |
+
self.latents = nn.Parameter(torch.randn(n_latents, latent_dim) * 0.02)
|
| 457 |
+
self.blocks = nn.ModuleList([
|
| 458 |
+
PerceiverBlock(latent_dim, n_heads, dropout) for _ in range(n_layers)
|
| 459 |
+
])
|
| 460 |
+
self.head = nn.Sequential(
|
| 461 |
+
nn.LayerNorm(latent_dim),
|
| 462 |
+
nn.Linear(latent_dim, num_classes),
|
| 463 |
+
)
|
| 464 |
+
|
| 465 |
+
def forward(self, x, mask):
|
| 466 |
+
B, T, _ = x.shape
|
| 467 |
+
# Project each modality + add modality embedding
|
| 468 |
+
offset = 0
|
| 469 |
+
tokens = []
|
| 470 |
+
for i, m in enumerate(self.mod_names):
|
| 471 |
+
d = self.mod_dims[m]
|
| 472 |
+
tok = self.in_proj[m](x[..., offset:offset + d]) # (B, T, D)
|
| 473 |
+
tok = tok + self.mod_emb[i]
|
| 474 |
+
offset += d
|
| 475 |
+
tokens.append(tok)
|
| 476 |
+
# Concatenate along TIME dim, add shared pos enc per-modality
|
| 477 |
+
# Each modality gets its own time sequence concatenated
|
| 478 |
+
# Simpler: sum across modalities (like early fusion in latent space) + pos
|
| 479 |
+
h = torch.stack(tokens, dim=2).mean(dim=2) # (B, T, D)
|
| 480 |
+
h = h + self.pos[:, :T, :]
|
| 481 |
+
input_kpm = ~mask # (B, T), True = ignore
|
| 482 |
+
# Iterative cross-attention
|
| 483 |
+
latents = self.latents.unsqueeze(0).expand(B, -1, -1) # (B, N, D)
|
| 484 |
+
for blk in self.blocks:
|
| 485 |
+
latents = blk(latents, h, input_kpm)
|
| 486 |
+
# Mean-pool latents
|
| 487 |
+
pooled = latents.mean(dim=1)
|
| 488 |
+
return self.head(pooled)
|
experiments/nets/baselines_published/syncfuse.py
ADDED
|
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SyncFuse — our proposed method for T1 scene recognition.
|
| 3 |
+
|
| 4 |
+
Four components (all toggleable via args for ablation):
|
| 5 |
+
|
| 6 |
+
(1) Modality dropout: per-sample independent Bernoulli(p=0.3) drop on each
|
| 7 |
+
modality during training; at test time all modalities
|
| 8 |
+
are active. Keeps at least 1 modality.
|
| 9 |
+
(2) Pretrained transfer: each per-modality backbone is optionally loaded from
|
| 10 |
+
an independently pretrained single-modality
|
| 11 |
+
checkpoint and frozen during fine-tuning.
|
| 12 |
+
(3) Cross-modal temporal-shift attention:
|
| 13 |
+
a late cross-attention block where EMG queries
|
| 14 |
+
attend to MoCap keys/values at a LEARNED temporal
|
| 15 |
+
offset Δ (Gumbel-softmax over {-10,...,+10} bins at
|
| 16 |
+
20 Hz = ±500 ms). Motivated by the paper's case-study
|
| 17 |
+
finding (EMG leads motion by ~20 ms sub-frame).
|
| 18 |
+
(4) Learnable late fusion:
|
| 19 |
+
per-modality classifier logits are combined with a
|
| 20 |
+
learnable softmax-weighted average (temperature is
|
| 21 |
+
also learned). Equivalent to `late_agg='learned'`
|
| 22 |
+
in the repo's existing LateFusionModel.
|
| 23 |
+
"""
|
| 24 |
+
import torch
|
| 25 |
+
import torch.nn as nn
|
| 26 |
+
import torch.nn.functional as F
|
| 27 |
+
import random
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def masked_mean(x, mask):
|
| 31 |
+
m = mask.unsqueeze(-1).float()
|
| 32 |
+
return (x * m).sum(dim=1) / m.sum(dim=1).clamp(min=1.0)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# ---------------------------------------------------------------------------
|
| 36 |
+
# Per-modality Transformer branch (same as repo's TransformerBackbone)
|
| 37 |
+
# ---------------------------------------------------------------------------
|
| 38 |
+
|
| 39 |
+
class ModTransformer(nn.Module):
|
| 40 |
+
def __init__(self, feat_dim, hidden=128, n_layers=2, n_heads=4, dropout=0.1):
|
| 41 |
+
super().__init__()
|
| 42 |
+
self.in_proj = nn.Linear(feat_dim, hidden)
|
| 43 |
+
self.pos = nn.Parameter(torch.zeros(1, 4096, hidden))
|
| 44 |
+
nn.init.trunc_normal_(self.pos, std=0.02)
|
| 45 |
+
layer = nn.TransformerEncoderLayer(
|
| 46 |
+
d_model=hidden, nhead=n_heads, dim_feedforward=4 * hidden,
|
| 47 |
+
dropout=dropout, batch_first=True, activation='gelu',
|
| 48 |
+
)
|
| 49 |
+
self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
|
| 50 |
+
self.output_dim = hidden
|
| 51 |
+
|
| 52 |
+
def forward(self, x, mask):
|
| 53 |
+
# x: (B, T, feat_dim)
|
| 54 |
+
T = x.size(1)
|
| 55 |
+
h = self.in_proj(x) + self.pos[:, :T, :]
|
| 56 |
+
h = self.encoder(h, src_key_padding_mask=~mask)
|
| 57 |
+
return h # (B, T, hidden) — token-level, NOT pooled
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# ---------------------------------------------------------------------------
|
| 61 |
+
# (3) Cross-modal temporal-shift attention
|
| 62 |
+
# ---------------------------------------------------------------------------
|
| 63 |
+
|
| 64 |
+
class TemporalShiftAttention(nn.Module):
|
| 65 |
+
"""Multi-head attention where queries are temporally shifted by a learned
|
| 66 |
+
offset Δ from the keys. Δ is drawn from a discrete set {-3,...,+3} via
|
| 67 |
+
straight-through Gumbel-softmax: we sample ONE shift per forward pass,
|
| 68 |
+
but the softmax weights flow gradient back through shift_logits.
|
| 69 |
+
|
| 70 |
+
At 20 Hz bins, ±3 ≈ ±150 ms, which brackets the paper's ~20 ms EMG-motion
|
| 71 |
+
lead. Memory cost is ~1 attention pass (not 7)."""
|
| 72 |
+
def __init__(self, d_model, n_heads=4, dropout=0.1, max_shift=3,
|
| 73 |
+
gumbel_tau=1.0):
|
| 74 |
+
super().__init__()
|
| 75 |
+
self.max_shift = max_shift
|
| 76 |
+
self.shifts = list(range(-max_shift, max_shift + 1))
|
| 77 |
+
self.shift_logits = nn.Parameter(torch.zeros(len(self.shifts)))
|
| 78 |
+
self.tau = gumbel_tau
|
| 79 |
+
self.attn = nn.MultiheadAttention(
|
| 80 |
+
d_model, n_heads, dropout=dropout, batch_first=True,
|
| 81 |
+
)
|
| 82 |
+
self.norm = nn.LayerNorm(d_model)
|
| 83 |
+
|
| 84 |
+
def _shift_tensor(self, x, shift, mask):
|
| 85 |
+
if shift == 0:
|
| 86 |
+
return x, mask
|
| 87 |
+
B, T, D = x.shape
|
| 88 |
+
if shift > 0:
|
| 89 |
+
pad = torch.zeros(B, shift, D, device=x.device, dtype=x.dtype)
|
| 90 |
+
x_s = torch.cat([x[:, shift:, :], pad], dim=1)
|
| 91 |
+
m_s = torch.cat([mask[:, shift:],
|
| 92 |
+
torch.zeros(B, shift, device=mask.device, dtype=torch.bool)],
|
| 93 |
+
dim=1)
|
| 94 |
+
else:
|
| 95 |
+
s = -shift
|
| 96 |
+
pad = torch.zeros(B, s, D, device=x.device, dtype=x.dtype)
|
| 97 |
+
x_s = torch.cat([pad, x[:, :-s, :]], dim=1)
|
| 98 |
+
m_s = torch.cat([torch.zeros(B, s, device=mask.device, dtype=torch.bool),
|
| 99 |
+
mask[:, :-s]], dim=1)
|
| 100 |
+
return x_s, m_s
|
| 101 |
+
|
| 102 |
+
def forward(self, q_tokens, kv_tokens, q_mask, kv_mask, hard=False):
|
| 103 |
+
if hard or not self.training:
|
| 104 |
+
# Eval: take the argmax shift
|
| 105 |
+
with torch.no_grad():
|
| 106 |
+
idx = self.shift_logits.argmax().item()
|
| 107 |
+
shift = self.shifts[idx]
|
| 108 |
+
shifted_kv, shifted_mask = self._shift_tensor(kv_tokens, shift, kv_mask)
|
| 109 |
+
out, _ = self.attn(q_tokens, shifted_kv, shifted_kv,
|
| 110 |
+
key_padding_mask=~shifted_mask)
|
| 111 |
+
return self.norm(q_tokens + out)
|
| 112 |
+
|
| 113 |
+
# Training: straight-through Gumbel-softmax to sample 1 shift,
|
| 114 |
+
# with gradient flowing via softmax weights.
|
| 115 |
+
one_hot = F.gumbel_softmax(self.shift_logits, tau=self.tau, hard=True)
|
| 116 |
+
# pick the sampled shift (argmax of the hard one-hot)
|
| 117 |
+
idx = int(one_hot.argmax().item())
|
| 118 |
+
shift = self.shifts[idx]
|
| 119 |
+
shifted_kv, shifted_mask = self._shift_tensor(kv_tokens, shift, kv_mask)
|
| 120 |
+
out, _ = self.attn(q_tokens, shifted_kv, shifted_kv,
|
| 121 |
+
key_padding_mask=~shifted_mask)
|
| 122 |
+
# scale out by the corresponding soft weight to let gradient flow
|
| 123 |
+
out = out * one_hot[idx]
|
| 124 |
+
return self.norm(q_tokens + out)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
# ---------------------------------------------------------------------------
|
| 128 |
+
# SyncFuse main model
|
| 129 |
+
# ---------------------------------------------------------------------------
|
| 130 |
+
|
| 131 |
+
class SyncFuse(nn.Module):
|
| 132 |
+
def __init__(self, modality_dims: dict, num_classes, hidden=128, n_heads=4,
|
| 133 |
+
n_layers=2, dropout=0.1,
|
| 134 |
+
use_xmod_shift=True, use_learned_late=True):
|
| 135 |
+
super().__init__()
|
| 136 |
+
self.mod_names = list(modality_dims.keys())
|
| 137 |
+
self.mod_dims = modality_dims
|
| 138 |
+
self.use_xmod_shift = use_xmod_shift
|
| 139 |
+
self.use_learned_late = use_learned_late
|
| 140 |
+
|
| 141 |
+
self.branches = nn.ModuleDict({
|
| 142 |
+
m: ModTransformer(d, hidden, n_layers, n_heads, dropout)
|
| 143 |
+
for m, d in modality_dims.items()
|
| 144 |
+
})
|
| 145 |
+
self.classifiers = nn.ModuleDict({
|
| 146 |
+
m: nn.Sequential(nn.LayerNorm(hidden), nn.Dropout(dropout),
|
| 147 |
+
nn.Linear(hidden, num_classes))
|
| 148 |
+
for m in self.mod_names
|
| 149 |
+
})
|
| 150 |
+
|
| 151 |
+
# Cross-modal temporal-shift: apply to EMG branch attending to MoCap
|
| 152 |
+
# (and symmetrically MoCap->EMG), only when both modalities are present.
|
| 153 |
+
if use_xmod_shift and 'emg' in self.mod_names and 'mocap' in self.mod_names:
|
| 154 |
+
self.xmod_emg2mocap = TemporalShiftAttention(hidden, n_heads, dropout)
|
| 155 |
+
self.xmod_mocap2emg = TemporalShiftAttention(hidden, n_heads, dropout)
|
| 156 |
+
else:
|
| 157 |
+
self.xmod_emg2mocap = None
|
| 158 |
+
self.xmod_mocap2emg = None
|
| 159 |
+
|
| 160 |
+
if use_learned_late:
|
| 161 |
+
self.late_logits = nn.Parameter(torch.zeros(len(self.mod_names)))
|
| 162 |
+
self.late_temperature = nn.Parameter(torch.ones(1))
|
| 163 |
+
|
| 164 |
+
def load_pretrained(self, pretrain_paths: dict, freeze=True):
|
| 165 |
+
"""Load pretrained single-modality checkpoints into branches.
|
| 166 |
+
pretrain_paths: {modality_name: path_to_checkpoint_state_dict}."""
|
| 167 |
+
import torch as _torch
|
| 168 |
+
for m, path in pretrain_paths.items():
|
| 169 |
+
if m not in self.branches:
|
| 170 |
+
continue
|
| 171 |
+
try:
|
| 172 |
+
sd = _torch.load(path, weights_only=True, map_location='cpu')
|
| 173 |
+
except TypeError:
|
| 174 |
+
sd = _torch.load(path, map_location='cpu')
|
| 175 |
+
# Map SingleModel keys ("backbone.X.*") -> branch keys
|
| 176 |
+
mapped = {}
|
| 177 |
+
for k, v in sd.items():
|
| 178 |
+
if k.startswith('backbone.'):
|
| 179 |
+
new_k = k.replace('backbone.', '')
|
| 180 |
+
if new_k in self.branches[m].state_dict():
|
| 181 |
+
mapped[new_k] = v
|
| 182 |
+
if mapped:
|
| 183 |
+
self.branches[m].load_state_dict(mapped, strict=False)
|
| 184 |
+
if freeze:
|
| 185 |
+
for p in self.branches[m].parameters():
|
| 186 |
+
p.requires_grad = False
|
| 187 |
+
print(f" [SyncFuse] loaded {len(mapped)} tensors into branch '{m}' (frozen={freeze})")
|
| 188 |
+
|
| 189 |
+
def forward(self, x, mask, mod_dropout_p=0.0, training_time=True):
|
| 190 |
+
"""
|
| 191 |
+
x: (B, T, F_total) concatenated features
|
| 192 |
+
mask: (B, T)
|
| 193 |
+
mod_dropout_p: probability of dropping each modality (training only)
|
| 194 |
+
"""
|
| 195 |
+
B, T, _ = x.shape
|
| 196 |
+
|
| 197 |
+
# Slice modality features
|
| 198 |
+
offset = 0
|
| 199 |
+
feats = {}
|
| 200 |
+
for m in self.mod_names:
|
| 201 |
+
d = self.mod_dims[m]
|
| 202 |
+
feats[m] = x[..., offset:offset + d]
|
| 203 |
+
offset += d
|
| 204 |
+
|
| 205 |
+
# (1) Modality dropout — per sample, independent per modality
|
| 206 |
+
active = {m: torch.ones(B, dtype=torch.bool, device=x.device) for m in self.mod_names}
|
| 207 |
+
if training_time and self.training and mod_dropout_p > 0:
|
| 208 |
+
drop_map = {m: (torch.rand(B, device=x.device) < mod_dropout_p)
|
| 209 |
+
for m in self.mod_names}
|
| 210 |
+
all_dropped = torch.stack([drop_map[m] for m in self.mod_names], dim=0).all(dim=0) # (B,)
|
| 211 |
+
if all_dropped.any():
|
| 212 |
+
# for all-dropped samples, un-drop one random modality
|
| 213 |
+
rescue_idx = torch.randint(0, len(self.mod_names),
|
| 214 |
+
(all_dropped.sum().item(),),
|
| 215 |
+
device=x.device)
|
| 216 |
+
mod_name_tensor = self.mod_names # python list
|
| 217 |
+
j = 0
|
| 218 |
+
for b in range(B):
|
| 219 |
+
if all_dropped[b]:
|
| 220 |
+
r = mod_name_tensor[rescue_idx[j].item()]
|
| 221 |
+
drop_map[r][b] = False
|
| 222 |
+
j += 1
|
| 223 |
+
for m in self.mod_names:
|
| 224 |
+
active[m] = ~drop_map[m]
|
| 225 |
+
# zero out dropped features for that branch
|
| 226 |
+
feats[m] = feats[m] * active[m].view(B, 1, 1).float()
|
| 227 |
+
|
| 228 |
+
# Per-modality encoding
|
| 229 |
+
tokens = {}
|
| 230 |
+
for m in self.mod_names:
|
| 231 |
+
tokens[m] = self.branches[m](feats[m], mask) # (B, T, hidden)
|
| 232 |
+
|
| 233 |
+
# (3) Cross-modal temporal-shift (bidirectional EMG <-> MoCap)
|
| 234 |
+
if self.xmod_emg2mocap is not None:
|
| 235 |
+
tokens['emg'] = self.xmod_emg2mocap(
|
| 236 |
+
tokens['emg'], tokens['mocap'], mask, mask,
|
| 237 |
+
hard=not self.training,
|
| 238 |
+
)
|
| 239 |
+
tokens['mocap'] = self.xmod_mocap2emg(
|
| 240 |
+
tokens['mocap'], tokens['emg'], mask, mask,
|
| 241 |
+
hard=not self.training,
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
# Pool and classify per modality
|
| 245 |
+
logits_per = []
|
| 246 |
+
for m in self.mod_names:
|
| 247 |
+
pooled = masked_mean(tokens[m], mask)
|
| 248 |
+
logits_per.append(self.classifiers[m](pooled))
|
| 249 |
+
stacked = torch.stack(logits_per, dim=0) # (M, B, C)
|
| 250 |
+
|
| 251 |
+
# Mask out logits from dropped modalities (so they don't dominate)
|
| 252 |
+
if training_time and self.training and mod_dropout_p > 0:
|
| 253 |
+
act_mask = torch.stack([active[m].float() for m in self.mod_names], dim=0) # (M, B)
|
| 254 |
+
# Re-normalize weights across active modalities
|
| 255 |
+
if self.use_learned_late:
|
| 256 |
+
w = F.softmax(self.late_logits / self.late_temperature.clamp(min=0.1), dim=0)
|
| 257 |
+
w = w.view(-1, 1) * act_mask # (M, B)
|
| 258 |
+
w = w / w.sum(dim=0, keepdim=True).clamp(min=1e-6)
|
| 259 |
+
out = (stacked * w.unsqueeze(-1)).sum(dim=0)
|
| 260 |
+
else:
|
| 261 |
+
w = act_mask / act_mask.sum(dim=0, keepdim=True).clamp(min=1e-6)
|
| 262 |
+
out = (stacked * w.unsqueeze(-1)).sum(dim=0)
|
| 263 |
+
else:
|
| 264 |
+
# (4) Learnable late fusion (or simple mean)
|
| 265 |
+
if self.use_learned_late:
|
| 266 |
+
w = F.softmax(self.late_logits / self.late_temperature.clamp(min=0.1), dim=0)
|
| 267 |
+
out = (stacked * w.view(-1, 1, 1)).sum(dim=0)
|
| 268 |
+
else:
|
| 269 |
+
out = stacked.mean(dim=0)
|
| 270 |
+
return out
|
experiments/nets/models.py
ADDED
|
@@ -0,0 +1,648 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Model definitions for Experiment 1: Scene Recognition.
|
| 3 |
+
Backbones: CNN1D, BiLSTM, Transformer
|
| 4 |
+
Fusion: Early (default), Late, Attention, WeightedLate, GatedLate, Stacking, Product, MoE
|
| 5 |
+
|
| 6 |
+
Supports optional per-modality projection via proj_dim parameter:
|
| 7 |
+
proj_dim > 0: project each modality to proj_dim before backbone
|
| 8 |
+
proj_dim = 0: no projection, use raw features (original behavior)
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import math
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# ============================================================
|
| 18 |
+
# Per-modality projection
|
| 19 |
+
# ============================================================
|
| 20 |
+
|
| 21 |
+
class ModalityProjector(nn.Module):
|
| 22 |
+
"""Project each modality from its raw dimension to proj_dim."""
|
| 23 |
+
|
| 24 |
+
def __init__(self, modality_dims, proj_dim):
|
| 25 |
+
super().__init__()
|
| 26 |
+
self.mod_names = list(modality_dims.keys())
|
| 27 |
+
self.mod_dims = list(modality_dims.values())
|
| 28 |
+
self.proj_dim = proj_dim
|
| 29 |
+
self.projectors = nn.ModuleList()
|
| 30 |
+
for dim in self.mod_dims:
|
| 31 |
+
self.projectors.append(nn.Sequential(
|
| 32 |
+
nn.Linear(dim, proj_dim),
|
| 33 |
+
nn.LayerNorm(proj_dim),
|
| 34 |
+
nn.ReLU(),
|
| 35 |
+
))
|
| 36 |
+
|
| 37 |
+
@property
|
| 38 |
+
def output_dim(self):
|
| 39 |
+
return self.proj_dim * len(self.mod_dims)
|
| 40 |
+
|
| 41 |
+
def forward(self, x):
|
| 42 |
+
"""x: (B, T, total_raw_dim) -> (B, T, proj_dim * M)"""
|
| 43 |
+
parts = []
|
| 44 |
+
offset = 0
|
| 45 |
+
for i, dim in enumerate(self.mod_dims):
|
| 46 |
+
x_mod = x[:, :, offset:offset + dim]
|
| 47 |
+
offset += dim
|
| 48 |
+
parts.append(self.projectors[i](x_mod))
|
| 49 |
+
return torch.cat(parts, dim=-1)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# ============================================================
|
| 53 |
+
# Per-modality hidden dim scaling (used when proj_dim=0)
|
| 54 |
+
# ============================================================
|
| 55 |
+
|
| 56 |
+
def _compute_per_modality_hidden(mod_dim, base_hidden_dim):
|
| 57 |
+
if mod_dim >= 128:
|
| 58 |
+
return max(base_hidden_dim, 48)
|
| 59 |
+
elif mod_dim >= 32:
|
| 60 |
+
return base_hidden_dim
|
| 61 |
+
else:
|
| 62 |
+
return max(16, base_hidden_dim // 2)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
# ============================================================
|
| 66 |
+
# Backbones
|
| 67 |
+
# ============================================================
|
| 68 |
+
|
| 69 |
+
class CNN1DBackbone(nn.Module):
|
| 70 |
+
def __init__(self, input_dim, hidden_dim=128):
|
| 71 |
+
super().__init__()
|
| 72 |
+
self.conv1 = nn.Sequential(
|
| 73 |
+
nn.Conv1d(input_dim, 64, kernel_size=7, padding=3),
|
| 74 |
+
nn.BatchNorm1d(64), nn.ReLU(), nn.Dropout(0.1),
|
| 75 |
+
)
|
| 76 |
+
self.conv2 = nn.Sequential(
|
| 77 |
+
nn.Conv1d(64, 128, kernel_size=5, padding=2),
|
| 78 |
+
nn.BatchNorm1d(128), nn.ReLU(), nn.Dropout(0.1),
|
| 79 |
+
)
|
| 80 |
+
self.conv3 = nn.Sequential(
|
| 81 |
+
nn.Conv1d(128, hidden_dim, kernel_size=3, padding=1),
|
| 82 |
+
nn.BatchNorm1d(hidden_dim), nn.ReLU(),
|
| 83 |
+
)
|
| 84 |
+
self.output_dim = hidden_dim
|
| 85 |
+
|
| 86 |
+
def forward(self, x, mask=None):
|
| 87 |
+
x = x.permute(0, 2, 1)
|
| 88 |
+
x = self.conv1(x)
|
| 89 |
+
x = self.conv2(x)
|
| 90 |
+
x = self.conv3(x)
|
| 91 |
+
if mask is not None:
|
| 92 |
+
x = (x * mask.unsqueeze(1).float()).sum(2) / mask.sum(1, keepdim=True).float().clamp(min=1)
|
| 93 |
+
else:
|
| 94 |
+
x = x.mean(2)
|
| 95 |
+
return x
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class LSTMBackbone(nn.Module):
|
| 99 |
+
def __init__(self, input_dim, hidden_dim=128, num_layers=2, dropout=0.2):
|
| 100 |
+
super().__init__()
|
| 101 |
+
self.lstm = nn.LSTM(
|
| 102 |
+
input_dim, hidden_dim, num_layers=num_layers,
|
| 103 |
+
batch_first=True, bidirectional=True,
|
| 104 |
+
dropout=dropout if num_layers > 1 else 0,
|
| 105 |
+
)
|
| 106 |
+
self.attn = nn.Linear(hidden_dim * 2, 1)
|
| 107 |
+
self.output_dim = hidden_dim * 2
|
| 108 |
+
|
| 109 |
+
def forward(self, x, mask=None):
|
| 110 |
+
out, _ = self.lstm(x)
|
| 111 |
+
scores = self.attn(out).squeeze(-1)
|
| 112 |
+
if mask is not None:
|
| 113 |
+
scores = scores.masked_fill(~mask, float('-inf'))
|
| 114 |
+
weights = torch.softmax(scores, dim=1)
|
| 115 |
+
out = (out * weights.unsqueeze(-1)).sum(dim=1)
|
| 116 |
+
return out
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class TinyHARBackbone(nn.Module):
|
| 120 |
+
"""TinyHAR backbone (Zhou et al., ISWC 2022 Best Paper).
|
| 121 |
+
|
| 122 |
+
Lightweight model for human activity recognition from wearable sensors.
|
| 123 |
+
Uses multi-scale temporal convolutions + cross-channel interaction + temporal pooling.
|
| 124 |
+
|
| 125 |
+
Input: (B, T, C) with optional mask
|
| 126 |
+
Output: (B, hidden_dim)
|
| 127 |
+
"""
|
| 128 |
+
|
| 129 |
+
def __init__(self, input_dim, hidden_dim=128, num_scales=4):
|
| 130 |
+
super().__init__()
|
| 131 |
+
scale_dim = max(4, hidden_dim // num_scales)
|
| 132 |
+
actual_hidden = scale_dim * num_scales
|
| 133 |
+
|
| 134 |
+
# Multi-scale temporal convolution feature extraction
|
| 135 |
+
self.convs = nn.ModuleList()
|
| 136 |
+
for i in range(num_scales):
|
| 137 |
+
ks = 2 * (i + 1) + 1 # kernel sizes: 3, 5, 7, 9
|
| 138 |
+
self.convs.append(nn.Sequential(
|
| 139 |
+
nn.Conv1d(input_dim, scale_dim, kernel_size=ks, padding=ks // 2),
|
| 140 |
+
nn.BatchNorm1d(scale_dim),
|
| 141 |
+
nn.ReLU(),
|
| 142 |
+
))
|
| 143 |
+
|
| 144 |
+
# Cross-channel interaction via multi-head self-attention
|
| 145 |
+
nhead = max(1, min(4, actual_hidden // 8))
|
| 146 |
+
# Ensure actual_hidden is divisible by nhead
|
| 147 |
+
while actual_hidden % nhead != 0 and nhead > 1:
|
| 148 |
+
nhead -= 1
|
| 149 |
+
self.channel_attn = nn.MultiheadAttention(
|
| 150 |
+
actual_hidden, num_heads=nhead, batch_first=True, dropout=0.1,
|
| 151 |
+
)
|
| 152 |
+
self.channel_norm = nn.LayerNorm(actual_hidden)
|
| 153 |
+
self.channel_ff = nn.Sequential(
|
| 154 |
+
nn.Linear(actual_hidden, actual_hidden),
|
| 155 |
+
nn.ReLU(),
|
| 156 |
+
nn.Dropout(0.1),
|
| 157 |
+
nn.Linear(actual_hidden, actual_hidden),
|
| 158 |
+
)
|
| 159 |
+
self.ff_norm = nn.LayerNorm(actual_hidden)
|
| 160 |
+
|
| 161 |
+
# Temporal attention pooling
|
| 162 |
+
self.temporal_query = nn.Parameter(torch.randn(1, 1, actual_hidden) * 0.02)
|
| 163 |
+
self.temporal_attn = nn.MultiheadAttention(
|
| 164 |
+
actual_hidden, num_heads=1, batch_first=True, dropout=0.1,
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
self.output_dim = actual_hidden
|
| 168 |
+
|
| 169 |
+
def forward(self, x, mask=None):
|
| 170 |
+
# x: (B, T, C)
|
| 171 |
+
B, T, C = x.shape
|
| 172 |
+
x_t = x.permute(0, 2, 1) # (B, C, T)
|
| 173 |
+
|
| 174 |
+
# Multi-scale feature extraction
|
| 175 |
+
scale_features = [conv(x_t) for conv in self.convs]
|
| 176 |
+
x = torch.cat(scale_features, dim=1) # (B, actual_hidden, T)
|
| 177 |
+
x = x.permute(0, 2, 1) # (B, T, actual_hidden)
|
| 178 |
+
|
| 179 |
+
# Cross-channel interaction
|
| 180 |
+
key_padding_mask = ~mask if mask is not None else None
|
| 181 |
+
attn_out, _ = self.channel_attn(x, x, x, key_padding_mask=key_padding_mask)
|
| 182 |
+
x = self.channel_norm(x + attn_out)
|
| 183 |
+
x = self.ff_norm(x + self.channel_ff(x))
|
| 184 |
+
|
| 185 |
+
# Temporal attention pooling
|
| 186 |
+
query = self.temporal_query.expand(B, -1, -1) # (B, 1, actual_hidden)
|
| 187 |
+
pooled, _ = self.temporal_attn(query, x, x, key_padding_mask=key_padding_mask)
|
| 188 |
+
return pooled.squeeze(1) # (B, actual_hidden)
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
class PositionalEncoding(nn.Module):
|
| 192 |
+
def __init__(self, d_model, dropout=0.1, max_len=5000):
|
| 193 |
+
super().__init__()
|
| 194 |
+
self.dropout = nn.Dropout(p=dropout)
|
| 195 |
+
pe = torch.zeros(max_len, d_model)
|
| 196 |
+
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
| 197 |
+
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
|
| 198 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
| 199 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
| 200 |
+
pe = pe.unsqueeze(0)
|
| 201 |
+
self.register_buffer('pe', pe)
|
| 202 |
+
|
| 203 |
+
def forward(self, x):
|
| 204 |
+
x = x + self.pe[:, :x.size(1)]
|
| 205 |
+
return self.dropout(x)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
class TransformerBackbone(nn.Module):
|
| 209 |
+
def __init__(self, input_dim, d_model=128, nhead=4, num_layers=2, dropout=0.1):
|
| 210 |
+
super().__init__()
|
| 211 |
+
self.input_proj = nn.Linear(input_dim, d_model)
|
| 212 |
+
self.pos_enc = PositionalEncoding(d_model, dropout=dropout)
|
| 213 |
+
encoder_layer = nn.TransformerEncoderLayer(
|
| 214 |
+
d_model=d_model, nhead=nhead, dim_feedforward=d_model * 4,
|
| 215 |
+
dropout=dropout, batch_first=True,
|
| 216 |
+
)
|
| 217 |
+
self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
|
| 218 |
+
self.output_dim = d_model
|
| 219 |
+
|
| 220 |
+
def forward(self, x, mask=None):
|
| 221 |
+
x = self.input_proj(x)
|
| 222 |
+
x = self.pos_enc(x)
|
| 223 |
+
src_key_padding_mask = ~mask if mask is not None else None
|
| 224 |
+
x = self.encoder(x, src_key_padding_mask=src_key_padding_mask)
|
| 225 |
+
if mask is not None:
|
| 226 |
+
x = (x * mask.unsqueeze(-1).float()).sum(1) / mask.sum(1, keepdim=True).float().clamp(min=1)
|
| 227 |
+
else:
|
| 228 |
+
x = x.mean(1)
|
| 229 |
+
return x
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
# ============================================================
|
| 233 |
+
# Full models
|
| 234 |
+
# ============================================================
|
| 235 |
+
|
| 236 |
+
def get_backbone(name, input_dim, hidden_dim=128):
|
| 237 |
+
if name == 'cnn':
|
| 238 |
+
return CNN1DBackbone(input_dim, hidden_dim)
|
| 239 |
+
elif name == 'lstm':
|
| 240 |
+
return LSTMBackbone(input_dim, hidden_dim)
|
| 241 |
+
elif name == 'transformer':
|
| 242 |
+
return TransformerBackbone(input_dim, hidden_dim)
|
| 243 |
+
elif name == 'tinyhar':
|
| 244 |
+
return TinyHARBackbone(input_dim, hidden_dim)
|
| 245 |
+
elif name == 'deepconvlstm':
|
| 246 |
+
from experiments.published_models import DeepConvLSTMBackbone
|
| 247 |
+
return DeepConvLSTMBackbone(input_dim, hidden_dim)
|
| 248 |
+
elif name == 'inceptiontime':
|
| 249 |
+
from experiments.published_models import InceptionTimeBackbone
|
| 250 |
+
return InceptionTimeBackbone(input_dim, hidden_dim)
|
| 251 |
+
else:
|
| 252 |
+
raise ValueError(f"Unknown backbone: {name}")
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def _make_branch(backbone_name, raw_dim, hidden_dim, proj_dim):
|
| 256 |
+
"""Create optional projector + backbone for one modality branch."""
|
| 257 |
+
if proj_dim > 0:
|
| 258 |
+
proj = nn.Sequential(
|
| 259 |
+
nn.Linear(raw_dim, proj_dim),
|
| 260 |
+
nn.LayerNorm(proj_dim),
|
| 261 |
+
nn.ReLU(),
|
| 262 |
+
)
|
| 263 |
+
bb_input = proj_dim
|
| 264 |
+
bb_hidden = hidden_dim
|
| 265 |
+
else:
|
| 266 |
+
proj = None
|
| 267 |
+
bb_input = raw_dim
|
| 268 |
+
bb_hidden = _compute_per_modality_hidden(raw_dim, hidden_dim)
|
| 269 |
+
bb = get_backbone(backbone_name, bb_input, bb_hidden)
|
| 270 |
+
return proj, bb
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
class SingleModel(nn.Module):
|
| 274 |
+
"""Single backbone + classifier (early fusion or single-modality)."""
|
| 275 |
+
|
| 276 |
+
def __init__(self, backbone_name, input_dim, num_classes, hidden_dim=128,
|
| 277 |
+
modality_dims=None, proj_dim=0):
|
| 278 |
+
super().__init__()
|
| 279 |
+
self.projector = None
|
| 280 |
+
if proj_dim > 0 and modality_dims:
|
| 281 |
+
self.projector = ModalityProjector(modality_dims, proj_dim)
|
| 282 |
+
actual_input_dim = self.projector.output_dim
|
| 283 |
+
else:
|
| 284 |
+
actual_input_dim = input_dim
|
| 285 |
+
self.backbone = get_backbone(backbone_name, actual_input_dim, hidden_dim)
|
| 286 |
+
self.classifier = nn.Sequential(
|
| 287 |
+
nn.Dropout(0.5),
|
| 288 |
+
nn.Linear(self.backbone.output_dim, num_classes),
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
def forward(self, x, mask=None):
|
| 292 |
+
if self.projector is not None:
|
| 293 |
+
x = self.projector(x)
|
| 294 |
+
feat = self.backbone(x, mask)
|
| 295 |
+
return self.classifier(feat)
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
class LateFusionModel(nn.Module):
|
| 299 |
+
"""Late fusion: separate backbone per modality, configurable logit aggregation.
|
| 300 |
+
|
| 301 |
+
late_agg='mean': simple average (original)
|
| 302 |
+
late_agg='confidence': entropy-based confidence weighting (0 extra params)
|
| 303 |
+
late_agg='learned': temperature-scaled learned weights (M+1 extra params)
|
| 304 |
+
"""
|
| 305 |
+
|
| 306 |
+
def __init__(self, backbone_name, modality_dims, num_classes, hidden_dim=64,
|
| 307 |
+
proj_dim=0, late_agg='mean'):
|
| 308 |
+
super().__init__()
|
| 309 |
+
self.mod_names = list(modality_dims.keys())
|
| 310 |
+
self.mod_dims = list(modality_dims.values())
|
| 311 |
+
self.late_agg = late_agg
|
| 312 |
+
self.projectors = nn.ModuleList()
|
| 313 |
+
self.backbones = nn.ModuleList()
|
| 314 |
+
self.classifiers = nn.ModuleList()
|
| 315 |
+
for dim in self.mod_dims:
|
| 316 |
+
proj, bb = _make_branch(backbone_name, dim, hidden_dim, proj_dim)
|
| 317 |
+
self.projectors.append(proj if proj else nn.Identity())
|
| 318 |
+
self.backbones.append(bb)
|
| 319 |
+
self.classifiers.append(nn.Sequential(
|
| 320 |
+
nn.Dropout(0.5), nn.Linear(bb.output_dim, num_classes),
|
| 321 |
+
))
|
| 322 |
+
self._has_proj = proj_dim > 0
|
| 323 |
+
|
| 324 |
+
M = len(self.mod_dims)
|
| 325 |
+
if late_agg == 'learned':
|
| 326 |
+
self.modality_logits = nn.Parameter(torch.zeros(M))
|
| 327 |
+
self.temperature = nn.Parameter(torch.ones(1))
|
| 328 |
+
|
| 329 |
+
def forward(self, x, mask=None):
|
| 330 |
+
offset = 0
|
| 331 |
+
all_logits = []
|
| 332 |
+
for i, dim in enumerate(self.mod_dims):
|
| 333 |
+
x_mod = x[:, :, offset:offset + dim]
|
| 334 |
+
offset += dim
|
| 335 |
+
if self._has_proj:
|
| 336 |
+
x_mod = self.projectors[i](x_mod)
|
| 337 |
+
feat = self.backbones[i](x_mod, mask)
|
| 338 |
+
all_logits.append(self.classifiers[i](feat))
|
| 339 |
+
|
| 340 |
+
stacked = torch.stack(all_logits, dim=0) # (M, B, C)
|
| 341 |
+
|
| 342 |
+
if self.late_agg == 'confidence':
|
| 343 |
+
# Weight by confidence: low entropy → high weight
|
| 344 |
+
probs = F.softmax(stacked, dim=-1) # (M, B, C)
|
| 345 |
+
entropy = -(probs * (probs + 1e-8).log()).sum(dim=-1) # (M, B)
|
| 346 |
+
weights = F.softmax(-entropy, dim=0).unsqueeze(-1) # (M, B, 1)
|
| 347 |
+
return (stacked * weights).sum(dim=0)
|
| 348 |
+
elif self.late_agg == 'learned':
|
| 349 |
+
weights = F.softmax(self.modality_logits / self.temperature, dim=0)
|
| 350 |
+
return (stacked * weights.view(-1, 1, 1)).sum(dim=0)
|
| 351 |
+
else: # 'mean'
|
| 352 |
+
return stacked.mean(dim=0)
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
class AttentionFusionModel(nn.Module):
|
| 356 |
+
"""Attention fusion: separate encoder per modality -> cross-modal attention -> classifier."""
|
| 357 |
+
|
| 358 |
+
def __init__(self, backbone_name, modality_dims, num_classes, hidden_dim=64, proj_dim=0):
|
| 359 |
+
super().__init__()
|
| 360 |
+
self.mod_names = list(modality_dims.keys())
|
| 361 |
+
self.mod_dims = list(modality_dims.values())
|
| 362 |
+
unified_dim = hidden_dim
|
| 363 |
+
self.projectors = nn.ModuleList()
|
| 364 |
+
self.backbones = nn.ModuleList()
|
| 365 |
+
self.feat_projections = nn.ModuleList()
|
| 366 |
+
for dim in self.mod_dims:
|
| 367 |
+
proj, bb = _make_branch(backbone_name, dim, hidden_dim, proj_dim)
|
| 368 |
+
self.projectors.append(proj if proj else nn.Identity())
|
| 369 |
+
self.backbones.append(bb)
|
| 370 |
+
if bb.output_dim != unified_dim:
|
| 371 |
+
self.feat_projections.append(nn.Linear(bb.output_dim, unified_dim))
|
| 372 |
+
else:
|
| 373 |
+
self.feat_projections.append(nn.Identity())
|
| 374 |
+
self._has_proj = proj_dim > 0
|
| 375 |
+
nhead = 4 if unified_dim % 4 == 0 else (2 if unified_dim % 2 == 0 else 1)
|
| 376 |
+
self.cross_attn = nn.TransformerEncoderLayer(
|
| 377 |
+
d_model=unified_dim, nhead=nhead, dim_feedforward=unified_dim * 2,
|
| 378 |
+
dropout=0.1, batch_first=True,
|
| 379 |
+
)
|
| 380 |
+
self.classifier = nn.Sequential(
|
| 381 |
+
nn.Dropout(0.5), nn.Linear(unified_dim, num_classes),
|
| 382 |
+
)
|
| 383 |
+
|
| 384 |
+
def forward(self, x, mask=None):
|
| 385 |
+
offset = 0
|
| 386 |
+
mod_features = []
|
| 387 |
+
for i, dim in enumerate(self.mod_dims):
|
| 388 |
+
x_mod = x[:, :, offset:offset + dim]
|
| 389 |
+
offset += dim
|
| 390 |
+
if self._has_proj:
|
| 391 |
+
x_mod = self.projectors[i](x_mod)
|
| 392 |
+
feat = self.backbones[i](x_mod, mask)
|
| 393 |
+
feat = self.feat_projections[i](feat)
|
| 394 |
+
mod_features.append(feat)
|
| 395 |
+
tokens = torch.stack(mod_features, dim=1)
|
| 396 |
+
tokens = self.cross_attn(tokens)
|
| 397 |
+
pooled = tokens.mean(dim=1)
|
| 398 |
+
return self.classifier(pooled)
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
class WeightedLateFusionModel(nn.Module):
|
| 402 |
+
def __init__(self, backbone_name, modality_dims, num_classes, hidden_dim=64, proj_dim=0):
|
| 403 |
+
super().__init__()
|
| 404 |
+
self.mod_names = list(modality_dims.keys())
|
| 405 |
+
self.mod_dims = list(modality_dims.values())
|
| 406 |
+
self.projectors = nn.ModuleList()
|
| 407 |
+
self.backbones = nn.ModuleList()
|
| 408 |
+
self.classifiers = nn.ModuleList()
|
| 409 |
+
for dim in self.mod_dims:
|
| 410 |
+
proj, bb = _make_branch(backbone_name, dim, hidden_dim, proj_dim)
|
| 411 |
+
self.projectors.append(proj if proj else nn.Identity())
|
| 412 |
+
self.backbones.append(bb)
|
| 413 |
+
self.classifiers.append(nn.Sequential(
|
| 414 |
+
nn.Dropout(0.5), nn.Linear(bb.output_dim, num_classes),
|
| 415 |
+
))
|
| 416 |
+
self._has_proj = proj_dim > 0
|
| 417 |
+
self.modality_weights = nn.Parameter(torch.ones(len(self.mod_dims)))
|
| 418 |
+
|
| 419 |
+
def forward(self, x, mask=None):
|
| 420 |
+
offset = 0
|
| 421 |
+
all_logits = []
|
| 422 |
+
for i, dim in enumerate(self.mod_dims):
|
| 423 |
+
x_mod = x[:, :, offset:offset + dim]
|
| 424 |
+
offset += dim
|
| 425 |
+
if self._has_proj:
|
| 426 |
+
x_mod = self.projectors[i](x_mod)
|
| 427 |
+
feat = self.backbones[i](x_mod, mask)
|
| 428 |
+
all_logits.append(self.classifiers[i](feat))
|
| 429 |
+
weights = F.softmax(self.modality_weights, dim=0)
|
| 430 |
+
stacked = torch.stack(all_logits, dim=0)
|
| 431 |
+
return (stacked * weights.view(-1, 1, 1)).sum(dim=0)
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
class GatedLateFusionModel(nn.Module):
|
| 435 |
+
def __init__(self, backbone_name, modality_dims, num_classes, hidden_dim=64, proj_dim=0):
|
| 436 |
+
super().__init__()
|
| 437 |
+
self.mod_names = list(modality_dims.keys())
|
| 438 |
+
self.mod_dims = list(modality_dims.values())
|
| 439 |
+
M = len(self.mod_dims)
|
| 440 |
+
self.projectors = nn.ModuleList()
|
| 441 |
+
self.backbones = nn.ModuleList()
|
| 442 |
+
self.classifiers = nn.ModuleList()
|
| 443 |
+
total_feat_dim = 0
|
| 444 |
+
for dim in self.mod_dims:
|
| 445 |
+
proj, bb = _make_branch(backbone_name, dim, hidden_dim, proj_dim)
|
| 446 |
+
self.projectors.append(proj if proj else nn.Identity())
|
| 447 |
+
self.backbones.append(bb)
|
| 448 |
+
total_feat_dim += bb.output_dim
|
| 449 |
+
self.classifiers.append(nn.Sequential(
|
| 450 |
+
nn.Dropout(0.5), nn.Linear(bb.output_dim, num_classes),
|
| 451 |
+
))
|
| 452 |
+
self._has_proj = proj_dim > 0
|
| 453 |
+
self.gate = nn.Sequential(
|
| 454 |
+
nn.Linear(total_feat_dim, 32), nn.ReLU(), nn.Linear(32, M),
|
| 455 |
+
)
|
| 456 |
+
|
| 457 |
+
def forward(self, x, mask=None):
|
| 458 |
+
offset = 0
|
| 459 |
+
all_feats, all_logits = [], []
|
| 460 |
+
for i, dim in enumerate(self.mod_dims):
|
| 461 |
+
x_mod = x[:, :, offset:offset + dim]
|
| 462 |
+
offset += dim
|
| 463 |
+
if self._has_proj:
|
| 464 |
+
x_mod = self.projectors[i](x_mod)
|
| 465 |
+
feat = self.backbones[i](x_mod, mask)
|
| 466 |
+
all_feats.append(feat)
|
| 467 |
+
all_logits.append(self.classifiers[i](feat))
|
| 468 |
+
cat_feats = torch.cat(all_feats, dim=1)
|
| 469 |
+
gate_weights = F.softmax(self.gate(cat_feats), dim=1)
|
| 470 |
+
stacked = torch.stack(all_logits, dim=1)
|
| 471 |
+
return (stacked * gate_weights.unsqueeze(-1)).sum(dim=1)
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
class StackingFusionModel(nn.Module):
|
| 475 |
+
def __init__(self, backbone_name, modality_dims, num_classes, hidden_dim=64, proj_dim=0):
|
| 476 |
+
super().__init__()
|
| 477 |
+
self.mod_names = list(modality_dims.keys())
|
| 478 |
+
self.mod_dims = list(modality_dims.values())
|
| 479 |
+
M = len(self.mod_dims)
|
| 480 |
+
self.projectors = nn.ModuleList()
|
| 481 |
+
self.backbones = nn.ModuleList()
|
| 482 |
+
self.classifiers = nn.ModuleList()
|
| 483 |
+
for dim in self.mod_dims:
|
| 484 |
+
proj, bb = _make_branch(backbone_name, dim, hidden_dim, proj_dim)
|
| 485 |
+
self.projectors.append(proj if proj else nn.Identity())
|
| 486 |
+
self.backbones.append(bb)
|
| 487 |
+
self.classifiers.append(nn.Sequential(
|
| 488 |
+
nn.Dropout(0.5), nn.Linear(bb.output_dim, num_classes),
|
| 489 |
+
))
|
| 490 |
+
self._has_proj = proj_dim > 0
|
| 491 |
+
self.meta_learner = nn.Sequential(
|
| 492 |
+
nn.Linear(M * num_classes, 32), nn.ReLU(),
|
| 493 |
+
nn.Dropout(0.5), nn.Linear(32, num_classes),
|
| 494 |
+
)
|
| 495 |
+
|
| 496 |
+
def forward(self, x, mask=None):
|
| 497 |
+
offset = 0
|
| 498 |
+
all_logits = []
|
| 499 |
+
for i, dim in enumerate(self.mod_dims):
|
| 500 |
+
x_mod = x[:, :, offset:offset + dim]
|
| 501 |
+
offset += dim
|
| 502 |
+
if self._has_proj:
|
| 503 |
+
x_mod = self.projectors[i](x_mod)
|
| 504 |
+
feat = self.backbones[i](x_mod, mask)
|
| 505 |
+
all_logits.append(self.classifiers[i](feat))
|
| 506 |
+
cat_logits = torch.cat(all_logits, dim=1)
|
| 507 |
+
return self.meta_learner(cat_logits)
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
class ProductOfExpertsModel(nn.Module):
|
| 511 |
+
def __init__(self, backbone_name, modality_dims, num_classes, hidden_dim=64, proj_dim=0):
|
| 512 |
+
super().__init__()
|
| 513 |
+
self.mod_names = list(modality_dims.keys())
|
| 514 |
+
self.mod_dims = list(modality_dims.values())
|
| 515 |
+
self.projectors = nn.ModuleList()
|
| 516 |
+
self.backbones = nn.ModuleList()
|
| 517 |
+
self.classifiers = nn.ModuleList()
|
| 518 |
+
for dim in self.mod_dims:
|
| 519 |
+
proj, bb = _make_branch(backbone_name, dim, hidden_dim, proj_dim)
|
| 520 |
+
self.projectors.append(proj if proj else nn.Identity())
|
| 521 |
+
self.backbones.append(bb)
|
| 522 |
+
self.classifiers.append(nn.Sequential(
|
| 523 |
+
nn.Dropout(0.5), nn.Linear(bb.output_dim, num_classes),
|
| 524 |
+
))
|
| 525 |
+
self._has_proj = proj_dim > 0
|
| 526 |
+
|
| 527 |
+
def forward(self, x, mask=None):
|
| 528 |
+
offset = 0
|
| 529 |
+
log_probs_sum = None
|
| 530 |
+
for i, dim in enumerate(self.mod_dims):
|
| 531 |
+
x_mod = x[:, :, offset:offset + dim]
|
| 532 |
+
offset += dim
|
| 533 |
+
if self._has_proj:
|
| 534 |
+
x_mod = self.projectors[i](x_mod)
|
| 535 |
+
feat = self.backbones[i](x_mod, mask)
|
| 536 |
+
logits = self.classifiers[i](feat)
|
| 537 |
+
log_p = F.log_softmax(logits, dim=1)
|
| 538 |
+
log_probs_sum = log_p if log_probs_sum is None else log_probs_sum + log_p
|
| 539 |
+
return log_probs_sum
|
| 540 |
+
|
| 541 |
+
|
| 542 |
+
class MoEFusionModel(nn.Module):
|
| 543 |
+
def __init__(self, backbone_name, modality_dims, num_classes, hidden_dim=64, proj_dim=0):
|
| 544 |
+
super().__init__()
|
| 545 |
+
self.mod_names = list(modality_dims.keys())
|
| 546 |
+
self.mod_dims = list(modality_dims.values())
|
| 547 |
+
M = len(self.mod_dims)
|
| 548 |
+
self.top_k = min(2, M)
|
| 549 |
+
self.projectors = nn.ModuleList()
|
| 550 |
+
self.backbones = nn.ModuleList()
|
| 551 |
+
self.classifiers = nn.ModuleList()
|
| 552 |
+
total_feat_dim = 0
|
| 553 |
+
for dim in self.mod_dims:
|
| 554 |
+
proj, bb = _make_branch(backbone_name, dim, hidden_dim, proj_dim)
|
| 555 |
+
self.projectors.append(proj if proj else nn.Identity())
|
| 556 |
+
self.backbones.append(bb)
|
| 557 |
+
total_feat_dim += bb.output_dim
|
| 558 |
+
self.classifiers.append(nn.Sequential(
|
| 559 |
+
nn.Dropout(0.5), nn.Linear(bb.output_dim, num_classes),
|
| 560 |
+
))
|
| 561 |
+
self._has_proj = proj_dim > 0
|
| 562 |
+
self.router = nn.Linear(total_feat_dim, M)
|
| 563 |
+
|
| 564 |
+
def forward(self, x, mask=None):
|
| 565 |
+
offset = 0
|
| 566 |
+
all_feats, all_logits = [], []
|
| 567 |
+
for i, dim in enumerate(self.mod_dims):
|
| 568 |
+
x_mod = x[:, :, offset:offset + dim]
|
| 569 |
+
offset += dim
|
| 570 |
+
if self._has_proj:
|
| 571 |
+
x_mod = self.projectors[i](x_mod)
|
| 572 |
+
feat = self.backbones[i](x_mod, mask)
|
| 573 |
+
all_feats.append(feat)
|
| 574 |
+
all_logits.append(self.classifiers[i](feat))
|
| 575 |
+
cat_feats = torch.cat(all_feats, dim=1)
|
| 576 |
+
router_logits = self.router(cat_feats)
|
| 577 |
+
top_vals, top_idx = router_logits.topk(self.top_k, dim=1)
|
| 578 |
+
top_weights = F.softmax(top_vals, dim=1)
|
| 579 |
+
stacked = torch.stack(all_logits, dim=1)
|
| 580 |
+
top_idx_exp = top_idx.unsqueeze(-1).expand(-1, -1, stacked.size(-1))
|
| 581 |
+
selected = stacked.gather(1, top_idx_exp)
|
| 582 |
+
return (selected * top_weights.unsqueeze(-1)).sum(dim=1)
|
| 583 |
+
|
| 584 |
+
|
| 585 |
+
class FeatureConcatFusionModel(nn.Module):
|
| 586 |
+
"""Feature-level late fusion: separate backbones, concatenate features, joint classifier."""
|
| 587 |
+
|
| 588 |
+
def __init__(self, backbone_name, modality_dims, num_classes, hidden_dim=64, proj_dim=0):
|
| 589 |
+
super().__init__()
|
| 590 |
+
self.mod_names = list(modality_dims.keys())
|
| 591 |
+
self.mod_dims = list(modality_dims.values())
|
| 592 |
+
self.projectors = nn.ModuleList()
|
| 593 |
+
self.backbones = nn.ModuleList()
|
| 594 |
+
total_feat_dim = 0
|
| 595 |
+
for dim in self.mod_dims:
|
| 596 |
+
proj, bb = _make_branch(backbone_name, dim, hidden_dim, proj_dim)
|
| 597 |
+
self.projectors.append(proj if proj else nn.Identity())
|
| 598 |
+
self.backbones.append(bb)
|
| 599 |
+
total_feat_dim += bb.output_dim
|
| 600 |
+
self._has_proj = proj_dim > 0
|
| 601 |
+
self.classifier = nn.Sequential(
|
| 602 |
+
nn.LayerNorm(total_feat_dim),
|
| 603 |
+
nn.Dropout(0.5),
|
| 604 |
+
nn.Linear(total_feat_dim, hidden_dim),
|
| 605 |
+
nn.ReLU(),
|
| 606 |
+
nn.Dropout(0.3),
|
| 607 |
+
nn.Linear(hidden_dim, num_classes),
|
| 608 |
+
)
|
| 609 |
+
|
| 610 |
+
def forward(self, x, mask=None):
|
| 611 |
+
offset = 0
|
| 612 |
+
all_feats = []
|
| 613 |
+
for i, dim in enumerate(self.mod_dims):
|
| 614 |
+
x_mod = x[:, :, offset:offset + dim]
|
| 615 |
+
offset += dim
|
| 616 |
+
if self._has_proj:
|
| 617 |
+
x_mod = self.projectors[i](x_mod)
|
| 618 |
+
feat = self.backbones[i](x_mod, mask)
|
| 619 |
+
all_feats.append(feat)
|
| 620 |
+
cat_feats = torch.cat(all_feats, dim=1)
|
| 621 |
+
return self.classifier(cat_feats)
|
| 622 |
+
|
| 623 |
+
|
| 624 |
+
def build_model(backbone_name, fusion, input_dim, modality_dims, num_classes,
|
| 625 |
+
hidden_dim=128, proj_dim=0, late_agg='mean'):
|
| 626 |
+
"""Factory function. proj_dim=0 means no projection (raw features)."""
|
| 627 |
+
if fusion == 'early':
|
| 628 |
+
return SingleModel(backbone_name, input_dim, num_classes, hidden_dim,
|
| 629 |
+
modality_dims=modality_dims, proj_dim=proj_dim)
|
| 630 |
+
elif fusion == 'late':
|
| 631 |
+
return LateFusionModel(backbone_name, modality_dims, num_classes, hidden_dim,
|
| 632 |
+
proj_dim, late_agg=late_agg)
|
| 633 |
+
elif fusion == 'attention':
|
| 634 |
+
return AttentionFusionModel(backbone_name, modality_dims, num_classes, hidden_dim, proj_dim)
|
| 635 |
+
elif fusion == 'weighted_late':
|
| 636 |
+
return WeightedLateFusionModel(backbone_name, modality_dims, num_classes, hidden_dim, proj_dim)
|
| 637 |
+
elif fusion == 'gated_late':
|
| 638 |
+
return GatedLateFusionModel(backbone_name, modality_dims, num_classes, hidden_dim, proj_dim)
|
| 639 |
+
elif fusion == 'stacking':
|
| 640 |
+
return StackingFusionModel(backbone_name, modality_dims, num_classes, hidden_dim, proj_dim)
|
| 641 |
+
elif fusion == 'product':
|
| 642 |
+
return ProductOfExpertsModel(backbone_name, modality_dims, num_classes, hidden_dim, proj_dim)
|
| 643 |
+
elif fusion == 'moe':
|
| 644 |
+
return MoEFusionModel(backbone_name, modality_dims, num_classes, hidden_dim, proj_dim)
|
| 645 |
+
elif fusion == 'feat_concat':
|
| 646 |
+
return FeatureConcatFusionModel(backbone_name, modality_dims, num_classes, hidden_dim, proj_dim)
|
| 647 |
+
else:
|
| 648 |
+
raise ValueError(f"Unknown fusion: {fusion}")
|
experiments/nets/models_forecast.py
ADDED
|
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Frame-level future forecasting models.
|
| 2 |
+
|
| 3 |
+
Three baselines (all sharing the same forecast head signature):
|
| 4 |
+
- TransformerForecast (our DAF-style)
|
| 5 |
+
- FUTRForecast (Transformer encoder + parallel query decoder)
|
| 6 |
+
- DeepConvLSTMForecast (Ordoñez & Roggen 2016 wearable HAR backbone)
|
| 7 |
+
|
| 8 |
+
All take a dict {mod: (B, T_obs, F_mod)} and output (B, T_fut, num_classes).
|
| 9 |
+
"""
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
from typing import Dict, List
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# ---------------------------------------------------------------------------
|
| 19 |
+
# Shared per-modality projection: each modality -> hidden dim d_model
|
| 20 |
+
# ---------------------------------------------------------------------------
|
| 21 |
+
|
| 22 |
+
class _PerModalityProj(nn.Module):
|
| 23 |
+
def __init__(self, modality_dims: Dict[str, int], d_model: int):
|
| 24 |
+
super().__init__()
|
| 25 |
+
self.proj = nn.ModuleDict({
|
| 26 |
+
m: nn.Linear(d, d_model) for m, d in modality_dims.items()
|
| 27 |
+
})
|
| 28 |
+
self.mod_emb = nn.Parameter(torch.zeros(len(modality_dims), d_model))
|
| 29 |
+
nn.init.trunc_normal_(self.mod_emb, std=0.02)
|
| 30 |
+
self.mods = list(modality_dims.keys())
|
| 31 |
+
|
| 32 |
+
def forward(self, x: Dict[str, torch.Tensor]) -> torch.Tensor:
|
| 33 |
+
# Concatenate per-modality projections along time? Or sum?
|
| 34 |
+
# We sum modality-projected features per time step (with modality
|
| 35 |
+
# embedding broadcast). Equivalent to early-fusion at the d_model
|
| 36 |
+
# space and is what a "modality-aware Transformer" typically uses.
|
| 37 |
+
out = None
|
| 38 |
+
for i, m in enumerate(self.mods):
|
| 39 |
+
h = self.proj[m](x[m]) + self.mod_emb[i]
|
| 40 |
+
out = h if out is None else out + h
|
| 41 |
+
return out / len(self.mods) # (B, T_obs, d_model)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# ---------------------------------------------------------------------------
|
| 45 |
+
# 1. Transformer (DAF-style) forecast model
|
| 46 |
+
# ---------------------------------------------------------------------------
|
| 47 |
+
|
| 48 |
+
class TransformerForecast(nn.Module):
|
| 49 |
+
def __init__(self, modality_dims: Dict[str, int], num_classes: int,
|
| 50 |
+
t_obs: int, t_fut: int, d_model: int = 128,
|
| 51 |
+
n_heads: int = 4, n_layers: int = 2, dropout: float = 0.1):
|
| 52 |
+
super().__init__()
|
| 53 |
+
self.t_obs = t_obs
|
| 54 |
+
self.t_fut = t_fut
|
| 55 |
+
self.num_classes = num_classes
|
| 56 |
+
self.embed = _PerModalityProj(modality_dims, d_model)
|
| 57 |
+
self.pos = nn.Parameter(torch.zeros(1, t_obs, d_model))
|
| 58 |
+
nn.init.trunc_normal_(self.pos, std=0.02)
|
| 59 |
+
layer = nn.TransformerEncoderLayer(
|
| 60 |
+
d_model=d_model, nhead=n_heads, dim_feedforward=4 * d_model,
|
| 61 |
+
dropout=dropout, batch_first=True, activation="gelu",
|
| 62 |
+
)
|
| 63 |
+
self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
|
| 64 |
+
self.queries = nn.Parameter(torch.zeros(1, t_fut, d_model))
|
| 65 |
+
nn.init.trunc_normal_(self.queries, std=0.02)
|
| 66 |
+
self.cross_attn = nn.MultiheadAttention(
|
| 67 |
+
d_model, n_heads, dropout=dropout, batch_first=True
|
| 68 |
+
)
|
| 69 |
+
self.norm = nn.LayerNorm(d_model)
|
| 70 |
+
self.head = nn.Linear(d_model, num_classes)
|
| 71 |
+
|
| 72 |
+
def forward(self, x: Dict[str, torch.Tensor]) -> torch.Tensor:
|
| 73 |
+
h = self.embed(x) + self.pos
|
| 74 |
+
h = self.encoder(h) # (B, T_obs, D)
|
| 75 |
+
q = self.queries.expand(h.size(0), -1, -1) # (B, T_fut, D)
|
| 76 |
+
out, _ = self.cross_attn(q, h, h, need_weights=False)
|
| 77 |
+
out = self.norm(out)
|
| 78 |
+
return self.head(out) # (B, T_fut, C)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
# ---------------------------------------------------------------------------
|
| 82 |
+
# 2. FUTR-style forecast (Future Transformer, Gong et al. CVPR 2022)
|
| 83 |
+
# Same encoder + parallel query decoder. We add a small Transformer
|
| 84 |
+
# decoder so it's not literally identical to TransformerForecast.
|
| 85 |
+
# ---------------------------------------------------------------------------
|
| 86 |
+
|
| 87 |
+
class FUTRForecast(nn.Module):
|
| 88 |
+
def __init__(self, modality_dims: Dict[str, int], num_classes: int,
|
| 89 |
+
t_obs: int, t_fut: int, d_model: int = 128,
|
| 90 |
+
n_heads: int = 4, n_enc: int = 2, n_dec: int = 1,
|
| 91 |
+
dropout: float = 0.1):
|
| 92 |
+
super().__init__()
|
| 93 |
+
self.t_obs = t_obs
|
| 94 |
+
self.t_fut = t_fut
|
| 95 |
+
self.num_classes = num_classes
|
| 96 |
+
self.embed = _PerModalityProj(modality_dims, d_model)
|
| 97 |
+
self.pos = nn.Parameter(torch.zeros(1, t_obs, d_model))
|
| 98 |
+
nn.init.trunc_normal_(self.pos, std=0.02)
|
| 99 |
+
enc_layer = nn.TransformerEncoderLayer(
|
| 100 |
+
d_model=d_model, nhead=n_heads, dim_feedforward=4 * d_model,
|
| 101 |
+
dropout=dropout, batch_first=True, activation="gelu",
|
| 102 |
+
)
|
| 103 |
+
self.encoder = nn.TransformerEncoder(enc_layer, num_layers=n_enc)
|
| 104 |
+
dec_layer = nn.TransformerDecoderLayer(
|
| 105 |
+
d_model=d_model, nhead=n_heads, dim_feedforward=4 * d_model,
|
| 106 |
+
dropout=dropout, batch_first=True, activation="gelu",
|
| 107 |
+
)
|
| 108 |
+
self.decoder = nn.TransformerDecoder(dec_layer, num_layers=n_dec)
|
| 109 |
+
self.queries = nn.Parameter(torch.zeros(1, t_fut, d_model))
|
| 110 |
+
nn.init.trunc_normal_(self.queries, std=0.02)
|
| 111 |
+
self.head = nn.Linear(d_model, num_classes)
|
| 112 |
+
|
| 113 |
+
def forward(self, x: Dict[str, torch.Tensor]) -> torch.Tensor:
|
| 114 |
+
memory = self.encoder(self.embed(x) + self.pos) # (B, T_obs, D)
|
| 115 |
+
q = self.queries.expand(memory.size(0), -1, -1) # (B, T_fut, D)
|
| 116 |
+
out = self.decoder(q, memory)
|
| 117 |
+
return self.head(out) # (B, T_fut, C)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
# ---------------------------------------------------------------------------
|
| 121 |
+
# 3. DeepConvLSTM-style forecast
|
| 122 |
+
# ---------------------------------------------------------------------------
|
| 123 |
+
|
| 124 |
+
class DeepConvLSTMForecast(nn.Module):
|
| 125 |
+
def __init__(self, modality_dims: Dict[str, int], num_classes: int,
|
| 126 |
+
t_obs: int, t_fut: int, conv_filters: int = 64,
|
| 127 |
+
lstm_hidden: int = 128, n_lstm_layers: int = 2,
|
| 128 |
+
dropout: float = 0.1):
|
| 129 |
+
super().__init__()
|
| 130 |
+
self.t_obs = t_obs
|
| 131 |
+
self.t_fut = t_fut
|
| 132 |
+
self.num_classes = num_classes
|
| 133 |
+
self.mods = list(modality_dims.keys())
|
| 134 |
+
in_ch = sum(modality_dims.values())
|
| 135 |
+
# Same 4-layer conv stack as the original DeepConvLSTM
|
| 136 |
+
layers = []
|
| 137 |
+
ch = in_ch
|
| 138 |
+
for i in range(4):
|
| 139 |
+
layers.append(nn.Sequential(
|
| 140 |
+
nn.Conv1d(ch, conv_filters, kernel_size=5, padding=2),
|
| 141 |
+
nn.BatchNorm1d(conv_filters),
|
| 142 |
+
nn.ReLU(),
|
| 143 |
+
nn.Dropout(dropout if i < 3 else 0.2),
|
| 144 |
+
))
|
| 145 |
+
ch = conv_filters
|
| 146 |
+
self.convs = nn.ModuleList(layers)
|
| 147 |
+
self.lstm = nn.LSTM(
|
| 148 |
+
conv_filters, lstm_hidden, num_layers=n_lstm_layers,
|
| 149 |
+
batch_first=True, dropout=dropout if n_lstm_layers > 1 else 0,
|
| 150 |
+
)
|
| 151 |
+
self.head = nn.Linear(lstm_hidden, t_fut * num_classes)
|
| 152 |
+
|
| 153 |
+
def forward(self, x: Dict[str, torch.Tensor]) -> torch.Tensor:
|
| 154 |
+
h = torch.cat([x[m] for m in self.mods], dim=-1) # (B, T_obs, F_total)
|
| 155 |
+
h = h.permute(0, 2, 1) # (B, F, T_obs)
|
| 156 |
+
for c in self.convs:
|
| 157 |
+
h = c(h)
|
| 158 |
+
h = h.permute(0, 2, 1) # (B, T_obs, conv_filters)
|
| 159 |
+
out, (h_n, _) = self.lstm(h)
|
| 160 |
+
feat = h_n[-1] # (B, lstm_hidden)
|
| 161 |
+
logits = self.head(feat).view(-1, self.t_fut, self.num_classes)
|
| 162 |
+
return logits
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
# ---------------------------------------------------------------------------
|
| 166 |
+
# 4. RU-LSTM (Furnari et al. RAL 2019, "Rolling-Unrolling LSTM for action
|
| 167 |
+
# anticipation"). Two-phase LSTM: a "rolling" phase encodes past, an
|
| 168 |
+
# "unrolling" phase autoregressively decodes future tokens.
|
| 169 |
+
# ---------------------------------------------------------------------------
|
| 170 |
+
|
| 171 |
+
class RULSTMForecast(nn.Module):
|
| 172 |
+
def __init__(self, modality_dims: Dict[str, int], num_classes: int,
|
| 173 |
+
t_obs: int, t_fut: int, d_model: int = 128,
|
| 174 |
+
n_lstm_layers: int = 2, dropout: float = 0.1):
|
| 175 |
+
super().__init__()
|
| 176 |
+
self.t_obs = t_obs
|
| 177 |
+
self.t_fut = t_fut
|
| 178 |
+
self.num_classes = num_classes
|
| 179 |
+
self.embed = _PerModalityProj(modality_dims, d_model)
|
| 180 |
+
self.rolling = nn.LSTM(
|
| 181 |
+
d_model, d_model, num_layers=n_lstm_layers,
|
| 182 |
+
batch_first=True, dropout=dropout if n_lstm_layers > 1 else 0,
|
| 183 |
+
)
|
| 184 |
+
self.unrolling = nn.LSTM(
|
| 185 |
+
d_model, d_model, num_layers=n_lstm_layers,
|
| 186 |
+
batch_first=True, dropout=dropout if n_lstm_layers > 1 else 0,
|
| 187 |
+
)
|
| 188 |
+
self.fut_init = nn.Parameter(torch.zeros(1, 1, d_model))
|
| 189 |
+
nn.init.trunc_normal_(self.fut_init, std=0.02)
|
| 190 |
+
self.head = nn.Linear(d_model, num_classes)
|
| 191 |
+
|
| 192 |
+
def forward(self, x: Dict[str, torch.Tensor]) -> torch.Tensor:
|
| 193 |
+
h_past = self.embed(x) # (B, T_obs, D)
|
| 194 |
+
_, (h_n, c_n) = self.rolling(h_past)
|
| 195 |
+
B = h_past.size(0)
|
| 196 |
+
# Use a learned initial future token, repeated T_fut times
|
| 197 |
+
fut_input = self.fut_init.expand(B, self.t_fut, -1)
|
| 198 |
+
out, _ = self.unrolling(fut_input, (h_n, c_n))
|
| 199 |
+
return self.head(out) # (B, T_fut, C)
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
# ---------------------------------------------------------------------------
|
| 203 |
+
# 5. AVT (Girdhar & Grauman ICCV 2021, "Anticipative Video Transformer").
|
| 204 |
+
# Causal Transformer over the concatenation of past + future tokens.
|
| 205 |
+
# ---------------------------------------------------------------------------
|
| 206 |
+
|
| 207 |
+
class AVTForecast(nn.Module):
|
| 208 |
+
def __init__(self, modality_dims: Dict[str, int], num_classes: int,
|
| 209 |
+
t_obs: int, t_fut: int, d_model: int = 128,
|
| 210 |
+
n_heads: int = 4, n_layers: int = 2, dropout: float = 0.1):
|
| 211 |
+
super().__init__()
|
| 212 |
+
self.t_obs = t_obs
|
| 213 |
+
self.t_fut = t_fut
|
| 214 |
+
self.num_classes = num_classes
|
| 215 |
+
self.embed = _PerModalityProj(modality_dims, d_model)
|
| 216 |
+
seq_len = t_obs + t_fut
|
| 217 |
+
self.pos = nn.Parameter(torch.zeros(1, seq_len, d_model))
|
| 218 |
+
nn.init.trunc_normal_(self.pos, std=0.02)
|
| 219 |
+
layer = nn.TransformerEncoderLayer(
|
| 220 |
+
d_model=d_model, nhead=n_heads, dim_feedforward=4 * d_model,
|
| 221 |
+
dropout=dropout, batch_first=True, activation="gelu",
|
| 222 |
+
)
|
| 223 |
+
self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
|
| 224 |
+
self.fut_tokens = nn.Parameter(torch.zeros(1, t_fut, d_model))
|
| 225 |
+
nn.init.trunc_normal_(self.fut_tokens, std=0.02)
|
| 226 |
+
self.head = nn.Linear(d_model, num_classes)
|
| 227 |
+
# Causal mask over concatenated [past | future] sequence
|
| 228 |
+
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
|
| 229 |
+
self.register_buffer("causal_mask", mask)
|
| 230 |
+
|
| 231 |
+
def forward(self, x: Dict[str, torch.Tensor]) -> torch.Tensor:
|
| 232 |
+
h_past = self.embed(x) # (B, T_obs, D)
|
| 233 |
+
B = h_past.size(0)
|
| 234 |
+
h_fut = self.fut_tokens.expand(B, -1, -1) # (B, T_fut, D)
|
| 235 |
+
seq = torch.cat([h_past, h_fut], dim=1) + self.pos
|
| 236 |
+
out = self.encoder(seq, mask=self.causal_mask)
|
| 237 |
+
out_fut = out[:, self.t_obs:, :]
|
| 238 |
+
return self.head(out_fut) # (B, T_fut, C)
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
# ---------------------------------------------------------------------------
|
| 242 |
+
# Builder
|
| 243 |
+
# ---------------------------------------------------------------------------
|
| 244 |
+
|
| 245 |
+
def build_forecast_model(name: str, modality_dims: Dict[str, int],
|
| 246 |
+
num_classes: int, t_obs: int, t_fut: int,
|
| 247 |
+
d_model: int = 128, dropout: float = 0.1) -> nn.Module:
|
| 248 |
+
name = name.lower()
|
| 249 |
+
if name in ("daf", "transformer"):
|
| 250 |
+
return TransformerForecast(modality_dims, num_classes,
|
| 251 |
+
t_obs=t_obs, t_fut=t_fut,
|
| 252 |
+
d_model=d_model, dropout=dropout)
|
| 253 |
+
if name == "futr":
|
| 254 |
+
return FUTRForecast(modality_dims, num_classes,
|
| 255 |
+
t_obs=t_obs, t_fut=t_fut,
|
| 256 |
+
d_model=d_model, dropout=dropout)
|
| 257 |
+
if name == "deepconvlstm":
|
| 258 |
+
return DeepConvLSTMForecast(modality_dims, num_classes,
|
| 259 |
+
t_obs=t_obs, t_fut=t_fut,
|
| 260 |
+
dropout=dropout)
|
| 261 |
+
if name in ("rulstm", "ru-lstm", "ru_lstm"):
|
| 262 |
+
return RULSTMForecast(modality_dims, num_classes,
|
| 263 |
+
t_obs=t_obs, t_fut=t_fut,
|
| 264 |
+
d_model=d_model, dropout=dropout)
|
| 265 |
+
if name == "avt":
|
| 266 |
+
return AVTForecast(modality_dims, num_classes,
|
| 267 |
+
t_obs=t_obs, t_fut=t_fut,
|
| 268 |
+
d_model=d_model, dropout=dropout)
|
| 269 |
+
raise ValueError(f"Unknown forecast model: {name!r}")
|
experiments/nets/models_forecast_priv.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Models for T8 v3 — privileged future-pressure conditioning.
|
| 2 |
+
|
| 3 |
+
Wraps the existing TransformerForecast (DAF) to accept future pressure as
|
| 4 |
+
side-channel context. The future pressure trajectory is encoded into T_fut
|
| 5 |
+
tokens that get appended to the past memory; future queries cross-attend
|
| 6 |
+
over the union (past sensors + future pressure). This is privileged
|
| 7 |
+
information (oracle) — at test time we'd not have future pressure — so
|
| 8 |
+
this is a hypothesis-test setup, not a deployable forecaster.
|
| 9 |
+
"""
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
from typing import Dict
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class _PerModalityProj(nn.Module):
|
| 18 |
+
def __init__(self, modality_dims, d_model):
|
| 19 |
+
super().__init__()
|
| 20 |
+
self.proj = nn.ModuleDict({
|
| 21 |
+
m: nn.Linear(d, d_model) for m, d in modality_dims.items()
|
| 22 |
+
})
|
| 23 |
+
self.mod_emb = nn.Parameter(torch.zeros(len(modality_dims), d_model))
|
| 24 |
+
nn.init.trunc_normal_(self.mod_emb, std=0.02)
|
| 25 |
+
self.mods = list(modality_dims.keys())
|
| 26 |
+
|
| 27 |
+
def forward(self, x):
|
| 28 |
+
out = None
|
| 29 |
+
for i, m in enumerate(self.mods):
|
| 30 |
+
h = self.proj[m](x[m]) + self.mod_emb[i]
|
| 31 |
+
out = h if out is None else out + h
|
| 32 |
+
return out / len(self.mods)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class DAFFuturePressure(nn.Module):
|
| 36 |
+
"""DAF backbone + future-pressure conditioning."""
|
| 37 |
+
|
| 38 |
+
def __init__(self, modality_dims: Dict[str, int], target_dim: int,
|
| 39 |
+
t_obs: int, t_fut: int, future_pressure_dim: int = 50,
|
| 40 |
+
d_model: int = 128, n_heads: int = 4, n_layers: int = 2,
|
| 41 |
+
dropout: float = 0.1):
|
| 42 |
+
super().__init__()
|
| 43 |
+
self.t_obs = t_obs
|
| 44 |
+
self.t_fut = t_fut
|
| 45 |
+
self.embed = _PerModalityProj(modality_dims, d_model)
|
| 46 |
+
self.pos = nn.Parameter(torch.zeros(1, t_obs, d_model))
|
| 47 |
+
nn.init.trunc_normal_(self.pos, std=0.02)
|
| 48 |
+
layer = nn.TransformerEncoderLayer(
|
| 49 |
+
d_model=d_model, nhead=n_heads, dim_feedforward=4 * d_model,
|
| 50 |
+
dropout=dropout, batch_first=True, activation="gelu",
|
| 51 |
+
)
|
| 52 |
+
self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
|
| 53 |
+
# future-pressure encoder
|
| 54 |
+
self.fp_proj = nn.Linear(future_pressure_dim, d_model)
|
| 55 |
+
self.fp_pos = nn.Parameter(torch.zeros(1, t_fut, d_model))
|
| 56 |
+
nn.init.trunc_normal_(self.fp_pos, std=0.02)
|
| 57 |
+
self.fp_seg = nn.Parameter(torch.zeros(1, 1, d_model)) # segment id
|
| 58 |
+
nn.init.trunc_normal_(self.fp_seg, std=0.02)
|
| 59 |
+
# decoder side
|
| 60 |
+
self.queries = nn.Parameter(torch.zeros(1, t_fut, d_model))
|
| 61 |
+
nn.init.trunc_normal_(self.queries, std=0.02)
|
| 62 |
+
self.cross_attn = nn.MultiheadAttention(
|
| 63 |
+
d_model, n_heads, dropout=dropout, batch_first=True
|
| 64 |
+
)
|
| 65 |
+
self.norm = nn.LayerNorm(d_model)
|
| 66 |
+
self.head = nn.Linear(d_model, target_dim)
|
| 67 |
+
|
| 68 |
+
def forward(self, x: Dict[str, torch.Tensor],
|
| 69 |
+
future_pressure: torch.Tensor) -> torch.Tensor:
|
| 70 |
+
h_past = self.encoder(self.embed(x) + self.pos) # (B, T_obs, D)
|
| 71 |
+
h_fp = self.fp_proj(future_pressure) + self.fp_pos + self.fp_seg
|
| 72 |
+
memory = torch.cat([h_past, h_fp], dim=1) # (B, T_obs+T_fut, D)
|
| 73 |
+
q = self.queries.expand(memory.size(0), -1, -1) # (B, T_fut, D)
|
| 74 |
+
out, _ = self.cross_attn(q, memory, memory, need_weights=False)
|
| 75 |
+
out = self.norm(out)
|
| 76 |
+
return self.head(out) # (B, T_fut, target_dim)
|
experiments/nets/models_seqpred.py
ADDED
|
@@ -0,0 +1,806 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Models for T10 Triplet Next-Action Prediction.
|
| 3 |
+
|
| 4 |
+
Two classes live here:
|
| 5 |
+
|
| 6 |
+
* TripletHead — shared head module producing (verb_fine, verb_composite,
|
| 7 |
+
noun, hand) logits from a pooled feature vector.
|
| 8 |
+
* DeepConvLSTMTriplet — single-flow CNN+LSTM baseline (concatenates all
|
| 9 |
+
available modalities along the feature axis).
|
| 10 |
+
* DailyActFormer — our full-modality cross-modal Transformer that keeps
|
| 11 |
+
each modality in its own stem, fuses via a modality
|
| 12 |
+
token, and runs a causal temporal Transformer. Supports
|
| 13 |
+
the anticipatory auxiliary loss mentioned in the paper
|
| 14 |
+
plan (currently as a stub; enabled later in training).
|
| 15 |
+
|
| 16 |
+
All models take:
|
| 17 |
+
x: dict[mod_name -> (B, T, F_mod)]
|
| 18 |
+
mask: BoolTensor (B, T)
|
| 19 |
+
and return a dict:
|
| 20 |
+
{'verb_fine': (B, NUM_VERB_FINE),
|
| 21 |
+
'verb_composite': (B, NUM_VERB_COMPOSITE),
|
| 22 |
+
'noun': (B, NUM_NOUN),
|
| 23 |
+
'hand': (B, NUM_HAND)}
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
from __future__ import annotations
|
| 27 |
+
|
| 28 |
+
import math
|
| 29 |
+
import sys
|
| 30 |
+
from pathlib import Path
|
| 31 |
+
from typing import Dict, List, Optional, Sequence
|
| 32 |
+
|
| 33 |
+
import torch
|
| 34 |
+
import torch.nn as nn
|
| 35 |
+
import torch.nn.functional as F
|
| 36 |
+
|
| 37 |
+
# Importable from either (a) neurips26 root, or (b) frozen row/code/ folder.
|
| 38 |
+
_THIS = Path(__file__).resolve()
|
| 39 |
+
sys.path.insert(0, str(_THIS.parent))
|
| 40 |
+
sys.path.insert(0, str(_THIS.parent.parent))
|
| 41 |
+
|
| 42 |
+
try:
|
| 43 |
+
from experiments.taxonomy import (
|
| 44 |
+
NUM_VERB_FINE, NUM_VERB_COMPOSITE, NUM_NOUN, NUM_HAND,
|
| 45 |
+
)
|
| 46 |
+
except ModuleNotFoundError:
|
| 47 |
+
from taxonomy import (
|
| 48 |
+
NUM_VERB_FINE, NUM_VERB_COMPOSITE, NUM_NOUN, NUM_HAND,
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
# ---------------------------------------------------------------------------
|
| 52 |
+
# Shared triplet head
|
| 53 |
+
# ---------------------------------------------------------------------------
|
| 54 |
+
|
| 55 |
+
class _PrevActionConcat(nn.Module):
|
| 56 |
+
"""Embeds the previous-segment (verb_composite, noun) ground-truth labels
|
| 57 |
+
and concatenates them to a pooled feature vector. Used by every model
|
| 58 |
+
when `use_prev_action=True`. The +1 vocab slot is the BOS / no-prev
|
| 59 |
+
sentinel emitted by the dataset for the first kept segment of each
|
| 60 |
+
recording. Output dim added to pooled = 2 * prev_emb_dim."""
|
| 61 |
+
|
| 62 |
+
def __init__(self, prev_emb_dim: int = 32):
|
| 63 |
+
super().__init__()
|
| 64 |
+
from taxonomy import NUM_VERB_COMPOSITE as _NVC, NUM_NOUN as _NN # noqa
|
| 65 |
+
self.vc_emb = nn.Embedding(_NVC + 1, prev_emb_dim)
|
| 66 |
+
self.n_emb = nn.Embedding(_NN + 1, prev_emb_dim)
|
| 67 |
+
self.out_dim = 2 * prev_emb_dim
|
| 68 |
+
|
| 69 |
+
def forward(self, pooled: torch.Tensor,
|
| 70 |
+
prev_v_comp: Optional[torch.Tensor] = None,
|
| 71 |
+
prev_noun: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 72 |
+
if prev_v_comp is None or prev_noun is None:
|
| 73 |
+
B = pooled.size(0)
|
| 74 |
+
prev_v_comp = torch.full((B,), self.vc_emb.num_embeddings - 1,
|
| 75 |
+
dtype=torch.long, device=pooled.device)
|
| 76 |
+
prev_noun = torch.full((B,), self.n_emb.num_embeddings - 1,
|
| 77 |
+
dtype=torch.long, device=pooled.device)
|
| 78 |
+
pe = torch.cat([self.vc_emb(prev_v_comp), self.n_emb(prev_noun)], dim=-1)
|
| 79 |
+
return torch.cat([pooled, pe], dim=-1)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class TripletHead(nn.Module):
|
| 83 |
+
def __init__(self, feat_dim: int, hidden: int = 256, dropout: float = 0.2):
|
| 84 |
+
super().__init__()
|
| 85 |
+
self.norm = nn.LayerNorm(feat_dim)
|
| 86 |
+
self.trunk = nn.Sequential(
|
| 87 |
+
nn.Linear(feat_dim, hidden),
|
| 88 |
+
nn.GELU(),
|
| 89 |
+
nn.Dropout(dropout),
|
| 90 |
+
)
|
| 91 |
+
self.verb_fine = nn.Linear(hidden, NUM_VERB_FINE)
|
| 92 |
+
self.verb_composite = nn.Linear(hidden, NUM_VERB_COMPOSITE)
|
| 93 |
+
self.noun = nn.Linear(hidden, NUM_NOUN)
|
| 94 |
+
self.hand = nn.Linear(hidden, NUM_HAND)
|
| 95 |
+
|
| 96 |
+
def forward(self, feat: torch.Tensor) -> Dict[str, torch.Tensor]:
|
| 97 |
+
h = self.trunk(self.norm(feat))
|
| 98 |
+
return {
|
| 99 |
+
"verb_fine": self.verb_fine(h),
|
| 100 |
+
"verb_composite": self.verb_composite(h),
|
| 101 |
+
"noun": self.noun(h),
|
| 102 |
+
"hand": self.hand(h),
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def _masked_mean_pool(h: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
| 107 |
+
"""Mean over the time axis of `h` (B, T, D) using a boolean mask (B, T)."""
|
| 108 |
+
m = mask.to(h.dtype).unsqueeze(-1)
|
| 109 |
+
return (h * m).sum(dim=1) / m.sum(dim=1).clamp(min=1.0)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
# ---------------------------------------------------------------------------
|
| 113 |
+
# Baseline: DeepConvLSTM (Ordonez & Roggen 2016) adapted for triplet prediction
|
| 114 |
+
# ---------------------------------------------------------------------------
|
| 115 |
+
|
| 116 |
+
class DeepConvLSTMTriplet(nn.Module):
|
| 117 |
+
"""Single-flow CNN+LSTM. Concatenates per-modality features on F axis."""
|
| 118 |
+
|
| 119 |
+
def __init__(
|
| 120 |
+
self,
|
| 121 |
+
modality_dims: Dict[str, int],
|
| 122 |
+
conv_filters: int = 64,
|
| 123 |
+
conv_kernel: int = 5,
|
| 124 |
+
num_conv_layers: int = 4,
|
| 125 |
+
lstm_hidden: int = 128,
|
| 126 |
+
num_lstm_layers: int = 2,
|
| 127 |
+
dropout: float = 0.2,
|
| 128 |
+
head_hidden: int = 256,
|
| 129 |
+
use_prev_action: bool = False,
|
| 130 |
+
prev_emb_dim: int = 32,
|
| 131 |
+
):
|
| 132 |
+
super().__init__()
|
| 133 |
+
self.modality_dims = dict(modality_dims)
|
| 134 |
+
self.use_prev_action = use_prev_action
|
| 135 |
+
in_ch = sum(modality_dims.values())
|
| 136 |
+
|
| 137 |
+
convs: List[nn.Module] = []
|
| 138 |
+
c = in_ch
|
| 139 |
+
for i in range(num_conv_layers):
|
| 140 |
+
convs.append(nn.Sequential(
|
| 141 |
+
nn.Conv1d(c, conv_filters, conv_kernel, padding=conv_kernel // 2),
|
| 142 |
+
nn.BatchNorm1d(conv_filters),
|
| 143 |
+
nn.ReLU(),
|
| 144 |
+
nn.Dropout(dropout if i < num_conv_layers - 1 else dropout + 0.1),
|
| 145 |
+
))
|
| 146 |
+
c = conv_filters
|
| 147 |
+
self.convs = nn.Sequential(*convs)
|
| 148 |
+
|
| 149 |
+
self.lstm = nn.LSTM(
|
| 150 |
+
conv_filters, lstm_hidden, num_layers=num_lstm_layers,
|
| 151 |
+
batch_first=True, bidirectional=False,
|
| 152 |
+
dropout=dropout if num_lstm_layers > 1 else 0.0,
|
| 153 |
+
)
|
| 154 |
+
head_in = lstm_hidden
|
| 155 |
+
if use_prev_action:
|
| 156 |
+
self.prev_concat = _PrevActionConcat(prev_emb_dim)
|
| 157 |
+
head_in += self.prev_concat.out_dim
|
| 158 |
+
else:
|
| 159 |
+
self.prev_concat = None
|
| 160 |
+
self.head = TripletHead(head_in, hidden=head_hidden, dropout=dropout)
|
| 161 |
+
|
| 162 |
+
def forward(
|
| 163 |
+
self, x: Dict[str, torch.Tensor], mask: torch.Tensor,
|
| 164 |
+
prev_v_comp: Optional[torch.Tensor] = None,
|
| 165 |
+
prev_noun: Optional[torch.Tensor] = None,
|
| 166 |
+
) -> Dict[str, torch.Tensor]:
|
| 167 |
+
feats = torch.cat([x[m] for m in x], dim=-1).transpose(1, 2)
|
| 168 |
+
feats = self.convs(feats).transpose(1, 2)
|
| 169 |
+
out, (h_n, _) = self.lstm(feats)
|
| 170 |
+
pooled = h_n[-1]
|
| 171 |
+
if self.use_prev_action:
|
| 172 |
+
pooled = self.prev_concat(pooled, prev_v_comp, prev_noun)
|
| 173 |
+
return self.head(pooled)
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
# ---------------------------------------------------------------------------
|
| 177 |
+
# Our model: DailyActFormer
|
| 178 |
+
# ---------------------------------------------------------------------------
|
| 179 |
+
|
| 180 |
+
class _ModalityStem(nn.Module):
|
| 181 |
+
"""Multi-scale 1-D conv stem (kernels 3, 5, 9) per modality.
|
| 182 |
+
|
| 183 |
+
Borrowed from HandFormer (the top-1 baseline on T10 recognition): three
|
| 184 |
+
parallel convolutions capture fast (k=3, ~0.15s @ 20Hz), medium (k=5),
|
| 185 |
+
and slow (k=9, ~0.45s) temporal patterns. Output is a 1×1 fusion of
|
| 186 |
+
the three branches, projected back to d_model.
|
| 187 |
+
"""
|
| 188 |
+
|
| 189 |
+
def __init__(self, in_dim: int, d_model: int, kernels=(3, 5, 9),
|
| 190 |
+
dropout: float = 0.1):
|
| 191 |
+
super().__init__()
|
| 192 |
+
self.kernels = kernels
|
| 193 |
+
self.branches = nn.ModuleList([
|
| 194 |
+
nn.Conv1d(in_dim, d_model, k, padding=k // 2) for k in kernels
|
| 195 |
+
])
|
| 196 |
+
self.merge = nn.Sequential(
|
| 197 |
+
nn.GELU(),
|
| 198 |
+
nn.Conv1d(d_model * len(kernels), d_model, 1),
|
| 199 |
+
)
|
| 200 |
+
self.norm = nn.LayerNorm(d_model)
|
| 201 |
+
self.drop = nn.Dropout(dropout)
|
| 202 |
+
|
| 203 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 204 |
+
# x: (B, T, F_in) -> (B, F_in, T) for conv1d
|
| 205 |
+
z = x.transpose(1, 2)
|
| 206 |
+
multi = [c(z) for c in self.branches] # each (B, D, T)
|
| 207 |
+
h = self.merge(torch.cat(multi, dim=1)).transpose(1, 2) # (B, T, D)
|
| 208 |
+
return self.drop(self.norm(h))
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
class _QueryPool(nn.Module):
|
| 212 |
+
"""Learnable-query cross-attention pooling (replaces mean pool).
|
| 213 |
+
|
| 214 |
+
Inspired by FUTR (the top-5 baseline winner): a single learnable query
|
| 215 |
+
cross-attends to the entire encoder output, producing one summary vector.
|
| 216 |
+
Compared to a plain mean pool this lets the model weight informative
|
| 217 |
+
frames more heavily.
|
| 218 |
+
"""
|
| 219 |
+
|
| 220 |
+
def __init__(self, d_model: int, n_heads: int = 4, dropout: float = 0.1):
|
| 221 |
+
super().__init__()
|
| 222 |
+
self.q = nn.Parameter(torch.zeros(1, 1, d_model))
|
| 223 |
+
nn.init.trunc_normal_(self.q, std=0.02)
|
| 224 |
+
self.attn = nn.MultiheadAttention(
|
| 225 |
+
d_model, n_heads, dropout=dropout, batch_first=True,
|
| 226 |
+
)
|
| 227 |
+
self.norm = nn.LayerNorm(d_model)
|
| 228 |
+
|
| 229 |
+
def forward(self, h: torch.Tensor, key_padding_mask: Optional[torch.Tensor]):
|
| 230 |
+
# h: (B, T, D); key_padding_mask: (B, T) where True = pad-to-mask-out
|
| 231 |
+
B = h.size(0)
|
| 232 |
+
q = self.q.expand(B, -1, -1)
|
| 233 |
+
out, _ = self.attn(q, h, h, key_padding_mask=key_padding_mask,
|
| 234 |
+
need_weights=False)
|
| 235 |
+
return self.norm(out.squeeze(1))
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
class _CrossModalTemporalShift(nn.Module):
|
| 239 |
+
"""Cross-modal temporal-shift attention between two modalities.
|
| 240 |
+
|
| 241 |
+
Motivation (paper case study, §sec:grasp-phase-main): EMG activation leads
|
| 242 |
+
motion onset by a sub-frame ~20ms in our 100Hz recordings. After the 5x
|
| 243 |
+
downsample to 20Hz, that lag is ~0.4 frames, but per-subject variability
|
| 244 |
+
plus slack in our segment annotations introduces a few frames of drift
|
| 245 |
+
that a fixed alignment cannot capture.
|
| 246 |
+
|
| 247 |
+
We learn a discrete temporal shift Δ ∈ {-max_shift, …, +max_shift} frames
|
| 248 |
+
applied to one of the two modalities (EMG by default), so the shifted
|
| 249 |
+
tokens align with the other branch (MoCap) before cross-modal fusion. The
|
| 250 |
+
shift is sampled via straight-through Gumbel-softmax during training; at
|
| 251 |
+
inference we take the argmax (deterministic).
|
| 252 |
+
|
| 253 |
+
Inputs are per-modality token sequences (B, T, D). Outputs the same shape.
|
| 254 |
+
Only the `shift_modality` branch is shifted; other modalities pass through.
|
| 255 |
+
"""
|
| 256 |
+
|
| 257 |
+
def __init__(self, max_shift: int = 3, tau: float = 1.0):
|
| 258 |
+
super().__init__()
|
| 259 |
+
self.max_shift = max_shift
|
| 260 |
+
self.tau = tau
|
| 261 |
+
# Logits over 2*max_shift+1 categorical shift candidates.
|
| 262 |
+
self.shift_logits = nn.Parameter(torch.zeros(2 * max_shift + 1))
|
| 263 |
+
|
| 264 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 265 |
+
# x: (B, T, D); produce a shifted version that's a soft-blend over
|
| 266 |
+
# the shift dimension. Hard at inference, gumbel-softmax at training.
|
| 267 |
+
if self.training:
|
| 268 |
+
w = F.gumbel_softmax(self.shift_logits, tau=self.tau, hard=True, dim=-1)
|
| 269 |
+
else:
|
| 270 |
+
w = F.one_hot(self.shift_logits.argmax(),
|
| 271 |
+
num_classes=2 * self.max_shift + 1).float()
|
| 272 |
+
shifted = []
|
| 273 |
+
for i, s in enumerate(range(-self.max_shift, self.max_shift + 1)):
|
| 274 |
+
shifted.append(w[i] * torch.roll(x, shifts=s, dims=1))
|
| 275 |
+
return torch.stack(shifted, dim=0).sum(dim=0)
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
class _CausalTransformerBlock(nn.Module):
|
| 279 |
+
"""Standard Transformer encoder block with a strictly causal attention mask."""
|
| 280 |
+
|
| 281 |
+
def __init__(self, d_model: int, n_heads: int, mlp_ratio: float = 4.0,
|
| 282 |
+
dropout: float = 0.1):
|
| 283 |
+
super().__init__()
|
| 284 |
+
self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout,
|
| 285 |
+
batch_first=True)
|
| 286 |
+
self.norm1 = nn.LayerNorm(d_model)
|
| 287 |
+
self.norm2 = nn.LayerNorm(d_model)
|
| 288 |
+
mlp_dim = int(d_model * mlp_ratio)
|
| 289 |
+
self.mlp = nn.Sequential(
|
| 290 |
+
nn.Linear(d_model, mlp_dim), nn.GELU(), nn.Dropout(dropout),
|
| 291 |
+
nn.Linear(mlp_dim, d_model), nn.Dropout(dropout),
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
def forward(self, x: torch.Tensor, attn_mask: torch.Tensor,
|
| 295 |
+
key_padding_mask: Optional[torch.Tensor]) -> torch.Tensor:
|
| 296 |
+
h = self.norm1(x)
|
| 297 |
+
h, _ = self.attn(h, h, h, attn_mask=attn_mask,
|
| 298 |
+
key_padding_mask=key_padding_mask, need_weights=False)
|
| 299 |
+
x = x + h
|
| 300 |
+
x = x + self.mlp(self.norm2(x))
|
| 301 |
+
return x
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
class DailyActFormer(nn.Module):
|
| 305 |
+
"""Cross-modal Transformer that uses every available modality.
|
| 306 |
+
|
| 307 |
+
Architecture outline:
|
| 308 |
+
per-modality stem → learnable modality embedding →
|
| 309 |
+
concat across time (each frame -> M modality tokens) →
|
| 310 |
+
1 fusion-layer cross-modal attention (compress M→1 per frame) →
|
| 311 |
+
temporal Transformer (bidirectional by default; causal when
|
| 312 |
+
`causal=True` for anticipation-style next-action prediction)
|
| 313 |
+
→ pooled → TripletHead
|
| 314 |
+
|
| 315 |
+
For simplicity the fusion step is an attention pooling with learnable
|
| 316 |
+
queries, rather than a full cross-modal block. This keeps the parameter
|
| 317 |
+
count modest (2–4 M range with d_model=128).
|
| 318 |
+
"""
|
| 319 |
+
|
| 320 |
+
def __init__(
|
| 321 |
+
self,
|
| 322 |
+
modality_dims: Dict[str, int],
|
| 323 |
+
d_model: int = 128,
|
| 324 |
+
n_layers: int = 4,
|
| 325 |
+
n_heads: int = 4,
|
| 326 |
+
dropout: float = 0.1,
|
| 327 |
+
head_hidden: int = 256,
|
| 328 |
+
max_T: int = 256,
|
| 329 |
+
causal: bool = False,
|
| 330 |
+
xshift_modality: Optional[str] = "emg",
|
| 331 |
+
xshift_max: int = 3,
|
| 332 |
+
use_prev_action: bool = False,
|
| 333 |
+
prev_emb_dim: int = 32,
|
| 334 |
+
):
|
| 335 |
+
super().__init__()
|
| 336 |
+
self.modalities = list(modality_dims.keys())
|
| 337 |
+
self.causal = causal
|
| 338 |
+
self.use_prev_action = use_prev_action
|
| 339 |
+
|
| 340 |
+
# Prev-action concat (shared helper)
|
| 341 |
+
if use_prev_action:
|
| 342 |
+
self.prev_concat = _PrevActionConcat(prev_emb_dim)
|
| 343 |
+
self._prev_extra_dim = self.prev_concat.out_dim
|
| 344 |
+
else:
|
| 345 |
+
self.prev_concat = None
|
| 346 |
+
self._prev_extra_dim = 0
|
| 347 |
+
|
| 348 |
+
# 0) Cross-modal temporal-shift block on one branch (EMG by default).
|
| 349 |
+
# Disabled if `xshift_modality` is None or not present.
|
| 350 |
+
if xshift_modality is not None and xshift_modality in modality_dims:
|
| 351 |
+
self.xshift_modality = xshift_modality
|
| 352 |
+
self.xshift = _CrossModalTemporalShift(max_shift=xshift_max)
|
| 353 |
+
else:
|
| 354 |
+
self.xshift_modality = None
|
| 355 |
+
self.xshift = None
|
| 356 |
+
|
| 357 |
+
# 1) per-modality 1-D conv stems (each produces d_model features/frame)
|
| 358 |
+
self.stems = nn.ModuleDict({
|
| 359 |
+
m: _ModalityStem(F, d_model, dropout=dropout)
|
| 360 |
+
for m, F in modality_dims.items()
|
| 361 |
+
})
|
| 362 |
+
|
| 363 |
+
# 2) modality embedding (broadcast-add to per-modality tokens)
|
| 364 |
+
self.modality_embed = nn.Parameter(
|
| 365 |
+
torch.zeros(len(self.modalities), d_model)
|
| 366 |
+
)
|
| 367 |
+
nn.init.trunc_normal_(self.modality_embed, std=0.02)
|
| 368 |
+
|
| 369 |
+
# 3) per-frame cross-modal fusion: use a single learnable query token
|
| 370 |
+
self.fusion_q = nn.Parameter(torch.zeros(1, 1, d_model))
|
| 371 |
+
self.fusion_kv = nn.LayerNorm(d_model)
|
| 372 |
+
self.fusion_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
|
| 373 |
+
|
| 374 |
+
# 4) positional embedding along time (post-fusion)
|
| 375 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, max_T, d_model))
|
| 376 |
+
nn.init.trunc_normal_(self.pos_embed, std=0.02)
|
| 377 |
+
self.max_T = max_T
|
| 378 |
+
|
| 379 |
+
# 5) causal temporal Transformer
|
| 380 |
+
self.temporal_norm = nn.LayerNorm(d_model)
|
| 381 |
+
self.temporal = nn.ModuleList([
|
| 382 |
+
_CausalTransformerBlock(d_model, n_heads, dropout=dropout)
|
| 383 |
+
for _ in range(n_layers)
|
| 384 |
+
])
|
| 385 |
+
|
| 386 |
+
# 6) Pool: learnable-query cross-attention (replaces mean pool, FUTR-style)
|
| 387 |
+
self.pool = _QueryPool(d_model, n_heads=n_heads, dropout=dropout)
|
| 388 |
+
|
| 389 |
+
# 7) triplet head: input dim = d_model + (optional prev-action embed)
|
| 390 |
+
head_in = d_model + self._prev_extra_dim
|
| 391 |
+
self.head = TripletHead(head_in, hidden=head_hidden, dropout=dropout)
|
| 392 |
+
|
| 393 |
+
nn.init.trunc_normal_(self.fusion_q, std=0.02)
|
| 394 |
+
|
| 395 |
+
# ---- helpers ----
|
| 396 |
+
def _causal_mask(self, T: int, device) -> torch.Tensor:
|
| 397 |
+
# MultiheadAttention wants additive mask with -inf above diag.
|
| 398 |
+
m = torch.full((T, T), float("-inf"), device=device)
|
| 399 |
+
m.triu_(diagonal=1)
|
| 400 |
+
return m
|
| 401 |
+
|
| 402 |
+
# ---- forward ----
|
| 403 |
+
def forward(
|
| 404 |
+
self, x: Dict[str, torch.Tensor], mask: torch.Tensor,
|
| 405 |
+
prev_v_comp: Optional[torch.Tensor] = None,
|
| 406 |
+
prev_noun: Optional[torch.Tensor] = None,
|
| 407 |
+
return_features: bool = False,
|
| 408 |
+
) -> Dict[str, torch.Tensor]:
|
| 409 |
+
# Stems: per-modality token streams
|
| 410 |
+
stem_tokens: List[torch.Tensor] = []
|
| 411 |
+
mods_in = [m for m in self.modalities if m in x]
|
| 412 |
+
if not mods_in:
|
| 413 |
+
raise ValueError("No modality from the model signature was provided.")
|
| 414 |
+
for i, m in enumerate(mods_in):
|
| 415 |
+
h = self.stems[m](x[m]) # (B, T, D)
|
| 416 |
+
# Cross-modal temporal shift: apply to one branch (e.g. EMG) so it
|
| 417 |
+
# aligns with the others before fusion. Implements paper SyncFuse's
|
| 418 |
+
# main novelty (sub-frame anticipatory coupling between EMG/MoCap).
|
| 419 |
+
if self.xshift is not None and m == self.xshift_modality:
|
| 420 |
+
h = self.xshift(h)
|
| 421 |
+
h = h + self.modality_embed[self.modalities.index(m)]
|
| 422 |
+
stem_tokens.append(h)
|
| 423 |
+
|
| 424 |
+
# Cross-modal fusion: per-frame, attend learnable query over the M stacked
|
| 425 |
+
# modality tokens. Output is (B, T, D).
|
| 426 |
+
B, T, D = stem_tokens[0].shape
|
| 427 |
+
# stack -> (B, T, M, D) -> reshape as (B*T, M, D)
|
| 428 |
+
stacked = torch.stack(stem_tokens, dim=2) # (B, T, M, D)
|
| 429 |
+
M = stacked.size(2)
|
| 430 |
+
stacked = stacked.reshape(B * T, M, D)
|
| 431 |
+
kv = self.fusion_kv(stacked)
|
| 432 |
+
q = self.fusion_q.expand(B * T, -1, -1)
|
| 433 |
+
fused, _ = self.fusion_attn(q, kv, kv, need_weights=False)
|
| 434 |
+
fused = fused.reshape(B, T, D) # (B, T, D)
|
| 435 |
+
|
| 436 |
+
# Positional embedding + causal temporal Transformer
|
| 437 |
+
if T > self.max_T:
|
| 438 |
+
raise ValueError(f"T={T} exceeds max_T={self.max_T}")
|
| 439 |
+
h = fused + self.pos_embed[:, :T, :]
|
| 440 |
+
h = self.temporal_norm(h)
|
| 441 |
+
|
| 442 |
+
attn_mask = self._causal_mask(T, h.device) if self.causal else None
|
| 443 |
+
key_padding = ~mask if mask is not None else None
|
| 444 |
+
for block in self.temporal:
|
| 445 |
+
h = block(h, attn_mask=attn_mask, key_padding_mask=key_padding)
|
| 446 |
+
|
| 447 |
+
# Pool: learnable-query cross-attention (FUTR-style) over valid frames
|
| 448 |
+
pooled = self.pool(h, key_padding_mask=key_padding)
|
| 449 |
+
|
| 450 |
+
# Optional: condition on previous segment's labels
|
| 451 |
+
if self.use_prev_action:
|
| 452 |
+
pooled = self.prev_concat(pooled, prev_v_comp, prev_noun)
|
| 453 |
+
|
| 454 |
+
logits = self.head(pooled)
|
| 455 |
+
if return_features:
|
| 456 |
+
logits["_pooled"] = pooled
|
| 457 |
+
return logits
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
# ===========================================================================
|
| 461 |
+
# Published baselines, sensor-adapted. Each keeps the original paper's key
|
| 462 |
+
# idea (rolling+unrolling LSTM for RULSTM, causal encoder–decoder for FUTR,
|
| 463 |
+
# early modality-token fusion for AFFT, etc.) but swaps the RGB/feature input
|
| 464 |
+
# for our multimodal sensor streams, and the classification head for our
|
| 465 |
+
# shared TripletHead.
|
| 466 |
+
# ===========================================================================
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
# ---------------------------------------------------------------------------
|
| 470 |
+
# RULSTM (Furnari & Farinella, TPAMI 2020) — sensor-adapted
|
| 471 |
+
# Per-modality rolling LSTM summarises the past, a second unrolling LSTM
|
| 472 |
+
# takes R-LSTM state and walks `future_steps` steps forward to mimic
|
| 473 |
+
# anticipation without needing future sensor data. Fusion is late: each
|
| 474 |
+
# modality produces logits, we average them.
|
| 475 |
+
# ---------------------------------------------------------------------------
|
| 476 |
+
|
| 477 |
+
class _RULSTMBranch(nn.Module):
|
| 478 |
+
def __init__(self, in_dim: int, hidden: int, future_steps: int,
|
| 479 |
+
dropout: float = 0.2):
|
| 480 |
+
super().__init__()
|
| 481 |
+
self.future_steps = future_steps
|
| 482 |
+
self.rolling = nn.LSTM(in_dim, hidden, batch_first=True)
|
| 483 |
+
self.unrolling = nn.LSTMCell(hidden, hidden)
|
| 484 |
+
self.drop = nn.Dropout(dropout)
|
| 485 |
+
self.out_dim = hidden
|
| 486 |
+
|
| 487 |
+
def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
| 488 |
+
# x: (B, T, F_in), mask: (B, T)
|
| 489 |
+
# Pack-free: LSTM on padded sequences is fine since we pool from h_n.
|
| 490 |
+
_, (h_n, c_n) = self.rolling(x) # (1, B, H)
|
| 491 |
+
h = h_n.squeeze(0); c = c_n.squeeze(0)
|
| 492 |
+
inp = h
|
| 493 |
+
for _ in range(self.future_steps):
|
| 494 |
+
h, c = self.unrolling(inp, (h, c))
|
| 495 |
+
inp = h
|
| 496 |
+
return self.drop(h)
|
| 497 |
+
|
| 498 |
+
|
| 499 |
+
class RULSTMTriplet(nn.Module):
|
| 500 |
+
def __init__(self, modality_dims: Dict[str, int], hidden: int = 128,
|
| 501 |
+
future_steps: int = 8, dropout: float = 0.2,
|
| 502 |
+
head_hidden: int = 256,
|
| 503 |
+
use_prev_action: bool = False, prev_emb_dim: int = 32):
|
| 504 |
+
super().__init__()
|
| 505 |
+
self.use_prev_action = use_prev_action
|
| 506 |
+
self.branches = nn.ModuleDict({
|
| 507 |
+
m: _RULSTMBranch(F, hidden, future_steps, dropout)
|
| 508 |
+
for m, F in modality_dims.items()
|
| 509 |
+
})
|
| 510 |
+
head_in = hidden
|
| 511 |
+
if use_prev_action:
|
| 512 |
+
self.prev_concat = _PrevActionConcat(prev_emb_dim)
|
| 513 |
+
head_in += self.prev_concat.out_dim
|
| 514 |
+
else:
|
| 515 |
+
self.prev_concat = None
|
| 516 |
+
self.head = TripletHead(head_in, hidden=head_hidden, dropout=dropout)
|
| 517 |
+
|
| 518 |
+
def forward(self, x, mask, prev_v_comp=None, prev_noun=None):
|
| 519 |
+
feats = []
|
| 520 |
+
for m in x:
|
| 521 |
+
feats.append(self.branches[m](x[m], mask))
|
| 522 |
+
fused = torch.stack(feats, dim=0).mean(dim=0)
|
| 523 |
+
if self.use_prev_action:
|
| 524 |
+
fused = self.prev_concat(fused, prev_v_comp, prev_noun)
|
| 525 |
+
return self.head(fused)
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
# ---------------------------------------------------------------------------
|
| 529 |
+
# FUTR (Gong et al., CVPR 2022) — sensor-adapted
|
| 530 |
+
# Transformer encoder over observation frames (with per-frame feature from
|
| 531 |
+
# concat(modalities)). A decoder query attends over the encoder memory to
|
| 532 |
+
# produce a single future-action embedding which is fed into the triplet
|
| 533 |
+
# head. No autoregressive decoding — we only predict 1 target segment.
|
| 534 |
+
# ---------------------------------------------------------------------------
|
| 535 |
+
|
| 536 |
+
class FUTRTriplet(nn.Module):
|
| 537 |
+
def __init__(self, modality_dims: Dict[str, int], d_model: int = 128,
|
| 538 |
+
n_heads: int = 4, n_layers: int = 3, dropout: float = 0.1,
|
| 539 |
+
head_hidden: int = 256, max_T: int = 256,
|
| 540 |
+
use_prev_action: bool = False, prev_emb_dim: int = 32):
|
| 541 |
+
super().__init__()
|
| 542 |
+
self.use_prev_action = use_prev_action
|
| 543 |
+
in_dim = sum(modality_dims.values())
|
| 544 |
+
self.in_proj = nn.Linear(in_dim, d_model)
|
| 545 |
+
self.pos = nn.Parameter(torch.zeros(1, max_T, d_model))
|
| 546 |
+
nn.init.trunc_normal_(self.pos, std=0.02)
|
| 547 |
+
self.max_T = max_T
|
| 548 |
+
|
| 549 |
+
enc_layer = nn.TransformerEncoderLayer(
|
| 550 |
+
d_model=d_model, nhead=n_heads, dim_feedforward=4 * d_model,
|
| 551 |
+
dropout=dropout, batch_first=True, activation="gelu",
|
| 552 |
+
)
|
| 553 |
+
self.encoder = nn.TransformerEncoder(enc_layer, num_layers=n_layers)
|
| 554 |
+
|
| 555 |
+
self.future_q = nn.Parameter(torch.zeros(1, 1, d_model))
|
| 556 |
+
nn.init.trunc_normal_(self.future_q, std=0.02)
|
| 557 |
+
self.cross_attn = nn.MultiheadAttention(
|
| 558 |
+
d_model, n_heads, dropout=dropout, batch_first=True,
|
| 559 |
+
)
|
| 560 |
+
head_in = d_model
|
| 561 |
+
if use_prev_action:
|
| 562 |
+
self.prev_concat = _PrevActionConcat(prev_emb_dim)
|
| 563 |
+
head_in += self.prev_concat.out_dim
|
| 564 |
+
else:
|
| 565 |
+
self.prev_concat = None
|
| 566 |
+
self.head = TripletHead(head_in, hidden=head_hidden, dropout=dropout)
|
| 567 |
+
|
| 568 |
+
def forward(self, x, mask, prev_v_comp=None, prev_noun=None):
|
| 569 |
+
feats = torch.cat([x[m] for m in x], dim=-1)
|
| 570 |
+
B, T, _ = feats.shape
|
| 571 |
+
if T > self.max_T:
|
| 572 |
+
raise ValueError(f"T={T} exceeds FUTR max_T={self.max_T}")
|
| 573 |
+
h = self.in_proj(feats) + self.pos[:, :T, :]
|
| 574 |
+
h = self.encoder(h, src_key_padding_mask=~mask)
|
| 575 |
+
q = self.future_q.expand(B, -1, -1)
|
| 576 |
+
out, _ = self.cross_attn(q, h, h, key_padding_mask=~mask,
|
| 577 |
+
need_weights=False)
|
| 578 |
+
pooled = out.squeeze(1)
|
| 579 |
+
if self.use_prev_action:
|
| 580 |
+
pooled = self.prev_concat(pooled, prev_v_comp, prev_noun)
|
| 581 |
+
return self.head(pooled)
|
| 582 |
+
|
| 583 |
+
|
| 584 |
+
# ---------------------------------------------------------------------------
|
| 585 |
+
# AFFT (Zhong et al., WACV 2023) — sensor-adapted
|
| 586 |
+
# Per-modality tokens (one per frame per modality) are concatenated into a
|
| 587 |
+
# long token sequence of length T*M and passed through an encoder with
|
| 588 |
+
# causal temporal attention so the model must anticipate strictly from the
|
| 589 |
+
# past. Fusion happens "anticipatively" inside the attention.
|
| 590 |
+
# ---------------------------------------------------------------------------
|
| 591 |
+
|
| 592 |
+
class AFFTTriplet(nn.Module):
|
| 593 |
+
def __init__(self, modality_dims: Dict[str, int], d_model: int = 96,
|
| 594 |
+
n_heads: int = 4, n_layers: int = 3, dropout: float = 0.1,
|
| 595 |
+
head_hidden: int = 256, max_T: int = 256,
|
| 596 |
+
use_prev_action: bool = False, prev_emb_dim: int = 32):
|
| 597 |
+
super().__init__()
|
| 598 |
+
self.use_prev_action = use_prev_action
|
| 599 |
+
self.modalities = list(modality_dims.keys())
|
| 600 |
+
self.stems = nn.ModuleDict({
|
| 601 |
+
m: nn.Linear(F, d_model) for m, F in modality_dims.items()
|
| 602 |
+
})
|
| 603 |
+
self.mod_embed = nn.Parameter(
|
| 604 |
+
torch.zeros(len(self.modalities), d_model)
|
| 605 |
+
)
|
| 606 |
+
nn.init.trunc_normal_(self.mod_embed, std=0.02)
|
| 607 |
+
self.pos = nn.Parameter(torch.zeros(1, max_T, d_model))
|
| 608 |
+
nn.init.trunc_normal_(self.pos, std=0.02)
|
| 609 |
+
self.max_T = max_T
|
| 610 |
+
self.d_model = d_model
|
| 611 |
+
|
| 612 |
+
self.blocks = nn.ModuleList([
|
| 613 |
+
_CausalTransformerBlock(d_model, n_heads, dropout=dropout)
|
| 614 |
+
for _ in range(n_layers)
|
| 615 |
+
])
|
| 616 |
+
head_in = d_model
|
| 617 |
+
if use_prev_action:
|
| 618 |
+
self.prev_concat = _PrevActionConcat(prev_emb_dim)
|
| 619 |
+
head_in += self.prev_concat.out_dim
|
| 620 |
+
else:
|
| 621 |
+
self.prev_concat = None
|
| 622 |
+
self.head = TripletHead(head_in, hidden=head_hidden, dropout=dropout)
|
| 623 |
+
|
| 624 |
+
def _expand_causal_mask(self, T: int, M: int, device) -> torch.Tensor:
|
| 625 |
+
# Token layout: [m0_t0, m1_t0, ..., mM_t0, m0_t1, ..., mM_t(T-1)]
|
| 626 |
+
# Token at (m, t) can attend to all (m', t') with t' <= t.
|
| 627 |
+
ts = torch.arange(T, device=device).unsqueeze(1).expand(-1, M).reshape(-1)
|
| 628 |
+
return ts[:, None] < ts[None, :] # True where future (mask out)
|
| 629 |
+
|
| 630 |
+
def forward(self, x, mask, prev_v_comp=None, prev_noun=None):
|
| 631 |
+
# Build per-frame token streams.
|
| 632 |
+
mods = [m for m in self.modalities if m in x]
|
| 633 |
+
per_mod_tokens = []
|
| 634 |
+
B, T, _ = x[mods[0]].shape
|
| 635 |
+
for i, m in enumerate(mods):
|
| 636 |
+
h = self.stems[m](x[m]) + self.mod_embed[self.modalities.index(m)]
|
| 637 |
+
per_mod_tokens.append(h)
|
| 638 |
+
stacked = torch.stack(per_mod_tokens, dim=2)
|
| 639 |
+
M = stacked.size(2)
|
| 640 |
+
tokens = stacked.reshape(B, T * M, self.d_model)
|
| 641 |
+
if T > self.max_T:
|
| 642 |
+
raise ValueError(f"T={T} exceeds AFFT max_T={self.max_T}")
|
| 643 |
+
pos_per_frame = self.pos[:, :T, :].unsqueeze(2).expand(-1, -1, M, -1)
|
| 644 |
+
tokens = tokens + pos_per_frame.reshape(1, T * M, self.d_model)
|
| 645 |
+
attn_mask = self._expand_causal_mask(T, M, tokens.device)
|
| 646 |
+
attn_mask = torch.where(attn_mask, torch.tensor(float("-inf"),
|
| 647 |
+
device=tokens.device),
|
| 648 |
+
torch.tensor(0.0, device=tokens.device))
|
| 649 |
+
kp = (~mask).unsqueeze(2).expand(-1, -1, M).reshape(B, T * M)
|
| 650 |
+
for blk in self.blocks:
|
| 651 |
+
tokens = blk(tokens, attn_mask=attn_mask, key_padding_mask=kp)
|
| 652 |
+
last_slice = tokens[:, -M:, :]
|
| 653 |
+
pooled = last_slice.mean(dim=1)
|
| 654 |
+
if self.use_prev_action:
|
| 655 |
+
pooled = self.prev_concat(pooled, prev_v_comp, prev_noun)
|
| 656 |
+
return self.head(pooled)
|
| 657 |
+
|
| 658 |
+
|
| 659 |
+
# ---------------------------------------------------------------------------
|
| 660 |
+
# HandFormer (Shamil et al., ECCV 2024) — sensor-adapted
|
| 661 |
+
# Originally on 3D hand poses. We feed it only the MoCap modality (which
|
| 662 |
+
# contains 10 fingertip joints). Multi-scale 1-D conv over time, followed
|
| 663 |
+
# by a Transformer. If MoCap is not in `modalities`, falls back to whatever
|
| 664 |
+
# is provided (but then it's no longer the paper's "pose-only" setup).
|
| 665 |
+
# ---------------------------------------------------------------------------
|
| 666 |
+
|
| 667 |
+
class HandFormerTriplet(nn.Module):
|
| 668 |
+
def __init__(self, modality_dims: Dict[str, int], d_model: int = 128,
|
| 669 |
+
n_heads: int = 4, n_layers: int = 3, kernels=(3, 5, 9),
|
| 670 |
+
dropout: float = 0.1, head_hidden: int = 256, max_T: int = 256,
|
| 671 |
+
use_prev_action: bool = False, prev_emb_dim: int = 32):
|
| 672 |
+
super().__init__()
|
| 673 |
+
self.use_prev_action = use_prev_action
|
| 674 |
+
in_dim = sum(modality_dims.values())
|
| 675 |
+
self.multi_conv = nn.ModuleList([
|
| 676 |
+
nn.Conv1d(in_dim, d_model, k, padding=k // 2) for k in kernels
|
| 677 |
+
])
|
| 678 |
+
self.conv_merge = nn.Conv1d(d_model * len(kernels), d_model, 1)
|
| 679 |
+
|
| 680 |
+
self.pos = nn.Parameter(torch.zeros(1, max_T, d_model))
|
| 681 |
+
nn.init.trunc_normal_(self.pos, std=0.02)
|
| 682 |
+
self.max_T = max_T
|
| 683 |
+
|
| 684 |
+
enc_layer = nn.TransformerEncoderLayer(
|
| 685 |
+
d_model=d_model, nhead=n_heads, dim_feedforward=4 * d_model,
|
| 686 |
+
dropout=dropout, batch_first=True, activation="gelu",
|
| 687 |
+
)
|
| 688 |
+
self.encoder = nn.TransformerEncoder(enc_layer, num_layers=n_layers)
|
| 689 |
+
head_in = d_model
|
| 690 |
+
if use_prev_action:
|
| 691 |
+
self.prev_concat = _PrevActionConcat(prev_emb_dim)
|
| 692 |
+
head_in += self.prev_concat.out_dim
|
| 693 |
+
else:
|
| 694 |
+
self.prev_concat = None
|
| 695 |
+
self.head = TripletHead(head_in, hidden=head_hidden, dropout=dropout)
|
| 696 |
+
|
| 697 |
+
def forward(self, x, mask, prev_v_comp=None, prev_noun=None):
|
| 698 |
+
feats = torch.cat([x[m] for m in x], dim=-1).transpose(1, 2)
|
| 699 |
+
multi = [c(feats) for c in self.multi_conv]
|
| 700 |
+
h = self.conv_merge(torch.cat(multi, dim=1))
|
| 701 |
+
h = h.transpose(1, 2)
|
| 702 |
+
T = h.size(1)
|
| 703 |
+
if T > self.max_T:
|
| 704 |
+
raise ValueError(f"T={T} exceeds HandFormer max_T={self.max_T}")
|
| 705 |
+
h = h + self.pos[:, :T, :]
|
| 706 |
+
h = self.encoder(h, src_key_padding_mask=~mask)
|
| 707 |
+
pooled = _masked_mean_pool(h, mask)
|
| 708 |
+
if self.use_prev_action:
|
| 709 |
+
pooled = self.prev_concat(pooled, prev_v_comp, prev_noun)
|
| 710 |
+
return self.head(pooled)
|
| 711 |
+
|
| 712 |
+
|
| 713 |
+
# ---------------------------------------------------------------------------
|
| 714 |
+
# Placeholder ActionLLM — a conv-stem sensor encoder + a 2-layer Transformer
|
| 715 |
+
# trained from scratch as a surrogate. The *full* LoRA+Qwen version lives in
|
| 716 |
+
# `train_pred.py` and can be wired in later if the surrogate is too weak.
|
| 717 |
+
# ---------------------------------------------------------------------------
|
| 718 |
+
|
| 719 |
+
class ActionLLMSurrogate(nn.Module):
|
| 720 |
+
def __init__(self, modality_dims: Dict[str, int], d_model: int = 192,
|
| 721 |
+
n_heads: int = 6, n_layers: int = 2, dropout: float = 0.1,
|
| 722 |
+
head_hidden: int = 256, max_T: int = 256,
|
| 723 |
+
use_prev_action: bool = False, prev_emb_dim: int = 32):
|
| 724 |
+
super().__init__()
|
| 725 |
+
self.use_prev_action = use_prev_action
|
| 726 |
+
in_dim = sum(modality_dims.values())
|
| 727 |
+
self.stem = nn.Sequential(
|
| 728 |
+
nn.Conv1d(in_dim, d_model, 5, padding=2),
|
| 729 |
+
nn.GELU(),
|
| 730 |
+
nn.Conv1d(d_model, d_model, 5, padding=2),
|
| 731 |
+
)
|
| 732 |
+
self.pos = nn.Parameter(torch.zeros(1, max_T, d_model))
|
| 733 |
+
nn.init.trunc_normal_(self.pos, std=0.02)
|
| 734 |
+
self.max_T = max_T
|
| 735 |
+
enc_layer = nn.TransformerEncoderLayer(
|
| 736 |
+
d_model=d_model, nhead=n_heads, dim_feedforward=4 * d_model,
|
| 737 |
+
dropout=dropout, batch_first=True, activation="gelu",
|
| 738 |
+
)
|
| 739 |
+
self.encoder = nn.TransformerEncoder(enc_layer, num_layers=n_layers)
|
| 740 |
+
head_in = d_model
|
| 741 |
+
if use_prev_action:
|
| 742 |
+
self.prev_concat = _PrevActionConcat(prev_emb_dim)
|
| 743 |
+
head_in += self.prev_concat.out_dim
|
| 744 |
+
else:
|
| 745 |
+
self.prev_concat = None
|
| 746 |
+
self.head = TripletHead(head_in, hidden=head_hidden, dropout=dropout)
|
| 747 |
+
|
| 748 |
+
def forward(self, x, mask, prev_v_comp=None, prev_noun=None):
|
| 749 |
+
feats = torch.cat([x[m] for m in x], dim=-1).transpose(1, 2)
|
| 750 |
+
h = self.stem(feats).transpose(1, 2)
|
| 751 |
+
T = h.size(1)
|
| 752 |
+
if T > self.max_T:
|
| 753 |
+
raise ValueError(f"T={T} exceeds ActionLLM max_T={self.max_T}")
|
| 754 |
+
h = h + self.pos[:, :T, :]
|
| 755 |
+
h = self.encoder(h, src_key_padding_mask=~mask)
|
| 756 |
+
pooled = _masked_mean_pool(h, mask)
|
| 757 |
+
if self.use_prev_action:
|
| 758 |
+
pooled = self.prev_concat(pooled, prev_v_comp, prev_noun)
|
| 759 |
+
return self.head(pooled)
|
| 760 |
+
|
| 761 |
+
|
| 762 |
+
# ---------------------------------------------------------------------------
|
| 763 |
+
# Factory
|
| 764 |
+
# ---------------------------------------------------------------------------
|
| 765 |
+
|
| 766 |
+
def build_model(
|
| 767 |
+
name: str, modality_dims: Dict[str, int], **kwargs,
|
| 768 |
+
) -> nn.Module:
|
| 769 |
+
name = name.lower()
|
| 770 |
+
if name in ("deepconvlstm", "dcl"):
|
| 771 |
+
return DeepConvLSTMTriplet(modality_dims, **kwargs)
|
| 772 |
+
if name in ("dailyactformer", "ours", "daf"):
|
| 773 |
+
return DailyActFormer(modality_dims, **kwargs)
|
| 774 |
+
if name in ("rulstm",):
|
| 775 |
+
return RULSTMTriplet(modality_dims, **kwargs)
|
| 776 |
+
if name in ("futr",):
|
| 777 |
+
return FUTRTriplet(modality_dims, **kwargs)
|
| 778 |
+
if name in ("afft",):
|
| 779 |
+
return AFFTTriplet(modality_dims, **kwargs)
|
| 780 |
+
if name in ("handformer",):
|
| 781 |
+
return HandFormerTriplet(modality_dims, **kwargs)
|
| 782 |
+
if name in ("actionllm",):
|
| 783 |
+
return ActionLLMSurrogate(modality_dims, **kwargs)
|
| 784 |
+
raise ValueError(f"Unknown model: {name}")
|
| 785 |
+
|
| 786 |
+
|
| 787 |
+
# ---------------------------------------------------------------------------
|
| 788 |
+
# Smoke-test: build each model, run a random batch, check output shapes.
|
| 789 |
+
# ---------------------------------------------------------------------------
|
| 790 |
+
|
| 791 |
+
if __name__ == "__main__":
|
| 792 |
+
B, T = 2, 160
|
| 793 |
+
dims = {"imu": 180, "emg": 8, "eyetrack": 24}
|
| 794 |
+
x = {m: torch.randn(B, T, d) for m, d in dims.items()}
|
| 795 |
+
mask = torch.ones(B, T, dtype=torch.bool)
|
| 796 |
+
|
| 797 |
+
for name in ("deepconvlstm", "dailyactformer", "rulstm", "futr", "afft",
|
| 798 |
+
"handformer", "actionllm"):
|
| 799 |
+
model = build_model(name, dims)
|
| 800 |
+
n_params = sum(p.numel() for p in model.parameters())
|
| 801 |
+
out = model(x, mask)
|
| 802 |
+
print(f"{name:16s} params={n_params:>10,} shapes="
|
| 803 |
+
f"vf={tuple(out['verb_fine'].shape)} "
|
| 804 |
+
f"vc={tuple(out['verb_composite'].shape)} "
|
| 805 |
+
f"n={tuple(out['noun'].shape)} "
|
| 806 |
+
f"h={tuple(out['hand'].shape)}")
|
experiments/nets/published_models.py
ADDED
|
@@ -0,0 +1,699 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Published baseline models for NeurIPS 2026 benchmark experiments.
|
| 3 |
+
|
| 4 |
+
Contains faithful implementations of 6 published models:
|
| 5 |
+
1. DeepConvLSTM (Ordonez & Roggen, Sensors 2016) - Exp1/Exp3
|
| 6 |
+
2. InceptionTime (Fawaz et al., DMKD 2020) - Exp1/Exp3
|
| 7 |
+
3. MS-TCN++ (Li et al., TPAMI 2020) - Exp2
|
| 8 |
+
4. DiffAct (Liu et al., ICCV 2023) - Exp2
|
| 9 |
+
5. UnderPressure (Mourot et al., SCA/CGF 2022) - Exp3/Exp4a
|
| 10 |
+
6. emg2pose (Meta, NeurIPS 2024 D&B) - Exp4b
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import math
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
import numpy as np
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# ============================================================
|
| 21 |
+
# 1. DeepConvLSTM (Ordonez & Roggen, Sensors 2016)
|
| 22 |
+
# "Deep Convolutional and LSTM Recurrent Neural Networks
|
| 23 |
+
# for Multimodal Wearable Activity Recognition"
|
| 24 |
+
# 4 Conv layers -> 2 LSTM layers -> pooling/per-frame output
|
| 25 |
+
# ============================================================
|
| 26 |
+
|
| 27 |
+
class DeepConvLSTMBackbone(nn.Module):
|
| 28 |
+
"""DeepConvLSTM backbone for sequence-level classification (Exp1).
|
| 29 |
+
|
| 30 |
+
Input: (B, T, C), optional mask
|
| 31 |
+
Output: (B, output_dim)
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
def __init__(self, input_dim, hidden_dim=128, num_conv_layers=4,
|
| 35 |
+
conv_filters=64, conv_kernel=5, num_lstm_layers=2):
|
| 36 |
+
super().__init__()
|
| 37 |
+
conv_layers = []
|
| 38 |
+
in_ch = input_dim
|
| 39 |
+
for i in range(num_conv_layers):
|
| 40 |
+
out_ch = conv_filters
|
| 41 |
+
conv_layers.append(nn.Sequential(
|
| 42 |
+
nn.Conv1d(in_ch, out_ch, conv_kernel, padding=conv_kernel // 2),
|
| 43 |
+
nn.BatchNorm1d(out_ch),
|
| 44 |
+
nn.ReLU(),
|
| 45 |
+
nn.Dropout(0.1 if i < num_conv_layers - 1 else 0.2),
|
| 46 |
+
))
|
| 47 |
+
in_ch = out_ch
|
| 48 |
+
self.convs = nn.ModuleList(conv_layers)
|
| 49 |
+
|
| 50 |
+
self.lstm = nn.LSTM(
|
| 51 |
+
conv_filters, hidden_dim, num_layers=num_lstm_layers,
|
| 52 |
+
batch_first=True, bidirectional=False,
|
| 53 |
+
dropout=0.2 if num_lstm_layers > 1 else 0,
|
| 54 |
+
)
|
| 55 |
+
self.output_dim = hidden_dim
|
| 56 |
+
|
| 57 |
+
def forward(self, x, mask=None):
|
| 58 |
+
# x: (B, T, C) -> Conv expects (B, C, T)
|
| 59 |
+
x = x.permute(0, 2, 1)
|
| 60 |
+
for conv in self.convs:
|
| 61 |
+
x = conv(x)
|
| 62 |
+
x = x.permute(0, 2, 1) # (B, T, conv_filters)
|
| 63 |
+
|
| 64 |
+
out, (h_n, _) = self.lstm(x)
|
| 65 |
+
# Use last hidden state
|
| 66 |
+
feat = h_n[-1] # (B, hidden_dim)
|
| 67 |
+
return feat
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class DeepConvLSTMContact(nn.Module):
|
| 71 |
+
"""DeepConvLSTM for frame-level contact detection (Exp3).
|
| 72 |
+
|
| 73 |
+
Input: (B, T, C)
|
| 74 |
+
Output: (B, T, 2)
|
| 75 |
+
"""
|
| 76 |
+
|
| 77 |
+
def __init__(self, input_dim, hidden_dim=64, num_conv_layers=4,
|
| 78 |
+
conv_filters=64, conv_kernel=5):
|
| 79 |
+
super().__init__()
|
| 80 |
+
conv_layers = []
|
| 81 |
+
in_ch = input_dim
|
| 82 |
+
for i in range(num_conv_layers):
|
| 83 |
+
conv_layers.append(nn.Sequential(
|
| 84 |
+
nn.Conv1d(in_ch, conv_filters, conv_kernel, padding=conv_kernel // 2),
|
| 85 |
+
nn.BatchNorm1d(conv_filters),
|
| 86 |
+
nn.ReLU(),
|
| 87 |
+
nn.Dropout(0.1),
|
| 88 |
+
))
|
| 89 |
+
in_ch = conv_filters
|
| 90 |
+
self.convs = nn.ModuleList(conv_layers)
|
| 91 |
+
self.lstm = nn.LSTM(conv_filters, hidden_dim, num_layers=2,
|
| 92 |
+
batch_first=True, bidirectional=True, dropout=0.2)
|
| 93 |
+
self.head = nn.Linear(hidden_dim * 2, 2)
|
| 94 |
+
|
| 95 |
+
def forward(self, x):
|
| 96 |
+
x = x.permute(0, 2, 1)
|
| 97 |
+
for conv in self.convs:
|
| 98 |
+
x = conv(x)
|
| 99 |
+
x = x.permute(0, 2, 1)
|
| 100 |
+
out, _ = self.lstm(x)
|
| 101 |
+
return self.head(out)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
# ============================================================
|
| 105 |
+
# 2. InceptionTime (Fawaz et al., DMKD 2020)
|
| 106 |
+
# "InceptionTime: Finding AlexNet for Time Series Classification"
|
| 107 |
+
# Inception modules with multi-scale convolutions + residual
|
| 108 |
+
# ============================================================
|
| 109 |
+
|
| 110 |
+
class InceptionModule(nn.Module):
|
| 111 |
+
"""Single Inception module for time series."""
|
| 112 |
+
|
| 113 |
+
def __init__(self, in_channels, n_filters=32, kernel_sizes=(9, 19, 39),
|
| 114 |
+
bottleneck_channels=32):
|
| 115 |
+
super().__init__()
|
| 116 |
+
# Bottleneck
|
| 117 |
+
self.bottleneck = nn.Conv1d(in_channels, bottleneck_channels, 1, bias=False)
|
| 118 |
+
|
| 119 |
+
# Parallel convolutions with different kernel sizes (odd kernels for symmetric padding)
|
| 120 |
+
self.convs = nn.ModuleList()
|
| 121 |
+
for ks in kernel_sizes:
|
| 122 |
+
self.convs.append(
|
| 123 |
+
nn.Conv1d(bottleneck_channels, n_filters, ks,
|
| 124 |
+
padding=(ks - 1) // 2, bias=False)
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
# MaxPool branch
|
| 128 |
+
self.maxpool_conv = nn.Sequential(
|
| 129 |
+
nn.MaxPool1d(3, stride=1, padding=1),
|
| 130 |
+
nn.Conv1d(in_channels, n_filters, 1, bias=False),
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
self.bn = nn.BatchNorm1d(n_filters * (len(kernel_sizes) + 1))
|
| 134 |
+
self.relu = nn.ReLU()
|
| 135 |
+
|
| 136 |
+
def forward(self, x):
|
| 137 |
+
# x: (B, C, T)
|
| 138 |
+
x_bottleneck = self.bottleneck(x)
|
| 139 |
+
conv_outputs = [conv(x_bottleneck) for conv in self.convs]
|
| 140 |
+
conv_outputs.append(self.maxpool_conv(x))
|
| 141 |
+
out = torch.cat(conv_outputs, dim=1)
|
| 142 |
+
return self.relu(self.bn(out))
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
class InceptionBlock(nn.Module):
|
| 146 |
+
"""Stack of Inception modules with a residual connection."""
|
| 147 |
+
|
| 148 |
+
def __init__(self, in_channels, n_filters=32, depth=3):
|
| 149 |
+
super().__init__()
|
| 150 |
+
n_out = n_filters * 4 # 3 conv branches + 1 maxpool branch
|
| 151 |
+
modules = []
|
| 152 |
+
for i in range(depth):
|
| 153 |
+
inc = in_channels if i == 0 else n_out
|
| 154 |
+
modules.append(InceptionModule(inc, n_filters))
|
| 155 |
+
self.modules_list = nn.ModuleList(modules)
|
| 156 |
+
|
| 157 |
+
# Residual connection
|
| 158 |
+
self.use_residual = (in_channels != n_out)
|
| 159 |
+
if self.use_residual:
|
| 160 |
+
self.residual = nn.Sequential(
|
| 161 |
+
nn.Conv1d(in_channels, n_out, 1, bias=False),
|
| 162 |
+
nn.BatchNorm1d(n_out),
|
| 163 |
+
)
|
| 164 |
+
self.relu = nn.ReLU()
|
| 165 |
+
|
| 166 |
+
def forward(self, x):
|
| 167 |
+
residual = x
|
| 168 |
+
for mod in self.modules_list:
|
| 169 |
+
x = mod(x)
|
| 170 |
+
if self.use_residual:
|
| 171 |
+
residual = self.residual(residual)
|
| 172 |
+
return self.relu(x + residual)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
class InceptionTimeBackbone(nn.Module):
|
| 176 |
+
"""InceptionTime backbone for sequence-level classification (Exp1).
|
| 177 |
+
|
| 178 |
+
Input: (B, T, C), optional mask
|
| 179 |
+
Output: (B, output_dim)
|
| 180 |
+
"""
|
| 181 |
+
|
| 182 |
+
def __init__(self, input_dim, hidden_dim=128, n_filters=32, num_blocks=2, depth=3):
|
| 183 |
+
super().__init__()
|
| 184 |
+
blocks = []
|
| 185 |
+
in_ch = input_dim
|
| 186 |
+
for i in range(num_blocks):
|
| 187 |
+
blocks.append(InceptionBlock(in_ch, n_filters, depth))
|
| 188 |
+
in_ch = n_filters * 4
|
| 189 |
+
self.blocks = nn.ModuleList(blocks)
|
| 190 |
+
self.output_dim = n_filters * 4
|
| 191 |
+
|
| 192 |
+
def forward(self, x, mask=None):
|
| 193 |
+
# x: (B, T, C) -> (B, C, T)
|
| 194 |
+
x = x.permute(0, 2, 1)
|
| 195 |
+
for block in self.blocks:
|
| 196 |
+
x = block(x)
|
| 197 |
+
# Global average pooling with mask
|
| 198 |
+
if mask is not None:
|
| 199 |
+
x = (x * mask.unsqueeze(1).float()).sum(2) / mask.sum(1, keepdim=True).float().clamp(min=1)
|
| 200 |
+
else:
|
| 201 |
+
x = x.mean(2)
|
| 202 |
+
return x # (B, n_filters*4)
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
class InceptionTimeContact(nn.Module):
|
| 206 |
+
"""InceptionTime for frame-level contact detection (Exp3).
|
| 207 |
+
|
| 208 |
+
Input: (B, T, C)
|
| 209 |
+
Output: (B, T, 2)
|
| 210 |
+
"""
|
| 211 |
+
|
| 212 |
+
def __init__(self, input_dim, hidden_dim=64, n_filters=32, num_blocks=2, depth=3):
|
| 213 |
+
super().__init__()
|
| 214 |
+
blocks = []
|
| 215 |
+
in_ch = input_dim
|
| 216 |
+
for i in range(num_blocks):
|
| 217 |
+
blocks.append(InceptionBlock(in_ch, n_filters, depth))
|
| 218 |
+
in_ch = n_filters * 4
|
| 219 |
+
self.blocks = nn.ModuleList(blocks)
|
| 220 |
+
self.head = nn.Conv1d(n_filters * 4, 2, 1)
|
| 221 |
+
|
| 222 |
+
def forward(self, x):
|
| 223 |
+
x = x.permute(0, 2, 1)
|
| 224 |
+
for block in self.blocks:
|
| 225 |
+
x = block(x)
|
| 226 |
+
out = self.head(x)
|
| 227 |
+
return out.permute(0, 2, 1) # (B, T, 2)
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
# ============================================================
|
| 231 |
+
# 3. MS-TCN++ (Li et al., TPAMI 2020)
|
| 232 |
+
# "MS-TCN++: Multi-Stage Temporal Convolutional Network
|
| 233 |
+
# for Action Segmentation"
|
| 234 |
+
# Key improvement: dual dilated layers in each residual block
|
| 235 |
+
# ============================================================
|
| 236 |
+
|
| 237 |
+
class DualDilatedResBlock(nn.Module):
|
| 238 |
+
"""Dual dilated residual block (MS-TCN++ key contribution).
|
| 239 |
+
|
| 240 |
+
Uses two parallel dilated convolutions with different dilation rates
|
| 241 |
+
to capture both short-range and long-range temporal patterns.
|
| 242 |
+
"""
|
| 243 |
+
|
| 244 |
+
def __init__(self, channels, dilation1, dilation2):
|
| 245 |
+
super().__init__()
|
| 246 |
+
# Branch 1: smaller dilation
|
| 247 |
+
self.conv1_dilated = nn.Conv1d(
|
| 248 |
+
channels, channels, 3,
|
| 249 |
+
padding=dilation1, dilation=dilation1
|
| 250 |
+
)
|
| 251 |
+
# Branch 2: larger dilation
|
| 252 |
+
self.conv2_dilated = nn.Conv1d(
|
| 253 |
+
channels, channels, 3,
|
| 254 |
+
padding=dilation2, dilation=dilation2
|
| 255 |
+
)
|
| 256 |
+
self.conv_fusion = nn.Conv1d(channels, channels, 1)
|
| 257 |
+
self.bn = nn.BatchNorm1d(channels)
|
| 258 |
+
self.dropout = nn.Dropout(0.3)
|
| 259 |
+
|
| 260 |
+
def forward(self, x):
|
| 261 |
+
residual = x
|
| 262 |
+
out1 = F.relu(self.conv1_dilated(x))
|
| 263 |
+
out2 = F.relu(self.conv2_dilated(x))
|
| 264 |
+
out = out1 + out2
|
| 265 |
+
out = self.dropout(F.relu(self.bn(self.conv_fusion(out))))
|
| 266 |
+
return out + residual
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
class MSTCNPPStage(nn.Module):
|
| 270 |
+
"""Single stage of MS-TCN++ with dual dilated layers."""
|
| 271 |
+
|
| 272 |
+
def __init__(self, in_channels, hidden_channels, num_classes, num_layers=10):
|
| 273 |
+
super().__init__()
|
| 274 |
+
self.input_conv = nn.Conv1d(in_channels, hidden_channels, 1)
|
| 275 |
+
self.layers = nn.ModuleList()
|
| 276 |
+
for i in range(num_layers):
|
| 277 |
+
dilation1 = 2 ** i
|
| 278 |
+
dilation2 = 2 ** (i + 1) if i < num_layers - 1 else 2 ** i
|
| 279 |
+
self.layers.append(DualDilatedResBlock(hidden_channels, dilation1, dilation2))
|
| 280 |
+
self.output_conv = nn.Conv1d(hidden_channels, num_classes, 1)
|
| 281 |
+
|
| 282 |
+
def forward(self, x):
|
| 283 |
+
x = self.input_conv(x)
|
| 284 |
+
for layer in self.layers:
|
| 285 |
+
x = layer(x)
|
| 286 |
+
return self.output_conv(x)
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
class MSTCNPP(nn.Module):
|
| 290 |
+
"""MS-TCN++ for temporal action segmentation (Exp2).
|
| 291 |
+
|
| 292 |
+
Input: (B, T, C)
|
| 293 |
+
Output: list of (B, T, num_classes) per stage
|
| 294 |
+
"""
|
| 295 |
+
|
| 296 |
+
def __init__(self, input_dim, num_classes, hidden_dim=64, num_stages=4, num_layers=10):
|
| 297 |
+
super().__init__()
|
| 298 |
+
self.stages = nn.ModuleList()
|
| 299 |
+
# First stage: input features -> predictions
|
| 300 |
+
self.stages.append(MSTCNPPStage(input_dim, hidden_dim, num_classes, num_layers))
|
| 301 |
+
# Refinement stages: predictions -> refined predictions
|
| 302 |
+
for _ in range(num_stages - 1):
|
| 303 |
+
self.stages.append(MSTCNPPStage(num_classes, hidden_dim, num_classes, num_layers))
|
| 304 |
+
|
| 305 |
+
def forward(self, x):
|
| 306 |
+
x = x.permute(0, 2, 1) # (B, C, T)
|
| 307 |
+
outputs = []
|
| 308 |
+
for stage in self.stages:
|
| 309 |
+
x = stage(x)
|
| 310 |
+
outputs.append(x.permute(0, 2, 1)) # (B, T, num_classes)
|
| 311 |
+
# Feed softmax of predictions to next stage
|
| 312 |
+
if stage != self.stages[-1]:
|
| 313 |
+
x = F.softmax(x, dim=1)
|
| 314 |
+
return outputs
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
# ============================================================
|
| 318 |
+
# 4. DiffAct (Liu et al., ICCV 2023)
|
| 319 |
+
# "Diffusion Action Segmentation"
|
| 320 |
+
# Denoising diffusion model for iterative action refinement.
|
| 321 |
+
# Simplified but faithful implementation.
|
| 322 |
+
# ============================================================
|
| 323 |
+
|
| 324 |
+
class ConditionalLayerNorm(nn.Module):
|
| 325 |
+
"""Layer norm conditioned on diffusion timestep."""
|
| 326 |
+
|
| 327 |
+
def __init__(self, channels):
|
| 328 |
+
super().__init__()
|
| 329 |
+
self.norm = nn.GroupNorm(1, channels) # equivalent to LayerNorm for 1D
|
| 330 |
+
|
| 331 |
+
def forward(self, x):
|
| 332 |
+
return self.norm(x)
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
class DiffActBlock(nn.Module):
|
| 336 |
+
"""Residual block for DiffAct denoising network."""
|
| 337 |
+
|
| 338 |
+
def __init__(self, channels, dilation, time_emb_dim):
|
| 339 |
+
super().__init__()
|
| 340 |
+
self.conv1 = nn.Conv1d(channels, channels, 3, padding=dilation, dilation=dilation)
|
| 341 |
+
self.conv2 = nn.Conv1d(channels, channels, 1)
|
| 342 |
+
self.norm1 = ConditionalLayerNorm(channels)
|
| 343 |
+
self.norm2 = ConditionalLayerNorm(channels)
|
| 344 |
+
self.time_proj = nn.Linear(time_emb_dim, channels)
|
| 345 |
+
self.dropout = nn.Dropout(0.1)
|
| 346 |
+
|
| 347 |
+
def forward(self, x, time_emb):
|
| 348 |
+
residual = x
|
| 349 |
+
x = self.norm1(x)
|
| 350 |
+
x = F.relu(self.conv1(x))
|
| 351 |
+
# Add time embedding
|
| 352 |
+
t = self.time_proj(time_emb).unsqueeze(-1) # (B, C, 1)
|
| 353 |
+
x = x + t
|
| 354 |
+
x = self.norm2(x)
|
| 355 |
+
x = self.dropout(F.relu(self.conv2(x)))
|
| 356 |
+
return x + residual
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
class DiffActConditionEncoder(nn.Module):
|
| 360 |
+
"""Temporal feature encoder for conditioning the denoising network."""
|
| 361 |
+
|
| 362 |
+
def __init__(self, input_dim, hidden_dim, num_layers=6):
|
| 363 |
+
super().__init__()
|
| 364 |
+
self.input_conv = nn.Conv1d(input_dim, hidden_dim, 1)
|
| 365 |
+
self.layers = nn.ModuleList()
|
| 366 |
+
for i in range(num_layers):
|
| 367 |
+
dilation = 2 ** (i % 5)
|
| 368 |
+
self.layers.append(nn.Sequential(
|
| 369 |
+
nn.Conv1d(hidden_dim, hidden_dim, 3, padding=dilation, dilation=dilation),
|
| 370 |
+
nn.BatchNorm1d(hidden_dim),
|
| 371 |
+
nn.ReLU(),
|
| 372 |
+
nn.Dropout(0.1),
|
| 373 |
+
))
|
| 374 |
+
|
| 375 |
+
def forward(self, x):
|
| 376 |
+
x = self.input_conv(x)
|
| 377 |
+
for layer in self.layers:
|
| 378 |
+
x = layer(x) + x # residual
|
| 379 |
+
return x
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
class SinusoidalTimeEmbedding(nn.Module):
|
| 383 |
+
"""Sinusoidal positional embedding for diffusion timestep."""
|
| 384 |
+
|
| 385 |
+
def __init__(self, dim):
|
| 386 |
+
super().__init__()
|
| 387 |
+
self.dim = dim
|
| 388 |
+
self.mlp = nn.Sequential(
|
| 389 |
+
nn.Linear(dim, dim * 4),
|
| 390 |
+
nn.GELU(),
|
| 391 |
+
nn.Linear(dim * 4, dim),
|
| 392 |
+
)
|
| 393 |
+
|
| 394 |
+
def forward(self, t):
|
| 395 |
+
half_dim = self.dim // 2
|
| 396 |
+
emb = math.log(10000) / (half_dim - 1)
|
| 397 |
+
emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
|
| 398 |
+
emb = t.unsqueeze(-1).float() * emb.unsqueeze(0)
|
| 399 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
|
| 400 |
+
return self.mlp(emb)
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
class DiffAct(nn.Module):
|
| 404 |
+
"""DiffAct: Diffusion Action Segmentation (Exp2).
|
| 405 |
+
|
| 406 |
+
During training: noises ground-truth action probabilities and denoises.
|
| 407 |
+
During inference: iteratively denoises from pure noise.
|
| 408 |
+
|
| 409 |
+
Input: (B, T, C)
|
| 410 |
+
Output: list of (B, T, num_classes) [final denoised prediction]
|
| 411 |
+
"""
|
| 412 |
+
|
| 413 |
+
def __init__(self, input_dim, num_classes, hidden_dim=64,
|
| 414 |
+
num_encoder_layers=6, num_denoise_layers=6,
|
| 415 |
+
num_diffusion_steps=10):
|
| 416 |
+
super().__init__()
|
| 417 |
+
self.num_classes = num_classes
|
| 418 |
+
self.num_steps = num_diffusion_steps
|
| 419 |
+
|
| 420 |
+
# Condition encoder: extract temporal features from input
|
| 421 |
+
self.condition_encoder = DiffActConditionEncoder(input_dim, hidden_dim, num_encoder_layers)
|
| 422 |
+
|
| 423 |
+
# Initial prediction head (non-diffusion baseline)
|
| 424 |
+
self.initial_head = nn.Conv1d(hidden_dim, num_classes, 1)
|
| 425 |
+
|
| 426 |
+
# Time embedding
|
| 427 |
+
self.time_emb = SinusoidalTimeEmbedding(hidden_dim)
|
| 428 |
+
|
| 429 |
+
# Denoising network
|
| 430 |
+
self.denoise_input = nn.Conv1d(num_classes + hidden_dim, hidden_dim, 1)
|
| 431 |
+
self.denoise_blocks = nn.ModuleList()
|
| 432 |
+
for i in range(num_denoise_layers):
|
| 433 |
+
dilation = 2 ** (i % 5)
|
| 434 |
+
self.denoise_blocks.append(DiffActBlock(hidden_dim, dilation, hidden_dim))
|
| 435 |
+
self.denoise_output = nn.Conv1d(hidden_dim, num_classes, 1)
|
| 436 |
+
|
| 437 |
+
# Noise schedule (cosine)
|
| 438 |
+
self._setup_noise_schedule()
|
| 439 |
+
|
| 440 |
+
def _setup_noise_schedule(self):
|
| 441 |
+
steps = self.num_steps
|
| 442 |
+
s = 0.008
|
| 443 |
+
t = torch.linspace(0, steps, steps + 1)
|
| 444 |
+
alphas_cumprod = torch.cos(((t / steps) + s) / (1 + s) * math.pi * 0.5) ** 2
|
| 445 |
+
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
|
| 446 |
+
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
|
| 447 |
+
betas = torch.clamp(betas, 0.0001, 0.999)
|
| 448 |
+
alphas = 1.0 - betas
|
| 449 |
+
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
| 450 |
+
self.register_buffer('betas', betas)
|
| 451 |
+
self.register_buffer('alphas_cumprod', alphas_cumprod)
|
| 452 |
+
self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
|
| 453 |
+
self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1 - alphas_cumprod))
|
| 454 |
+
|
| 455 |
+
def _add_noise(self, x_start, t, noise=None):
|
| 456 |
+
"""Add noise to x_start at timestep t."""
|
| 457 |
+
if noise is None:
|
| 458 |
+
noise = torch.randn_like(x_start)
|
| 459 |
+
sqrt_alpha = self.sqrt_alphas_cumprod[t].view(-1, 1, 1)
|
| 460 |
+
sqrt_one_minus = self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1)
|
| 461 |
+
return sqrt_alpha * x_start + sqrt_one_minus * noise
|
| 462 |
+
|
| 463 |
+
def _denoise_step(self, x_noisy, cond_features, time_emb):
|
| 464 |
+
"""Single denoising step."""
|
| 465 |
+
x = torch.cat([x_noisy, cond_features], dim=1) # (B, C+hidden, T)
|
| 466 |
+
x = self.denoise_input(x)
|
| 467 |
+
for block in self.denoise_blocks:
|
| 468 |
+
x = block(x, time_emb)
|
| 469 |
+
return self.denoise_output(x)
|
| 470 |
+
|
| 471 |
+
def forward(self, x):
|
| 472 |
+
"""
|
| 473 |
+
Training: returns [initial_pred, denoised_pred]
|
| 474 |
+
Inference: returns [initial_pred, iteratively_denoised_pred]
|
| 475 |
+
"""
|
| 476 |
+
x_in = x.permute(0, 2, 1) # (B, C, T)
|
| 477 |
+
B, _, T = x_in.shape
|
| 478 |
+
|
| 479 |
+
# Encode condition features
|
| 480 |
+
cond = self.condition_encoder(x_in) # (B, hidden, T)
|
| 481 |
+
initial_logits = self.initial_head(cond).permute(0, 2, 1) # (B, T, num_classes)
|
| 482 |
+
|
| 483 |
+
if self.training:
|
| 484 |
+
# Training: noise the initial prediction and denoise (end-to-end)
|
| 485 |
+
x_start = F.softmax(initial_logits, dim=-1).permute(0, 2, 1) # (B, C, T)
|
| 486 |
+
t = torch.randint(0, self.num_steps, (B,), device=x.device)
|
| 487 |
+
noise = torch.randn_like(x_start)
|
| 488 |
+
x_noisy = self._add_noise(x_start.detach(), t, noise)
|
| 489 |
+
time_emb = self.time_emb(t)
|
| 490 |
+
denoised = self._denoise_step(x_noisy, cond, time_emb)
|
| 491 |
+
return [initial_logits, denoised.permute(0, 2, 1)]
|
| 492 |
+
else:
|
| 493 |
+
# Inference: iterative denoising from noise
|
| 494 |
+
x_t = torch.randn(B, self.num_classes, T, device=x.device)
|
| 495 |
+
for step in reversed(range(self.num_steps)):
|
| 496 |
+
t = torch.full((B,), step, device=x.device, dtype=torch.long)
|
| 497 |
+
time_emb = self.time_emb(t)
|
| 498 |
+
pred_noise = self._denoise_step(x_t, cond, time_emb)
|
| 499 |
+
# Simplified DDPM update
|
| 500 |
+
alpha = self.alphas_cumprod[step]
|
| 501 |
+
alpha_prev = self.alphas_cumprod[step - 1] if step > 0 else torch.tensor(1.0)
|
| 502 |
+
beta = self.betas[step]
|
| 503 |
+
x_t = (1 / torch.sqrt(1 - beta)) * (
|
| 504 |
+
x_t - beta / self.sqrt_one_minus_alphas_cumprod[step] * pred_noise
|
| 505 |
+
)
|
| 506 |
+
if step > 0:
|
| 507 |
+
x_t = x_t + torch.sqrt(beta) * torch.randn_like(x_t) * 0.5
|
| 508 |
+
return [initial_logits, x_t.permute(0, 2, 1)]
|
| 509 |
+
|
| 510 |
+
|
| 511 |
+
# ============================================================
|
| 512 |
+
# 5. UnderPressure (Mourot et al., SCA/CGF 2022)
|
| 513 |
+
# "UnderPressure: Deep Learning for Foot Contact Detection,
|
| 514 |
+
# Ground Reaction Force Estimation and Footskate Cleanup"
|
| 515 |
+
# GRU-based architecture for contact detection + force regression.
|
| 516 |
+
# Adapted for hand contact detection and MoCap->Pressure prediction.
|
| 517 |
+
# ============================================================
|
| 518 |
+
|
| 519 |
+
class UnderPressureContact(nn.Module):
|
| 520 |
+
"""UnderPressure model adapted for hand contact detection (Exp3).
|
| 521 |
+
|
| 522 |
+
Architecture: Conv feature extractor -> BiGRU -> contact prediction head
|
| 523 |
+
Input: (B, T, C)
|
| 524 |
+
Output: (B, T, 2) [right_contact, left_contact]
|
| 525 |
+
"""
|
| 526 |
+
|
| 527 |
+
def __init__(self, input_dim, hidden_dim=64, num_gru_layers=2):
|
| 528 |
+
super().__init__()
|
| 529 |
+
# Feature extractor (conv layers for local temporal patterns)
|
| 530 |
+
self.feature_extractor = nn.Sequential(
|
| 531 |
+
nn.Conv1d(input_dim, hidden_dim, 7, padding=3),
|
| 532 |
+
nn.BatchNorm1d(hidden_dim),
|
| 533 |
+
nn.ReLU(),
|
| 534 |
+
nn.Conv1d(hidden_dim, hidden_dim, 5, padding=2),
|
| 535 |
+
nn.BatchNorm1d(hidden_dim),
|
| 536 |
+
nn.ReLU(),
|
| 537 |
+
)
|
| 538 |
+
# BiGRU for temporal modeling
|
| 539 |
+
self.gru = nn.GRU(
|
| 540 |
+
hidden_dim, hidden_dim, num_layers=num_gru_layers,
|
| 541 |
+
batch_first=True, bidirectional=True,
|
| 542 |
+
dropout=0.2 if num_gru_layers > 1 else 0,
|
| 543 |
+
)
|
| 544 |
+
# Contact prediction head
|
| 545 |
+
self.contact_head = nn.Sequential(
|
| 546 |
+
nn.Linear(hidden_dim * 2, hidden_dim),
|
| 547 |
+
nn.ReLU(),
|
| 548 |
+
nn.Dropout(0.2),
|
| 549 |
+
nn.Linear(hidden_dim, 2),
|
| 550 |
+
)
|
| 551 |
+
|
| 552 |
+
def forward(self, x):
|
| 553 |
+
# x: (B, T, C) -> (B, C, T)
|
| 554 |
+
feat = self.feature_extractor(x.permute(0, 2, 1))
|
| 555 |
+
feat = feat.permute(0, 2, 1) # (B, T, hidden)
|
| 556 |
+
gru_out, _ = self.gru(feat)
|
| 557 |
+
return self.contact_head(gru_out) # (B, T, 2)
|
| 558 |
+
|
| 559 |
+
|
| 560 |
+
class UnderPressureRegressor(nn.Module):
|
| 561 |
+
"""UnderPressure model adapted for MoCap -> Pressure regression (Exp4a).
|
| 562 |
+
|
| 563 |
+
Architecture: Conv feature extractor -> BiGRU -> pressure regression head
|
| 564 |
+
Input: (B, T, input_dim)
|
| 565 |
+
Output: (B, T, output_dim)
|
| 566 |
+
"""
|
| 567 |
+
|
| 568 |
+
def __init__(self, input_dim, output_dim, hidden_dim=128, num_gru_layers=2):
|
| 569 |
+
super().__init__()
|
| 570 |
+
self.feature_extractor = nn.Sequential(
|
| 571 |
+
nn.Conv1d(input_dim, hidden_dim, 7, padding=3),
|
| 572 |
+
nn.BatchNorm1d(hidden_dim),
|
| 573 |
+
nn.ReLU(),
|
| 574 |
+
nn.Conv1d(hidden_dim, hidden_dim, 5, padding=2),
|
| 575 |
+
nn.BatchNorm1d(hidden_dim),
|
| 576 |
+
nn.ReLU(),
|
| 577 |
+
nn.Conv1d(hidden_dim, hidden_dim, 3, padding=1),
|
| 578 |
+
nn.BatchNorm1d(hidden_dim),
|
| 579 |
+
nn.ReLU(),
|
| 580 |
+
)
|
| 581 |
+
self.gru = nn.GRU(
|
| 582 |
+
hidden_dim, hidden_dim, num_layers=num_gru_layers,
|
| 583 |
+
batch_first=True, bidirectional=True,
|
| 584 |
+
dropout=0.2 if num_gru_layers > 1 else 0,
|
| 585 |
+
)
|
| 586 |
+
self.regression_head = nn.Sequential(
|
| 587 |
+
nn.Linear(hidden_dim * 2, hidden_dim),
|
| 588 |
+
nn.ReLU(),
|
| 589 |
+
nn.Dropout(0.2),
|
| 590 |
+
nn.Linear(hidden_dim, output_dim),
|
| 591 |
+
)
|
| 592 |
+
|
| 593 |
+
def forward(self, x):
|
| 594 |
+
feat = self.feature_extractor(x.permute(0, 2, 1))
|
| 595 |
+
feat = feat.permute(0, 2, 1)
|
| 596 |
+
gru_out, _ = self.gru(feat)
|
| 597 |
+
return self.regression_head(gru_out)
|
| 598 |
+
|
| 599 |
+
|
| 600 |
+
# ============================================================
|
| 601 |
+
# 6. emg2pose (Meta/Facebook Research, NeurIPS 2024 D&B)
|
| 602 |
+
# "emg2pose: A Large and Diverse Benchmark for
|
| 603 |
+
# Surface Electromyographic Hand Pose Estimation"
|
| 604 |
+
# CNN feature extractor + Transformer encoder,
|
| 605 |
+
# with optional velocity-based integration (vemg2pose).
|
| 606 |
+
# ============================================================
|
| 607 |
+
|
| 608 |
+
class EMG2PoseEncoder(nn.Module):
|
| 609 |
+
"""CNN + Transformer encoder from emg2pose."""
|
| 610 |
+
|
| 611 |
+
def __init__(self, input_dim, hidden_dim=128, num_transformer_layers=4, nhead=4):
|
| 612 |
+
super().__init__()
|
| 613 |
+
# Multi-scale CNN feature extractor
|
| 614 |
+
self.conv_small = nn.Sequential(
|
| 615 |
+
nn.Conv1d(input_dim, hidden_dim // 2, 3, padding=1),
|
| 616 |
+
nn.BatchNorm1d(hidden_dim // 2),
|
| 617 |
+
nn.ReLU(),
|
| 618 |
+
)
|
| 619 |
+
self.conv_medium = nn.Sequential(
|
| 620 |
+
nn.Conv1d(input_dim, hidden_dim // 4, 7, padding=3),
|
| 621 |
+
nn.BatchNorm1d(hidden_dim // 4),
|
| 622 |
+
nn.ReLU(),
|
| 623 |
+
)
|
| 624 |
+
self.conv_large = nn.Sequential(
|
| 625 |
+
nn.Conv1d(input_dim, hidden_dim // 4, 15, padding=7),
|
| 626 |
+
nn.BatchNorm1d(hidden_dim // 4),
|
| 627 |
+
nn.ReLU(),
|
| 628 |
+
)
|
| 629 |
+
# Projection to hidden_dim
|
| 630 |
+
self.proj = nn.Sequential(
|
| 631 |
+
nn.Conv1d(hidden_dim, hidden_dim, 1),
|
| 632 |
+
nn.BatchNorm1d(hidden_dim),
|
| 633 |
+
nn.ReLU(),
|
| 634 |
+
)
|
| 635 |
+
# Transformer encoder for temporal modeling
|
| 636 |
+
encoder_layer = nn.TransformerEncoderLayer(
|
| 637 |
+
d_model=hidden_dim, nhead=nhead,
|
| 638 |
+
dim_feedforward=hidden_dim * 4,
|
| 639 |
+
dropout=0.1, batch_first=True,
|
| 640 |
+
)
|
| 641 |
+
self.transformer = nn.TransformerEncoder(encoder_layer, num_transformer_layers)
|
| 642 |
+
|
| 643 |
+
def forward(self, x):
|
| 644 |
+
# x: (B, T, C) -> (B, C, T)
|
| 645 |
+
x_t = x.permute(0, 2, 1)
|
| 646 |
+
f_small = self.conv_small(x_t)
|
| 647 |
+
f_medium = self.conv_medium(x_t)
|
| 648 |
+
f_large = self.conv_large(x_t)
|
| 649 |
+
feat = torch.cat([f_small, f_medium, f_large], dim=1)
|
| 650 |
+
feat = self.proj(feat).permute(0, 2, 1) # (B, T, hidden)
|
| 651 |
+
return self.transformer(feat)
|
| 652 |
+
|
| 653 |
+
|
| 654 |
+
class EMG2Pose(nn.Module):
|
| 655 |
+
"""emg2pose model for EMG -> Hand Pose regression (Exp4b).
|
| 656 |
+
|
| 657 |
+
Predicts per-frame hand joint positions from EMG signals.
|
| 658 |
+
Uses velocity-based integration (vemg2pose variant):
|
| 659 |
+
predict velocity -> integrate to get positions.
|
| 660 |
+
|
| 661 |
+
Input: (B, T, input_dim) [EMG channels]
|
| 662 |
+
Output: (B, T, output_dim) [hand joint positions]
|
| 663 |
+
"""
|
| 664 |
+
|
| 665 |
+
def __init__(self, input_dim, output_dim, hidden_dim=128,
|
| 666 |
+
num_transformer_layers=4, use_velocity=True):
|
| 667 |
+
super().__init__()
|
| 668 |
+
self.use_velocity = use_velocity
|
| 669 |
+
self.encoder = EMG2PoseEncoder(input_dim, hidden_dim, num_transformer_layers)
|
| 670 |
+
|
| 671 |
+
if use_velocity:
|
| 672 |
+
# Predict velocity, then integrate
|
| 673 |
+
self.velocity_head = nn.Sequential(
|
| 674 |
+
nn.Linear(hidden_dim, hidden_dim // 2),
|
| 675 |
+
nn.ReLU(),
|
| 676 |
+
nn.Dropout(0.1),
|
| 677 |
+
nn.Linear(hidden_dim // 2, output_dim),
|
| 678 |
+
)
|
| 679 |
+
# Learnable initial position
|
| 680 |
+
self.initial_pos = nn.Parameter(torch.zeros(1, 1, output_dim))
|
| 681 |
+
else:
|
| 682 |
+
# Direct position prediction
|
| 683 |
+
self.position_head = nn.Sequential(
|
| 684 |
+
nn.Linear(hidden_dim, hidden_dim // 2),
|
| 685 |
+
nn.ReLU(),
|
| 686 |
+
nn.Dropout(0.1),
|
| 687 |
+
nn.Linear(hidden_dim // 2, output_dim),
|
| 688 |
+
)
|
| 689 |
+
|
| 690 |
+
def forward(self, x):
|
| 691 |
+
features = self.encoder(x) # (B, T, hidden)
|
| 692 |
+
|
| 693 |
+
if self.use_velocity:
|
| 694 |
+
velocity = self.velocity_head(features) # (B, T, output_dim)
|
| 695 |
+
# Cumulative sum to integrate velocity -> position
|
| 696 |
+
positions = torch.cumsum(velocity, dim=1) + self.initial_pos
|
| 697 |
+
return positions
|
| 698 |
+
else:
|
| 699 |
+
return self.position_head(features)
|
experiments/s9_primitives.json
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"version": "s9_docx_2025_12_05",
|
| 3 |
+
"source": "${PULSE_ROOT}",
|
| 4 |
+
"categories": ["hand", "arm", "body", "fine", "composite"],
|
| 5 |
+
"primitives": [
|
| 6 |
+
{"id": 0, "category": "hand", "zh": "伸手", "en": "reach", "note": "forward/up/down/side"},
|
| 7 |
+
{"id": 1, "category": "hand", "zh": "抓握", "en": "grasp", "note": "pinch / hold / clamp"},
|
| 8 |
+
{"id": 2, "category": "hand", "zh": "松开", "en": "release", "note": "release object"},
|
| 9 |
+
{"id": 3, "category": "hand", "zh": "旋转手腕", "en": "rotate_wrist", "note": "twist / turn"},
|
| 10 |
+
{"id": 4, "category": "hand", "zh": "按压", "en": "press", "note": "downward force"},
|
| 11 |
+
{"id": 5, "category": "hand", "zh": "拉动", "en": "pull", "note": "toward self"},
|
| 12 |
+
{"id": 6, "category": "hand", "zh": "推动", "en": "push", "note": "outward force"},
|
| 13 |
+
{"id": 7, "category": "hand", "zh": "滑动", "en": "slide", "note": "translation motion"},
|
| 14 |
+
{"id": 8, "category": "hand", "zh": "捏合", "en": "pinch", "note": "two/multi finger pinch"},
|
| 15 |
+
{"id": 9, "category": "hand", "zh": "展开", "en": "spread_fingers", "note": "fingers open"},
|
| 16 |
+
|
| 17 |
+
{"id": 10, "category": "arm", "zh": "抬起", "en": "raise_arm", "note": "arm up"},
|
| 18 |
+
{"id": 11, "category": "arm", "zh": "放下", "en": "lower_arm", "note": "arm down"},
|
| 19 |
+
{"id": 12, "category": "arm", "zh": "伸展", "en": "extend_arm", "note": "arm straight"},
|
| 20 |
+
{"id": 13, "category": "arm", "zh": "弯曲", "en": "bend_elbow", "note": "elbow bend"},
|
| 21 |
+
{"id": 14, "category": "arm", "zh": "摆动", "en": "swing_arm", "note": "left-right / forward-back"},
|
| 22 |
+
{"id": 15, "category": "arm", "zh": "环绕", "en": "circle_arm", "note": "circular motion"},
|
| 23 |
+
|
| 24 |
+
{"id": 16, "category": "body", "zh": "弯腰", "en": "bend_torso", "note": "lean forward"},
|
| 25 |
+
{"id": 17, "category": "body", "zh": "直立", "en": "stand_upright", "note": "return to standing"},
|
| 26 |
+
{"id": 18, "category": "body", "zh": "蹲下", "en": "squat_down", "note": "lower center of mass"},
|
| 27 |
+
{"id": 19, "category": "body", "zh": "站起", "en": "stand_up", "note": "return to height"},
|
| 28 |
+
{"id": 20, "category": "body", "zh": "转身", "en": "turn_body", "note": "torso rotate"},
|
| 29 |
+
{"id": 21, "category": "body", "zh": "侧身", "en": "lean_side", "note": "torso tilt"},
|
| 30 |
+
{"id": 22, "category": "body", "zh": "迈步", "en": "step", "note": "shift position"},
|
| 31 |
+
|
| 32 |
+
{"id": 23, "category": "fine", "zh": "插入", "en": "insert", "note": "object enters"},
|
| 33 |
+
{"id": 24, "category": "fine", "zh": "拔出", "en": "extract", "note": "object exits"},
|
| 34 |
+
{"id": 25, "category": "fine", "zh": "折叠", "en": "fold", "note": "change shape"},
|
| 35 |
+
{"id": 26, "category": "fine", "zh": "撕扯", "en": "tear", "note": "separate"},
|
| 36 |
+
{"id": 27, "category": "fine", "zh": "擦拭", "en": "wipe", "note": "back-and-forth"},
|
| 37 |
+
|
| 38 |
+
{"id": 28, "category": "composite", "zh": "拿起物品", "en": "pick_up_object", "note": "reach -> grasp -> raise"},
|
| 39 |
+
{"id": 29, "category": "composite", "zh": "放下物品", "en": "put_down_object", "note": "move -> release -> retract"},
|
| 40 |
+
{"id": 30, "category": "composite", "zh": "移动物品", "en": "move_object", "note": "pick_up -> move -> put_down"},
|
| 41 |
+
{"id": 31, "category": "composite", "zh": "交换手持物", "en": "transfer_between_hands","note": "one hand grasp -> other hand take -> first release"},
|
| 42 |
+
{"id": 32, "category": "composite", "zh": "打开盖子", "en": "open_lid", "note": "grasp -> rotate/lift"},
|
| 43 |
+
{"id": 33, "category": "composite", "zh": "关闭盖子", "en": "close_lid", "note": "align -> press/rotate"},
|
| 44 |
+
{"id": 34, "category": "composite", "zh": "倒入液体", "en": "pour_liquid", "note": "lift -> tilt -> control flow -> reset"},
|
| 45 |
+
{"id": 35, "category": "composite", "zh": "舀取", "en": "scoop", "note": "insert -> raise -> move"},
|
| 46 |
+
{"id": 36, "category": "composite", "zh": "打开柜门", "en": "open_cabinet_door", "note": "grasp handle -> pull"},
|
| 47 |
+
{"id": 37, "category": "composite", "zh": "关闭柜门", "en": "close_cabinet_door", "note": "push -> confirm"},
|
| 48 |
+
{"id": 38, "category": "composite", "zh": "打开抽屉", "en": "open_drawer", "note": "grasp -> pull out"},
|
| 49 |
+
{"id": 39, "category": "composite", "zh": "按下开关", "en": "press_switch", "note": "reach -> press"},
|
| 50 |
+
{"id": 40, "category": "composite", "zh": "折叠衣物", "en": "fold_clothing", "note": "spread -> fold -> flatten"},
|
| 51 |
+
{"id": 41, "category": "composite", "zh": "叠放物品", "en": "stack_objects", "note": "pick_up -> align -> place gently"},
|
| 52 |
+
{"id": 42, "category": "composite", "zh": "排列物品", "en": "arrange_objects", "note": "move -> adjust spacing -> align"},
|
| 53 |
+
{"id": 43, "category": "composite", "zh": "分类收纳", "en": "sort_and_store", "note": "identify -> group -> place"},
|
| 54 |
+
{"id": 44, "category": "composite", "zh": "擦拭表面", "en": "wipe_surface", "note": "take cloth -> press -> back-and-forth"},
|
| 55 |
+
{"id": 45, "category": "composite", "zh": "扫除垃圾", "en": "sweep_debris", "note": "broom -> gather -> dustpan"},
|
| 56 |
+
{"id": 46, "category": "composite", "zh": "倾倒垃圾", "en": "dump_trash", "note": "lift container -> align -> tilt -> pour"},
|
| 57 |
+
{"id": 47, "category": "composite", "zh": "喷洒液体", "en": "spray_liquid", "note": "press nozzle -> move -> release"},
|
| 58 |
+
{"id": 48, "category": "composite", "zh": "撕胶带", "en": "tear_tape", "note": "pull -> tear off"},
|
| 59 |
+
{"id": 49, "category": "composite", "zh": "贴标签", "en": "stick_label", "note": "peel -> align -> press"},
|
| 60 |
+
{"id": 50, "category": "composite", "zh": "包裹物品", "en": "wrap_object", "note": "spread wrap -> place item -> fold -> seal"},
|
| 61 |
+
{"id": 51, "category": "composite", "zh": "系绳打结", "en": "tie_knot", "note": "cross -> through -> tighten"},
|
| 62 |
+
{"id": 52, "category": "composite", "zh": "拿起笔", "en": "pick_up_pen", "note": "pinch -> adjust grip"},
|
| 63 |
+
{"id": 53, "category": "composite", "zh": "写字", "en": "write", "note": "controlled motion -> apply pressure"},
|
| 64 |
+
{"id": 54, "category": "composite", "zh": "翻页", "en": "turn_page", "note": "pinch corner -> flip"},
|
| 65 |
+
{"id": 55, "category": "composite", "zh": "插入电源", "en": "plug_in_power", "note": "align -> push in"},
|
| 66 |
+
{"id": 56, "category": "composite", "zh": "连接线缆", "en": "connect_cable", "note": "align connector -> insert -> confirm"},
|
| 67 |
+
{"id": 57, "category": "composite", "zh": "组装部件", "en": "assemble_parts", "note": "align -> snap/screw"},
|
| 68 |
+
{"id": 58, "category": "composite", "zh": "称重", "en": "weigh", "note": "place item -> read scale"},
|
| 69 |
+
{"id": 59, "category": "composite", "zh": "量取", "en": "measure_volume", "note": "pour -> read marking -> adjust"},
|
| 70 |
+
{"id": 60, "category": "composite", "zh": "计数", "en": "count", "note": "move one by one -> tally"},
|
| 71 |
+
{"id": 61, "category": "composite", "zh": "挂衣服", "en": "hang_clothing", "note": "take hanger -> insert garment -> hang"},
|
| 72 |
+
{"id": 62, "category": "composite", "zh": "铲猫砂", "en": "scoop_litter", "note": "insert -> raise -> sift -> pour"},
|
| 73 |
+
{"id": 63, "category": "composite", "zh": "搅拌", "en": "stir", "note": "insert spoon -> circular motion"},
|
| 74 |
+
{"id": 64, "category": "composite", "zh": "剪切", "en": "cut", "note": "hold scissors -> align -> close"}
|
| 75 |
+
]
|
| 76 |
+
}
|
experiments/slurm/freeze_all_rows.sh
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# Create folder structure for ALL rows across Tables 1, 3, 4, 5, 7 and
|
| 3 |
+
# freeze the current experiments/ code into each one. After this you can
|
| 4 |
+
# cd into any <table>/<row>/ and run ./run.sh to submit 5 SLURM seeds.
|
| 5 |
+
#
|
| 6 |
+
# Re-running this script is safe: it will re-freeze the code (overwrite the
|
| 7 |
+
# snapshot), but won't clobber any existing seeds/ outputs.
|
| 8 |
+
set -euo pipefail
|
| 9 |
+
|
| 10 |
+
BASEDIR=${BASEDIR:-${PULSE_ROOT}}
|
| 11 |
+
EXP=${BASEDIR}/experiments
|
| 12 |
+
SETUP="${EXP}/setup_row.sh"
|
| 13 |
+
|
| 14 |
+
COMMON="--epochs 40 --batch_size 32 --lr 3e-4 --weight_decay 1e-4 \
|
| 15 |
+
--patience 12 --label_smoothing 0.05 --use_class_weights \
|
| 16 |
+
--num_workers 2"
|
| 17 |
+
|
| 18 |
+
ALL5="imu,emg,eyetrack,mocap,pressure"
|
| 19 |
+
|
| 20 |
+
row () {
|
| 21 |
+
# $1=table $2=row $3=desc $4=cli
|
| 22 |
+
bash "${SETUP}" --table "$1" --row "$2" --desc "$3" --cli "$4 ${COMMON}"
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
# ============================================================
|
| 26 |
+
# Table 1: Main comparison at T_fut=2s
|
| 27 |
+
# ============================================================
|
| 28 |
+
T1=table1_main_comparison
|
| 29 |
+
cat > "${BASEDIR}/${T1}/README.md" <<'EOF'
|
| 30 |
+
# Table 1: Main Comparison (Next-Action Prediction, T_fut = 2 s)
|
| 31 |
+
|
| 32 |
+
Each baseline is run on its most favourable modality subset; our model
|
| 33 |
+
(DailyActFormer) uses all 5 synchronised modalities. 5 seeds per row;
|
| 34 |
+
report mean ± std of Verb fine Top-1/5, Noun Top-1/5, Hand Top-1, Action
|
| 35 |
+
Top-1 (= verb ∧ noun ∧ hand). Action Top-1 is the headline metric.
|
| 36 |
+
|
| 37 |
+
| Row | Method | Family | Modalities |
|
| 38 |
+
|-----|-------------------|-----------------|---------------------|
|
| 39 |
+
| 01 | DailyActFormer | cross-modal Trf | imu+emg+eye+mocap+P |
|
| 40 |
+
| 02 | DeepConvLSTM | CNN+LSTM (IMU) | imu |
|
| 41 |
+
| 03 | DeepConvLSTM 3mod | CNN+LSTM | imu+mocap+emg |
|
| 42 |
+
| 04 | RULSTM | rolling LSTM | imu+mocap |
|
| 43 |
+
| 05 | FUTR | long-term Trf | mocap+imu+emg |
|
| 44 |
+
| 06 | AFFT | multimodal Trf | imu+emg+eye+mocap |
|
| 45 |
+
| 07 | HandFormer | hand-pose Trf | mocap (fingers) |
|
| 46 |
+
| 08 | ActionLLM (LoRA) | LLM-based | imu+emg+eye |
|
| 47 |
+
EOF
|
| 48 |
+
|
| 49 |
+
mkdir -p "${BASEDIR}/${T1}"
|
| 50 |
+
row ${T1} row01_ours_dailyactformer_all5 \
|
| 51 |
+
"Our model, all 5 modalities (headline row)" \
|
| 52 |
+
"--model dailyactformer --modalities ${ALL5} --t_obs 8 --t_fut 2"
|
| 53 |
+
|
| 54 |
+
row ${T1} row02_deepconvlstm_imu \
|
| 55 |
+
"DeepConvLSTM on IMU only (classic HAR baseline)" \
|
| 56 |
+
"--model deepconvlstm --modalities imu --t_obs 8 --t_fut 2"
|
| 57 |
+
|
| 58 |
+
row ${T1} row03_deepconvlstm_3mod \
|
| 59 |
+
"DeepConvLSTM on IMU+MoCap+EMG (best 3-modality concat)" \
|
| 60 |
+
"--model deepconvlstm --modalities imu,mocap,emg --t_obs 8 --t_fut 2"
|
| 61 |
+
|
| 62 |
+
row ${T1} row04_rulstm_imu_mocap \
|
| 63 |
+
"RULSTM, rolling-unrolling LSTM (IMU + MoCap late fusion)" \
|
| 64 |
+
"--model rulstm --modalities imu,mocap --t_obs 8 --t_fut 2"
|
| 65 |
+
|
| 66 |
+
row ${T1} row05_futr_3mod \
|
| 67 |
+
"FUTR (causal transformer) on MoCap+IMU+EMG" \
|
| 68 |
+
"--model futr --modalities mocap,imu,emg --t_obs 8 --t_fut 2"
|
| 69 |
+
|
| 70 |
+
row ${T1} row06_afft_4mod \
|
| 71 |
+
"AFFT (anticipative feature fusion transformer) on 4 modalities" \
|
| 72 |
+
"--model afft --modalities imu,emg,eyetrack,mocap --t_obs 8 --t_fut 2"
|
| 73 |
+
|
| 74 |
+
row ${T1} row07_handformer_mocap \
|
| 75 |
+
"HandFormer (skeleton-only ECCV'24) on MoCap finger joints" \
|
| 76 |
+
"--model handformer --modalities mocap --t_obs 8 --t_fut 2"
|
| 77 |
+
|
| 78 |
+
row ${T1} row08_actionllm_3mod \
|
| 79 |
+
"ActionLLM (Qwen2.5-0.5B + LoRA) on IMU+EMG+EyeTrack" \
|
| 80 |
+
"--model actionllm --modalities imu,emg,eyetrack --t_obs 8 --t_fut 2"
|
| 81 |
+
|
| 82 |
+
# ============================================================
|
| 83 |
+
# Table 3: Horizon curve (DailyActFormer)
|
| 84 |
+
# ============================================================
|
| 85 |
+
T3=table3_horizon_curve
|
| 86 |
+
mkdir -p "${BASEDIR}/${T3}"
|
| 87 |
+
cat > "${BASEDIR}/${T3}/README.md" <<'EOF'
|
| 88 |
+
# Table 3: Prediction Horizon Curve (DailyActFormer, all 5 modalities)
|
| 89 |
+
|
| 90 |
+
Same model, varying T_fut. Expect monotonic drop in Action Top-1 as
|
| 91 |
+
horizon grows; plot line graph in the paper alongside this table.
|
| 92 |
+
EOF
|
| 93 |
+
HORIZONS=(1 2 5 10 15)
|
| 94 |
+
for i in "${!HORIZONS[@]}"; do
|
| 95 |
+
tfut="${HORIZONS[$i]}"
|
| 96 |
+
idx=$(printf "%02d" $((i+1)))
|
| 97 |
+
row ${T3} row${idx}_ours_tfut${tfut}s \
|
| 98 |
+
"Our model at T_fut=${tfut}s" \
|
| 99 |
+
"--model dailyactformer --modalities ${ALL5} --t_obs 8 --t_fut ${tfut}"
|
| 100 |
+
done
|
| 101 |
+
|
| 102 |
+
# ============================================================
|
| 103 |
+
# Table 4: Modality ablation on DailyActFormer (T_fut=2s)
|
| 104 |
+
# ============================================================
|
| 105 |
+
T4=table4_modality_ablation
|
| 106 |
+
mkdir -p "${BASEDIR}/${T4}"
|
| 107 |
+
cat > "${BASEDIR}/${T4}/README.md" <<'EOF'
|
| 108 |
+
# Table 4: Modality Ablation (DailyActFormer, T_fut = 2 s)
|
| 109 |
+
|
| 110 |
+
Same model, progressively remove modalities. Each row trained from scratch.
|
| 111 |
+
EOF
|
| 112 |
+
row ${T4} row01_full_5mod "Full 5-modality (reference)" "--model dailyactformer --modalities imu,emg,eyetrack,mocap,pressure --t_obs 8 --t_fut 2"
|
| 113 |
+
row ${T4} row02_no_pressure "Drop pressure" "--model dailyactformer --modalities imu,emg,eyetrack,mocap --t_obs 8 --t_fut 2"
|
| 114 |
+
row ${T4} row03_no_eyetrack "Drop eye-tracking" "--model dailyactformer --modalities imu,emg,mocap,pressure --t_obs 8 --t_fut 2"
|
| 115 |
+
row ${T4} row04_no_emg "Drop EMG" "--model dailyactformer --modalities imu,eyetrack,mocap,pressure --t_obs 8 --t_fut 2"
|
| 116 |
+
row ${T4} row05_no_imu "Drop IMU" "--model dailyactformer --modalities emg,eyetrack,mocap,pressure --t_obs 8 --t_fut 2"
|
| 117 |
+
row ${T4} row06_no_mocap "Drop MoCap" "--model dailyactformer --modalities imu,emg,eyetrack,pressure --t_obs 8 --t_fut 2"
|
| 118 |
+
row ${T4} row07_imu_emg_only "Only IMU + EMG (physiology-light)" "--model dailyactformer --modalities imu,emg --t_obs 8 --t_fut 2"
|
| 119 |
+
row ${T4} row08_mocap_only "Only MoCap (skeleton-only)" "--model dailyactformer --modalities mocap --t_obs 8 --t_fut 2"
|
| 120 |
+
|
| 121 |
+
# ============================================================
|
| 122 |
+
# Table 5: Component ablation (DailyActFormer switches)
|
| 123 |
+
# ============================================================
|
| 124 |
+
T5=table5_component_ablation
|
| 125 |
+
mkdir -p "${BASEDIR}/${T5}"
|
| 126 |
+
cat > "${BASEDIR}/${T5}/README.md" <<'EOF'
|
| 127 |
+
# Table 5: Component Ablation (DailyActFormer, T_fut = 2 s)
|
| 128 |
+
|
| 129 |
+
Each row toggles one architectural/training component of our model.
|
| 130 |
+
Component flags are implemented as CLI switches on train_seqpred.py;
|
| 131 |
+
see models_seqpred.py for the corresponding model options.
|
| 132 |
+
EOF
|
| 133 |
+
row ${T5} row01_full \
|
| 134 |
+
"Full model (reference)" \
|
| 135 |
+
"--model dailyactformer --modalities ${ALL5} --t_obs 8 --t_fut 2"
|
| 136 |
+
row ${T5} row02_no_composite_head \
|
| 137 |
+
"Drop the auxiliary verb-composite head (lambda=0)" \
|
| 138 |
+
"--model dailyactformer --modalities ${ALL5} --t_obs 8 --t_fut 2 --lambda_verb_composite 0.0"
|
| 139 |
+
row ${T5} row03_equal_lambda \
|
| 140 |
+
"Equal-weight all 4 heads (no prior on verb>hand)" \
|
| 141 |
+
"--model dailyactformer --modalities ${ALL5} --t_obs 8 --t_fut 2 --lambda_verb_composite 1.0 --lambda_hand 1.0"
|
| 142 |
+
row ${T5} row04_no_class_weight \
|
| 143 |
+
"No inverse-frequency class weighting" \
|
| 144 |
+
"--model dailyactformer --modalities ${ALL5} --t_obs 8 --t_fut 2 --lambda_verb_composite 0.5"
|
| 145 |
+
# row04 re-exposes the default; the variable-off is the absence of --use_class_weights
|
| 146 |
+
# We patch this manually — strip the flag out of COMMON.
|
| 147 |
+
ROW_DIR="${BASEDIR}/${T5}/row04_no_class_weight/run.sh"
|
| 148 |
+
if [[ -e "${ROW_DIR}" ]]; then
|
| 149 |
+
sed -i 's/--use_class_weights //g' "${ROW_DIR}"
|
| 150 |
+
fi
|
| 151 |
+
|
| 152 |
+
row ${T5} row05_no_label_smoothing \
|
| 153 |
+
"Label smoothing off" \
|
| 154 |
+
"--model dailyactformer --modalities ${ALL5} --t_obs 8 --t_fut 2 --label_smoothing 0.0"
|
| 155 |
+
|
| 156 |
+
# ============================================================
|
| 157 |
+
# Table 7: Missing-modality robustness (train once, eval 6 ways)
|
| 158 |
+
# ============================================================
|
| 159 |
+
T7=table7_missing_modality
|
| 160 |
+
mkdir -p "${BASEDIR}/${T7}"
|
| 161 |
+
cat > "${BASEDIR}/${T7}/README.md" <<'EOF'
|
| 162 |
+
# Table 7: Missing-Modality Robustness (T_fut = 2 s)
|
| 163 |
+
|
| 164 |
+
Train DailyActFormer with random per-modality dropout (p=0.3). At test time,
|
| 165 |
+
evaluate under 6 configurations: full / drop one modality each. Only the
|
| 166 |
+
training job has its own folder; eval uses the trained checkpoint to fill
|
| 167 |
+
multiple rows of the final table.
|
| 168 |
+
EOF
|
| 169 |
+
row ${T7} row01_train_with_modality_dropout \
|
| 170 |
+
"DailyActFormer trained with --modality_dropout 0.3" \
|
| 171 |
+
"--model dailyactformer --modalities ${ALL5} --t_obs 8 --t_fut 2 --modality_dropout 0.3"
|
| 172 |
+
# The 6 test-time configurations (full / no_P / no_E / no_emg / no_imu /
|
| 173 |
+
# no_mocap) will be produced by a separate eval script that loads the
|
| 174 |
+
# checkpoint from row01 and runs evaluate() with modality subsets. See
|
| 175 |
+
# experiments/tasks/eval_missing_modality.py (TBD).
|
| 176 |
+
|
| 177 |
+
echo ""
|
| 178 |
+
echo "[ok] Froze rows under:"
|
| 179 |
+
echo " ${BASEDIR}/{${T1},${T3},${T4},${T5},${T7}}/"
|
experiments/slurm/run_ablation_fix.sh
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
#SBATCH --job-name=ablation_fix
|
| 3 |
+
#SBATCH --partition=gpuA800
|
| 4 |
+
#SBATCH --gres=gpu:1
|
| 5 |
+
#SBATCH --nodes=1
|
| 6 |
+
#SBATCH --ntasks=1
|
| 7 |
+
#SBATCH --cpus-per-task=4
|
| 8 |
+
#SBATCH --mem=32G
|
| 9 |
+
#SBATCH --time=1:00:00
|
| 10 |
+
#SBATCH --output=${PULSE_ROOT}/results/ablation_fix_%j.log
|
| 11 |
+
|
| 12 |
+
# Fix: mocap+emg late+pretrained — pretrain MOCAP branch (idx=0) instead of emg
|
| 13 |
+
set -e
|
| 14 |
+
export PYTHONUNBUFFERED=1
|
| 15 |
+
|
| 16 |
+
PYTHON=python
|
| 17 |
+
BASEDIR=${PULSE_ROOT}
|
| 18 |
+
SCRIPT=${BASEDIR}/experiments/train_exp1.py
|
| 19 |
+
OUTDIR=${BASEDIR}/results/modality_ablation
|
| 20 |
+
COMMON="--model transformer --epochs 100 --batch_size 16 --lr 1e-3 --weight_decay 1e-4 --hidden_dim 128 --downsample 5 --patience 15 --proj_dim 0 --output_dir $OUTDIR"
|
| 21 |
+
SEEDS=(42 123 456 789 2024)
|
| 22 |
+
|
| 23 |
+
PT_MOCAP=${BASEDIR}/results/exp1_v8/transformer_mocap_early/model_best.pt
|
| 24 |
+
|
| 25 |
+
echo "=== Fix: mocap+emg / late+pretrained(mocap, idx=0) ==="
|
| 26 |
+
for seed in "${SEEDS[@]}"; do
|
| 27 |
+
echo " mocap+emg seed=$seed"
|
| 28 |
+
$PYTHON $SCRIPT --modalities mocap,emg --fusion late --seed $seed \
|
| 29 |
+
--pretrained_backbone $PT_MOCAP --freeze_backbone_idx 0 \
|
| 30 |
+
--tag ablation_pt_s${seed} $COMMON 2>&1 | tail -5
|
| 31 |
+
done
|
| 32 |
+
|
| 33 |
+
echo "=== Done ==="
|
experiments/slurm/run_ablation_fusion.sh
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
#SBATCH --job-name=ablation_fuse
|
| 3 |
+
#SBATCH --partition=gpuA800
|
| 4 |
+
#SBATCH --gres=gpu:2
|
| 5 |
+
#SBATCH --nodes=1
|
| 6 |
+
#SBATCH --ntasks=1
|
| 7 |
+
#SBATCH --cpus-per-task=8
|
| 8 |
+
#SBATCH --mem=64G
|
| 9 |
+
#SBATCH --time=4:00:00
|
| 10 |
+
#SBATCH --output=${PULSE_ROOT}/results/ablation_fusion_%j.log
|
| 11 |
+
|
| 12 |
+
# Test confidence-weighted and learned-weight fusion on all multi-modal combos
|
| 13 |
+
# Compare against existing mean fusion results
|
| 14 |
+
|
| 15 |
+
set -e
|
| 16 |
+
export PYTHONUNBUFFERED=1
|
| 17 |
+
|
| 18 |
+
PYTHON=python
|
| 19 |
+
BASEDIR=${PULSE_ROOT}
|
| 20 |
+
SCRIPT=${BASEDIR}/experiments/train_exp1.py
|
| 21 |
+
OUTDIR=${BASEDIR}/results/modality_ablation
|
| 22 |
+
COMMON="--model transformer --epochs 100 --batch_size 16 --lr 1e-3 --weight_decay 1e-4 --hidden_dim 128 --downsample 5 --patience 15 --proj_dim 0 --output_dir $OUTDIR"
|
| 23 |
+
SEEDS=(42 123 456 789 2024)
|
| 24 |
+
|
| 25 |
+
PT_IMU=${BASEDIR}/results/exp1_v7/transformer_imu_early/model_best.pt
|
| 26 |
+
PT_MOCAP=${BASEDIR}/results/exp1_v8/transformer_mocap_early/model_best.pt
|
| 27 |
+
|
| 28 |
+
echo "=== Ablation: Confidence & Learned Fusion ==="
|
| 29 |
+
|
| 30 |
+
# ============================================================
|
| 31 |
+
# GPU 0: confidence-weighted fusion
|
| 32 |
+
# ============================================================
|
| 33 |
+
(
|
| 34 |
+
export CUDA_VISIBLE_DEVICES=0
|
| 35 |
+
|
| 36 |
+
# mocap+imu / confidence / pretrained imu (idx=1)
|
| 37 |
+
echo "--- GPU0: mocap+imu / confidence ---"
|
| 38 |
+
for seed in "${SEEDS[@]}"; do
|
| 39 |
+
echo " mocap+imu confidence seed=$seed"
|
| 40 |
+
$PYTHON $SCRIPT --modalities mocap,imu --fusion late --late_agg confidence \
|
| 41 |
+
--seed $seed --pretrained_backbone $PT_IMU --freeze_backbone_idx 1 \
|
| 42 |
+
--tag ablation_conf_s${seed} $COMMON 2>&1 | tail -3
|
| 43 |
+
done
|
| 44 |
+
|
| 45 |
+
# emg+imu / confidence / pretrained imu (idx=1)
|
| 46 |
+
echo "--- GPU0: emg+imu / confidence ---"
|
| 47 |
+
for seed in "${SEEDS[@]}"; do
|
| 48 |
+
echo " emg+imu confidence seed=$seed"
|
| 49 |
+
$PYTHON $SCRIPT --modalities emg,imu --fusion late --late_agg confidence \
|
| 50 |
+
--seed $seed --pretrained_backbone $PT_IMU --freeze_backbone_idx 1 \
|
| 51 |
+
--tag ablation_conf_s${seed} $COMMON 2>&1 | tail -3
|
| 52 |
+
done
|
| 53 |
+
|
| 54 |
+
# mocap+emg / confidence / pretrained mocap (idx=0)
|
| 55 |
+
echo "--- GPU0: mocap+emg / confidence ---"
|
| 56 |
+
for seed in "${SEEDS[@]}"; do
|
| 57 |
+
echo " mocap+emg confidence seed=$seed"
|
| 58 |
+
$PYTHON $SCRIPT --modalities mocap,emg --fusion late --late_agg confidence \
|
| 59 |
+
--seed $seed --pretrained_backbone $PT_MOCAP --freeze_backbone_idx 0 \
|
| 60 |
+
--tag ablation_conf_s${seed} $COMMON 2>&1 | tail -3
|
| 61 |
+
done
|
| 62 |
+
|
| 63 |
+
# mocap+emg+imu / confidence / pretrained imu (idx=2, modalities=mocap,emg,imu)
|
| 64 |
+
echo "--- GPU0: mocap+emg+imu / confidence ---"
|
| 65 |
+
for seed in "${SEEDS[@]}"; do
|
| 66 |
+
echo " mocap+emg+imu confidence seed=$seed"
|
| 67 |
+
$PYTHON $SCRIPT --modalities imu,mocap,emg --fusion late --late_agg confidence \
|
| 68 |
+
--seed $seed --pretrained_backbone $PT_IMU --freeze_backbone_idx 0 \
|
| 69 |
+
--tag ablation_conf_s${seed} $COMMON 2>&1 | tail -3
|
| 70 |
+
done
|
| 71 |
+
|
| 72 |
+
echo "--- GPU0 Done ---"
|
| 73 |
+
) &
|
| 74 |
+
PID0=$!
|
| 75 |
+
|
| 76 |
+
# ============================================================
|
| 77 |
+
# GPU 1: learned-weight fusion
|
| 78 |
+
# ============================================================
|
| 79 |
+
(
|
| 80 |
+
export CUDA_VISIBLE_DEVICES=1
|
| 81 |
+
|
| 82 |
+
# mocap+imu / learned / pretrained imu (idx=1)
|
| 83 |
+
echo "--- GPU1: mocap+imu / learned ---"
|
| 84 |
+
for seed in "${SEEDS[@]}"; do
|
| 85 |
+
echo " mocap+imu learned seed=$seed"
|
| 86 |
+
$PYTHON $SCRIPT --modalities mocap,imu --fusion late --late_agg learned \
|
| 87 |
+
--seed $seed --pretrained_backbone $PT_IMU --freeze_backbone_idx 1 \
|
| 88 |
+
--tag ablation_lrn_s${seed} $COMMON 2>&1 | tail -3
|
| 89 |
+
done
|
| 90 |
+
|
| 91 |
+
# emg+imu / learned / pretrained imu (idx=1)
|
| 92 |
+
echo "--- GPU1: emg+imu / learned ---"
|
| 93 |
+
for seed in "${SEEDS[@]}"; do
|
| 94 |
+
echo " emg+imu learned seed=$seed"
|
| 95 |
+
$PYTHON $SCRIPT --modalities emg,imu --fusion late --late_agg learned \
|
| 96 |
+
--seed $seed --pretrained_backbone $PT_IMU --freeze_backbone_idx 1 \
|
| 97 |
+
--tag ablation_lrn_s${seed} $COMMON 2>&1 | tail -3
|
| 98 |
+
done
|
| 99 |
+
|
| 100 |
+
# mocap+emg / learned / pretrained mocap (idx=0)
|
| 101 |
+
echo "--- GPU1: mocap+emg / learned ---"
|
| 102 |
+
for seed in "${SEEDS[@]}"; do
|
| 103 |
+
echo " mocap+emg learned seed=$seed"
|
| 104 |
+
$PYTHON $SCRIPT --modalities mocap,emg --fusion late --late_agg learned \
|
| 105 |
+
--seed $seed --pretrained_backbone $PT_MOCAP --freeze_backbone_idx 0 \
|
| 106 |
+
--tag ablation_lrn_s${seed} $COMMON 2>&1 | tail -3
|
| 107 |
+
done
|
| 108 |
+
|
| 109 |
+
# mocap+emg+imu / learned / pretrained imu (idx=0, modalities=imu,mocap,emg)
|
| 110 |
+
echo "--- GPU1: mocap+emg+imu / learned ---"
|
| 111 |
+
for seed in "${SEEDS[@]}"; do
|
| 112 |
+
echo " mocap+emg+imu learned seed=$seed"
|
| 113 |
+
$PYTHON $SCRIPT --modalities imu,mocap,emg --fusion late --late_agg learned \
|
| 114 |
+
--seed $seed --pretrained_backbone $PT_IMU --freeze_backbone_idx 0 \
|
| 115 |
+
--tag ablation_lrn_s${seed} $COMMON 2>&1 | tail -3
|
| 116 |
+
done
|
| 117 |
+
|
| 118 |
+
echo "--- GPU1 Done ---"
|
| 119 |
+
) &
|
| 120 |
+
PID1=$!
|
| 121 |
+
|
| 122 |
+
wait $PID0 $PID1
|
| 123 |
+
|
| 124 |
+
# ============================================================
|
| 125 |
+
# Collect results
|
| 126 |
+
# ============================================================
|
| 127 |
+
echo ""
|
| 128 |
+
echo "=== Fusion Comparison ==="
|
| 129 |
+
$PYTHON -c "
|
| 130 |
+
import json, os, numpy as np
|
| 131 |
+
|
| 132 |
+
base = '$OUTDIR'
|
| 133 |
+
v8_base = '${BASEDIR}/results/exp1_v8_multiseed'
|
| 134 |
+
v9_base = '${BASEDIR}/results/exp1_v9'
|
| 135 |
+
seeds = [42, 123, 456, 789, 2024]
|
| 136 |
+
|
| 137 |
+
configs = [
|
| 138 |
+
# (label, pattern_template)
|
| 139 |
+
# mean (from previous ablation run)
|
| 140 |
+
('mocap+imu / mean', base + '/transformer_mocap-imu_late_ablation_pt_s{}/results.json'),
|
| 141 |
+
('mocap+imu / confidence', base + '/transformer_mocap-imu_late_ablation_conf_s{}/results.json'),
|
| 142 |
+
('mocap+imu / learned', base + '/transformer_mocap-imu_late_ablation_lrn_s{}/results.json'),
|
| 143 |
+
('emg+imu / mean', base + '/transformer_emg-imu_late_ablation_pt_s{}/results.json'),
|
| 144 |
+
('emg+imu / confidence', base + '/transformer_emg-imu_late_ablation_conf_s{}/results.json'),
|
| 145 |
+
('emg+imu / learned', base + '/transformer_emg-imu_late_ablation_lrn_s{}/results.json'),
|
| 146 |
+
('mocap+emg / mean', base + '/transformer_mocap-emg_late_ablation_pt_s{}/results.json'),
|
| 147 |
+
('mocap+emg / confidence', base + '/transformer_mocap-emg_late_ablation_conf_s{}/results.json'),
|
| 148 |
+
('mocap+emg / learned', base + '/transformer_mocap-emg_late_ablation_lrn_s{}/results.json'),
|
| 149 |
+
('3mod / mean', v9_base + '/transformer_imu-mocap-emg_late_pt_s{}/results.json'),
|
| 150 |
+
('3mod / confidence', base + '/transformer_imu-mocap-emg_late_ablation_conf_s{}/results.json'),
|
| 151 |
+
('3mod / learned', base + '/transformer_imu-mocap-emg_late_ablation_lrn_s{}/results.json'),
|
| 152 |
+
]
|
| 153 |
+
|
| 154 |
+
print(f'{\"Config\":<30} {\"F1 (mean±std)\":<20} {\"Acc (mean±std)\":<20} N')
|
| 155 |
+
print('-' * 75)
|
| 156 |
+
for label, pat in configs:
|
| 157 |
+
f1s, accs = [], []
|
| 158 |
+
for s in seeds:
|
| 159 |
+
path = pat.format(s)
|
| 160 |
+
if os.path.exists(path):
|
| 161 |
+
with open(path) as f:
|
| 162 |
+
d = json.load(f)
|
| 163 |
+
f1s.append(d['test_macro_f1'])
|
| 164 |
+
accs.append(d['test_accuracy'])
|
| 165 |
+
if f1s:
|
| 166 |
+
f1 = np.array(f1s)
|
| 167 |
+
acc = np.array(accs)
|
| 168 |
+
print(f'{label:<30} {f1.mean():.3f}±{f1.std():.3f} {acc.mean():.3f}±{acc.std():.3f} {len(f1s)}')
|
| 169 |
+
else:
|
| 170 |
+
print(f'{label:<30} (no results)')
|
| 171 |
+
"
|
| 172 |
+
|
| 173 |
+
echo ""
|
| 174 |
+
echo "=== All done ==="
|
experiments/slurm/run_asformer_exp3.sh
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
#SBATCH --partition=gpuA800
|
| 3 |
+
#SBATCH --nodes=1
|
| 4 |
+
#SBATCH --ntasks=1
|
| 5 |
+
#SBATCH --cpus-per-task=4
|
| 6 |
+
#SBATCH --gres=gpu:1
|
| 7 |
+
#SBATCH --mem=32G
|
| 8 |
+
#SBATCH --time=4:00:00
|
| 9 |
+
#SBATCH --job-name=ASF_exp3
|
| 10 |
+
#SBATCH --output=${PULSE_ROOT}/results/asformer_exp3_%j.log
|
| 11 |
+
|
| 12 |
+
set -e
|
| 13 |
+
PYTHON=python
|
| 14 |
+
PROJECT=${PULSE_ROOT}
|
| 15 |
+
cd $PROJECT
|
| 16 |
+
|
| 17 |
+
EXP3_OUT=$PROJECT/results/published_baselines/exp3_asformer
|
| 18 |
+
mkdir -p $EXP3_OUT
|
| 19 |
+
|
| 20 |
+
echo "=== ASFormer Contact Detection ==="
|
| 21 |
+
|
| 22 |
+
for MOD in mocap emg imu "mocap,emg" "mocap,emg,eyetrack" "mocap,emg,eyetrack,imu"; do
|
| 23 |
+
echo "--- ASFormer / ${MOD} ---"
|
| 24 |
+
$PYTHON experiments/train_exp3.py \
|
| 25 |
+
--model asformer --modalities $MOD \
|
| 26 |
+
--hidden_dim 64 --epochs 50 --batch_size 32 \
|
| 27 |
+
--lr 1e-3 --weight_decay 1e-4 --downsample 2 \
|
| 28 |
+
--seed 42 --output_dir $EXP3_OUT 2>&1 | tail -8
|
| 29 |
+
done
|
| 30 |
+
|
| 31 |
+
echo ""
|
| 32 |
+
echo "=== Results ==="
|
| 33 |
+
for f in $EXP3_OUT/*/results.json; do
|
| 34 |
+
if [ -f "$f" ]; then
|
| 35 |
+
$PYTHON -c "
|
| 36 |
+
import json
|
| 37 |
+
with open('$f') as fp:
|
| 38 |
+
r = json.load(fp)
|
| 39 |
+
mods = ','.join(r.get('input_modalities', []))
|
| 40 |
+
m = r.get('test_metrics', {})
|
| 41 |
+
print(f' ASFormer | {mods:<30} | R_F1={m.get(\"right_f1\",0):.4f} L_F1={m.get(\"left_f1\",0):.4f} Avg_F1={m.get(\"avg_f1\",0):.4f}')
|
| 42 |
+
"
|
| 43 |
+
fi
|
| 44 |
+
done
|
experiments/slurm/run_exp1.sh
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
#SBATCH -J exp1_scene
|
| 3 |
+
#SBATCH -p gpuA800
|
| 4 |
+
#SBATCH --gres=gpu:1
|
| 5 |
+
#SBATCH -N 1
|
| 6 |
+
#SBATCH -n 1
|
| 7 |
+
#SBATCH --cpus-per-task=8
|
| 8 |
+
#SBATCH --mem=64G
|
| 9 |
+
#SBATCH -t 12:00:00
|
| 10 |
+
#SBATCH -o ${PULSE_ROOT}/results/exp1/slurm_%j.out
|
| 11 |
+
#SBATCH -e ${PULSE_ROOT}/results/exp1/slurm_%j.err
|
| 12 |
+
|
| 13 |
+
export PYTHONUNBUFFERED=1
|
| 14 |
+
|
| 15 |
+
echo "=== Job Info ==="
|
| 16 |
+
echo "Job ID: $SLURM_JOB_ID"
|
| 17 |
+
echo "Node: $SLURM_NODELIST"
|
| 18 |
+
echo "Start time: $(date)"
|
| 19 |
+
nvidia-smi --query-gpu=name,memory.total --format=csv,noheader
|
| 20 |
+
echo "================"
|
| 21 |
+
|
| 22 |
+
PYTHON=python
|
| 23 |
+
SCRIPT=${PULSE_ROOT}/experiments/train_exp1.py
|
| 24 |
+
OUTDIR=${PULSE_ROOT}/results/exp1
|
| 25 |
+
|
| 26 |
+
cd ${PULSE_ROOT}
|
| 27 |
+
|
| 28 |
+
$PYTHON $SCRIPT --run_all \
|
| 29 |
+
--epochs 100 \
|
| 30 |
+
--batch_size 16 \
|
| 31 |
+
--lr 1e-3 \
|
| 32 |
+
--weight_decay 1e-4 \
|
| 33 |
+
--hidden_dim 128 \
|
| 34 |
+
--downsample 5 \
|
| 35 |
+
--patience 15 \
|
| 36 |
+
--seed 42 \
|
| 37 |
+
--output_dir $OUTDIR
|
| 38 |
+
|
| 39 |
+
echo "=== Done ==="
|
| 40 |
+
echo "End time: $(date)"
|
experiments/slurm/run_exp1_fusion.sh
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# Submit all fusion experiments as individual 1-GPU SLURM jobs
|
| 3 |
+
# SLURM scheduler will automatically place them on any available GPU
|
| 4 |
+
|
| 5 |
+
PYTHON=python
|
| 6 |
+
SCRIPT=${PULSE_ROOT}/experiments/train_exp1.py
|
| 7 |
+
OUTDIR=${PULSE_ROOT}/results/exp1
|
| 8 |
+
LOGDIR=${OUTDIR}/slurm_logs
|
| 9 |
+
mkdir -p $LOGDIR
|
| 10 |
+
|
| 11 |
+
COMMON_ARGS="--model transformer --epochs 100 --batch_size 16 --lr 1e-3 --weight_decay 1e-4 --hidden_dim 128 --downsample 5 --patience 15 --seed 42 --output_dir $OUTDIR"
|
| 12 |
+
|
| 13 |
+
FUSIONS=(weighted_late gated_late stacking product moe late attention)
|
| 14 |
+
MODALITIES=("mocap,emg,eyetrack" "mocap,emg,eyetrack,imu,pressure")
|
| 15 |
+
|
| 16 |
+
for fusion in "${FUSIONS[@]}"; do
|
| 17 |
+
for mods in "${MODALITIES[@]}"; do
|
| 18 |
+
mod_tag=$(echo $mods | tr ',' '-')
|
| 19 |
+
job_name="f_${fusion}_${mod_tag}"
|
| 20 |
+
sbatch \
|
| 21 |
+
-J "$job_name" \
|
| 22 |
+
-p gpuA800 \
|
| 23 |
+
--gres=gpu:1 \
|
| 24 |
+
-N 1 -n 1 \
|
| 25 |
+
--cpus-per-task=8 \
|
| 26 |
+
--mem=32G \
|
| 27 |
+
-t 3:00:00 \
|
| 28 |
+
-o "${LOGDIR}/${job_name}_%j.out" \
|
| 29 |
+
-e "${LOGDIR}/${job_name}_%j.err" \
|
| 30 |
+
--export=ALL \
|
| 31 |
+
--wrap="export PYTHONUNBUFFERED=1; cd ${PULSE_ROOT}; $PYTHON $SCRIPT --fusion $fusion --modalities $mods $COMMON_ARGS"
|
| 32 |
+
echo "Submitted: $job_name"
|
| 33 |
+
done
|
| 34 |
+
done
|
| 35 |
+
|
| 36 |
+
echo "All 14 fusion experiments submitted!"
|
experiments/slurm/run_exp1_parallel.sh
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# Scene Recognition (Exp1) - Parallelized version
|
| 3 |
+
# Part 1: 9 modality combos × 3 backbones = 27 jobs (early fusion)
|
| 4 |
+
# Part 2: 7 fusion methods × transformer × (3-core + all-5) = 14 jobs
|
| 5 |
+
# Total: 41 jobs
|
| 6 |
+
|
| 7 |
+
PYTHON=python
|
| 8 |
+
BASEDIR=${PULSE_ROOT}
|
| 9 |
+
SCRIPT=${BASEDIR}/experiments/train_exp1.py
|
| 10 |
+
OUTDIR=${BASEDIR}/results/exp1_v2
|
| 11 |
+
LOGDIR=${OUTDIR}/slurm_logs
|
| 12 |
+
mkdir -p $LOGDIR
|
| 13 |
+
|
| 14 |
+
COMMON="--epochs 100 --batch_size 16 --lr 1e-3 --weight_decay 1e-4 --hidden_dim 128 --downsample 5 --patience 15 --seed 42 --output_dir $OUTDIR"
|
| 15 |
+
|
| 16 |
+
MODS=("mocap" "emg" "eyetrack" "imu" "pressure" "mocap,emg,eyetrack" "mocap,emg,eyetrack,imu" "mocap,emg,eyetrack,pressure" "mocap,emg,eyetrack,imu,pressure")
|
| 17 |
+
MODELS=("cnn" "lstm" "transformer")
|
| 18 |
+
|
| 19 |
+
# Part 1: Modality ablation × 3 backbones
|
| 20 |
+
echo "=== Part 1: Modality Ablation (27 jobs) ==="
|
| 21 |
+
for mods in "${MODS[@]}"; do
|
| 22 |
+
mod_tag=$(echo $mods | tr ',' '-')
|
| 23 |
+
for model in "${MODELS[@]}"; do
|
| 24 |
+
sbatch \
|
| 25 |
+
-J "exp1_${model}_${mod_tag}" \
|
| 26 |
+
-p gpuA800 \
|
| 27 |
+
--gres=gpu:1 \
|
| 28 |
+
-N 1 -n 1 \
|
| 29 |
+
--cpus-per-task=4 \
|
| 30 |
+
--mem=32G \
|
| 31 |
+
-t 2:00:00 \
|
| 32 |
+
-o "${LOGDIR}/${model}_${mod_tag}_early_%j.out" \
|
| 33 |
+
-e "${LOGDIR}/${model}_${mod_tag}_early_%j.err" \
|
| 34 |
+
--export=ALL \
|
| 35 |
+
--wrap="export PYTHONUNBUFFERED=1; cd ${BASEDIR}; $PYTHON $SCRIPT --model $model --modalities $mods --fusion early $COMMON"
|
| 36 |
+
echo " Submitted: $model / $mods / early"
|
| 37 |
+
done
|
| 38 |
+
done
|
| 39 |
+
|
| 40 |
+
# Part 2: Fusion methods × transformer
|
| 41 |
+
FUSIONS=("late" "attention" "weighted_late" "gated_late" "stacking" "product" "moe")
|
| 42 |
+
FUSION_MODS=("mocap,emg,eyetrack" "mocap,emg,eyetrack,imu,pressure")
|
| 43 |
+
|
| 44 |
+
echo ""
|
| 45 |
+
echo "=== Part 2: Fusion Ablation (14 jobs) ==="
|
| 46 |
+
for fmods in "${FUSION_MODS[@]}"; do
|
| 47 |
+
fmod_tag=$(echo $fmods | tr ',' '-')
|
| 48 |
+
for fusion in "${FUSIONS[@]}"; do
|
| 49 |
+
sbatch \
|
| 50 |
+
-J "exp1_tf_${fusion}_${fmod_tag}" \
|
| 51 |
+
-p gpuA800 \
|
| 52 |
+
--gres=gpu:1 \
|
| 53 |
+
-N 1 -n 1 \
|
| 54 |
+
--cpus-per-task=4 \
|
| 55 |
+
--mem=32G \
|
| 56 |
+
-t 2:00:00 \
|
| 57 |
+
-o "${LOGDIR}/transformer_${fmod_tag}_${fusion}_%j.out" \
|
| 58 |
+
-e "${LOGDIR}/transformer_${fmod_tag}_${fusion}_%j.err" \
|
| 59 |
+
--export=ALL \
|
| 60 |
+
--wrap="export PYTHONUNBUFFERED=1; cd ${BASEDIR}; $PYTHON $SCRIPT --model transformer --modalities $fmods --fusion $fusion $COMMON"
|
| 61 |
+
echo " Submitted: transformer / $fmods / $fusion"
|
| 62 |
+
done
|
| 63 |
+
done
|
| 64 |
+
|
| 65 |
+
echo ""
|
| 66 |
+
echo "Total: 41 jobs | Scene Recognition | Updated IMU data"
|
| 67 |
+
echo "Results: $OUTDIR"
|
experiments/slurm/run_exp1_small.sh
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# Exp1 small model: hidden_dim=32, dropout=0.5, weight_decay=1e-3
|
| 3 |
+
# 3 modalities: mocap, emg, imu (exclude pressure & eyetrack)
|
| 4 |
+
# Output: results/exp1_small
|
| 5 |
+
|
| 6 |
+
PYTHON=python
|
| 7 |
+
SCRIPT=${PULSE_ROOT}/experiments/train_exp1.py
|
| 8 |
+
OUTDIR=${PULSE_ROOT}/results/exp1_small
|
| 9 |
+
LOGDIR=${OUTDIR}/slurm_logs
|
| 10 |
+
mkdir -p $LOGDIR
|
| 11 |
+
|
| 12 |
+
COMMON="--model transformer --epochs 100 --batch_size 16 --lr 1e-3 --weight_decay 1e-3 --hidden_dim 32 --downsample 5 --patience 15 --seed 42 --output_dir $OUTDIR"
|
| 13 |
+
|
| 14 |
+
# ============================================================
|
| 15 |
+
# Part 1: Single modality (early fusion = single backbone)
|
| 16 |
+
# ============================================================
|
| 17 |
+
for mod in mocap emg imu; do
|
| 18 |
+
job_name="s_${mod}"
|
| 19 |
+
sbatch \
|
| 20 |
+
-J "$job_name" \
|
| 21 |
+
-p gpuA800 \
|
| 22 |
+
--gres=gpu:1 \
|
| 23 |
+
-N 1 -n 1 \
|
| 24 |
+
--cpus-per-task=8 \
|
| 25 |
+
--mem=32G \
|
| 26 |
+
-t 1:00:00 \
|
| 27 |
+
-o "${LOGDIR}/${job_name}_%j.out" \
|
| 28 |
+
-e "${LOGDIR}/${job_name}_%j.err" \
|
| 29 |
+
--export=ALL \
|
| 30 |
+
--wrap="export PYTHONUNBUFFERED=1; cd ${PULSE_ROOT}; $PYTHON $SCRIPT --fusion early --modalities $mod $COMMON"
|
| 31 |
+
echo "Submitted: $job_name"
|
| 32 |
+
done
|
| 33 |
+
|
| 34 |
+
# ============================================================
|
| 35 |
+
# Part 2: Multi-modality early fusion (4 combos)
|
| 36 |
+
# ============================================================
|
| 37 |
+
EARLY_COMBOS=("mocap,emg" "mocap,imu" "emg,imu" "mocap,emg,imu")
|
| 38 |
+
for mods in "${EARLY_COMBOS[@]}"; do
|
| 39 |
+
mod_tag=$(echo $mods | tr ',' '-')
|
| 40 |
+
job_name="e_${mod_tag}"
|
| 41 |
+
sbatch \
|
| 42 |
+
-J "$job_name" \
|
| 43 |
+
-p gpuA800 \
|
| 44 |
+
--gres=gpu:1 \
|
| 45 |
+
-N 1 -n 1 \
|
| 46 |
+
--cpus-per-task=8 \
|
| 47 |
+
--mem=32G \
|
| 48 |
+
-t 1:00:00 \
|
| 49 |
+
-o "${LOGDIR}/${job_name}_%j.out" \
|
| 50 |
+
-e "${LOGDIR}/${job_name}_%j.err" \
|
| 51 |
+
--export=ALL \
|
| 52 |
+
--wrap="export PYTHONUNBUFFERED=1; cd ${PULSE_ROOT}; $PYTHON $SCRIPT --fusion early --modalities $mods $COMMON"
|
| 53 |
+
echo "Submitted: $job_name"
|
| 54 |
+
done
|
| 55 |
+
|
| 56 |
+
# ============================================================
|
| 57 |
+
# Part 3: Fusion methods x modality sets
|
| 58 |
+
# ============================================================
|
| 59 |
+
FUSIONS=(late attention weighted_late gated_late stacking product moe)
|
| 60 |
+
FUSION_MODS=("mocap,emg,imu" "mocap,imu")
|
| 61 |
+
|
| 62 |
+
for fusion in "${FUSIONS[@]}"; do
|
| 63 |
+
for mods in "${FUSION_MODS[@]}"; do
|
| 64 |
+
mod_tag=$(echo $mods | tr ',' '-')
|
| 65 |
+
job_name="f_${fusion}_${mod_tag}"
|
| 66 |
+
sbatch \
|
| 67 |
+
-J "$job_name" \
|
| 68 |
+
-p gpuA800 \
|
| 69 |
+
--gres=gpu:1 \
|
| 70 |
+
-N 1 -n 1 \
|
| 71 |
+
--cpus-per-task=8 \
|
| 72 |
+
--mem=32G \
|
| 73 |
+
-t 1:00:00 \
|
| 74 |
+
-o "${LOGDIR}/${job_name}_%j.out" \
|
| 75 |
+
-e "${LOGDIR}/${job_name}_%j.err" \
|
| 76 |
+
--export=ALL \
|
| 77 |
+
--wrap="export PYTHONUNBUFFERED=1; cd ${PULSE_ROOT}; $PYTHON $SCRIPT --fusion $fusion --modalities $mods $COMMON"
|
| 78 |
+
echo "Submitted: $job_name"
|
| 79 |
+
done
|
| 80 |
+
done
|
| 81 |
+
|
| 82 |
+
echo ""
|
| 83 |
+
echo "Total: 3 single + 4 early + 14 fusion = 21 jobs submitted!"
|
| 84 |
+
echo "Results will be saved to: $OUTDIR"
|
experiments/slurm/run_exp1_small2.sh
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# Exp1 small2: per-modality hidden_dim + missing emg+imu fusion experiments
|
| 3 |
+
# hidden_dim=32 base, scaled per modality: mocap(211)->48, imu(161)->48, emg(9)->16
|
| 4 |
+
# Output: results/exp1_small2
|
| 5 |
+
|
| 6 |
+
PYTHON=python
|
| 7 |
+
SCRIPT=${PULSE_ROOT}/experiments/train_exp1.py
|
| 8 |
+
OUTDIR=${PULSE_ROOT}/results/exp1_small2
|
| 9 |
+
LOGDIR=${OUTDIR}/slurm_logs
|
| 10 |
+
mkdir -p $LOGDIR
|
| 11 |
+
|
| 12 |
+
COMMON="--model transformer --epochs 100 --batch_size 16 --lr 1e-3 --weight_decay 1e-3 --hidden_dim 32 --downsample 5 --patience 15 --seed 42 --output_dir $OUTDIR"
|
| 13 |
+
|
| 14 |
+
# ============================================================
|
| 15 |
+
# Part 1: Single modality baselines (3 jobs)
|
| 16 |
+
# ============================================================
|
| 17 |
+
for mod in mocap emg imu; do
|
| 18 |
+
job_name="s2_${mod}"
|
| 19 |
+
sbatch \
|
| 20 |
+
-J "$job_name" \
|
| 21 |
+
-p gpuA800 \
|
| 22 |
+
--gres=gpu:1 \
|
| 23 |
+
-N 1 -n 1 \
|
| 24 |
+
--cpus-per-task=8 \
|
| 25 |
+
--mem=32G \
|
| 26 |
+
-t 1:00:00 \
|
| 27 |
+
-o "${LOGDIR}/${job_name}_%j.out" \
|
| 28 |
+
-e "${LOGDIR}/${job_name}_%j.err" \
|
| 29 |
+
--export=ALL \
|
| 30 |
+
--wrap="export PYTHONUNBUFFERED=1; cd ${PULSE_ROOT}; $PYTHON $SCRIPT --fusion early --modalities $mod $COMMON"
|
| 31 |
+
echo "Submitted: $job_name"
|
| 32 |
+
done
|
| 33 |
+
|
| 34 |
+
# ============================================================
|
| 35 |
+
# Part 2: Early fusion baselines (3 combos)
|
| 36 |
+
# ============================================================
|
| 37 |
+
EARLY_COMBOS=("emg,imu" "mocap,imu" "mocap,emg,imu")
|
| 38 |
+
for mods in "${EARLY_COMBOS[@]}"; do
|
| 39 |
+
mod_tag=$(echo $mods | tr ',' '-')
|
| 40 |
+
job_name="s2_e_${mod_tag}"
|
| 41 |
+
sbatch \
|
| 42 |
+
-J "$job_name" \
|
| 43 |
+
-p gpuA800 \
|
| 44 |
+
--gres=gpu:1 \
|
| 45 |
+
-N 1 -n 1 \
|
| 46 |
+
--cpus-per-task=8 \
|
| 47 |
+
--mem=32G \
|
| 48 |
+
-t 1:00:00 \
|
| 49 |
+
-o "${LOGDIR}/${job_name}_%j.out" \
|
| 50 |
+
-e "${LOGDIR}/${job_name}_%j.err" \
|
| 51 |
+
--export=ALL \
|
| 52 |
+
--wrap="export PYTHONUNBUFFERED=1; cd ${PULSE_ROOT}; $PYTHON $SCRIPT --fusion early --modalities $mods $COMMON"
|
| 53 |
+
echo "Submitted: $job_name"
|
| 54 |
+
done
|
| 55 |
+
|
| 56 |
+
# ============================================================
|
| 57 |
+
# Part 3: Fusion methods x modality combos (7 methods x 3 combos = 21 jobs)
|
| 58 |
+
# Key addition: emg,imu fusion (was missing in round 1)
|
| 59 |
+
# ============================================================
|
| 60 |
+
FUSIONS=(late attention weighted_late gated_late stacking product moe)
|
| 61 |
+
FUSION_MODS=("emg,imu" "mocap,imu" "mocap,emg,imu")
|
| 62 |
+
|
| 63 |
+
for fusion in "${FUSIONS[@]}"; do
|
| 64 |
+
for mods in "${FUSION_MODS[@]}"; do
|
| 65 |
+
mod_tag=$(echo $mods | tr ',' '-')
|
| 66 |
+
job_name="s2_${fusion}_${mod_tag}"
|
| 67 |
+
sbatch \
|
| 68 |
+
-J "$job_name" \
|
| 69 |
+
-p gpuA800 \
|
| 70 |
+
--gres=gpu:1 \
|
| 71 |
+
-N 1 -n 1 \
|
| 72 |
+
--cpus-per-task=8 \
|
| 73 |
+
--mem=32G \
|
| 74 |
+
-t 1:00:00 \
|
| 75 |
+
-o "${LOGDIR}/${job_name}_%j.out" \
|
| 76 |
+
-e "${LOGDIR}/${job_name}_%j.err" \
|
| 77 |
+
--export=ALL \
|
| 78 |
+
--wrap="export PYTHONUNBUFFERED=1; cd ${PULSE_ROOT}; $PYTHON $SCRIPT --fusion $fusion --modalities $mods $COMMON"
|
| 79 |
+
echo "Submitted: $job_name"
|
| 80 |
+
done
|
| 81 |
+
done
|
| 82 |
+
|
| 83 |
+
echo ""
|
| 84 |
+
echo "Total: 3 single + 3 early + 21 fusion = 27 jobs submitted!"
|
| 85 |
+
echo "Results will be saved to: $OUTDIR"
|
experiments/slurm/run_exp1_small3.sh
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# Exp1 small3: Data augmentation + Frozen pretrained IMU + Label smoothing
|
| 3 |
+
# Goal: Break the IMU-alone F1=0.771 ceiling with emg+imu fusion
|
| 4 |
+
# Phase 0: pretrain IMU with hidden_dim=48 (matches fusion branch)
|
| 5 |
+
# Baselines: IMU+aug+ls, emg+imu early+aug+ls
|
| 6 |
+
# Group A: 7 fusion + aug + ls (no freeze)
|
| 7 |
+
# Group B: 7 fusion + frozen IMU + ls (no aug) [dep: phase0]
|
| 8 |
+
# Group C: 7 fusion + frozen IMU + aug + ls [dep: phase0]
|
| 9 |
+
# Total: 1 + 2 + 7 + 7 + 7 = 24 jobs
|
| 10 |
+
|
| 11 |
+
PYTHON=python
|
| 12 |
+
SCRIPT=${PULSE_ROOT}/experiments/train_exp1.py
|
| 13 |
+
OUTDIR=${PULSE_ROOT}/results/exp1_small3
|
| 14 |
+
LOGDIR=${OUTDIR}/slurm_logs
|
| 15 |
+
mkdir -p $LOGDIR
|
| 16 |
+
|
| 17 |
+
COMMON="--model transformer --epochs 100 --batch_size 16 --lr 1e-3 --weight_decay 1e-3 --hidden_dim 32 --downsample 5 --patience 15 --seed 42"
|
| 18 |
+
FUSIONS=(late attention weighted_late gated_late stacking product moe)
|
| 19 |
+
|
| 20 |
+
# ============================================================
|
| 21 |
+
# Phase 0: Pretrain IMU with hidden_dim=48 (matches fusion branch)
|
| 22 |
+
# ============================================================
|
| 23 |
+
PHASE0_JOB=$(sbatch --parsable \
|
| 24 |
+
-J "s3_phase0_imu48" \
|
| 25 |
+
-p gpuA800 \
|
| 26 |
+
--gres=gpu:1 \
|
| 27 |
+
-N 1 -n 1 \
|
| 28 |
+
--cpus-per-task=8 \
|
| 29 |
+
--mem=32G \
|
| 30 |
+
-t 1:00:00 \
|
| 31 |
+
-o "${LOGDIR}/phase0_imu48_%j.out" \
|
| 32 |
+
-e "${LOGDIR}/phase0_imu48_%j.err" \
|
| 33 |
+
--export=ALL \
|
| 34 |
+
--wrap="export PYTHONUNBUFFERED=1; cd ${PULSE_ROOT}; $PYTHON $SCRIPT --model transformer --fusion early --modalities imu --hidden_dim 48 --epochs 100 --batch_size 16 --lr 1e-3 --weight_decay 1e-3 --downsample 5 --patience 15 --seed 42 --output_dir ${OUTDIR}/phase0")
|
| 35 |
+
echo "Phase 0 (IMU h48): job $PHASE0_JOB"
|
| 36 |
+
|
| 37 |
+
PRETRAINED="${OUTDIR}/phase0/transformer_imu_early/model_best.pt"
|
| 38 |
+
|
| 39 |
+
# ============================================================
|
| 40 |
+
# Baselines (no dependency)
|
| 41 |
+
# ============================================================
|
| 42 |
+
|
| 43 |
+
# Baseline 1: IMU alone + augment + label_smoothing
|
| 44 |
+
sbatch \
|
| 45 |
+
-J "s3_bl_imu_aug" \
|
| 46 |
+
-p gpuA800 \
|
| 47 |
+
--gres=gpu:1 \
|
| 48 |
+
-N 1 -n 1 \
|
| 49 |
+
--cpus-per-task=8 \
|
| 50 |
+
--mem=32G \
|
| 51 |
+
-t 1:00:00 \
|
| 52 |
+
-o "${LOGDIR}/bl_imu_aug_%j.out" \
|
| 53 |
+
-e "${LOGDIR}/bl_imu_aug_%j.err" \
|
| 54 |
+
--export=ALL \
|
| 55 |
+
--wrap="export PYTHONUNBUFFERED=1; cd ${PULSE_ROOT}; $PYTHON $SCRIPT --fusion early --modalities imu $COMMON --augment --label_smoothing 0.1 --tag bl_aug --output_dir $OUTDIR"
|
| 56 |
+
echo "Submitted: baseline IMU+aug+ls"
|
| 57 |
+
|
| 58 |
+
# Baseline 2: emg,imu early + augment + label_smoothing
|
| 59 |
+
sbatch \
|
| 60 |
+
-J "s3_bl_ei_aug" \
|
| 61 |
+
-p gpuA800 \
|
| 62 |
+
--gres=gpu:1 \
|
| 63 |
+
-N 1 -n 1 \
|
| 64 |
+
--cpus-per-task=8 \
|
| 65 |
+
--mem=32G \
|
| 66 |
+
-t 1:00:00 \
|
| 67 |
+
-o "${LOGDIR}/bl_ei_aug_%j.out" \
|
| 68 |
+
-e "${LOGDIR}/bl_ei_aug_%j.err" \
|
| 69 |
+
--export=ALL \
|
| 70 |
+
--wrap="export PYTHONUNBUFFERED=1; cd ${PULSE_ROOT}; $PYTHON $SCRIPT --fusion early --modalities emg,imu $COMMON --augment --label_smoothing 0.1 --tag bl_aug --output_dir $OUTDIR"
|
| 71 |
+
echo "Submitted: baseline emg+imu early+aug+ls"
|
| 72 |
+
|
| 73 |
+
# ============================================================
|
| 74 |
+
# Group A: emg+imu x 7 fusion + augment + label_smoothing (no freeze)
|
| 75 |
+
# ============================================================
|
| 76 |
+
for fusion in "${FUSIONS[@]}"; do
|
| 77 |
+
sbatch \
|
| 78 |
+
-J "s3_A_${fusion}" \
|
| 79 |
+
-p gpuA800 \
|
| 80 |
+
--gres=gpu:1 \
|
| 81 |
+
-N 1 -n 1 \
|
| 82 |
+
--cpus-per-task=8 \
|
| 83 |
+
--mem=32G \
|
| 84 |
+
-t 1:00:00 \
|
| 85 |
+
-o "${LOGDIR}/grpA_${fusion}_%j.out" \
|
| 86 |
+
-e "${LOGDIR}/grpA_${fusion}_%j.err" \
|
| 87 |
+
--export=ALL \
|
| 88 |
+
--wrap="export PYTHONUNBUFFERED=1; cd ${PULSE_ROOT}; $PYTHON $SCRIPT --fusion $fusion --modalities emg,imu $COMMON --augment --label_smoothing 0.1 --tag grpA --output_dir $OUTDIR"
|
| 89 |
+
echo "Submitted: Group A $fusion"
|
| 90 |
+
done
|
| 91 |
+
|
| 92 |
+
# ============================================================
|
| 93 |
+
# Group B: emg+imu x 7 fusion + frozen IMU + label_smoothing (no augment)
|
| 94 |
+
# Depends on Phase 0
|
| 95 |
+
# ============================================================
|
| 96 |
+
for fusion in "${FUSIONS[@]}"; do
|
| 97 |
+
sbatch \
|
| 98 |
+
--dependency=afterok:${PHASE0_JOB} \
|
| 99 |
+
-J "s3_B_${fusion}" \
|
| 100 |
+
-p gpuA800 \
|
| 101 |
+
--gres=gpu:1 \
|
| 102 |
+
-N 1 -n 1 \
|
| 103 |
+
--cpus-per-task=8 \
|
| 104 |
+
--mem=32G \
|
| 105 |
+
-t 1:00:00 \
|
| 106 |
+
-o "${LOGDIR}/grpB_${fusion}_%j.out" \
|
| 107 |
+
-e "${LOGDIR}/grpB_${fusion}_%j.err" \
|
| 108 |
+
--export=ALL \
|
| 109 |
+
--wrap="export PYTHONUNBUFFERED=1; cd ${PULSE_ROOT}; $PYTHON $SCRIPT --fusion $fusion --modalities emg,imu $COMMON --label_smoothing 0.1 --pretrained_backbone $PRETRAINED --freeze_backbone_idx 1 --tag grpB --output_dir $OUTDIR"
|
| 110 |
+
echo "Submitted: Group B $fusion (dep: $PHASE0_JOB)"
|
| 111 |
+
done
|
| 112 |
+
|
| 113 |
+
# ============================================================
|
| 114 |
+
# Group C: emg+imu x 7 fusion + frozen IMU + augment + label_smoothing
|
| 115 |
+
# Depends on Phase 0
|
| 116 |
+
# ============================================================
|
| 117 |
+
for fusion in "${FUSIONS[@]}"; do
|
| 118 |
+
sbatch \
|
| 119 |
+
--dependency=afterok:${PHASE0_JOB} \
|
| 120 |
+
-J "s3_C_${fusion}" \
|
| 121 |
+
-p gpuA800 \
|
| 122 |
+
--gres=gpu:1 \
|
| 123 |
+
-N 1 -n 1 \
|
| 124 |
+
--cpus-per-task=8 \
|
| 125 |
+
--mem=32G \
|
| 126 |
+
-t 1:00:00 \
|
| 127 |
+
-o "${LOGDIR}/grpC_${fusion}_%j.out" \
|
| 128 |
+
-e "${LOGDIR}/grpC_${fusion}_%j.err" \
|
| 129 |
+
--export=ALL \
|
| 130 |
+
--wrap="export PYTHONUNBUFFERED=1; cd ${PULSE_ROOT}; $PYTHON $SCRIPT --fusion $fusion --modalities emg,imu $COMMON --augment --label_smoothing 0.1 --pretrained_backbone $PRETRAINED --freeze_backbone_idx 1 --tag grpC --output_dir $OUTDIR"
|
| 131 |
+
echo "Submitted: Group C $fusion (dep: $PHASE0_JOB)"
|
| 132 |
+
done
|
| 133 |
+
|
| 134 |
+
echo ""
|
| 135 |
+
echo "Total: 1 phase0 + 2 baselines + 7 grpA + 7 grpB + 7 grpC = 24 jobs"
|
| 136 |
+
echo "Results: $OUTDIR"
|
| 137 |
+
echo "Phase 0 job ID: $PHASE0_JOB (Groups B & C depend on it)"
|
experiments/slurm/run_exp1_v3.sh
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# Scene Recognition (Exp1 v3) - Train 14 vols / Test 4 vols (no val)
|
| 3 |
+
# v23,v24 moved from val to train; v3 stays in test
|
| 4 |
+
# Part 1: 9 modality combos × 3 backbones = 27 jobs (early fusion)
|
| 5 |
+
# Part 2: 7 fusion methods × transformer × (3-core + all-5) = 14 jobs
|
| 6 |
+
# Total: 41 jobs
|
| 7 |
+
|
| 8 |
+
PYTHON=python
|
| 9 |
+
BASEDIR=${PULSE_ROOT}
|
| 10 |
+
SCRIPT=${BASEDIR}/experiments/train_exp1.py
|
| 11 |
+
OUTDIR=${BASEDIR}/results/exp1_v3
|
| 12 |
+
LOGDIR=${OUTDIR}/slurm_logs
|
| 13 |
+
mkdir -p $LOGDIR
|
| 14 |
+
|
| 15 |
+
COMMON="--epochs 100 --batch_size 16 --lr 1e-3 --weight_decay 1e-4 --hidden_dim 128 --downsample 5 --patience 15 --seed 42 --output_dir $OUTDIR"
|
| 16 |
+
|
| 17 |
+
MODS=("mocap" "emg" "eyetrack" "imu" "pressure" "mocap,emg,eyetrack" "mocap,emg,eyetrack,imu" "mocap,emg,eyetrack,pressure" "mocap,emg,eyetrack,imu,pressure")
|
| 18 |
+
MODELS=("cnn" "lstm" "transformer")
|
| 19 |
+
|
| 20 |
+
# Part 1: Modality ablation × 3 backbones
|
| 21 |
+
echo "=== Part 1: Modality Ablation (27 jobs) ==="
|
| 22 |
+
for mods in "${MODS[@]}"; do
|
| 23 |
+
mod_tag=$(echo $mods | tr ',' '-')
|
| 24 |
+
for model in "${MODELS[@]}"; do
|
| 25 |
+
sbatch \
|
| 26 |
+
-J "e1v3_${model}_${mod_tag}" \
|
| 27 |
+
-p gpuA800 \
|
| 28 |
+
--gres=gpu:1 \
|
| 29 |
+
-N 1 -n 1 \
|
| 30 |
+
--cpus-per-task=4 \
|
| 31 |
+
--mem=32G \
|
| 32 |
+
-t 2:00:00 \
|
| 33 |
+
-o "${LOGDIR}/${model}_${mod_tag}_early_%j.out" \
|
| 34 |
+
-e "${LOGDIR}/${model}_${mod_tag}_early_%j.err" \
|
| 35 |
+
--export=ALL \
|
| 36 |
+
--wrap="export PYTHONUNBUFFERED=1; cd ${BASEDIR}; $PYTHON $SCRIPT --model $model --modalities $mods --fusion early $COMMON"
|
| 37 |
+
echo " $model / $mods / early"
|
| 38 |
+
done
|
| 39 |
+
done
|
| 40 |
+
|
| 41 |
+
# Part 2: Fusion methods × transformer
|
| 42 |
+
FUSIONS=("late" "attention" "weighted_late" "gated_late" "stacking" "product" "moe")
|
| 43 |
+
FUSION_MODS=("mocap,emg,eyetrack" "mocap,emg,eyetrack,imu,pressure")
|
| 44 |
+
|
| 45 |
+
echo ""
|
| 46 |
+
echo "=== Part 2: Fusion Ablation (14 jobs) ==="
|
| 47 |
+
for fmods in "${FUSION_MODS[@]}"; do
|
| 48 |
+
fmod_tag=$(echo $fmods | tr ',' '-')
|
| 49 |
+
for fusion in "${FUSIONS[@]}"; do
|
| 50 |
+
sbatch \
|
| 51 |
+
-J "e1v3_tf_${fusion}" \
|
| 52 |
+
-p gpuA800 \
|
| 53 |
+
--gres=gpu:1 \
|
| 54 |
+
-N 1 -n 1 \
|
| 55 |
+
--cpus-per-task=4 \
|
| 56 |
+
--mem=32G \
|
| 57 |
+
-t 2:00:00 \
|
| 58 |
+
-o "${LOGDIR}/transformer_${fmod_tag}_${fusion}_%j.out" \
|
| 59 |
+
-e "${LOGDIR}/transformer_${fmod_tag}_${fusion}_%j.err" \
|
| 60 |
+
--export=ALL \
|
| 61 |
+
--wrap="export PYTHONUNBUFFERED=1; cd ${BASEDIR}; $PYTHON $SCRIPT --model transformer --modalities $fmods --fusion $fusion $COMMON"
|
| 62 |
+
echo " transformer / $fmods / $fusion"
|
| 63 |
+
done
|
| 64 |
+
done
|
| 65 |
+
|
| 66 |
+
echo ""
|
| 67 |
+
echo "Total: 41 jobs | Scene Recognition v3 | Train=14vols, Test=4vols"
|
| 68 |
+
echo "Results: $OUTDIR"
|
experiments/slurm/run_exp1_v4.sh
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# Scene Recognition (Exp1 v4) - Per-modality projection to 50 dims
|
| 3 |
+
# All modalities projected to 50d via FC before backbone processing
|
| 4 |
+
# Train 14 vols / Test 4 vols (no val)
|
| 5 |
+
# Part 1: 9 modality combos × 3 backbones = 27 jobs (early fusion)
|
| 6 |
+
# Part 2: 7 fusion methods × transformer × (3-core + all-5) = 14 jobs
|
| 7 |
+
# Total: 41 jobs
|
| 8 |
+
|
| 9 |
+
PYTHON=python
|
| 10 |
+
BASEDIR=${PULSE_ROOT}
|
| 11 |
+
SCRIPT=${BASEDIR}/experiments/train_exp1.py
|
| 12 |
+
OUTDIR=${BASEDIR}/results/exp1_v4
|
| 13 |
+
LOGDIR=${OUTDIR}/slurm_logs
|
| 14 |
+
mkdir -p $LOGDIR
|
| 15 |
+
|
| 16 |
+
COMMON="--epochs 100 --batch_size 16 --lr 1e-3 --weight_decay 1e-4 --hidden_dim 128 --downsample 5 --patience 15 --seed 42 --output_dir $OUTDIR"
|
| 17 |
+
|
| 18 |
+
MODS=("mocap" "emg" "eyetrack" "imu" "pressure" "mocap,emg,eyetrack" "mocap,emg,eyetrack,imu" "mocap,emg,eyetrack,pressure" "mocap,emg,eyetrack,imu,pressure")
|
| 19 |
+
MODELS=("cnn" "lstm" "transformer")
|
| 20 |
+
|
| 21 |
+
# Part 1: Modality ablation × 3 backbones
|
| 22 |
+
echo "=== Part 1: Modality Ablation (27 jobs) ==="
|
| 23 |
+
for mods in "${MODS[@]}"; do
|
| 24 |
+
mod_tag=$(echo $mods | tr ',' '-')
|
| 25 |
+
for model in "${MODELS[@]}"; do
|
| 26 |
+
sbatch \
|
| 27 |
+
-J "e1v4_${model}_${mod_tag}" \
|
| 28 |
+
-p gpuA800 \
|
| 29 |
+
--gres=gpu:1 \
|
| 30 |
+
-N 1 -n 1 \
|
| 31 |
+
--cpus-per-task=4 \
|
| 32 |
+
--mem=32G \
|
| 33 |
+
-t 2:00:00 \
|
| 34 |
+
-o "${LOGDIR}/${model}_${mod_tag}_early_%j.out" \
|
| 35 |
+
-e "${LOGDIR}/${model}_${mod_tag}_early_%j.err" \
|
| 36 |
+
--export=ALL \
|
| 37 |
+
--wrap="export PYTHONUNBUFFERED=1; cd ${BASEDIR}; $PYTHON $SCRIPT --model $model --modalities $mods --fusion early $COMMON"
|
| 38 |
+
echo " $model / $mods / early"
|
| 39 |
+
done
|
| 40 |
+
done
|
| 41 |
+
|
| 42 |
+
# Part 2: Fusion methods × transformer
|
| 43 |
+
FUSIONS=("late" "attention" "weighted_late" "gated_late" "stacking" "product" "moe")
|
| 44 |
+
FUSION_MODS=("mocap,emg,eyetrack" "mocap,emg,eyetrack,imu,pressure")
|
| 45 |
+
|
| 46 |
+
echo ""
|
| 47 |
+
echo "=== Part 2: Fusion Ablation (14 jobs) ==="
|
| 48 |
+
for fmods in "${FUSION_MODS[@]}"; do
|
| 49 |
+
fmod_tag=$(echo $fmods | tr ',' '-')
|
| 50 |
+
for fusion in "${FUSIONS[@]}"; do
|
| 51 |
+
sbatch \
|
| 52 |
+
-J "e1v4_tf_${fusion}" \
|
| 53 |
+
-p gpuA800 \
|
| 54 |
+
--gres=gpu:1 \
|
| 55 |
+
-N 1 -n 1 \
|
| 56 |
+
--cpus-per-task=4 \
|
| 57 |
+
--mem=32G \
|
| 58 |
+
-t 2:00:00 \
|
| 59 |
+
-o "${LOGDIR}/transformer_${fmod_tag}_${fusion}_%j.out" \
|
| 60 |
+
-e "${LOGDIR}/transformer_${fmod_tag}_${fusion}_%j.err" \
|
| 61 |
+
--export=ALL \
|
| 62 |
+
--wrap="export PYTHONUNBUFFERED=1; cd ${BASEDIR}; $PYTHON $SCRIPT --model transformer --modalities $fmods --fusion $fusion $COMMON"
|
| 63 |
+
echo " transformer / $fmods / $fusion"
|
| 64 |
+
done
|
| 65 |
+
done
|
| 66 |
+
|
| 67 |
+
echo ""
|
| 68 |
+
echo "Total: 41 jobs | Scene Recognition v4 | Proj50d | Train=14vols, Test=4vols"
|
| 69 |
+
echo "Results: $OUTDIR"
|
experiments/slurm/run_exp1_v5.sh
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# Scene Recognition (Exp1 v5) - Only imu, mocap, emg
|
| 3 |
+
# Per-modality projection to 50d
|
| 4 |
+
# Train 14 vols / Test 4 vols
|
| 5 |
+
|
| 6 |
+
PYTHON=python
|
| 7 |
+
BASEDIR=${PULSE_ROOT}
|
| 8 |
+
SCRIPT=${BASEDIR}/experiments/train_exp1.py
|
| 9 |
+
OUTDIR=${BASEDIR}/results/exp1_v5
|
| 10 |
+
LOGDIR=${OUTDIR}/slurm_logs
|
| 11 |
+
mkdir -p $LOGDIR
|
| 12 |
+
|
| 13 |
+
COMMON="--epochs 100 --batch_size 16 --lr 1e-3 --weight_decay 1e-4 --hidden_dim 128 --downsample 5 --patience 15 --seed 42 --output_dir $OUTDIR"
|
| 14 |
+
MODELS=("cnn" "lstm" "transformer")
|
| 15 |
+
|
| 16 |
+
# Part 1: Single modality (3 mods × 3 backbones = 9 jobs)
|
| 17 |
+
echo "=== Part 1: Single Modality (9 jobs) ==="
|
| 18 |
+
for mods in "imu" "mocap" "emg"; do
|
| 19 |
+
for model in "${MODELS[@]}"; do
|
| 20 |
+
sbatch -J "e1v5_${model}_${mods}" -p gpuA800 --gres=gpu:1 -N1 -n1 \
|
| 21 |
+
--cpus-per-task=4 --mem=32G -t 2:00:00 \
|
| 22 |
+
-o "${LOGDIR}/${model}_${mods}_early_%j.out" \
|
| 23 |
+
-e "${LOGDIR}/${model}_${mods}_early_%j.err" \
|
| 24 |
+
--export=ALL \
|
| 25 |
+
--wrap="export PYTHONUNBUFFERED=1; cd ${BASEDIR}; $PYTHON $SCRIPT --model $model --modalities $mods --fusion early $COMMON"
|
| 26 |
+
echo " $model / $mods / early"
|
| 27 |
+
done
|
| 28 |
+
done
|
| 29 |
+
|
| 30 |
+
# Part 2: Multi-modality early fusion (4 combos × 3 backbones = 12 jobs)
|
| 31 |
+
echo ""
|
| 32 |
+
echo "=== Part 2: Multi-Modality Early Fusion (12 jobs) ==="
|
| 33 |
+
for mods in "imu,mocap" "imu,emg" "mocap,emg" "imu,mocap,emg"; do
|
| 34 |
+
mod_tag=$(echo $mods | tr ',' '-')
|
| 35 |
+
for model in "${MODELS[@]}"; do
|
| 36 |
+
sbatch -J "e1v5_${model}_${mod_tag}" -p gpuA800 --gres=gpu:1 -N1 -n1 \
|
| 37 |
+
--cpus-per-task=4 --mem=32G -t 2:00:00 \
|
| 38 |
+
-o "${LOGDIR}/${model}_${mod_tag}_early_%j.out" \
|
| 39 |
+
-e "${LOGDIR}/${model}_${mod_tag}_early_%j.err" \
|
| 40 |
+
--export=ALL \
|
| 41 |
+
--wrap="export PYTHONUNBUFFERED=1; cd ${BASEDIR}; $PYTHON $SCRIPT --model $model --modalities $mods --fusion early $COMMON"
|
| 42 |
+
echo " $model / $mods / early"
|
| 43 |
+
done
|
| 44 |
+
done
|
| 45 |
+
|
| 46 |
+
# Part 3: Fusion ablation with imu+mocap+emg × transformer (7 jobs)
|
| 47 |
+
FUSIONS=("late" "attention" "weighted_late" "gated_late" "stacking" "product" "moe")
|
| 48 |
+
echo ""
|
| 49 |
+
echo "=== Part 3: Fusion Ablation - transformer × imu+mocap+emg (7 jobs) ==="
|
| 50 |
+
for fusion in "${FUSIONS[@]}"; do
|
| 51 |
+
sbatch -J "e1v5_tf_${fusion}" -p gpuA800 --gres=gpu:1 -N1 -n1 \
|
| 52 |
+
--cpus-per-task=4 --mem=32G -t 2:00:00 \
|
| 53 |
+
-o "${LOGDIR}/transformer_imu-mocap-emg_${fusion}_%j.out" \
|
| 54 |
+
-e "${LOGDIR}/transformer_imu-mocap-emg_${fusion}_%j.err" \
|
| 55 |
+
--export=ALL \
|
| 56 |
+
--wrap="export PYTHONUNBUFFERED=1; cd ${BASEDIR}; $PYTHON $SCRIPT --model transformer --modalities imu,mocap,emg --fusion $fusion $COMMON"
|
| 57 |
+
echo " transformer / imu,mocap,emg / $fusion"
|
| 58 |
+
done
|
| 59 |
+
|
| 60 |
+
echo ""
|
| 61 |
+
echo "Total: 28 jobs | 3 modalities: imu(160d→50d), mocap(156d→50d), emg(8d→50d)"
|
| 62 |
+
echo "Results: $OUTDIR"
|