velvet-pine-22 committed on
Commit
b4b2877
·
verified ·
1 Parent(s): 6f63aa1

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. LICENSE +21 -0
  2. README.md +152 -0
  3. experiments/__init__.py +0 -0
  4. experiments/analysis/__init__.py +0 -0
  5. experiments/analysis/aggregate_new_exps.py +166 -0
  6. experiments/analysis/aggregate_t1_extended.py +60 -0
  7. experiments/analysis/analysis_figures.py +444 -0
  8. experiments/analysis/build_taxonomy.py +136 -0
  9. experiments/analysis/check_seg_lengths.py +229 -0
  10. experiments/analysis/data_statistics_figure.py +126 -0
  11. experiments/analysis/exp_per_subject.py +150 -0
  12. experiments/analysis/extract_video_features.py +208 -0
  13. experiments/analysis/extract_videomae_features.py +276 -0
  14. experiments/analysis/gen_val_comparison.py +74 -0
  15. experiments/analysis/generate_action_labels.py +133 -0
  16. experiments/analysis/generate_coarse_annotations.py +296 -0
  17. experiments/analysis/grasp_phase_analysis.py +442 -0
  18. experiments/analysis/modality_viz.py +145 -0
  19. experiments/analysis/reannotate_actions.py +363 -0
  20. experiments/data/__init__.py +0 -0
  21. experiments/data/__pycache__/dataset.cpython-312.pyc +0 -0
  22. experiments/data/dataset.py +332 -0
  23. experiments/data/dataset_forecast.py +319 -0
  24. experiments/data/dataset_grasp_state.py +571 -0
  25. experiments/data/dataset_seqpred.py +533 -0
  26. experiments/data/dataset_signal_forecast.py +391 -0
  27. experiments/nets/__init__.py +0 -0
  28. experiments/nets/__pycache__/models_seqpred.cpython-312.pyc +0 -0
  29. experiments/nets/baselines_published/__init__.py +0 -0
  30. experiments/nets/baselines_published/baselines.py +488 -0
  31. experiments/nets/baselines_published/syncfuse.py +270 -0
  32. experiments/nets/models.py +648 -0
  33. experiments/nets/models_forecast.py +269 -0
  34. experiments/nets/models_forecast_priv.py +76 -0
  35. experiments/nets/models_seqpred.py +806 -0
  36. experiments/nets/published_models.py +699 -0
  37. experiments/s9_primitives.json +76 -0
  38. experiments/slurm/freeze_all_rows.sh +179 -0
  39. experiments/slurm/run_ablation_fix.sh +33 -0
  40. experiments/slurm/run_ablation_fusion.sh +174 -0
  41. experiments/slurm/run_asformer_exp3.sh +44 -0
  42. experiments/slurm/run_exp1.sh +40 -0
  43. experiments/slurm/run_exp1_fusion.sh +36 -0
  44. experiments/slurm/run_exp1_parallel.sh +67 -0
  45. experiments/slurm/run_exp1_small.sh +84 -0
  46. experiments/slurm/run_exp1_small2.sh +85 -0
  47. experiments/slurm/run_exp1_small3.sh +137 -0
  48. experiments/slurm/run_exp1_v3.sh +68 -0
  49. experiments/slurm/run_exp1_v4.sh +69 -0
  50. experiments/slurm/run_exp1_v5.sh +62 -0
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Anonymous Authors (under double-blind review for NeurIPS 2026)
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ language:
4
+ - en
5
+ library_name: pytorch
6
+ tags:
7
+ - multi-modal
8
+ - daily-activity
9
+ - wearable-sensors
10
+ - benchmark
11
+ ---
12
+
13
+ # PULSE — Code Repository
14
+
15
+ Reference implementation, training scripts, and benchmark baselines for the
16
+ **PULSE** dataset paper (under double-blind review at NeurIPS 2026 Evaluations &
17
+ Datasets Track).
18
+
19
+ > **Dataset:** [`velvet-pine-22/PULSE`](https://huggingface.co/datasets/velvet-pine-22/PULSE)
20
+ > · **Sample subset (≈285 MB):** [`velvet-pine-22/PULSE-sample`](https://huggingface.co/datasets/velvet-pine-22/PULSE-sample)
21
+
22
+ ## Repository layout
23
+
24
+ ```
25
+ PULSE-code/
26
+ ├── experiments/
27
+ │ ├── data/ # PyTorch Dataset wrappers
28
+ │ │ ├── dataset.py # core multi-modal dataset (T1, T2)
29
+ │ │ ├── dataset_seqpred.py # T2 fine-grained action recognition
30
+ │ │ ├── dataset_grasp_state.py # T3 grasp onset anticipation
31
+ │ │ ├── dataset_forecast.py # auxiliary forecasting heads
32
+ │ │ └── dataset_signal_forecast.py # T5 tactile-driven motion forecast
33
+ │ │
34
+ │ ├── nets/ # Model architectures
35
+ │ │ ├── models.py # backbone networks (Transformer / LSTM / 1D-CNN)
36
+ │ │ ├── models_seqpred.py # DailyActFormer (DAF) — multi-modal Transformer
37
+ │ │ ├── models_forecast.py # forecasting heads
38
+ │ │ ├── models_forecast_priv.py # privileged-tactile variants for T5
39
+ │ │ ├── published_models.py # third-party model implementations
40
+ │ │ └── baselines_published/ # 7 published baselines (re-implementation)
41
+ │ │ ├── baselines.py # DeepConvLSTM / InceptionTime / MS-TCN / etc.
42
+ │ │ └── syncfuse.py # under-pressure-style multi-modal fusion
43
+ │ │
44
+ │ ├── tasks/ # Training + evaluation entry points
45
+ │ │ ├── train_exp1.py # T1 — scene recognition
46
+ │ │ ├── train_seqpred.py # T2 — action recognition (DAF + ablations)
47
+ │ │ ├── train_grasp_state.py # T3 — grasp onset anticipation
48
+ │ │ ├── train_pred_cls.py # T3 alt classification head
49
+ │ │ ├── train_exp_missing.py # T4 — missing-modality robustness
50
+ │ │ ├── train_signal_forecast.py # T5 — tactile-driven motion forecasting
51
+ │ │ ├── train_signal_forecast_priv.py # T5 privileged variants
52
+ │ │ ├── train_baselines_t1.py # baselines for T1
53
+ │ │ ├── train_exp{2,3,4}.py # ablation experiments
54
+ │ │ ├── train_exp_{anticipate,grip,pose,retrieval,zeroshot}.py # auxiliary
55
+ │ │ ├── train_pred.py / train_forecast.py
56
+ │ │ ├── eval_baselines.py / eval_combined.py
57
+ │ │ └── published_baselines.py # baseline registry
58
+ │ │
59
+ │ ├── analysis/ # Case study, figures, data prep utilities
60
+ │ │ ├── grasp_phase_analysis.py # case study (gaze→EMG→hand→contact cascade)
61
+ │ │ ├── modality_viz.py / analysis_figures.py / data_statistics_figure.py
62
+ │ │ ├── extract_video_features.py / extract_videomae_features.py
63
+ │ │ ├── build_taxonomy.py / generate_action_labels.py / generate_coarse_annotations.py
64
+ │ │ ├── reannotate_actions.py / gen_val_comparison.py
65
+ │ │ ├── exp_per_subject.py / check_seg_lengths.py
66
+ │ │ └── aggregate_*.py # collate run results
67
+ │ │
68
+ │ ├── slurm/ # 60+ SLURM launch scripts (one per main experiment)
69
+ │ │ └── run_*.sh
70
+ │ │
71
+ │ ├── taxonomy.py # shared 18-primitive taxonomy
72
+ │ ├── s9_primitives.json
73
+ │ └── taxonomy_v3.json
74
+
75
+ ├── scripts/ # Top-level utilities (not task-specific)
76
+ │ ├── build_paper_tables.py # collates results JSONs into LaTeX tables
77
+ │ ├── eval_macrof1.py / eval_subset.py / eval_topk_v3.py
78
+ │ └── dispatch_eval.sh # batch dispatcher
79
+
80
+ ├── LICENSE # MIT
81
+ ├── requirements.txt # Python deps
82
+ └── README.md
83
+ ```
84
+
85
+ ## Quick start
86
+
87
+ ```bash
88
+ # 1. Set up Python environment
89
+ python -m venv .venv && source .venv/bin/activate
90
+ pip install -r requirements.txt
91
+
92
+ # 2. Point at the PULSE dataset (download from HuggingFace first)
93
+ export PULSE_ROOT=/path/to/PULSE # the dataset root (not this code repo)
94
+
95
+ # 3. Run a training entry point as a module (from the experiments/ directory)
96
+ cd experiments
97
+ python -m tasks.train_seqpred \
98
+ --root $PULSE_ROOT \
99
+ --modalities mocap emg eyetrack imu pressure \
100
+ --output_dir runs/t2_daf
101
+
102
+ # 4. Reproduce paper tables (after training all benchmarks)
103
+ cd ..
104
+ python scripts/build_paper_tables.py \
105
+ --results_root experiments/runs/ \
106
+ --out tables/
107
+ ```
108
+
109
+ > **Why `python -m tasks.train_seqpred` and not `python tasks/train_seqpred.py`?**
110
+ > The training scripts import sibling modules (`from data.dataset import …`,
111
+ > `from nets.models import …`). Running with `-m` from the `experiments/`
112
+ > directory makes Python treat `data/`, `nets/`, `tasks/`, and `analysis/` as
113
+ > top-level packages so the imports resolve cleanly.
114
+
115
+ ## Reproducing the benchmark tasks
116
+
117
+ | Task | Entry point | Output |
118
+ |---|---|---|
119
+ | T1 — Scene recognition (8-way) | `tasks.train_exp1` | scene-classification metrics |
120
+ | T2 — Fine-grained action recognition | `tasks.train_seqpred` | verb / noun / hand top-k accuracy |
121
+ | T3 — Grasp onset anticipation | `tasks.train_grasp_state` / `tasks.train_pred_cls` | anticipation F1 / time-to-contact |
122
+ | T4 — Missing-modality robustness | `tasks.train_exp_missing` + `tasks.eval_combined` | per-modality ablation table |
123
+ | T5 — Tactile-driven grasp-state recognition | `tasks.train_signal_forecast` (+ `_priv` variants) | sub-second grasp-state metrics |
124
+ | T6 — Cross-modal pressure prediction | `tasks.train_forecast` / `tasks.train_signal_forecast` | pressure reconstruction metrics |
125
+
126
+ The exact command lines (with hyperparameters, seeds, GPU configs) used for
127
+ every paper table are checked in under `experiments/slurm/run_*.sh`, one
128
+ SLURM script per paper experiment. Output JSON files from these runs are
129
+ collated into LaTeX tables by `scripts/build_paper_tables.py`.
130
+
131
+ ## Hardware
132
+
133
+ Headline experiments were run on **NVIDIA A800 (80 GB)** GPUs. A single seed of
134
+ DailyActFormer T2 trains in ~6 hours on one A800. Most baselines fit on a
135
+ single 24 GB consumer GPU.
136
+
137
+ ## License & attribution
138
+
139
+ Code is released under **MIT** (see `LICENSE`). The PULSE dataset itself is
140
+ released under **CC BY-NC 4.0** (see the dataset repository).
141
+
142
+ ## Citation
143
+
144
+ ```bibtex
145
+ @inproceedings{anonymous2026pulse,
146
+ title = {PULSE: A Synchronized Five-Modality Dataset for Multi-Modal Daily Activity Understanding},
147
+ author = {Anonymous Authors},
148
+ booktitle = {Submitted to NeurIPS 2026 Evaluations and Datasets Track},
149
+ year = {2026},
150
+ note = {Under double-blind review}
151
+ }
152
+ ```
experiments/__init__.py ADDED
File without changes
experiments/analysis/__init__.py ADDED
File without changes
experiments/analysis/aggregate_new_exps.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Aggregate results from the three new benchmark experiments."""
3
+ import os
4
+ import json
5
+ import glob
6
+ import numpy as np
7
+
8
# Results directory. Python does not expand ``${...}`` inside string literals,
# so the placeholder is resolved from the PULSE_ROOT environment variable here.
# When PULSE_ROOT is unset the literal placeholder is kept unchanged.
ROOT = os.path.expandvars('${PULSE_ROOT}/results/exp_new')
9
+
10
+
11
def load_results(pattern):
    """Load every JSON file matching a glob pattern.

    Args:
        pattern: Glob pattern handed to ``glob.glob``.

    Returns:
        List of parsed JSON payloads, in sorted-path order. A file that cannot
        be opened or parsed is reported on stdout and skipped.
    """
    results = []
    for path in sorted(glob.glob(pattern)):
        try:
            # The context manager closes the handle even if parsing fails.
            with open(path) as fh:
                results.append(json.load(fh))
        except Exception as e:
            print(f" ERR: {path}: {e}")
    return results
20
+
21
+
22
def aggregate_expA():
    """Print the missing-modality robustness table (experiment A).

    For each of the ``expA_missing`` and ``expA_baseline`` result folders the
    per-seed ``results.json`` payloads are pooled by evaluation-config name,
    and the mean and standard deviation (``np.std`` defaults) of F1 and
    accuracy are printed. Row order: ``full``, then the ``drop_*`` configs,
    then the ``only_*`` configs, each group sorted by name.
    """
    rule = "=" * 70
    print("\n" + rule)
    print("EXP A: Missing-modality robustness")
    print(rule)

    for subdir in ['expA_missing', 'expA_baseline']:
        files = load_results(f'{ROOT}/{subdir}/*/results.json')
        if not files:
            print(f" No results yet for {subdir}")
            continue
        print(f"\n-- {subdir} (n seeds = {len(files)}) --")

        # Pool per-seed F1/accuracy under each evaluation-config name.
        per_config = {}
        for res in files:
            if 'eval_configs' in res:
                for name, info in res['eval_configs'].items():
                    entry = per_config.setdefault(
                        name, {'f1': [], 'acc': [], 'active': info['active']})
                    entry['f1'].append(info['f1'])
                    entry['acc'].append(info['acc'])

        # Display order: full model, leave-one-out, single-modality.
        ordered = (
            [n for n in per_config if n == 'full']
            + sorted(n for n in per_config if n.startswith('drop_'))
            + sorted(n for n in per_config if n.startswith('only_'))
        )

        print(f" {'Config':<22s} {'Active modalities':<42s} "
              f"{'F1 mean±std':<14s} {'Acc mean±std':<14s}")
        print(' ' + '-' * 96)
        for name in ordered:
            stats = per_config[name]
            f1_m, f1_s = np.mean(stats['f1']), np.std(stats['f1'])
            ac_m, ac_s = np.mean(stats['acc']), np.std(stats['acc'])
            active = ','.join(stats['active'])
            print(f" {name:<22s} {active:<42s} "
                  f"{f1_m:.3f}±{f1_s:.3f} {ac_m:.3f}±{ac_s:.3f}")
60
+
61
+
62
def aggregate_expB():
    """Print the grip-force regression table (experiment B).

    Runs found under ``expB_grip`` are bucketed by (backbone, modality list);
    each bucket's metrics are averaged over its runs and the rows are printed
    in descending order of mean average Pearson r.
    """
    rule = "=" * 70
    print("\n" + rule)
    print("EXP B: Grip force regression")
    print(rule)
    files = load_results(f'{ROOT}/expB_grip/*/results.json')
    if not files:
        print(" No results yet")
        return

    # Bucket runs that carry test metrics by (backbone, modalities).
    groups = {}
    for res in files:
        if 'best_test_metrics' in res:
            key = (res['backbone'], ','.join(res['modalities']))
            groups.setdefault(key, []).append(res)

    def mean_std(values):
        return (np.mean(values), np.std(values))

    rows = []
    for (bb, mods), runs in groups.items():
        metrics = [run['best_test_metrics'] for run in runs]
        mae_R = [m['right_hand']['mae_g'] for m in metrics]
        mae_L = [m['left_hand']['mae_g'] for m in metrics]
        r_R = [m['right_hand']['pearson_r'] for m in metrics]
        r_L = [m['left_hand']['pearson_r'] for m in metrics]
        r2_R = [m['right_hand']['r2'] for m in metrics]
        r2_L = [m['left_hand']['r2'] for m in metrics]
        mae_avg = [m['avg_mae_g'] for m in metrics]
        r_avg = [m['avg_pearson_r'] for m in metrics]
        rows.append({
            'backbone': bb,
            'modalities': mods,
            'n_seeds': len(runs),
            'mae_R': mean_std(mae_R),
            'mae_L': mean_std(mae_L),
            'mae_avg': mean_std(mae_avg),
            'r_R': mean_std(r_R),
            'r_L': mean_std(r_L),
            'r_avg': mean_std(r_avg),
            'r2_R': mean_std(r2_R),
            'r2_L': mean_std(r2_L),
        })

    # Best average correlation first.
    rows.sort(key=lambda entry: entry['r_avg'][0], reverse=True)
    print(f" {'Backbone':<12s} {'Modalities':<30s} N "
          f"{'MAE(g) avg':<14s} {'Pearson r avg':<14s} {'R²(R)':<12s} {'R²(L)':<12s}")
    print(' ' + '-' * 102)
    for row in rows:
        print(f" {row['backbone']:<12s} {row['modalities']:<30s} {row['n_seeds']} "
              f"{row['mae_avg'][0]:.1f}±{row['mae_avg'][1]:.1f} "
              f"{row['r_avg'][0]:.3f}±{row['r_avg'][1]:.3f} "
              f"{row['r2_R'][0]:.3f}±{row['r2_R'][1]:.3f} "
              f"{row['r2_L'][0]:.3f}±{row['r2_L'][1]:.3f}")
113
+
114
+
115
def aggregate_expC():
    """Print the cross-modal text retrieval table (experiment C).

    Runs found under ``expC_retrieval`` are bucketed by modality list; recall@1,
    recall@5, recall@10 and median rank are averaged over each bucket's runs
    and the rows are printed in descending order of mean recall@10.
    """
    rule = "=" * 70
    print("\n" + rule)
    print("EXP C: T5 Cross-modal text retrieval")
    print(rule)
    files = load_results(f'{ROOT}/expC_retrieval/*/results.json')
    if not files:
        print(" No results yet")
        return

    summary_key = 'final_avg_over_3_pool_seeds'
    groups = {}
    for res in files:
        if summary_key in res:
            groups.setdefault(','.join(res['modalities']), []).append(res)

    def mean_std(values):
        return (np.mean(values), np.std(values))

    rows = []
    for mods, runs in groups.items():
        summaries = [run[summary_key] for run in runs]
        r1 = [s['recall@1'] for s in summaries]
        r5 = [s['recall@5'] for s in summaries]
        r10 = [s['recall@10'] for s in summaries]
        medR = [s['median_rank'] for s in summaries]
        first = runs[0]
        rows.append({
            'modalities': mods,
            'n_seeds': len(runs),
            'r1': mean_std(r1),
            'r5': mean_std(r5),
            'r10': mean_std(r10),
            'medR': mean_std(medR),
            # Pool size / test-set size are taken from the first run of the bucket.
            'n_test': first.get('n_test_segments', 0),
            'K': first.get('K_pool', 100),
        })

    rows.sort(key=lambda entry: entry['r10'][0], reverse=True)
    print(f" {'Modalities':<30s} N N_test K "
          f"{'R@1':<12s} {'R@5':<12s} {'R@10':<12s} {'medR':<12s}")
    print(' ' + '-' * 100)
    for row in rows:
        print(f" {row['modalities']:<30s} {row['n_seeds']} {row['n_test']:<6d} {row['K']:<2d} "
              f"{row['r1'][0]:.3f}±{row['r1'][1]:.3f} "
              f"{row['r5'][0]:.3f}±{row['r5'][1]:.3f} "
              f"{row['r10'][0]:.3f}±{row['r10'][1]:.3f} "
              f"{row['medR'][0]:.1f}±{row['medR'][1]:.1f}")
157
+
158
+
159
def main():
    """Print the aggregated tables for experiments A, B and C, in that order."""
    for report in (aggregate_expA, aggregate_expB, aggregate_expC):
        report()


if __name__ == '__main__':
    main()
experiments/analysis/aggregate_t1_extended.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Aggregate T1 extended benchmark results.
3
+ Prints a Markdown-style table sorted by F1 desc."""
4
+ import os
5
+ import json
6
+ import glob
7
+ import numpy as np
8
+ from collections import defaultdict
9
+
10
# Results directory. Python does not expand ``${...}`` inside string literals,
# so the placeholder is resolved from the PULSE_ROOT environment variable here.
# When PULSE_ROOT is unset the literal placeholder is kept unchanged.
ROOT = os.path.expandvars('${PULSE_ROOT}/results/t1_extended')
11
+
12
+
13
def collect(pattern):
    """Group result payloads by method name (plus optional ablation tag).

    Args:
        pattern: Glob pattern selecting the ``results.json`` files to read.

    Returns:
        ``defaultdict(list)`` mapping a group key to the list of parsed
        payloads. The key is the payload's ``method`` field, falling back to
        the name of the directory containing the file; when ``args.tag`` is
        non-empty it is appended as ``"<key>_<tag>"``. Unreadable files are
        reported on stdout and skipped.
    """
    by_key = defaultdict(list)
    for f in sorted(glob.glob(pattern)):
        try:
            # The context manager closes the handle even if parsing fails.
            with open(f) as fh:
                r = json.load(fh)
        except Exception as e:
            print(f" ERR reading {f}: {e}")
            continue
        key = r.get('method', os.path.basename(os.path.dirname(f)))
        # Distinguish ablations by tag
        tag = r.get('args', {}).get('tag', '')
        if tag:
            key = f"{key}_{tag}"
        by_key[key].append(r)
    return by_key
28
+
29
+
30
def main():
    """Print one summary row per method for the T1 extended benchmark.

    Each row shows mean and standard deviation of test F1 and accuracy over
    the runs collected for that method; rows are sorted by mean F1, best first.
    """
    def summarize(key, rs):
        f1s = [run['test_f1'] for run in rs]
        accs = [run['test_acc'] for run in rs]
        first = rs[0]
        return {
            'method': key,
            'modalities': ','.join(first['modalities']),
            'n_seeds': len(rs),
            'f1_mean': np.mean(f1s),
            'f1_std': np.std(f1s),
            'acc_mean': np.mean(accs),
            'acc_std': np.std(accs),
            'n_params': first.get('n_params', 0),
        }

    groups = collect(f'{ROOT}/*/results.json')
    rows = [summarize(key, rs) for key, rs in groups.items()]
    rows.sort(key=lambda entry: entry['f1_mean'], reverse=True)

    print(f"\n{'Method':<28s} {'Modalities':<32s} N {'F1 mean±std':<14s} "
          f"{'Acc mean±std':<14s} Params")
    print('-' * 110)
    for r in rows:
        print(f"{r['method']:<28s} {r['modalities']:<32s} {r['n_seeds']} "
              f"{r['f1_mean']:.3f}±{r['f1_std']:.3f} "
              f"{r['acc_mean']:.3f}±{r['acc_std']:.3f} "
              f"{r['n_params']:,}")


if __name__ == '__main__':
    main()
experiments/analysis/analysis_figures.py ADDED
@@ -0,0 +1,444 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Generate three showcase figures for the main paper:
3
+ 1. Eye-Hand-Contact coordination (gaze fixation + hand velocity + pressure)
4
+ 2. Pressure fingerprints per action category
5
+ 3. 3D hand trajectory colored by pressure
6
+ """
7
+ import os, glob, json, re
8
+ import numpy as np
9
+ import pandas as pd
10
+ import matplotlib
11
+ matplotlib.use('Agg')
12
+ import matplotlib.pyplot as plt
13
+ from scipy.signal import savgol_filter
14
+
15
# Input / output locations. Python does not expand ``${...}`` inside string
# literals, so the placeholder is resolved from the PULSE_ROOT environment
# variable; when it is unset the literal placeholder is kept unchanged.
DATASET = os.path.expandvars("${PULSE_ROOT}/dataset")
OUT_DIR = os.path.expandvars("${PULSE_ROOT}/paper/figures")
os.makedirs(OUT_DIR, exist_ok=True)

# Default pressure threshold used by detect_grasp_events().
PRESSURE_THRESHOLD = 5.0
# NOTE(review): presumably the sample rate (Hz) of the aligned "*_100hz" streams — confirm.
FPS = 100
21
+
22
+
23
+ # ============================================================
24
+ # Shared data-loading helpers
25
+ # ============================================================
26
+
27
def load_pressure(scenario_dir):
    """Load per-sensor pressure readings for both hands of one scenario.

    Reads ``aligned_pressure_100hz.csv`` from ``scenario_dir`` and selects the
    columns whose names start with ``R`` (right) or ``L`` (left) and end with
    ``(g)``. Cells that are not numeric are coerced to NaN and replaced by 0.

    Returns:
        ``(right, left)`` tuple of 2-D arrays with one row per CSV row and one
        column per matching sensor column, or ``None`` when the file is
        missing or either hand has fewer than 20 matching columns.
    """
    f = os.path.join(scenario_dir, "aligned_pressure_100hz.csv")
    if not os.path.exists(f):
        return None
    df = pd.read_csv(f, low_memory=False)
    r_cols = [c for c in df.columns if c.startswith('R') and c.endswith('(g)')]
    l_cols = [c for c in df.columns if c.startswith('L') and c.endswith('(g)')]
    if len(r_cols) < 20 or len(l_cols) < 20:
        return None
    r = df[r_cols].apply(pd.to_numeric, errors='coerce').fillna(0).values
    l = df[l_cols].apply(pd.to_numeric, errors='coerce').fillna(0).values
    return r, l  # each (rows, n_sensor_columns); callers elsewhere assume 25 per hand
40
+
41
+
42
def load_emg(scenario_dir):
    """Load the EMG channels of one scenario as a float32 array.

    Reads ``aligned_emg_100hz.csv`` from ``scenario_dir`` and keeps every
    numeric column except ``time``, ``UTC`` and ``Frame``. NaNs are replaced
    via ``np.nan_to_num``. Returns ``None`` when the file is missing or fewer
    than four such columns exist.
    """
    csv_path = os.path.join(scenario_dir, "aligned_emg_100hz.csv")
    if os.path.exists(csv_path):
        frame = pd.read_csv(csv_path, low_memory=False)
        numeric_cols = frame.select_dtypes(include=[np.number]).columns
        channels = []
        for col in numeric_cols:
            if col not in ('time', 'UTC', 'Frame'):
                channels.append(col)
        if len(channels) >= 4:
            return np.nan_to_num(frame[channels].values.astype(np.float32))
    return None
52
+
53
+
54
def load_gaze(scenario_dir):
    """Load scene-camera gaze coordinates of one scenario.

    Reads ``aligned_eyetrack_100hz.csv`` from ``scenario_dir`` and uses the
    first column whose name contains both ``Gaze X`` and ``Scene Cam`` (and
    likewise for ``Gaze Y``). Non-numeric cells become 0.

    Returns:
        Array of shape (rows, 2) holding x and y, or ``None`` when the file or
        either column is missing.
    """
    csv_path = os.path.join(scenario_dir, "aligned_eyetrack_100hz.csv")
    if not os.path.exists(csv_path):
        return None
    frame = pd.read_csv(csv_path, low_memory=False)

    def scene_cam_columns(axis_label):
        return [c for c in frame.columns if axis_label in c and 'Scene Cam' in c]

    x_cols = scene_cam_columns('Gaze X')
    y_cols = scene_cam_columns('Gaze Y')
    if not (x_cols and y_cols):
        return None
    gx = pd.to_numeric(frame[x_cols[0]], errors='coerce').fillna(0).values
    gy = pd.to_numeric(frame[y_cols[0]], errors='coerce').fillna(0).values
    return np.stack([gx, gy], axis=1)
66
+
67
+
68
def load_mocap_hand(scenario_dir, vol, scenario):
    """Load right- and left-wrist 3-D positions from the aligned mocap TSV.

    Reads ``aligned_{vol}{scenario}_s_Q.tsv`` from ``scenario_dir``. For each
    hand, several column-naming conventions are tried in order and the first
    complete X/Y/Z triplet is used; non-numeric cells become 0.

    Returns:
        ``(r_wrist, l_wrist)``. Each element is a (rows, 3) array, or ``None``
        when none of the candidate triplets is present for that hand. Returns
        ``(None, None)`` when the file does not exist.
    """
    f = os.path.join(scenario_dir, f"aligned_{vol}{scenario}_s_Q.tsv")
    if not os.path.exists(f):
        return None, None
    df = pd.read_csv(f, sep='\t', low_memory=False)

    def first_available(candidates):
        # Return the first candidate triplet whose three columns all exist.
        for cols in candidates:
            if all(c in df.columns for c in cols):
                return df[cols].apply(pd.to_numeric, errors='coerce').fillna(0).values
        return None

    # Wrist columns (several naming patterns, tried in this order)
    r_wrist = first_available([
        ['RightHand_X', 'RightHand_Y', 'RightHand_Z'],
        ['R_Hand_X', 'R_Hand_Y', 'R_Hand_Z'],
        ['Q_RWristIn_X', 'Q_RWristIn_Y', 'Q_RWristIn_Z'],
    ])
    l_wrist = first_available([
        ['LeftHand_X', 'LeftHand_Y', 'LeftHand_Z'],
        ['L_Hand_X', 'L_Hand_Y', 'L_Hand_Z'],
        ['Q_LWristIn_X', 'Q_LWristIn_Y', 'Q_LWristIn_Z'],
    ])
    return r_wrist, l_wrist
93
+
94
+
95
def compute_velocity(position, window=5):
    """Per-sample speed: norm of the first difference, Savitzky-Golay smoothed.

    Args:
        position: 2-D array, one row per sample (the norm is taken over axis 1).
        window: Half-width of the smoothing window; the filter length is
            ``2 * window + 1``, capped at the largest odd number that does not
            exceed the series length.

    Returns:
        1-D array with one speed value per row of ``position``; the first
        value is computed from a zero difference. If the series is too short
        for the filter, the unsmoothed magnitudes are returned.
    """
    vel = np.zeros_like(position)
    vel[1:] = position[1:] - position[:-1]
    mag = np.linalg.norm(vel, axis=1)
    n = len(mag)
    longest_odd = n if n % 2 else n - 1
    try:
        mag = savgol_filter(mag, window_length=min(window * 2 + 1, longest_odd), polyorder=2)
    except Exception:
        # Best effort: keep the raw magnitudes when smoothing is impossible.
        # (A bare ``except`` would also swallow KeyboardInterrupt/SystemExit.)
        pass
    return mag
105
+
106
+
107
def detect_grasp_events(hand_pressure, threshold=PRESSURE_THRESHOLD, min_gap=50):
    """Return the sample indices at which sustained pressure begins.

    A 2-D input is summed over axis 1 first; a 1-D input is used as is. An
    onset is recorded when the signal rises above ``threshold`` and stays
    above it for more than 70 % of the next 10 samples, provided it lies more
    than ``min_gap`` samples after the previously recorded onset. The contact
    state is released once fewer than 30 % of the next 5 samples are above
    threshold.
    """
    if hand_pressure.ndim == 2:
        total = hand_pressure.sum(axis=1)
    else:
        total = hand_pressure
    above = total > threshold
    n = len(above)

    onsets = []
    in_contact = False
    for i, is_above in enumerate(above):
        if is_above and not in_contact:
            sustained = i + 10 < n and np.mean(above[i:i+10]) > 0.7
            if sustained:
                # Suppress onsets that follow the previous one too closely,
                # but still enter the contact state.
                if not onsets or i - onsets[-1] > min_gap:
                    onsets.append(i)
                in_contact = True
        elif in_contact and not is_above:
            released = i + 5 < n and np.mean(above[i:i+5]) < 0.3
            if released:
                in_contact = False
    return onsets
123
+
124
+
125
def emg_envelope(emg, window=20):
    """Summed moving-average envelope of a multi-channel EMG array.

    Each channel (column) is mean-centred and rectified, smoothed with a
    length-``window`` box filter (``np.convolve`` in ``'same'`` mode), and the
    smoothed channels are summed into a single 1-D trace.
    """
    box = np.ones(window) / window
    rectified = np.abs(emg - np.mean(emg, axis=0))
    smoothed = []
    for ch in range(rectified.shape[1]):
        smoothed.append(np.convolve(rectified[:, ch], box, mode='same'))
    return np.stack(smoothed, axis=1).sum(axis=1)
130
+
131
+
132
def gaze_velocity(gaze_xy, window=5):
    """Magnitude of gaze velocity — high = saccade, low = fixation.

    Args:
        gaze_xy: 2-D array, one row per sample (the norm is taken over axis 1).
        window: Half-width of the smoothing window; the Savitzky-Golay filter
            length is ``2 * window + 1``, capped at 15.

    Returns:
        1-D array with one value per row of ``gaze_xy``. If the filter cannot
        be applied (e.g. the series is shorter than the window), the
        unsmoothed magnitudes are returned.
    """
    v = np.zeros_like(gaze_xy)
    v[1:] = gaze_xy[1:] - gaze_xy[:-1]
    mag = np.linalg.norm(v, axis=1)
    try:
        mag = savgol_filter(mag, window_length=min(window * 2 + 1, 15), polyorder=2)
    except Exception:
        # Best effort: keep the raw magnitudes when smoothing is impossible.
        # (A bare ``except`` would also swallow KeyboardInterrupt/SystemExit.)
        pass
    return mag
142
+
143
+
144
+ # ============================================================
145
+ # FIGURE 1: Eye-Hand-Contact coordination
146
+ # ============================================================
147
def make_eye_hand_contact_figure():
    """Plot event-averaged gaze fixation, hand speed and pressure around contact.

    Scans every ``v*/s*`` scenario directory under ``DATASET`` whose
    ``alignment_metadata.json`` lists the pressure, eyetrack and mocap
    modalities, detects right-hand pressure onsets, and cuts a window of
    ``context`` samples before to ``after`` samples after each onset. Every
    trace is min-max normalised per event; the mean (± 0.4 std band) over
    events is drawn and saved as ``eye_hand_contact.pdf`` and ``.png`` in
    ``OUT_DIR``. Nothing is plotted when fewer than 50 events are collected.
    """
    print("=== Figure 1: Eye-Hand-Contact coordination ===")
    context = 200  # samples kept before onset (2 s at 10 ms per sample)
    after = 50     # samples kept after onset (0.5 s)
    events = []  # list of dicts: gaze_vel, hand_vel, pressure, all shape (context+after,)

    for vol_dir in sorted(glob.glob(f"{DATASET}/v*")):
        vol = os.path.basename(vol_dir)
        for sd in sorted(glob.glob(f"{vol_dir}/s*")):
            scenario = os.path.basename(sd)
            meta_path = os.path.join(sd, "alignment_metadata.json")
            if not os.path.exists(meta_path):
                continue
            # Context manager so the metadata handle is not leaked per scenario.
            with open(meta_path) as fh:
                meta = json.load(fh)
            if not {'pressure', 'eyetrack', 'mocap'}.issubset(set(meta['modalities'])):
                continue

            p = load_pressure(sd)
            g = load_gaze(sd)
            r_wrist, _ = load_mocap_hand(sd, vol, scenario)
            if p is None or g is None or r_wrist is None:
                continue
            r_p, _ = p
            # Truncate all streams to their common length.
            min_len = min(len(r_p), len(g), len(r_wrist))
            r_p, g, r_wrist = r_p[:min_len], g[:min_len], r_wrist[:min_len]

            hand_vel = compute_velocity(r_wrist)
            gvel = gaze_velocity(g)
            total_p = r_p.sum(axis=1)

            onsets = detect_grasp_events(r_p)
            for o in onsets:
                if o < context or o + after >= min_len:
                    continue
                # Require quiescent pre-grasp: hand speed 1.5-1.0 s before
                # onset must not exceed half the speed of the final 0.5 s.
                vel_rest = hand_vel[o-150:o-100]
                if np.mean(vel_rest) > hand_vel[o-50:o].mean() * 0.5:
                    continue
                gv_seg = gvel[o-context:o+after]
                hv_seg = hand_vel[o-context:o+after]
                pr_seg = total_p[o-context:o+after]
                if len(gv_seg) != context+after or np.isnan(gv_seg).any():
                    continue
                events.append({'gv': gv_seg, 'hv': hv_seg, 'p': pr_seg})
            # Stop scanning once enough events are collected (checked per scenario).
            if len(events) > 400:
                break
        if len(events) > 400:
            break

    print(f" Collected {len(events)} events")
    if len(events) < 50:
        print(" Not enough events, skipping")
        return

    # Gaze: fixation = low gaze velocity, so use "1 - normalized gaze velocity"
    # This represents "gaze fixation stability"
    def norm01(arr):
        # Min-max normalise each row (event) to [0, 1].
        arr = np.array(arr)
        arr = arr - arr.min(axis=1, keepdims=True)
        mx = arr.max(axis=1, keepdims=True)
        return arr / (mx + 1e-8)

    gv_stack = norm01([e['gv'] for e in events])
    hv_stack = norm01([e['hv'] for e in events])
    p_stack = norm01([e['p'] for e in events])

    # Gaze fixation = low velocity. Plot (1 - gaze_velocity) -> rises as gaze fixates
    gaze_fix = 1 - gv_stack  # high = fixating
    # Normalize each event's fix to [0,1] for display
    gaze_fix_plot = norm01(gaze_fix)

    time_axis = np.arange(-context, after) * 10  # ms (10 ms per sample)

    fig, ax = plt.subplots(figsize=(9, 4.5))

    for stack, color, label in [
        (gaze_fix_plot, '#8E44AD', 'Gaze fixation'),
        (hv_stack, '#3498DB', 'Hand velocity'),
        (p_stack, '#27AE60', 'Pressure (contact)'),
    ]:
        mean = stack.mean(axis=0)
        std = stack.std(axis=0)
        ax.plot(time_axis, mean, color=color, linewidth=2.5, label=label)
        ax.fill_between(time_axis, mean - std*0.4, mean + std*0.4, color=color, alpha=0.15)

    ax.axvline(0, color='black', linestyle='--', linewidth=1.2, alpha=0.7)
    ax.set_xlabel('Time relative to contact onset (ms)', fontsize=12)
    ax.set_ylabel('Normalized amplitude', fontsize=12)
    ax.set_title(f'Gaze → Hand → Contact coordination ({len(events)} events)',
                 fontsize=13, fontweight='bold')
    ax.set_xlim(-2000, 500)
    ax.legend(loc='upper left', fontsize=10, frameon=True)
    ax.grid(True, alpha=0.3)
    ax.set_ylim(-0.05, 1.1)

    plt.tight_layout()
    out_path = os.path.join(OUT_DIR, 'eye_hand_contact.pdf')
    plt.savefig(out_path, dpi=150, bbox_inches='tight')
    plt.savefig(out_path.replace('.pdf', '.png'), dpi=150, bbox_inches='tight')
    plt.close()
    print(f" Saved {out_path}")
250
+
251
+
252
+ # ============================================================
253
+ # FIGURE 2: Pressure fingerprints per action category
254
+ # ============================================================
255
def make_pressure_fingerprints():
    """Render Figure 2: mean per-channel pressure maps for the most frequent actions.

    Walks every ``{DATASET}/v*/s*`` recording whose ``alignment_metadata.json``
    lists a ``pressure`` modality, accumulates frame-weighted sums of the
    right- and left-hand pressure channels for every non-'Idle' action, and
    plots the six actions with the most labelled frames as 5x5 heat maps
    (top row: right hand, bottom row: left hand).

    The code assumes 25 pressure channels per hand (accumulators are
    ``np.zeros(25)`` and each profile is laid out row-major on a 5x5 grid).

    Side effects: writes ``pressure_fingerprints.pdf`` and
    ``pressure_fingerprints.png`` into ``OUT_DIR``. Prints " No data" and
    returns early when no action reached the 10-frame minimum in any
    recording.
    """
    print("\n=== Figure 2: Pressure fingerprints ===")
    import sys
    # '${PULSE_ROOT}' is a placeholder that Python does not expand by itself.
    # expandvars substitutes the environment variable when it is set and
    # leaves the string unchanged otherwise.
    sys.path.insert(0, os.path.expandvars('${PULSE_ROOT}'))
    # Imported once here instead of once per recording inside the loop.
    from experiments.train_exp2 import ACTION_NAMES, load_annotations

    # action name -> {'r': channel sums, 'l': channel sums, 'n': frame count}
    sums = {}

    for vol_dir in sorted(glob.glob(f"{DATASET}/v*")):
        vol = os.path.basename(vol_dir)
        for sd in sorted(glob.glob(f"{vol_dir}/s*")):
            scenario = os.path.basename(sd)
            meta_path = os.path.join(sd, "alignment_metadata.json")
            if not os.path.exists(meta_path):
                continue
            # Context manager so the metadata file handle is always closed.
            with open(meta_path) as fh:
                meta = json.load(fh)
            if 'pressure' not in set(meta.get('modalities', [])):
                continue
            p = load_pressure(sd)
            if p is None:
                continue
            r_p, l_p = p
            labels = load_annotations(vol, scenario, len(r_p),
                                      sampling_rate=100, use_coarse=False)
            if labels is None:
                continue
            # Align all three arrays to a common length; a boolean mask whose
            # length differs from the indexed array raises IndexError.
            n_frames = min(len(labels), len(r_p), len(l_p))
            labels = np.asarray(labels[:n_frames])
            r_p = r_p[:n_frames]
            l_p = l_p[:n_frames]

            for a_id, a_name in ACTION_NAMES.items():
                if a_name == 'Idle':
                    continue
                mask = labels == a_id
                count = mask.sum()
                # Skip actions with too few frames in this recording.
                if count < 10:
                    continue
                acc = sums.setdefault(
                    a_name, {'r': np.zeros(25), 'l': np.zeros(25), 'n': 0})
                # Summing frames directly equals mean * count, so the final
                # division yields a frame-weighted mean across recordings.
                acc['r'] += r_p[mask].sum(axis=0)
                acc['l'] += l_p[mask].sum(axis=0)
                acc['n'] += count

    # Compute mean for each action
    results = {
        name: {'r': acc['r'] / acc['n'], 'l': acc['l'] / acc['n']}
        for name, acc in sums.items() if acc['n'] > 0
    }
    print(f" Action categories: {list(results.keys())}")

    if not results:
        print(" No data")
        return

    # Keep the six actions with the most labelled frames.
    sorted_actions = sorted(results.keys(),
                            key=lambda a: sums[a]['n'], reverse=True)[:6]

    n = len(sorted_actions)
    fig, axes = plt.subplots(2, n, figsize=(2.0 * n, 4.8), squeeze=False)
    # Shared colour scale so panels are comparable across actions and hands.
    vmax = max(max(results[a]['r'].max(), results[a]['l'].max())
               for a in sorted_actions)

    for i, a in enumerate(sorted_actions):
        for row, (hand, title) in enumerate([('r', 'Right'), ('l', 'Left')]):
            ax = axes[row][i]
            # Stylised layout: channel idx goes to grid[idx // 5, idx % 5].
            # This is a plain row-major 5x5 grid, not the true sensor geometry.
            grid = np.asarray(results[a][hand], dtype=float).reshape(5, 5)
            im = ax.imshow(grid, cmap='hot', vmin=0, vmax=vmax, aspect='equal')
            ax.set_xticks([])
            ax.set_yticks([])
            if row == 0:
                ax.set_title(a, fontsize=11, fontweight='bold')
            if i == 0:
                ax.set_ylabel(title, fontsize=10)

    fig.suptitle('Per-action fingertip pressure signatures (mean across events)',
                 fontsize=12, fontweight='bold', y=0.98)
    cbar = fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.7, pad=0.02)
    cbar.set_label('Pressure (g)', fontsize=10)
    plt.savefig(os.path.join(OUT_DIR, 'pressure_fingerprints.pdf'), bbox_inches='tight')
    plt.savefig(os.path.join(OUT_DIR, 'pressure_fingerprints.png'), dpi=150, bbox_inches='tight')
    plt.close()
    print(f" Saved pressure_fingerprints.pdf")
364
+
365
+
366
+ # ============================================================
367
+ # FIGURE 3: 3D hand trajectory colored by pressure
368
+ # ============================================================
369
def make_3d_trajectory():
    """Render Figure 3: right-wrist 3D trajectories coloured by summed pressure.

    Tries a fixed list of candidate recordings and keeps up to three for
    which the directory exists, pressure and right-wrist mocap both load,
    and ``detect_grasp_events`` reports at least one onset. For each kept
    recording, a window of up to 150 samples on either side of the first
    onset is drawn as a 3D poly-line whose segments are coloured by the
    right-hand pressure summed over channels and normalised by the window
    maximum. A red star marks the sample with peak pressure.

    Side effects: writes ``hand_trajectory_3d.pdf`` and
    ``hand_trajectory_3d.png`` into ``OUT_DIR``. Prints a notice and returns
    early when no candidate qualifies.
    """
    print("\n=== Figure 3: 3D hand trajectory + pressure coloring ===")
    from mpl_toolkits.mplot3d import Axes3D
    # Hand-picked recordings expected to contain rich grasping.
    recordings = [('v1', 's3'), ('v2', 's4'), ('v1', 's5'), ('v1', 's7')]
    selected = []
    half_window = 150  # samples on each side of the onset (~3 s in total per the original note)

    for vol, scn in recordings:
        rec_dir = f"{DATASET}/{vol}/{scn}"
        if not os.path.isdir(rec_dir):
            continue
        pressure_pair = load_pressure(rec_dir)
        wrist_r, _ = load_mocap_hand(rec_dir, vol, scn)
        if pressure_pair is None or wrist_r is None:
            continue
        press_r, _ = pressure_pair
        # Pressure and mocap streams may differ in length; use the overlap.
        n_frames = min(len(press_r), len(wrist_r))
        press_r = press_r[:n_frames]
        wrist_r = wrist_r[:n_frames]
        summed = press_r.sum(axis=1)
        onsets = detect_grasp_events(press_r)
        if not onsets:
            continue
        first = onsets[0]
        lo = max(0, first - half_window)
        hi = min(n_frames, first + half_window)
        selected.append((vol, scn, wrist_r[lo:hi], summed[lo:hi]))
        if len(selected) >= 3:
            break

    if not selected:
        print(" No valid recordings found")
        return

    n_panels = len(selected)
    fig = plt.figure(figsize=(3.5 * n_panels, 4))
    for panel, (vol, scn, path, press) in enumerate(selected, start=1):
        ax = fig.add_subplot(1, n_panels, panel, projection='3d')
        # Scale to [0, 1); the epsilon avoids 0/0 on an all-zero window.
        shade = press / (press.max() + 1e-6)
        # One short line per consecutive sample pair so colour can vary.
        for k in range(len(path) - 1):
            pair = path[k:k+2]
            ax.plot(pair[:, 0], pair[:, 1], pair[:, 2],
                    color=plt.cm.coolwarm(shade[k]), linewidth=2.5, alpha=0.85)
        # Star at the sample with the highest summed pressure.
        peak = np.argmax(press)
        ax.scatter(path[peak, 0], path[peak, 1], path[peak, 2],
                   color='red', s=50, marker='*', zorder=5, label='Peak contact')
        ax.set_title(f'{vol}/{scn}', fontsize=10)
        ax.set_xlabel('X', fontsize=8)
        ax.set_ylabel('Y', fontsize=8)
        ax.set_zlabel('Z', fontsize=8)
        ax.tick_params(labelsize=7)

    # Shared colour bar for the normalised pressure scale.
    mappable = plt.cm.ScalarMappable(
        cmap='coolwarm', norm=matplotlib.colors.Normalize(vmin=0, vmax=1))
    mappable.set_array([])
    bar = fig.colorbar(mappable, ax=fig.axes, shrink=0.6, pad=0.02)
    bar.set_label('Normalised pressure', fontsize=10)

    fig.suptitle('Right-hand wrist 3D trajectory coloured by fingertip pressure',
                 fontsize=12, fontweight='bold', y=1.02)
    plt.savefig(os.path.join(OUT_DIR, 'hand_trajectory_3d.pdf'), bbox_inches='tight')
    plt.savefig(os.path.join(OUT_DIR, 'hand_trajectory_3d.png'), dpi=150, bbox_inches='tight')
    plt.close()
    print(f" Saved hand_trajectory_3d.pdf")
438
+
439
+
440
# Script entry point: build the three analysis figures in order, then report
# the directory they were written to.
if __name__ == '__main__':
    make_eye_hand_contact_figure()
    make_pressure_fingerprints()
    make_3d_trajectory()
    print("\nAll figures generated in", OUT_DIR)
experiments/analysis/build_taxonomy.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Rebuild the frozen taxonomy JSON from the current annotations_v3/ state.
4
+
5
+ Run this *once* after annotation is complete to lock the 28+ noun list. Later
6
+ experiments load the frozen list via taxonomy.py, so class indices don't
7
+ drift if more annotations are ever added.
8
+
9
+ Usage:
10
+ python3 experiments/analysis/build_taxonomy.py
10
+ python3 experiments/analysis/build_taxonomy.py --threshold 50 --out experiments/taxonomy_v3.json
12
+ """
13
+
14
+ import argparse
15
+ import glob
16
+ import json
17
+ import os
18
+ from collections import Counter
19
+ from pathlib import Path
20
+
21
# Repository root. This script lives in <root>/experiments/analysis/, so the
# root is two levels above the containing directory. parents[1] would be
# <root>/experiments, which breaks `from experiments.taxonomy import ...` and
# turns the default --out into <root>/experiments/experiments/taxonomy_v3.json.
REPO = Path(__file__).resolve().parents[2]
22
+
23
+
24
def main():
    """Count verbs, nouns and hands over all annotations and write the frozen taxonomy.

    Command-line options:
        --annotations_dir  directory holding ``v*/s*.json`` annotation files
        --threshold        minimum canonical-noun frequency for a noun to be kept
        --out              path of the taxonomy JSON to write

    A segment is counted only when its action name, object name and hand type
    are all non-empty. Segments whose verb is not in ``VERB_FINE`` or whose
    hand is not in ``HAND`` are tallied as dropped and excluded from the
    counts. Files that cannot be parsed are reported and skipped.

    Raises:
        SystemExit: when no annotation file matches under ``--annotations_dir``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument(
        "--annotations_dir",
        default=str(REPO / "annotations_v3"),
        help="Directory containing v*/s*.json annotation files",
    )
    ap.add_argument("--threshold", type=int, default=50,
                    help="Minimum noun frequency to keep (Strategy A drops the rest)")
    ap.add_argument(
        "--out",
        default=str(REPO / "experiments" / "taxonomy_v3.json"),
        help="Output frozen taxonomy JSON",
    )
    args = ap.parse_args()

    # Late import so building the list doesn't depend on the frozen file
    # being present yet.
    import sys
    sys.path.insert(0, str(REPO))
    from experiments.taxonomy import (  # noqa: F401  (NOUN_CANONICAL kept from the original import list)
        VERB_FINE, VERB_COMPOSITE, HAND, NOUN_CANONICAL, canonical_noun,
    )

    paths = sorted(glob.glob(os.path.join(args.annotations_dir, "v*", "s*.json")))
    if not paths:
        raise SystemExit(f"No json files under {args.annotations_dir}")

    verbs, nouns, hands = Counter(), Counter(), Counter()
    total = 0
    dropped_unknown_verb = 0
    dropped_unknown_hand = 0
    for p in paths:
        try:
            # Explicit UTF-8: the output is written with ensure_ascii=False, so
            # non-ASCII names are expected and must not depend on the locale.
            with open(p, encoding="utf-8") as f:
                d = json.load(f)
        except Exception as e:
            print(f" WARN: could not parse {p}: {e}")
            continue
        for s in d.get("segments", []):
            # `or {}` also covers an explicit null action_annotation.
            a = s.get("action_annotation") or {}
            v = a.get("action_name")
            n = a.get("object_name")
            h = a.get("hand_type")
            if not (v and n and h):
                continue
            total += 1
            if v not in VERB_FINE:
                dropped_unknown_verb += 1
                continue
            if h not in HAND:
                dropped_unknown_hand += 1
                continue
            verbs[v] += 1
            nouns[canonical_noun(n)] += 1
            hands[h] += 1

    # Deterministic ordering: most frequent first, ties broken alphabetically,
    # so equal-count nouns always get the same index. Nouns with different
    # counts can still swap places if the annotations change, which is why
    # the result is frozen to disk.
    kept = sorted(
        (n for n, c in nouns.items() if c >= args.threshold),
        key=lambda n: (-nouns[n], n),
    )

    # Every segment counted in `nouns` already passed the verb/hand filters,
    # so the survivors are exactly the kept-noun counts summed. This replaces
    # a second pass that re-read every file and crashed on files the first
    # pass had only warned about.
    surviving_segs = sum(nouns[n] for n in kept)

    out = {
        "threshold": args.threshold,
        "annotation_file_count": len(paths),
        "total_segments": total,
        "dropped_unknown_verb": dropped_unknown_verb,
        "dropped_unknown_hand": dropped_unknown_hand,
        "surviving_segments": surviving_segs,
        "verbs": VERB_FINE,
        "verb_composite": VERB_COMPOSITE,
        "hand": HAND,
        "nouns": kept,
        "noun_counts": {n: nouns[n] for n in kept},
        "verb_counts": dict(verbs),
        "hand_counts": dict(hands),
    }
    Path(args.out).parent.mkdir(parents=True, exist_ok=True)
    with open(args.out, "w", encoding="utf-8") as f:
        json.dump(out, f, ensure_ascii=False, indent=2)

    print(f"Scanned {len(paths)} files, {total} segments")
    print(f"Dropped (unknown verb / hand): {dropped_unknown_verb} / "
          f"{dropped_unknown_hand}")
    print(f"Kept {len(kept)} nouns (>= {args.threshold}):")
    for n in kept:
        print(f" {n}: {nouns[n]}")
    print(f"Surviving segments (Strategy A): "
          f"{surviving_segs} / {total} "
          f"({100 * surviving_segs / max(1, total):.1f}%)")
    print(f"Wrote {args.out}")
133
+
134
+
135
+ if __name__ == "__main__":
136
+ main()
experiments/analysis/check_seg_lengths.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Analyze segment lengths in the recognition dataset.
4
+
5
+ For each annotation file, computes segment lengths in:
6
+ - Raw frames (at 100Hz sampling rate)
7
+ - Downsampled frames (downsample=5 -> 20Hz effective)
8
+
9
+ Reports statistics and distribution relative to window_frames used in training.
10
+ """
11
+
12
+ import os
13
+ import sys
14
+ import json
15
+ import re
16
+ import numpy as np
17
+ from collections import defaultdict
18
+
19
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
20
+ from data.dataset import DATASET_DIR, TRAIN_VOLS, VAL_VOLS, TEST_VOLS
21
+
22
+ ANNOTATION_DIR = "${PULSE_ROOT}"
23
+ SAMPLING_RATE = 100 # Hz
24
+ DOWNSAMPLE = 5
25
+
26
+
27
def parse_timestamp(ts_str):
    """Convert an ``MM:SS`` or ``HH:MM:SS`` string to whole seconds.

    Surrounding whitespace is ignored. Any other number of ':'-separated
    fields yields 0. Non-numeric fields raise ValueError from ``int``.
    """
    fields = ts_str.strip().split(':')
    if len(fields) not in (2, 3):
        return 0
    # Base-60 accumulation: each earlier field is worth 60x the next one.
    seconds = 0
    for field in fields:
        seconds = seconds * 60 + int(field)
    return seconds
34
+
35
+
36
def main():
    """Print segment-length statistics for every annotated scenario.

    For each volunteer in TRAIN_VOLS + VAL_VOLS + TEST_VOLS, reads the JSON
    files under ``ANNOTATION_DIR/<vol>/``. A scenario is used only when
    ``DATASET_DIR/<vol>/<scenario>`` exists. Each segment's ``timestamp`` is
    parsed as "start - end"; segments that do not match, or whose end is not
    after their start, are counted as skipped. The report covers durations in
    seconds, raw frames at SAMPLING_RATE, and frames after DOWNSAMPLE, plus
    per-split statistics and a text histogram.

    Returns None; all output goes to stdout.
    """
    all_vols = TRAIN_VOLS + VAL_VOLS + TEST_VOLS

    # Collect segment lengths
    raw_lengths_sec = []     # in seconds
    raw_lengths_frames = []  # in raw 100Hz frames
    ds_lengths_frames = []   # in downsampled frames (100/5 = 20Hz)

    split_stats = defaultdict(list)  # split -> list of ds_lengths

    total_scenarios = 0
    total_segments = 0
    skipped_segments = 0

    # Compiled once instead of once per segment.
    ts_pattern = re.compile(r'(\d+:\d+(?::\d+)?)\s*-\s*(\d+:\d+(?::\d+)?)')

    for vol in sorted(all_vols):
        # Determine split
        if vol in TRAIN_VOLS:
            split = 'train'
        elif vol in VAL_VOLS:
            split = 'val'
        else:
            split = 'test'

        ann_vol_dir = os.path.join(ANNOTATION_DIR, vol)
        if not os.path.isdir(ann_vol_dir):
            print(f"WARNING: No annotation dir for {vol}")
            continue

        for ann_file in sorted(os.listdir(ann_vol_dir)):
            if not ann_file.endswith('.json'):
                continue
            scenario = ann_file.replace('.json', '')
            ann_path = os.path.join(ann_vol_dir, ann_file)

            # Also check that corresponding dataset dir exists
            scenario_dir = os.path.join(DATASET_DIR, vol, scenario)
            if not os.path.isdir(scenario_dir):
                continue

            with open(ann_path, encoding='utf-8') as f:
                ann = json.load(f)

            total_scenarios += 1

            for seg in ann.get('segments', []):
                # A missing or null timestamp counts as skipped rather than
                # raising KeyError/TypeError.
                m = ts_pattern.match(seg.get('timestamp') or '')
                if not m:
                    skipped_segments += 1
                    continue

                start_sec = parse_timestamp(m.group(1))
                end_sec = parse_timestamp(m.group(2))

                if end_sec <= start_sec:
                    skipped_segments += 1
                    continue

                duration_sec = end_sec - start_sec
                raw_frames = duration_sec * SAMPLING_RATE
                # Difference of floored indices, matching index-based slicing
                # of the downsampled stream.
                ds_frames = (int(end_sec * SAMPLING_RATE / DOWNSAMPLE)
                             - int(start_sec * SAMPLING_RATE / DOWNSAMPLE))

                raw_lengths_sec.append(duration_sec)
                raw_lengths_frames.append(raw_frames)
                ds_lengths_frames.append(ds_frames)
                split_stats[split].append(ds_frames)
                total_segments += 1

    # Without any segment, min/max raise on empty arrays and the percentage
    # computations divide by zero, so report and stop here.
    if total_segments == 0:
        print(f"No valid segments found "
              f"(scenarios scanned: {total_scenarios}, "
              f"segments skipped: {skipped_segments}); nothing to analyze.")
        return

    # Convert to numpy
    raw_sec = np.array(raw_lengths_sec)
    raw_fr = np.array(raw_lengths_frames)
    ds_fr = np.array(ds_lengths_frames)

    print("=" * 70)
    print("SEGMENT LENGTH ANALYSIS FOR RECOGNITION DATASET")
    print("=" * 70)
    print(f"\nTotal scenarios: {total_scenarios}")
    print(f"Total valid segments: {total_segments}")
    print(f"Skipped segments (bad timestamp): {skipped_segments}")
    print(f"Sampling rate: {SAMPLING_RATE} Hz")
    print(f"Downsample factor: {DOWNSAMPLE}")
    print(f"Effective rate after downsample: {SAMPLING_RATE / DOWNSAMPLE} Hz")

    # --- Raw seconds ---
    print("\n" + "-" * 70)
    print("SEGMENT DURATION (seconds)")
    print("-" * 70)
    print(f" Min: {raw_sec.min():.1f}s")
    print(f" Max: {raw_sec.max():.1f}s")
    print(f" Mean: {raw_sec.mean():.2f}s")
    print(f" Median: {np.median(raw_sec):.1f}s")
    print(f" Std: {raw_sec.std():.2f}s")

    # Percentiles
    for p in [5, 10, 25, 50, 75, 90, 95]:
        print(f" P{p:2d}: {np.percentile(raw_sec, p):.1f}s")

    # --- Raw frames (100Hz) ---
    print("\n" + "-" * 70)
    print("SEGMENT LENGTH (raw frames @ 100Hz)")
    print("-" * 70)
    print(f" Min: {raw_fr.min()}")
    print(f" Max: {raw_fr.max()}")
    print(f" Mean: {raw_fr.mean():.1f}")
    print(f" Median: {np.median(raw_fr):.0f}")

    # --- Downsampled frames ---
    print("\n" + "-" * 70)
    print(f"SEGMENT LENGTH (downsampled frames @ {SAMPLING_RATE/DOWNSAMPLE:.0f}Hz)")
    print("-" * 70)
    print(f" Min: {ds_fr.min()}")
    print(f" Max: {ds_fr.max()}")
    print(f" Mean: {ds_fr.mean():.1f}")
    print(f" Median: {np.median(ds_fr):.0f}")
    print(f" Std: {ds_fr.std():.1f}")

    for p in [5, 10, 25, 50, 75, 90, 95]:
        print(f" P{p:2d}: {np.percentile(ds_fr, p):.0f}")

    # --- Comparison with window_frames ---
    print("\n" + "-" * 70)
    print("COMPARISON WITH window_frames SETTINGS")
    print("-" * 70)

    # Common window_sec values and their corresponding window_frames
    for window_sec in [5.0, 10.0, 15.0, 20.0, 30.0]:
        wf = int(window_sec * SAMPLING_RATE / DOWNSAMPLE)
        shorter = (ds_fr < wf).sum()
        longer = (ds_fr > wf).sum()
        pct_shorter = 100.0 * shorter / len(ds_fr)
        pct_longer = 100.0 * longer / len(ds_fr)
        print(f"\n window_sec={window_sec:5.1f}s -> window_frames={wf}")
        print(f" Segments SHORTER than window: {shorter:4d} ({pct_shorter:5.1f}%) -> will be PADDED")
        print(f" Segments LONGER than window: {longer:4d} ({pct_longer:5.1f}%) -> will be CENTER-CROPPED")

    # --- Thresholds in downsampled frames ---
    print("\n" + "-" * 70)
    print("PERCENTAGE SHORTER THAN THRESHOLDS (downsampled frames)")
    print("-" * 70)
    for thresh in [20, 40, 60, 100, 200, 300, 400, 500, 1000, 2000]:
        pct = 100.0 * (ds_fr < thresh).sum() / len(ds_fr)
        print(f" < {thresh:5d} frames ({thresh * DOWNSAMPLE / SAMPLING_RATE:6.1f}s): {pct:5.1f}%")

    # --- Per-split stats ---
    print("\n" + "-" * 70)
    print("PER-SPLIT STATISTICS (downsampled frames)")
    print("-" * 70)
    for split in ['train', 'val', 'test']:
        arr = np.array(split_stats[split])
        if len(arr) == 0:
            print(f" {split}: no segments")
            continue
        print(f"\n {split.upper()} ({len(arr)} segments):")
        print(f" Min={arr.min()}, Max={arr.max()}, Mean={arr.mean():.1f}, Median={np.median(arr):.0f}")

    # --- Histogram (text-based) ---
    print("\n" + "-" * 70)
    print("HISTOGRAM OF SEGMENT DURATIONS (seconds)")
    print("-" * 70)
    bins = [0, 1, 2, 3, 4, 5, 7, 10, 15, 20, 30, 60, 120, 300, 600]
    for i in range(len(bins) - 1):
        count = ((raw_sec >= bins[i]) & (raw_sec < bins[i + 1])).sum()
        pct = 100.0 * count / len(raw_sec)
        bar = '#' * int(pct / 2)  # one '#' per 2 percentage points
        print(f" [{bins[i]:4d}-{bins[i+1]:4d})s: {count:5d} ({pct:5.1f}%) {bar}")
    # Last bin: >= 600
    count = (raw_sec >= bins[-1]).sum()
    pct = 100.0 * count / len(raw_sec)
    bar = '#' * int(pct / 2)
    print(f" [{bins[-1]:4d}+ )s: {count:5d} ({pct:5.1f}%) {bar}")

    # --- Key insight ---
    print("\n" + "=" * 70)
    print("KEY INSIGHTS")
    print("=" * 70)
    median_sec = np.median(raw_sec)
    mean_sec = raw_sec.mean()
    print(f" Median segment duration: {median_sec:.1f}s ({median_sec * SAMPLING_RATE / DOWNSAMPLE:.0f} ds-frames)")
    print(f" Mean segment duration: {mean_sec:.1f}s ({mean_sec * SAMPLING_RATE / DOWNSAMPLE:.0f} ds-frames)")
    print()
    # Suggest optimal window
    p95_sec = np.percentile(raw_sec, 95)
    print(f" 95th percentile duration: {p95_sec:.1f}s")
    print(f" -> A window of {p95_sec:.0f}s would cover 95% of segments without cropping")
    print(f" -> Current default window_sec=15.0 -> window_frames={int(15.0 * SAMPLING_RATE / DOWNSAMPLE)}")
    wf15 = int(15.0 * SAMPLING_RATE / DOWNSAMPLE)
    pct_crop = 100.0 * (ds_fr > wf15).sum() / len(ds_fr)
    pct_pad = 100.0 * (ds_fr < wf15).sum() / len(ds_fr)
    print(f" {pct_pad:.1f}% segments padded, {pct_crop:.1f}% center-cropped")
226
+
227
+
228
+ if __name__ == '__main__':
229
+ main()
experiments/analysis/data_statistics_figure.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Generate dataset statistics figure from the currently-available annotations.
2
+
3
+ Panels (3):
4
+ (a) Recording duration distribution per scene (boxplot)
5
+ (b) Segment length distribution (histogram)
6
+ (c) Top-20 manipulated objects by segment count
7
+
8
+ Note: panel for motor-primitive frequency is deferred until the 18-primitive
9
+ annotation pipeline (anno.py) is rerun across all recordings.
10
+ """
11
+ import json, re
12
+ from pathlib import Path
13
+ from collections import Counter
14
+ import numpy as np
15
+ import matplotlib.pyplot as plt
16
+
17
+ ANNO_DIR = Path("${PULSE_ROOT}/annotations_by_scene")
18
+ OUT = Path("${PULSE_ROOT}/paper/figures/dataset_stats.pdf")
19
+
20
+ # Chinese -> English object name mapping (from anno.py OBJECT_TRANSLATIONS)
21
+ OBJ_EN = {
22
+ "笔记本电脑": "laptop", "有线鼠标": "wired mouse", "有线键盘": "wired keyboard",
23
+ "马克笔": "marker", "胶带": "tape", "笔记本电源": "laptop power", "折叠伞": "umbrella",
24
+ "剪刀": "scissors", "钱包": "wallet", "纸": "paper", "订书机": "stapler",
25
+ "纸箱": "box", "文件": "document", "架子": "rack", "桌布": "tablecloth", "罐子": "jar",
26
+ "调料瓶": "seasoning bottle", "密封罐": "sealed jar", "厨房纸巾": "kitchen paper",
27
+ "抹布": "cloth", "茶包": "tea bag", "饭碗": "rice bowl", "菜盘": "plate",
28
+ "菜锅": "pot", "勺子": "spoon", "水杯": "water cup", "茶杯": "tea cup",
29
+ "茶壶": "teapot", "食物残渣": "food residue", "垃圾桶": "trash bin",
30
+ "纸巾": "tissue", "餐垫": "placemat", "托盘": "tray", "清洁喷雾": "spray",
31
+ "食物": "food", "电源": "power adapter", "移动硬盘": "HDD", "鼠标": "mouse",
32
+ "笔记本充电器": "laptop charger", "转换插头": "plug adapter", "插线板": "power strip",
33
+ "线材收纳包": "cable organizer", "衬衫": "shirt", "裤子": "pants",
34
+ "牙膏": "toothpaste", "牙刷": "toothbrush", "牙刷盒": "toothbrush case",
35
+ "剃须刀": "razor", "毛巾": "towel", "皮鞋": "shoes", "鞋袋": "shoe bag",
36
+ "耳机": "headphones", "护照套": "passport holder", "证件夹": "ID holder",
37
+ "纸巾包": "tissue pack", "行李箱": "suitcase", "马克杯": "mug",
38
+ "调料罐": "seasoning jar", "茶罐": "tea canister", "外套": "coat",
39
+ "围巾": "scarf", "衣架": "hanger",
40
+ }
41
+
42
+
43
def parse_t(ts: str) -> float:
    """Convert an ``MM:SS`` or ``HH:MM:SS`` timestamp to seconds.

    Raises ValueError when the string has neither two nor three
    ':'-separated fields, or when a field is not an integer.
    """
    fields = ts.split(":")
    if len(fields) != 2:
        # Unpacking fails with ValueError unless there are exactly three fields.
        hours, minutes, seconds = fields
        return int(hours) * 3600 + int(minutes) * 60 + int(seconds)
    minutes, seconds = fields  # MM:SS
    return int(minutes) * 60 + int(seconds)
50
+
51
+
52
# ---------------------------------------------------------------------------
# Collect statistics from ANNO_DIR/v*/s*.json
#   durations   : scene id ("S1".."S8") -> per-recording duration in minutes,
#                 taken as the latest segment end time in that recording
#   seg_lengths : every parsed segment length in seconds
#   objects     : object name (English where OBJ_EN has a mapping) -> count
# ---------------------------------------------------------------------------
durations = {f"S{i}": [] for i in range(1, 9)}
seg_lengths = []
objects = Counter()

for v_dir in sorted(ANNO_DIR.glob("v*")):
    for jf in sorted(v_dir.glob("s*.json")):
        scene = jf.stem.upper()
        try:
            # Explicit UTF-8: object names are Chinese. With a non-UTF-8
            # locale default, the broad except below would skip every file.
            data = json.loads(jf.read_text(encoding="utf-8"))
        except Exception:
            continue
        segs = data.get("segments", [])
        if not segs:
            continue
        max_end = 0
        for seg in segs:
            ts = seg.get("timestamp", "")
            if "-" not in ts:
                continue
            # Best effort: a malformed segment is skipped, not fatal.
            try:
                start, end = ts.split("-")
                s_sec, e_sec = parse_t(start), parse_t(end)
                seg_lengths.append(e_sec - s_sec)
                max_end = max(max_end, e_sec)
                for o in seg.get("objects", []) or []:
                    nm = o.get("name") if isinstance(o, dict) else o
                    if nm:
                        objects[OBJ_EN.get(nm, nm)] += 1
            except Exception:
                continue
        if max_end > 0 and scene in durations:
            durations[scene].append(max_end / 60.0)

print(f"Per-scene durations: { {s: len(v) for s, v in durations.items()} }")
print(f"Total segments: {len(seg_lengths)}")
print(f"Unique objects: {len(objects)}")
top_obj = objects.most_common(5)
print(f"Top objects: {top_obj}")

fig, axes = plt.subplots(1, 3, figsize=(12, 3.5))

# (a) Duration boxplot per scene
ax = axes[0]
scene_order = [f"S{i}" for i in range(1, 9)]
data = [durations[s] for s in scene_order]
ax.boxplot(data, tick_labels=scene_order, showfliers=False, patch_artist=True,
           boxprops=dict(facecolor="#b3cde3"))
ax.set_ylabel("Recording duration (min)")
ax.set_title("(a) Recording duration per scene")
ax.grid(axis="y", alpha=0.3)

# (b) Segment length histogram
ax = axes[1]
seg_arr = np.array(seg_lengths)
seg_arr = seg_arr[seg_arr <= 10]
# Integer-centred bins for 0..10 need edges -0.5 .. 10.5. The previous edges
# stopped at 9.5, so 10 s segments passed the filter but were never drawn.
ax.hist(seg_arr, bins=np.arange(0, 12) - 0.5, color="#8c96c6", edgecolor="black")
ax.set_xlabel("Segment length (s)")
ax.set_ylabel("Segment count")
ax.set_title(f"(b) Segment length (n={len(seg_lengths)})")
ax.set_xticks(range(0, 11))
ax.grid(axis="y", alpha=0.3)

# (c) Top-20 objects
ax = axes[2]
top20 = objects.most_common(20)
if top20:
    objs, ocounts = zip(*top20)
    ax.barh(objs[::-1], ocounts[::-1], color="#74c476")
else:
    # zip(*[]) cannot be unpacked into two names; show a placeholder instead.
    ax.text(0.5, 0.5, "no object annotations", ha="center", va="center",
            transform=ax.transAxes)
ax.set_xlabel("Segment count")
ax.set_title("(c) Top-20 manipulated objects")
ax.tick_params(axis="y", labelsize=8)
ax.grid(axis="x", alpha=0.3)

fig.tight_layout()
# savefig does not create missing directories.
OUT.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(OUT, bbox_inches="tight")
fig.savefig(str(OUT).replace(".pdf", ".png"), dpi=140, bbox_inches="tight")
print(f"Saved: {OUT}")
experiments/analysis/exp_per_subject.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Experiment G: Per-subject diagnostic analysis.
4
+
5
+ Load the best scene-recognition checkpoint(s) from previous T1 runs and
6
+ produce a per-test-volunteer breakdown of F1 and Accuracy. Reveals whether
7
+ aggregate metrics are driven by one or two outlier subjects, as reviewers
8
+ often ask.
9
+
10
+ Runs CPU-side; no training.
11
+ """
12
+
13
+ import os
14
+ import sys
15
+ import json
16
+ import glob
17
+ import argparse
18
+ import numpy as np
19
+ import torch
20
+ from sklearn.metrics import accuracy_score, f1_score
21
+
22
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
23
+ from data.dataset import (
24
+ MultimodalSceneDataset, TEST_VOLS, SCENE_LABELS, NUM_CLASSES,
25
+ get_dataloaders,
26
+ )
27
+ from nets.models import build_model
28
+
29
+
30
def per_subject_eval(model, device, modalities, stats, downsample):
    """Evaluate one model on each test volunteer separately.

    Args:
        model: network called as ``model(x, mask)`` and returning logits
            that are arg-maxed over dim 1.
        device: torch device the inputs are moved to.
        modalities: modality list forwarded to MultimodalSceneDataset.
        stats: normalisation statistics forwarded to the dataset.
        downsample: downsample factor forwarded to the dataset.

    Returns:
        dict mapping each volunteer in TEST_VOLS to ``{'n': 0}`` when the
        volunteer has no samples, otherwise to a dict with ``n``, ``acc``,
        macro ``f1``, the per-sample ``preds`` and ``labels`` (plain ints),
        and ``samples`` (the dataset's ``sample_info``).
    """
    breakdown = {}
    model.eval()  # set once, not once per volunteer
    for vol in TEST_VOLS:
        ds = MultimodalSceneDataset([vol], modalities, downsample=downsample,
                                    stats=stats)
        if len(ds) == 0:
            breakdown[vol] = {'n': 0}
            continue
        preds, ys = [], []
        with torch.no_grad():
            for i in range(len(ds)):
                x, y = ds[i]
                x = x.to(device).unsqueeze(0)
                # Batch of one, unpadded: every time step is valid.
                # assumes x is (1, T, ...) after unsqueeze — TODO confirm
                mask = torch.ones(1, x.size(1), dtype=torch.bool, device=device)
                logits = model(x, mask)
                preds.append(logits.argmax(dim=1).cpu().item())
                # Store a plain int: the dataset may return a tensor or numpy
                # scalar, which json.dump in the caller cannot serialise.
                ys.append(int(y))
        breakdown[vol] = {
            'n': len(ds),
            'acc': float(accuracy_score(ys, preds)),
            'f1': float(f1_score(ys, preds, average='macro', zero_division=0)),
            'preds': preds,
            'labels': ys,
            'samples': ds.sample_info,
        }
    return breakdown
58
+
59
+
60
def run_on_checkpoint(ckpt_path, args_json_path, output_dir):
    """Rebuild a model from a checkpoint and write its per-volunteer breakdown.

    Args:
        ckpt_path: path to the saved state dict (``model_best.pt``).
        args_json_path: path to the run's results JSON; its ``'args'`` entry
            supplies the modalities and model hyper-parameters.
        output_dir: directory for ``<run-dir-name>_per_subject.json``
            (created if missing).

    Returns:
        The summary dict that was written: checkpoint path, modalities,
        overall accuracy / macro-F1, per-subject metrics and the full
        per-subject detail.
    """
    with open(args_json_path) as f:
        ckpt_args = json.load(f)['args']
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    raw_modalities = ckpt_args['modalities']
    # Stored either as a list or as a comma-separated string.
    modalities = raw_modalities if isinstance(raw_modalities, list) \
        else raw_modalities.split(',')
    downsample = ckpt_args.get('downsample', 5)
    # Only `info` (feature dimensions) is needed from the loaders.
    _, _, _, info = get_dataloaders(modalities,
                                    batch_size=ckpt_args.get('batch_size', 16),
                                    downsample=downsample)
    # Normalisation stats come from the training split. The module is
    # data.dataset, the same one the rest of this file imports from. The old
    # `experiments.dataset` path does not exist on sys.path as set up here.
    from data.dataset import TRAIN_VOLS
    tr_ds = MultimodalSceneDataset(TRAIN_VOLS, modalities, downsample=downsample)
    stats = tr_ds.get_stats()

    model = build_model(
        ckpt_args.get('model', 'transformer'),
        ckpt_args.get('fusion', 'late'),
        info['feat_dim'], info['modality_dims'], NUM_CLASSES,
        hidden_dim=ckpt_args.get('hidden_dim', 128),
        proj_dim=ckpt_args.get('proj_dim', 0),
        late_agg=ckpt_args.get('late_agg', 'mean'),
    ).to(device)
    try:
        sd = torch.load(ckpt_path, weights_only=True, map_location=device)
    except Exception:
        # NOTE(review): this fallback may unpickle arbitrary objects; only
        # run it on checkpoints from a trusted source.
        sd = torch.load(ckpt_path, map_location=device)
    load_result = model.load_state_dict(sd, strict=False)
    # strict=False tolerates mismatches; report them so a model evaluated
    # with partly uninitialised weights does not go unnoticed.
    if load_result.missing_keys or load_result.unexpected_keys:
        print(f" WARN: state_dict mismatch for {ckpt_path}: "
              f"missing={list(load_result.missing_keys)} "
              f"unexpected={list(load_result.unexpected_keys)}")

    breakdown = per_subject_eval(model, device, modalities, stats, downsample)

    # Overall F1
    all_preds, all_ys = [], []
    for v, info_v in breakdown.items():
        if info_v.get('n', 0) > 0:
            all_preds.extend(info_v['preds'])
            all_ys.extend(info_v['labels'])
    overall_f1 = float(f1_score(all_ys, all_preds, average='macro', zero_division=0))
    overall_acc = float(accuracy_score(all_ys, all_preds))

    # Per-subject summary
    summary = {
        'ckpt': ckpt_path,
        'modalities': modalities,
        'overall': {'acc': overall_acc, 'f1': overall_f1,
                    'n': len(all_preds)},
        'per_subject': {
            v: {'n': b.get('n'), 'acc': b.get('acc'), 'f1': b.get('f1')}
            for v, b in breakdown.items()
        },
        'detail': breakdown,
    }

    def _jsonable(obj):
        # Used only for values json cannot encode natively: tensors and
        # numpy values become lists/scalars, anything else its string form.
        if hasattr(obj, 'tolist'):
            return obj.tolist()
        return str(obj)

    os.makedirs(output_dir, exist_ok=True)
    out_path = os.path.join(output_dir, os.path.basename(
        os.path.dirname(ckpt_path)) + '_per_subject.json')
    with open(out_path, 'w') as f:
        json.dump(summary, f, indent=2, default=_jsonable)
    print(f"Per-subject breakdown saved: {out_path}")
    print(f"Overall F1: {overall_f1:.4f} Acc: {overall_acc:.4f}")
    for v, b in summary['per_subject'].items():
        if b.get('n'):
            print(f" {v}: n={b['n']} acc={b.get('acc'):.3f} f1={b.get('f1'):.3f}")
        else:
            print(f" {v}: (empty)")
    return summary
124
+
125
+
126
def main():
    """Run the per-subject analysis on every run directory under --exp_root.

    A sub-directory counts as a run when it contains both ``model_best.pt``
    and ``results.json``; ``slurm_logs`` is ignored. A failure on one
    checkpoint is printed and does not stop the remaining runs.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_root', type=str, required=True,
                        help='Directory containing run subdirs with model_best.pt and results.json')
    parser.add_argument('--output_dir', type=str, required=True)
    args = parser.parse_args()

    # (checkpoint, results-json) pairs for each candidate sub-directory.
    candidates = [
        (os.path.join(args.exp_root, name, 'model_best.pt'),
         os.path.join(args.exp_root, name, 'results.json'))
        for name in sorted(os.listdir(args.exp_root))
        if name != 'slurm_logs'
    ]
    runs = [(ckpt, res) for ckpt, res in candidates
            if os.path.exists(ckpt) and os.path.exists(res)]
    print(f"Found {len(runs)} runs with checkpoints.")

    for ckpt, res in runs:
        try:
            run_on_checkpoint(ckpt, res, args.output_dir)
        except Exception as e:
            print(f" FAIL {ckpt}: {e}")
147
+
148
+
149
+ if __name__ == '__main__':
150
+ main()
experiments/analysis/extract_video_features.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Extract video features from Scene Camera videos using a pretrained backbone.
4
+ Uses torchvision's ImageNet-pretrained ViT-B/16 (not an actual CLIP checkpoint), which is lightweight and doesn't need video-specific pretraining.
5
+
6
+ Output: per-frame feature vectors saved as .npy files, aligned to 100Hz sensor data.
7
+ """
8
+
9
+ import os
10
+ import sys
11
+ import json
12
+ import glob
13
+ import argparse
14
+ import numpy as np
15
+ import cv2
16
+ import torch
17
+ import torch.nn as nn
18
+ from torchvision import transforms
19
+
20
# Root of the dataset tree. Python does not expand "${PULSE_ROOT}" on its own,
# so resolve it from the environment; if the variable is unset the string is
# left unchanged (same value as before).
DATASET_DIR = os.path.expandvars("${PULSE_ROOT}/dataset")
21
+
22
+
23
class CLIPFeatureExtractor:
    """Per-frame image feature extractor built on torchvision's ViT-B/16.

    NOTE: despite the class name (kept so existing callers keep working), the
    backbone loaded here is torchvision's ``ViT_B_16_Weights.IMAGENET1K_V1``
    model with its classification head removed, not a CLIP checkpoint.
    """

    def __init__(self, device='cpu'):
        """Load the pretrained backbone and move it to ``device``."""
        self.device = device
        # Imported lazily so the module can be imported without torchvision.
        from torchvision.models import vit_b_16, ViT_B_16_Weights
        weights = ViT_B_16_Weights.IMAGENET1K_V1
        model = vit_b_16(weights=weights)
        # Remove classification head, keep feature extractor
        model.heads = nn.Identity()
        model.eval()
        self.model = model.to(device)
        # Preprocessing preset that ships with the chosen weights.
        self.transform = weights.transforms()
        self.feat_dim = 768  # ViT-B/16 feature dimension

    @torch.no_grad()
    def extract_batch(self, frames):
        """Extract features from a batch of frames.

        Args:
            frames: list of numpy arrays (H, W, 3) in BGR format
        Returns:
            features: numpy array (N, feat_dim); an empty ``(0, feat_dim)``
                float32 array when ``frames`` is empty.
        """
        # torch.stack() raises on an empty list, so answer the empty case here.
        if len(frames) == 0:
            return np.zeros((0, self.feat_dim), dtype=np.float32)

        tensors = []
        for frame in frames:
            # BGR -> RGB -> CHW float tensor scaled to [0, 1]
            rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            tensor = torch.from_numpy(rgb).permute(2, 0, 1).float() / 255.0
            tensors.append(self.transform(tensor))

        batch = torch.stack(tensors).to(self.device)
        features = self.model(batch)
        return features.cpu().numpy()
59
+
60
+
61
def find_scene_video(scenario_dir, vol, scenario):
    """Find the Scene Camera video file.

    Looks in ``scenario_dir`` for ``trimmed_<vol><scenario>*Scene Cam.mp4``.

    Args:
        scenario_dir: directory to search.
        vol: volunteer identifier used in the file name.
        scenario: scenario identifier used in the file name.
    Returns:
        Path of the match, or None when nothing matches. With several matches
        the lexicographically first one is returned so the choice is stable
        across runs (glob order is otherwise arbitrary).
    """
    # Escape the caller-supplied parts so characters such as '[' in a
    # directory name are matched literally instead of as glob syntax.
    pattern = os.path.join(
        glob.escape(scenario_dir),
        f"trimmed_{glob.escape(vol)}{glob.escape(scenario)}*Scene Cam.mp4",
    )
    matches = sorted(glob.glob(pattern))
    return matches[0] if matches else None
66
+
67
+
68
def extract_features_for_video(extractor, video_path, target_fps=100,
                               batch_size=32, sample_fps=2):
    """Extract features from a video file.

    Args:
        extractor: feature extractor exposing ``extract_batch(frames)``
        video_path: path to video file
        target_fps: target frame rate to align with sensor data (100Hz)
        batch_size: batch size for feature extraction
        sample_fps: extract features at this rate (e.g., 2 = every 0.5s)
            Features are then interpolated to target_fps.
    Returns:
        features: numpy array (T_target, feat_dim) aligned to target_fps, or
            None when the video cannot be read or yields no frames.
    Raises:
        ValueError: if ``sample_fps`` is not positive.
    """
    if sample_fps <= 0:
        raise ValueError(f"sample_fps must be positive, got {sample_fps}")

    cap = cv2.VideoCapture(video_path)
    try:
        video_fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # An unreadable file reports 0 fps; bail out instead of dividing by it.
        if not cap.isOpened() or video_fps <= 0 or total_frames <= 0:
            print(f" Could not read video: {video_path}")
            return None
        duration = total_frames / video_fps

        # Sample frames at sample_fps. Clamp to 1 so a sample rate above the
        # video rate degrades to "every frame" rather than a zero range step.
        sample_interval = max(1, int(video_fps / sample_fps))
        sample_indices = list(range(0, total_frames, sample_interval))

        print(f" Video: {total_frames} frames @ {video_fps:.1f}fps = {duration:.1f}s")
        print(f" Sampling {len(sample_indices)} frames @ {sample_fps}fps")

        # Extract features in batches
        feature_chunks = []
        pending_frames = []
        read_indices = []  # frame index of every frame actually decoded

        for idx in sample_indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            ret, frame = cap.read()
            if not ret:
                break
            pending_frames.append(frame)
            read_indices.append(idx)

            if len(pending_frames) >= batch_size:
                feature_chunks.append(extractor.extract_batch(pending_frames))
                pending_frames = []
                if len(feature_chunks) % 10 == 0:
                    print(f" Processed {len(feature_chunks) * batch_size} frames...")

        if pending_frames:
            feature_chunks.append(extractor.extract_batch(pending_frames))
    finally:
        # Release the capture on every exit path, including exceptions.
        cap.release()

    if not feature_chunks:
        return None

    features = np.concatenate(feature_chunks, axis=0)  # (N_samples, feat_dim)
    sample_times = np.array(read_indices[:features.shape[0]]) / video_fps  # seconds

    # Interpolate to target_fps (100Hz)
    target_times = np.arange(0, duration, 1.0 / target_fps)
    n_target = len(target_times)

    # Linear interpolation per feature dimension
    from scipy.interpolate import interp1d
    if len(sample_times) < 2:
        # Not enough samples, repeat
        interpolated = np.tile(features[0], (n_target, 1))
    else:
        interp_func = interp1d(
            sample_times, features, axis=0,
            kind='linear', fill_value='extrapolate'
        )
        interpolated = interp_func(target_times).astype(np.float32)

    print(f" Output: {interpolated.shape} @ {target_fps}Hz")
    return interpolated
144
+
145
+
146
def main():
    """CLI entry point: extract ViT features for every scenario video.

    Walks ``DATASET_DIR/v*/s*``, skips scenarios whose output file already
    exists or that have no Scene Camera video, and writes
    ``video_features_100hz.npy`` next to each processed video.
    """
    parser = argparse.ArgumentParser(description='Extract video features')
    parser.add_argument('--sample_fps', type=int, default=2,
                        help='Sample rate for feature extraction (default: 2fps)')
    parser.add_argument('--batch_size', type=int, default=16,
                        help='Batch size for feature extraction')
    parser.add_argument('--device', type=str, default='cuda',
                        help='Device (cuda or cpu)')
    args = parser.parse_args()

    # Honour the requested device (including e.g. "cuda:1"); only fall back to
    # CPU when a CUDA device was requested but CUDA is unavailable.
    device = args.device
    if device.startswith('cuda') and not torch.cuda.is_available():
        print("CUDA requested but not available, falling back to CPU")
        device = 'cpu'
    print(f"Device: {device}")

    print("Loading ViT-B/16 feature extractor...")
    extractor = CLIPFeatureExtractor(device=device)
    print(f"Feature dim: {extractor.feat_dim}")

    # Process all volunteers and scenarios
    processed = 0
    skipped = 0
    failed = 0

    for vol_dir in sorted(glob.glob(f"{DATASET_DIR}/v*")):
        vol = os.path.basename(vol_dir)
        for scenario_dir in sorted(glob.glob(f"{vol_dir}/s*")):
            scenario = os.path.basename(scenario_dir)
            output_path = os.path.join(scenario_dir, "video_features_100hz.npy")

            # Skip if already extracted
            if os.path.exists(output_path):
                print(f"[{vol}/{scenario}] Already exists, skipping")
                skipped += 1
                continue

            # Find video
            video_path = find_scene_video(scenario_dir, vol, scenario)
            if video_path is None:
                print(f"[{vol}/{scenario}] No Scene Camera video found, skipping")
                skipped += 1
                continue

            print(f"\n[{vol}/{scenario}]")
            print(f" Video: {os.path.basename(video_path)}")

            features = extract_features_for_video(
                extractor, video_path,
                batch_size=args.batch_size,
                sample_fps=args.sample_fps,
            )

            if features is not None:
                # Write to a temporary name and rename, so an interrupted run
                # never leaves a truncated file that a later run would skip.
                tmp_path = output_path + ".tmp.npy"
                np.save(tmp_path, features)
                os.replace(tmp_path, output_path)
                print(f" Saved: {output_path} ({features.shape})")
                processed += 1
            else:
                print(f" FAILED: Could not extract features")
                failed += 1

    print(f"\n{'='*60}")
    print(f"Done! Processed: {processed}, Skipped: {skipped}")
    if failed:
        print(f"Failed: {failed}")
    print(f"Feature files: {DATASET_DIR}/*/*/video_features_100hz.npy")


if __name__ == '__main__':
    main()
experiments/analysis/extract_videomae_features.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Extract video features using VideoMAE (pretrained on Kinetics-400).
4
+ Process 16-frame video clips to capture temporal dynamics.
5
+
6
+ Output: per-frame feature vectors aligned to 100Hz sensor data.
7
+ """
8
+
9
+ import os
10
+ import sys
11
+ import json
12
+ import glob
13
+ import argparse
14
+ import numpy as np
15
+ import cv2
16
+ import torch
17
+
18
# Python does not expand "${PULSE_ROOT}" on its own, so resolve it from the
# environment; if the variable is unset the strings are left unchanged.
DATASET_DIR = os.path.expandvars("${PULSE_ROOT}/dataset")
MODEL_NAME = os.path.expandvars("${PULSE_ROOT}/models/videomae-base-kinetics")
20
+
21
+
22
class VideoMAEFeatureExtractor:
    """Extract features using VideoMAE-Base (16-frame clips). Multi-GPU enabled."""

    def __init__(self, device='cpu'):
        """Load processor and model from ``MODEL_NAME`` and move the model to ``device``."""
        from transformers import VideoMAEModel, VideoMAEImageProcessor
        import torch.nn as nn
        self.device = device
        self.processor = VideoMAEImageProcessor.from_pretrained(MODEL_NAME)
        model = VideoMAEModel.from_pretrained(MODEL_NAME).to(device)
        model.eval()
        # Read from the unwrapped model: the DataParallel wrapper does not
        # expose ``config`` directly.
        self.num_frames = model.config.num_frames
        self.feat_dim = model.config.hidden_size
        # Wrap with DataParallel only when the model actually lives on a GPU;
        # wrapping a CPU-resident model fails at the first forward pass.
        wants_cuda = str(device).startswith('cuda')
        if wants_cuda and torch.cuda.is_available() and torch.cuda.device_count() > 1:
            self.n_gpus = torch.cuda.device_count()
            print(f" Using DataParallel across {self.n_gpus} GPUs")
            self.model = nn.DataParallel(model)
        else:
            self.n_gpus = 1
            self.model = model

    def _fit_clip(self, frames):
        """Return ``frames`` padded (last frame repeated) or uniformly subsampled to ``num_frames``."""
        n = len(frames)
        if n < self.num_frames:
            return frames + [frames[-1]] * (self.num_frames - n)
        if n > self.num_frames:
            # uniform sampling
            indices = np.linspace(0, n - 1, self.num_frames, dtype=int)
            return [frames[i] for i in indices]
        return frames

    @torch.no_grad()
    def extract_clip(self, frames):
        """Extract feature from a single clip.

        Args:
            frames: list of RGB numpy arrays (H, W, 3); padded or subsampled
                to ``num_frames`` before inference
        Returns:
            feature: numpy array (feat_dim,) - mean-pooled patch tokens
        """
        # A single clip is just a batch of one.
        return self.extract_clip_batch([frames])[0]

    @torch.no_grad()
    def extract_clip_batch(self, clips):
        """Extract features from a batch of clips.

        Args:
            clips: list of clips, each a list of RGB frames
        Returns:
            features: numpy array (B, feat_dim)
        """
        all_pixel_values = []
        for frames in clips:
            inputs = self.processor(self._fit_clip(frames), return_tensors="pt")
            all_pixel_values.append(inputs["pixel_values"])

        batch = torch.cat(all_pixel_values, dim=0).to(self.device)
        outputs = self.model(batch)
        # Average pool over all patch tokens
        features = outputs.last_hidden_state.mean(dim=1)  # (B, feat_dim)
        return features.cpu().numpy()
93
+
94
+
95
def find_scene_video(scenario_dir, vol, scenario):
    """Find the Scene Camera video (``trimmed_<vol><scenario>*Scene Cam.mp4``).

    Returns the lexicographically first match so the choice is stable across
    runs, or None when nothing matches.
    """
    # Escape the caller-supplied parts so glob metacharacters in a directory
    # name are matched literally.
    pattern = os.path.join(
        glob.escape(scenario_dir),
        f"trimmed_{glob.escape(vol)}{glob.escape(scenario)}*Scene Cam.mp4",
    )
    matches = sorted(glob.glob(pattern))
    return matches[0] if matches else None
99
+
100
+
101
def extract_features_for_video(extractor, video_path, target_fps=100,
                               clip_stride_sec=0.5, batch_size=4):
    """Extract VideoMAE features from a video.

    Strategy (fast):
    - Sequentially decode video ONCE, keep roughly 16 frames per second
      (resized to 224x224 RGB) and store them in RAM
    - Build clips by indexing into the in-memory frame array (no random seeks)

    Args:
        extractor: object exposing ``extract_clip_batch(clips)``
        video_path: path to the video file
        target_fps: rate the clip features are interpolated to
        clip_stride_sec: spacing between clip centers, in seconds
        batch_size: number of clips per inference call
    Returns:
        numpy array (T_target, feat_dim), or None when the video cannot be
        read or is too short to form a clip.
    Raises:
        ValueError: if ``clip_stride_sec`` is not positive.
    """
    import time
    if clip_stride_sec <= 0:
        raise ValueError(f"clip_stride_sec must be positive, got {clip_stride_sec}")

    t0 = time.time()
    cap = cv2.VideoCapture(video_path)
    try:
        video_fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # An unreadable file reports 0 fps; bail out instead of dividing by it.
        if not cap.isOpened() or video_fps <= 0 or total_frames <= 0:
            print(f" Could not read video: {video_path}")
            return None
        duration = total_frames / video_fps

        # Read all frames sequentially, downsample to ~16fps (every video_fps/16 frame)
        decode_fps = 16  # we sample frames at this rate from the video
        decode_stride = max(1, int(round(video_fps / decode_fps)))
        print(f" Video: {total_frames} frames @ {video_fps:.1f}fps = {duration:.1f}s")
        print(f" Decoding sequentially with stride {decode_stride} (~{video_fps/decode_stride:.1f}fps)...")

        # Pre-resize to model input size during decoding to save memory
        # VideoMAE expects 224x224
        target_size = 224

        decoded_frames = []  # list of (H, W, 3) uint8 RGB arrays
        decoded_times = []   # corresponding timestamps in seconds
        frame_idx = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if frame_idx % decode_stride == 0:
                # Resize early to save memory
                resized = cv2.resize(frame, (target_size, target_size), interpolation=cv2.INTER_AREA)
                rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
                decoded_frames.append(rgb)
                decoded_times.append(frame_idx / video_fps)
            frame_idx += 1
    finally:
        # Release the capture on every exit path, including exceptions.
        cap.release()

    decoded_frames = np.array(decoded_frames)  # (N, 224, 224, 3)
    decoded_times = np.array(decoded_times)
    decode_time = time.time() - t0
    print(f" Decoded {len(decoded_frames)} frames in {decode_time:.1f}s")

    # Build clips: each clip = 16 frames spanning ~1 second
    # Sample 16 consecutive frames from in-memory array
    frames_per_clip = 16
    n_decoded = len(decoded_frames)
    if n_decoded < 4:
        return None

    # Each clip occupies 16 frames at ~16fps = 1 second
    clip_centers_sec = np.arange(0.5, duration - 0.5, clip_stride_sec)
    n_clips = len(clip_centers_sec)
    print(f" Building {n_clips} clips (stride={clip_stride_sec}s, {frames_per_clip} frames each)")

    all_features = []
    clip_times = []
    batch_clips = []
    batch_times = []

    t1 = time.time()
    half = frames_per_clip // 2
    for center_sec in clip_centers_sec:
        # Find decoded frames within a +/-0.5s window, clamped to the array
        center_idx = np.searchsorted(decoded_times, center_sec)
        start = max(0, center_idx - half)
        end = min(n_decoded, start + frames_per_clip)
        start = max(0, end - frames_per_clip)

        if end - start < 4:
            continue

        clip = list(decoded_frames[start:end])
        # Pad if needed
        if len(clip) < frames_per_clip:
            clip = clip + [clip[-1]] * (frames_per_clip - len(clip))

        batch_clips.append(clip)
        batch_times.append(center_sec)

        if len(batch_clips) >= batch_size:
            all_features.append(extractor.extract_clip_batch(batch_clips))
            clip_times.extend(batch_times)
            batch_clips = []
            batch_times = []

    if batch_clips:
        all_features.append(extractor.extract_clip_batch(batch_clips))
        clip_times.extend(batch_times)
    inference_time = time.time() - t1
    print(f" Inference time: {inference_time:.1f}s ({len(clip_times)} clips)")

    if not all_features:
        return None

    features = np.concatenate(all_features, axis=0)  # (N_clips, feat_dim)
    clip_times = np.array(clip_times[:features.shape[0]])

    # Interpolate to target_fps (100Hz)
    target_times = np.arange(0, duration, 1.0 / target_fps)
    n_target = len(target_times)

    from scipy.interpolate import interp1d
    if len(clip_times) < 2:
        interpolated = np.tile(features[0], (n_target, 1))
    else:
        interp_func = interp1d(
            clip_times, features, axis=0,
            kind='linear', fill_value='extrapolate'
        )
        interpolated = interp_func(target_times).astype(np.float32)

    print(f" Output: {interpolated.shape} @ {target_fps}Hz")
    return interpolated
220
+
221
+
222
def main():
    """CLI entry point: extract VideoMAE features for every scenario video.

    Walks ``DATASET_DIR/v*/s*``, skips scenarios whose output file already
    exists or that have no Scene Camera video, and saves the features as
    ``--output_name`` next to each processed video.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--clip_stride', type=float, default=0.5,
                        help='Clip extraction stride in seconds (default: 0.5)')
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--output_name', type=str, default='video_features_videomae_100hz.npy')
    args = parser.parse_args()

    # Honour the requested device (including e.g. "cuda:1"); only fall back to
    # CPU when a CUDA device was requested but CUDA is unavailable.
    device = args.device
    if device.startswith('cuda') and not torch.cuda.is_available():
        print("CUDA requested but not available, falling back to CPU")
        device = 'cpu'
    print(f"Device: {device}")

    print(f"Loading VideoMAE from {MODEL_NAME}...")
    extractor = VideoMAEFeatureExtractor(device=device)
    print(f"Feature dim: {extractor.feat_dim}, num frames per clip: {extractor.num_frames}")

    processed = 0
    skipped = 0
    failed = 0

    for vol_dir in sorted(glob.glob(f"{DATASET_DIR}/v*")):
        vol = os.path.basename(vol_dir)
        for scenario_dir in sorted(glob.glob(f"{vol_dir}/s*")):
            scenario = os.path.basename(scenario_dir)
            output_path = os.path.join(scenario_dir, args.output_name)

            if os.path.exists(output_path):
                print(f"[{vol}/{scenario}] exists, skip")
                skipped += 1
                continue

            video_path = find_scene_video(scenario_dir, vol, scenario)
            if video_path is None:
                print(f"[{vol}/{scenario}] no video, skip")
                skipped += 1
                continue

            print(f"\n[{vol}/{scenario}]")
            features = extract_features_for_video(
                extractor, video_path,
                clip_stride_sec=args.clip_stride,
                batch_size=args.batch_size,
            )

            if features is not None:
                # Save under a temporary ".npy" name and rename. This keeps an
                # interrupted run from leaving a truncated file that would be
                # skipped later, and makes the final name exactly
                # ``output_path`` even if --output_name lacks ".npy" (np.save
                # would otherwise append it and defeat the exists() check).
                tmp_path = output_path + ".tmp.npy"
                np.save(tmp_path, features)
                os.replace(tmp_path, output_path)
                print(f" Saved: {output_path} ({features.shape})")
                processed += 1
            else:
                print(f" FAILED")
                failed += 1

    print(f"\nDone! Processed: {processed}, Skipped: {skipped}")
    if failed:
        print(f"Failed: {failed}")


if __name__ == '__main__':
    main()
experiments/analysis/gen_val_comparison.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Write prediction-vs-reference listings for the validation and test splits.

Loads the frozen LLM with LoRA adapters, restores the trained sensor-to-text
checkpoint, generates text for every sample and writes one "[OK]/[XX]" entry
per sample to a text report.
"""
import os, sys, json, torch

# Python does not expand "${PULSE_ROOT}" on its own, so resolve it from the
# environment; if the variable is unset the string is left unchanged.
PULSE_ROOT = os.path.expandvars('${PULSE_ROOT}')
sys.path.insert(0, PULSE_ROOT)
os.environ['HF_HUB_OFFLINE'] = '1'
os.environ['TRANSFORMERS_OFFLINE'] = '1'

from tasks.train_pred import (
    TextPredictionDataset, SensorToTextModel, apply_lora, set_seed
)
from data.dataset import TRAIN_VOLS, VAL_VOLS, TEST_VOLS

set_seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load tokenizer & LLM
from transformers import AutoTokenizer, AutoModelForCausalLM
llm_path = os.path.join(PULSE_ROOT, 'models', 'qwen2.5-0.5b')
tokenizer = AutoTokenizer.from_pretrained(llm_path, trust_remote_code=True, local_files_only=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

llm = AutoModelForCausalLM.from_pretrained(
    llm_path, trust_remote_code=True, torch_dtype=torch.float32, local_files_only=True
).to(device)
llm.config.pad_token_id = tokenizer.pad_token_id
# Freeze the base LLM before the LoRA adapters are attached.
for p in llm.parameters():
    p.requires_grad = False
lora_params = apply_lora(llm, r=8, alpha=16)

modalities = ['mocap', 'emg', 'imu']

# Build datasets; normalisation statistics come from the training split.
train_ds = TextPredictionDataset(TRAIN_VOLS, modalities, tokenizer, window_sec=15.0, downsample=5)
stats = train_ds.get_stats()
val_ds = TextPredictionDataset(VAL_VOLS, modalities, tokenizer, window_sec=15.0, downsample=5, stats=stats)
test_ds = TextPredictionDataset(TEST_VOLS, modalities, tokenizer, window_sec=15.0, downsample=5, stats=stats)

# Build model & load weights
model = SensorToTextModel(train_ds.feat_dim, llm, tokenizer, n_sensor_tokens=8, d_model=64)
model.to(device)

ckpt_path = os.path.join(PULSE_ROOT, 'results', 'pred_llm2', 'pred_llm_mocap-emg-imu', 'model_best.pt')
sd = torch.load(ckpt_path, weights_only=True, map_location=device)
# strict=False tolerates key mismatches; report them so a checkpoint that
# only partially matches the model does not go unnoticed.
incompatible = model.load_state_dict(sd, strict=False)
if incompatible.missing_keys or incompatible.unexpected_keys:
    print(f"Checkpoint key mismatch: {len(incompatible.missing_keys)} missing, "
          f"{len(incompatible.unexpected_keys)} unexpected")
model.eval()

out_path = os.path.join(PULSE_ROOT, 'docs', 'pred_llm2_val_comparison.txt')

from torch.utils.data import DataLoader

with open(out_path, 'w') as f:
    for split_name, ds in [('Validation', val_ds), ('Test', test_ds)]:
        loader = DataLoader(ds, batch_size=8, shuffle=False)
        f.write(f"{'='*70}\n")
        f.write(f"{split_name} Set — mocap,emg,imu (best charF1=0.0324)\n")
        f.write(f"Samples: {len(ds)}\n")
        f.write(f"{'='*70}\n\n")

        idx = 0
        n_exact = 0
        for batch in loader:
            sensor = batch['sensor'].to(device)
            # Inference only: make sure no autograd graph is built.
            with torch.no_grad():
                preds = model.generate_text(sensor, tokenizer, max_new_tokens=20)
            # The loader is not shuffled, so references follow dataset order.
            refs = [ds.texts[idx + i] for i in range(len(preds))]
            for p, r in zip(preds, refs):
                is_exact = p.strip() == r.strip()
                n_exact += int(is_exact)
                match = "OK" if is_exact else "XX"
                f.write(f"[{match}] #{idx+1}\n")
                f.write(f" Pred: {p.strip()}\n")
                f.write(f" Ref: {r.strip()}\n\n")
                idx += 1

        # Stats
        f.write(f"\n--- {split_name} Summary ---\n")
        f.write(f"Total: {idx}\n\n")
        print(f"{split_name}: {n_exact}/{idx} exact matches")

print(f"Written to {out_path}")
experiments/analysis/generate_action_labels.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Generate action labels by clustering task descriptions using text embeddings.
4
+ No manual rules — uses sentence-transformers + K-Means clustering.
5
+ """
6
+
7
+ import os
8
+ import json
9
+ import glob
10
+ import argparse
11
+ import numpy as np
12
+ from collections import Counter
13
+ from sklearn.cluster import KMeans
14
+ from sklearn.metrics import silhouette_score
15
+
16
# Python does not expand "${PULSE_ROOT}" on its own, so resolve it from the
# environment; if the variable is unset the string is left unchanged.
ANNOTATION_DIR = os.path.expandvars("${PULSE_ROOT}")


def collect_tasks():
    """Collect all task descriptions from all annotation files.

    Reads every ``v*/s*.json`` under ``ANNOTATION_DIR`` in sorted path order
    and returns the ``task`` text of each entry in its ``segments`` list.

    Returns:
        list of task strings, one per segment (duplicates preserved).
    """
    tasks = []
    for path in sorted(glob.glob(os.path.join(ANNOTATION_DIR, 'v*/s*.json'))):
        # Decode explicitly as UTF-8: the default encoding is locale-dependent
        # and would garble non-ASCII task text on some platforms.
        with open(path, encoding='utf-8') as f:
            data = json.load(f)
        tasks.extend(seg['task'] for seg in data.get('segments', []))
    return tasks
28
+
29
+
30
def embed_texts(texts):
    """Turn ``texts`` into a 2-D feature matrix, one row per text.

    Tries the multilingual sentence-transformers model first; if anything in
    that path raises, falls back to TF-IDF over character 1- to 3-grams.
    """
    try:
        from sentence_transformers import SentenceTransformer
        encoder = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
        vectors = encoder.encode(texts, show_progress_bar=True, batch_size=128)
        print(f"Encoded {len(texts)} texts with sentence-transformers, dim={vectors.shape[1]}")
        return vectors
    except Exception as e:
        print(f"sentence-transformers failed ({e}), falling back to TF-IDF")

    # Fallback path, reached only when the embedding model was unusable.
    from sklearn.feature_extraction.text import TfidfVectorizer
    tfidf = TfidfVectorizer(analyzer='char', ngram_range=(1, 3), max_features=3000)
    dense = tfidf.fit_transform(texts).toarray()
    print(f"Encoded {len(texts)} texts with TF-IDF char n-grams, dim={dense.shape[1]}")
    return dense
45
+
46
+
47
def cluster_tasks(tasks, k_range=(10, 30)):
    """Cluster the unique task texts with K-Means, choosing K by silhouette score.

    Args:
        tasks: list of task strings (duplicates allowed).
        k_range: inclusive (k_min, k_max) range of cluster counts to try. The
            range is clamped to [2, n_unique - 1], the interval in which the
            silhouette score is defined.
    Returns:
        (task_to_cluster, cluster_representatives, cluster_members, best_k, scores)
        - task_to_cluster: unique task text -> cluster id (int)
        - cluster_representatives: cluster id -> member closest to the centroid
        - cluster_members: cluster id -> list of member texts
        - best_k: selected number of clusters
        - scores: K -> silhouette score (plain Python floats)
    Raises:
        ValueError: if there are too few unique tasks for the requested range.
    """
    unique_tasks = sorted(set(tasks))
    print(f"Total segments: {len(tasks)}, Unique task texts: {len(unique_tasks)}")

    X = np.asarray(embed_texts(unique_tasks))

    k_min = max(2, int(k_range[0]))
    k_max = min(int(k_range[1]), len(unique_tasks) - 1)
    if k_max < k_min:
        raise ValueError(
            f"Need at least {k_min + 1} unique task texts to cluster with "
            f"k_range={tuple(k_range)}, got {len(unique_tasks)}"
        )

    # Find optimal K via silhouette score, remembering the winning fit so it
    # does not have to be recomputed afterwards.
    best_k, best_score = k_min, -1
    best_model, best_labels = None, None
    scores = {}
    for k in range(k_min, k_max + 1):
        km = KMeans(n_clusters=k, random_state=42, n_init=10)
        labels = km.fit_predict(X)
        # float(): numpy scalars (e.g. float32) are not JSON serialisable.
        # random_state: keeps the subsample, and so the chosen K, reproducible.
        score = float(silhouette_score(
            X, labels,
            sample_size=min(2000, len(unique_tasks)),
            random_state=42,
        ))
        scores[k] = score
        if best_model is None or score > best_score:
            best_score = score
            best_k = k
            best_model, best_labels = km, labels
        print(f" K={k}: silhouette={score:.4f}" + (" *" if k == best_k else ""))

    print(f"\nBest K={best_k} (silhouette={best_score:.4f})")

    # Final clustering: reuse the best fit (same seed => same result as a refit)
    km, labels = best_model, best_labels

    task_to_cluster = {task: int(labels[i]) for i, task in enumerate(unique_tasks)}

    # Representative task per cluster (closest to centroid)
    cluster_representatives = {}
    cluster_members = {}
    for cid in range(best_k):
        member_idx = [i for i, l in enumerate(labels) if l == cid]
        cluster_members[cid] = [unique_tasks[i] for i in member_idx]
        centroid = km.cluster_centers_[cid]
        dists = np.linalg.norm(X[member_idx] - centroid, axis=1)
        closest = member_idx[np.argmin(dists)]
        cluster_representatives[cid] = unique_tasks[closest]

    return task_to_cluster, cluster_representatives, cluster_members, best_k, scores
87
+
88
+
89
def main():
    """CLI entry point: cluster all task descriptions and save ``action_labels.json``."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dir', type=str,
                        default='${PULSE_ROOT}/results/pred')
    parser.add_argument('--k_min', type=int, default=10)
    parser.add_argument('--k_max', type=int, default=30)
    args = parser.parse_args()

    # Resolve "${PULSE_ROOT}" (and any other env reference) in the path;
    # unknown variables are left untouched.
    output_dir = os.path.expandvars(args.output_dir)
    os.makedirs(output_dir, exist_ok=True)

    tasks = collect_tasks()
    task_to_cluster, representatives, members, K, scores = cluster_tasks(
        tasks, k_range=(args.k_min, args.k_max)
    )

    # Print summary
    segment_counts = Counter(task_to_cluster[t] for t in tasks)
    print(f"\n{'='*60}")
    print(f"Clusters (K={K}):")
    for cid in range(K):
        rep = representatives[cid]
        n_unique = len(members[cid])
        n_segs = segment_counts.get(cid, 0)
        # Up to three members other than the representative, as examples.
        examples = [m for m in members[cid] if m != rep][:3]
        print(f"\n [{cid:2d}] ({n_segs:4d} segs, {n_unique:3d} unique) \"{rep}\"")
        for ex in examples:
            print(f" - {ex}")

    # Save
    output = {
        'num_classes': K,
        'task_to_cluster': task_to_cluster,
        'cluster_representatives': {str(k): v for k, v in representatives.items()},
        'cluster_sizes_unique': {str(k): len(v) for k, v in members.items()},
        'cluster_sizes_segments': {str(k): v for k, v in segment_counts.items()},
        'silhouette_scores': {str(k): float(v) for k, v in scores.items()},
    }
    out_path = os.path.join(output_dir, 'action_labels.json')
    # ensure_ascii=False emits raw non-ASCII text, so the file must be opened
    # as UTF-8 rather than with the locale-dependent default encoding.
    with open(out_path, 'w', encoding='utf-8') as f:
        json.dump(output, f, indent=2, ensure_ascii=False)
    print(f"\nSaved to {out_path}")


if __name__ == '__main__':
    main()
experiments/analysis/generate_coarse_annotations.py ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Generate coarse-grained annotations by merging consecutive fine-grained segments
4
+ into composite actions (8-15s duration) using LLM.
5
+
6
+ Input: annotations_v2/ (fine-grained, ~2-3s segments, 11 classes)
7
+ Output: annotations_coarse/ (coarse-grained, ~8-15s segments, ~6 classes)
8
+
9
+ Does NOT modify annotations_v2/.
10
+ """
11
+
12
+ import os
13
+ import json
14
+ import re
15
+ import time
16
+ import glob
17
+ import urllib.request
18
+ from collections import Counter
19
+
20
# Python does not expand "${PULSE_ROOT}" on its own, so resolve it from the
# environment; if the variable is unset the strings are left unchanged.
INPUT_DIR = os.path.expandvars("${PULSE_ROOT}/annotations_v2")
OUTPUT_DIR = os.path.expandvars("${PULSE_ROOT}/annotations_coarse")
22
+
23
+ API_URL = "https://api.chatanywhere.tech/v1/chat/completions"
24
+ API_KEYS = [
25
+ "sk-MN5n1uEETyaky96fLJdHqZobXF1f7KmOrZHzwD3lt585asFQ",
26
+ "sk-YnYrtPdAXwlE12hRpi6dYqlE1RRVR3LDVBka6wKaefU4iQRY",
27
+ "sk-jOZtodDv6OxUOMu3NuJ8lzffjwBlshn9OHY5KSmqmPTtc9qs",
28
+ "sk-qAaKTKYIRF24btu1oQWgubWG4UdA92bILNtzOkHNEPAcCxdB",
29
+ "sk-MgCBBonblMrCFnSXd6fJZaBLTCfCJ5FjYZfSe2e46bgmyktk",
30
+ "sk-79e30kYRgduuf2fSU0Lsc814YjNkClXXzQqIbx0iLS40IOEH",
31
+ "sk-h9Tej4tW6AQC6fT0njfzrPKXEk6fBwpiSvvQd0aJAhw4UwLz",
32
+ "sk-k2QNHt5wAH26Fw8hZuPWuVXw8Psd1jX09qusiA6PdBj5Vzuu",
33
+ "sk-w7EkTblciNI44cwosHXi0PGZNUf1hnJmpzOQ85va9VPdAKbz",
34
+ "sk-Dexs5ZF7OjFCq7CZW45wJ8EKoGtIswv6rsLUMzUXXkWBDBBJ",
35
+ ]
36
+
37
+ SCENE_DESCRIPTIONS = {
38
+ "s1": "办公桌面整理与工作准备",
39
+ "s2": "快递打包发送",
40
+ "s3": "厨房调料整理",
41
+ "s4": "清理餐后桌面",
42
+ "s5": "餐前桌面布置",
43
+ "s6": "商务旅行行李箱打包",
44
+ "s7": "冲泡咖啡/饮品",
45
+ "s8": "晾衣架整理与衣物收纳",
46
+ }
47
+
48
+ COARSE_CATEGORIES = """粗粒度动作类别(共6类):
49
+
50
+ 1. Manipulate - 操作物体(抓取、调整、放置某个物体的完整过程,包含拿起→操作→放下的组合)
51
+ 2. CleanOrganize - 清洁/整理(擦桌子、理线、整理桌面、叠衣服等持续性整理活动)
52
+ 3. Transfer - 搬运/传递(将物体从一个位置搬到另一个位置的过程)
53
+ 4. Assemble - 组装/连接/包装(封箱、贴胶带、盖盖子、插电源、拧瓶盖等需要精细对准的操作)
54
+ 5. FoodPrep - 食物/饮品准备(倒水、倒调料、搅拌、冲泡等与食物饮品相关的操作)
55
+ 6. Idle - 空闲/过渡(无明确操作的间隔)
56
+ """
57
+
58
# Index into API_KEYS of the key to use next; advanced after every failure.
current_key_idx = 0
# Number of successful LLM calls made by this process.
call_count = 0


def call_llm(prompt, max_tokens=1500, retries=3):
    """Send ``prompt`` to the chat-completions endpoint and return the reply text.

    Keys in ``API_KEYS`` are used round-robin: every failed attempt moves on
    to the next key, for at most ``retries * len(API_KEYS)`` attempts.

    Args:
        prompt: user message content.
        max_tokens: completion length limit passed to the API.
        retries: number of passes over the key list.
    Returns:
        The reply content string, or None if every attempt failed (or no key
        is configured).
    """
    global current_key_idx, call_count

    # The request body does not depend on the key, so build it once.
    payload = json.dumps({
        "model": "gpt-4o-mini",
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "temperature": 0.1,
    }).encode()

    last_error = None
    for _ in range(retries * len(API_KEYS)):
        key = API_KEYS[current_key_idx]
        try:
            req = urllib.request.Request(
                API_URL, data=payload,
                headers={"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
            )
            # Context manager closes the HTTP response on every path.
            with urllib.request.urlopen(req, timeout=30) as resp:
                result = json.loads(resp.read())
            content = result["choices"][0]["message"]["content"]
        except Exception as e:
            last_error = e
            err = str(e)
            # Rate-limit / quota style errors: switch key immediately.
            # Anything else: brief pause first, then switch key as well.
            if not any(k in err for k in ["429", "quota", "limit", "402", "403"]):
                time.sleep(0.5)
            current_key_idx = (current_key_idx + 1) % len(API_KEYS)
            continue
        call_count += 1
        return content

    if last_error is not None:
        print(f" LLM call failed after all retries: {last_error}")
    return None
89
+
90
+
91
def parse_ts(ts_str):
    """Convert a leading 'MM:SS' in ``ts_str`` to whole seconds (0 if absent)."""
    found = re.match(r'(\d+):(\d+)', ts_str.strip())
    if found is None:
        return 0
    minutes, seconds = (int(part) for part in found.groups())
    return minutes * 60 + seconds
97
+
98
+
99
def format_ts(sec):
    """Render an integer number of seconds as zero-padded 'MM:SS'."""
    minutes, seconds = divmod(sec, 60)
    return f"{minutes:02d}:{seconds:02d}"
102
+
103
+
104
def merge_segments_with_llm(segments, scene_id):
    """Use LLM to merge fine-grained segments into coarse composite actions.

    Args:
        segments: list of dicts with 'timestamp' and 'task' keys and an
            optional 'action_label' (defaults to "Idle").
        scene_id: key into SCENE_DESCRIPTIONS (e.g. "s1"); unknown ids use a
            generic description.
    Returns:
        List of coarse-segment dicts (each with 'timestamp', 'coarse_action'
        and 'description', the action being one of the six known categories),
        or None when the LLM call fails or its answer contains no usable
        entry, so that the caller can apply its fallback.
    """
    allowed_actions = {"Manipulate", "CleanOrganize", "Transfer",
                       "Assemble", "FoodPrep", "Idle"}
    required_keys = ("timestamp", "coarse_action", "description")

    scene_desc = SCENE_DESCRIPTIONS.get(scene_id, "日常活动")

    # Build segment list (1-based numbering, referenced by "fine_segments")
    seg_lines = []
    for i, seg in enumerate(segments):
        label = seg.get("action_label", "Idle")
        seg_lines.append(f"{i+1}. [{seg['timestamp']}] {label}: {seg['task']}")
    seg_text = "\n".join(seg_lines)

    # NOTE(review): the "description" example below contained a mis-encoded
    # character in the source; it is restored here as "简要描述" - confirm.
    prompt = f"""你是一个动作标注专家。以下是一段"{scene_desc}"录制中的细粒度动作序列(每个2-3秒)。
请将相关的连续动作合并为粗粒度复合动作,每个复合动作持续5-15秒。

合并规则:
- 围绕同一个物体的连续操作合并为一个(如"抓取杯子→调整→放下"合并为一个Manipulate)
- 连续的整理/清洁动作合并
- 合并后的时间范围 = 第一个子动作的开始时间 到 最后一个子动作的结束时间
- 如果中间有短暂Idle(≤3秒),可以包含进去
- 每个复合动作必须从6个类别中选一个

{COARSE_CATEGORIES}

细粒度动作序列:
{seg_text}

请严格按以下JSON格式返回,不要添加任何额外文字:
[{{"timestamp": "MM:SS-MM:SS", "coarse_action": "类别名", "description": "简要描述这段复合动作", "fine_segments": [子动作编号列表]}}]"""

    response = call_llm(prompt, max_tokens=2000)
    if response is None:
        return None

    # The model may wrap the JSON in extra text; take the outermost [...] span.
    match = re.search(r'\[.*\]', response, re.DOTALL)
    if not match:
        return None

    try:
        results = json.loads(match.group())
    except json.JSONDecodeError as e:
        print(f" Parse error: {e}")
        return None
    if not isinstance(results, list):
        return None

    valid = []
    for r in results:
        # Ignore anything that is not a well-formed entry instead of crashing
        # on unexpected JSON shapes (numbers, strings, nested lists, ...).
        if not isinstance(r, dict):
            continue
        if not all(k in r for k in required_keys):
            continue
        # Validate category
        action = r["coarse_action"]
        if isinstance(action, str) and action in allowed_actions:
            valid.append(r)

    # No usable entry counts as a failed merge so the caller falls back.
    return valid if valid else None
152
+
153
+
154
def process_file(input_path, vol, scenario):
    """Build coarse annotations for one fine-grained annotation file.

    Args:
        input_path: Path to a JSON file containing a "segments" list.
        vol: Volume identifier of the recording (not used by this function).
        scenario: Scenario identifier, forwarded to the LLM merge as scene id.

    Returns:
        A tuple (result, n_coarse) where result is a dict with the keys
        "fine_segments" and "coarse_segments" and n_coarse is the number of
        coarse segments.
    """
    # Read as UTF-8 (the annotations contain Chinese text and the outputs are
    # written as UTF-8) and close the handle deterministically.
    with open(input_path, encoding="utf-8") as fh:
        data = json.load(fh)
    segments = data["segments"]

    if not segments:
        return {"fine_segments": segments, "coarse_segments": []}, 0

    print(f" Merging {len(segments)} fine segments...")
    coarse = merge_segments_with_llm(segments, scenario)

    if coarse is None:
        # Fallback: simple time-based merging without LLM
        print(f" LLM failed, using fallback merge")
        coarse = fallback_merge(segments)

    result = {
        "fine_segments": segments,
        "coarse_segments": coarse,
    }
    return result, len(coarse)
175
+
176
+
177
def fallback_merge(segments):
    """Rule-based merging of fine segments, used when the LLM merge fails.

    Consecutive segments are grouped while the gap to the previous segment is
    at most 3 s and the group would span at most 15 s. Segments whose
    timestamp is not of the form 'MM:SS-MM:SS' are simply attached to the
    current group.

    Args:
        segments: List of fine segment dicts with a 'timestamp' entry.

    Returns:
        A list of coarse segment dicts produced by _emit_group. Their
        'fine_segments' entry holds the 1-based positions of the members in
        `segments`, matching the numbering used in the LLM prompt.
    """
    if not segments:
        return []

    range_re = re.compile(r'(\d+:\d+)\s*-\s*(\d+:\d+)')
    start_re = re.compile(r'(\d+:\d+)')

    def group_start_seconds(members):
        # Start time of the first member with a readable timestamp; the
        # first member itself may be unparseable.
        for member in members:
            m = start_re.match(member["timestamp"])
            if m:
                return parse_ts(m.group(1))
        return None

    def emit(members, first_number):
        merged = _emit_group(members)
        # Report positions within the whole recording, not within the group.
        merged["fine_segments"] = list(range(first_number, first_number + len(members)))
        return merged

    coarse = []
    group = [segments[0]]
    group_first = 1  # 1-based position of group[0] in `segments`

    for number, seg in enumerate(segments[1:], start=2):
        m_prev = range_re.match(group[-1]["timestamp"])
        m_curr = range_re.match(seg["timestamp"])
        if not m_prev or not m_curr:
            group.append(seg)
            continue

        gap = parse_ts(m_curr.group(1)) - parse_ts(m_prev.group(2))

        # m_prev matched, so at least one member has a readable start time.
        group_start = group_start_seconds(group)
        group_duration = parse_ts(m_curr.group(2)) - group_start

        # Merge if gap <= 3s and group duration <= 15s
        if gap <= 3 and group_duration <= 15:
            group.append(seg)
        else:
            # Emit current group and start a new one with this segment
            coarse.append(emit(group, group_first))
            group = [seg]
            group_first = number

    coarse.append(emit(group, group_first))
    return coarse
215
+
216
+
217
+ def _emit_group(group):
218
+ """Create a coarse segment from a group of fine segments."""
219
+ m_start = re.match(r'(\d+:\d+)', group[0]["timestamp"])
220
+ m_end = re.match(r'\d+:\d+\s*-\s*(\d+:\d+)', group[-1]["timestamp"])
221
+ start = m_start.group(1) if m_start else "00:00"
222
+ end = m_end.group(1) if m_end else "00:00"
223
+
224
+ labels = [seg.get("action_label", "Idle") for seg in group]
225
+ label_counts = Counter(labels)
226
+ dominant = label_counts.most_common(1)[0][0]
227
+
228
+ # Map fine label to coarse
229
+ label_map = {
230
+ "Grasp": "Manipulate", "Place": "Manipulate", "Arrange": "CleanOrganize",
231
+ "Wipe": "CleanOrganize", "Fold": "CleanOrganize", "Transport": "Transfer",
232
+ "OpenClose": "Assemble", "TearCut": "Assemble",
233
+ "Pour": "FoodPrep", "Stir": "FoodPrep", "Idle": "Idle",
234
+ }
235
+ coarse_label = label_map.get(dominant, "Manipulate")
236
+
237
+ tasks = [seg["task"] for seg in group]
238
+ desc = tasks[0] if len(tasks) == 1 else f"{tasks[0]}...{tasks[-1]}"
239
+
240
+ return {
241
+ "timestamp": f"{start}-{end}",
242
+ "coarse_action": coarse_label,
243
+ "description": desc[:80],
244
+ "fine_segments": list(range(1, len(group) + 1)),
245
+ }
246
+
247
+
248
def main():
    """Generate coarse annotations for every 's*.json' file under INPUT_DIR/v*.

    Each result is written to OUTPUT_DIR/<volume>/<scenario>.json, and a
    summary with segment counts and the coarse label distribution is printed.
    """
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    n_files = 0
    fine_sum = 0
    coarse_sum = 0
    label_hist = Counter()

    for volume_dir in sorted(glob.glob(f"{INPUT_DIR}/v*")):
        vol = os.path.basename(volume_dir)
        target_dir = os.path.join(OUTPUT_DIR, vol)
        os.makedirs(target_dir, exist_ok=True)

        for ann_path in sorted(glob.glob(f"{volume_dir}/s*.json")):
            scenario = os.path.basename(ann_path).replace(".json", "")
            print(f"[{vol}/{scenario}]", flush=True)

            result, n_coarse = process_file(ann_path, vol, scenario)

            destination = os.path.join(target_dir, f"{scenario}.json")
            with open(destination, "w", encoding="utf-8") as sink:
                json.dump(result, sink, ensure_ascii=False, indent=2)

            n_fine = len(result["fine_segments"])
            n_files += 1
            fine_sum += n_fine
            coarse_sum += n_coarse

            # Tally the coarse categories of this file.
            label_hist.update(entry["coarse_action"] for entry in result["coarse_segments"])

            print(f" {n_fine} fine → {n_coarse} coarse segments", flush=True)

    # max(..., 1) guards the ratios against division by zero.
    denominator = max(coarse_sum, 1)

    print(f"\n{'='*60}")
    print(f"Total: {n_files} files")
    print(f" Fine segments: {fine_sum}")
    print(f" Coarse segments: {coarse_sum}")
    print(f" Compression: {fine_sum/denominator:.1f}x")
    print(f" API calls: {call_count}")

    print(f"\n Coarse label distribution:")
    for label, count in label_hist.most_common():
        print(f" {label:<20} {count:>5} ({count/denominator*100:.1f}%)")

    print(f"\n Output: {OUTPUT_DIR}")


if __name__ == "__main__":
    main()
experiments/analysis/grasp_phase_analysis.py ADDED
@@ -0,0 +1,442 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Grasp Phase Timing Analysis — Flagship visualization for the paper.
4
+
5
+ Classic neuroscience finding:
6
+ Eye gaze → EMG activation → Hand motion → Pressure contact
7
+
8
+ This script:
9
+ 1. Detects grasp events (pressure onset: 0 → >5g)
10
+ 2. Looks back in time to find:
11
+ - EMG envelope activation onset
12
+ - Hand velocity peak (from MoCap)
13
+ - Eye gaze fixation (if available)
14
+ 3. Computes statistics over all grasp events
15
+ 4. Produces the canonical "grasp phase" timing figure
16
+ """
17
+
18
+ import os
19
+ import glob
20
+ import json
21
+ import numpy as np
22
+ import pandas as pd
23
+ import matplotlib
24
+ matplotlib.use('Agg')
25
+ import matplotlib.pyplot as plt
26
+ from scipy import signal as scisig
27
+ from collections import defaultdict
28
+
29
# Dataset and output roots. The paths carry a ${PULSE_ROOT} placeholder, which
# Python does not expand inside string literals; expand it from the
# environment explicitly. When PULSE_ROOT is unset the text is left unchanged.
DATASET_DIR = os.path.expandvars("${PULSE_ROOT}/dataset")
OUTPUT_DIR = os.path.expandvars("${PULSE_ROOT}/results/grasp_phase")
SAMPLING_RATE = 100  # Hz
PRESSURE_THRESHOLD = 5.0  # grams
CONTEXT_WINDOW_SEC = 2.0  # look back 2s before contact
CONTEXT_FRAMES = int(CONTEXT_WINDOW_SEC * SAMPLING_RATE)

os.makedirs(OUTPUT_DIR, exist_ok=True)
37
+
38
+
39
def load_pressure(scenario_dir):
    """Read per-hand total pressure from 'aligned_pressure_100hz.csv'.

    Columns whose name starts with 'R' (resp. 'L') and ends with '(g)' are
    summed per row; values that are not numeric count as 0.

    Returns:
        An array with one row per sample and two columns
        [right_total, left_total], or None when the file is missing or either
        hand has no matching column.
    """
    csv_path = os.path.join(scenario_dir, "aligned_pressure_100hz.csv")
    if not os.path.exists(csv_path):
        return None
    table = pd.read_csv(csv_path, low_memory=False)

    def sensor_columns(prefix):
        return [name for name in table.columns
                if name.startswith(prefix) and name.endswith('(g)')]

    def row_totals(names):
        numeric = table[names].apply(pd.to_numeric, errors='coerce').fillna(0)
        return numeric.values.sum(axis=1)

    right_names = sensor_columns('R')
    left_names = sensor_columns('L')
    if not (right_names and left_names):
        return None

    return np.stack([row_totals(right_names), row_totals(left_names)], axis=1)  # (T, 2)
52
+
53
+
54
def load_emg(scenario_dir):
    """Read the EMG channels from 'aligned_emg_100hz.csv'.

    Every numeric column except the bookkeeping columns 'Frame', 'Time',
    'time' and 'UTC' is treated as a channel.

    Returns:
        A float32 array with one row per sample and one column per channel,
        with NaN and infinite values replaced by 0; None when the file is
        missing or fewer than 4 channel columns are found.
    """
    csv_path = os.path.join(scenario_dir, "aligned_emg_100hz.csv")
    if not os.path.exists(csv_path):
        return None
    table = pd.read_csv(csv_path, low_memory=False)

    bookkeeping = ('Frame', 'Time', 'time', 'UTC')
    channels = [name for name in table.select_dtypes(include=[np.number]).columns
                if name not in bookkeeping]
    if len(channels) < 4:
        return None

    samples = table[channels].values.astype(np.float32)
    return np.nan_to_num(samples, nan=0.0, posinf=0.0, neginf=0.0)
68
+
69
+
70
def load_mocap(scenario_dir, vol, scenario):
    """Read right- and left-hand positions from the aligned MoCap TSV file.

    The file is 'aligned_<vol><scenario>_s_Q.tsv'. Hand columns are looked up
    by name: first 'RightHand'/'LeftHand' columns ending in _X/_Y/_Z, then the
    alternative spellings 'R_Hand'/'RHand' and 'L_Hand'/'LHand'. Only the
    first three matching columns per hand are used; values that are not
    numeric become 0.

    Returns:
        (right_positions, left_positions), each with one row per sample, or
        (None, None) when the file is missing or a hand has no matching
        column.
    """
    tsv_path = os.path.join(scenario_dir, f"aligned_{vol}{scenario}_s_Q.tsv")
    if not os.path.exists(tsv_path):
        return None, None
    table = pd.read_csv(tsv_path, sep='\t', low_memory=False)

    axis_suffixes = ('_X', '_Y', '_Z')
    right = [name for name in table.columns
             if 'RightHand' in name and name.endswith(axis_suffixes)]
    left = [name for name in table.columns
            if 'LeftHand' in name and name.endswith(axis_suffixes)]

    if not (right and left):
        # Alternative marker naming
        right = [name for name in table.columns if 'R_Hand' in name or 'RHand' in name][:3]
        left = [name for name in table.columns if 'L_Hand' in name or 'LHand' in name][:3]
    if not (right and left):
        return None, None

    def positions(names):
        return table[names[:3]].apply(pd.to_numeric, errors='coerce').fillna(0).values

    return positions(right), positions(left)
90
+
91
+
92
def compute_emg_envelope(emg, window_size=20):
    """Compute a single activation envelope from multi-channel EMG.

    Each channel is mean-centred, rectified and smoothed with a moving
    average of `window_size` samples; the channels are then summed. The sum
    is min-max scaled unless it is constant.

    Args:
        emg: 2-D array, samples along axis 0 and channels along axis 1.
        window_size: Length of the moving-average kernel in samples.

    Returns:
        1-D array with one envelope value per sample.
    """
    # Rectify around the per-channel mean
    rectified = np.abs(emg - np.mean(emg, axis=0))

    # Moving-average smoothing, channel by channel
    box = np.ones(window_size) / window_size
    smoothed = np.zeros_like(rectified)
    for idx, channel in enumerate(rectified.T):
        smoothed[:, idx] = np.convolve(channel, box, mode='same')

    # Combine channels and scale
    total = smoothed.sum(axis=1)
    highest, lowest = total.max(), total.min()
    if highest > lowest:
        total = (total - lowest) / (highest - lowest + 1e-8)
    return total  # (T,)
106
+
107
+
108
def compute_velocity(position, window=3):
    """Per-sample displacement magnitude of a trajectory, lightly smoothed.

    Args:
        position: 2-D array, samples along axis 0 and coordinates along
            axis 1.
        window: Length of the moving-average kernel applied to the magnitude.

    Returns:
        1-D array with one value per sample. The displacement of the first
        sample is taken as 0. Values are per sample, not per second.
    """
    step = np.zeros_like(position)
    step[1:] = np.diff(position, axis=0)
    speed = np.linalg.norm(step, axis=1)
    # Smooth
    return np.convolve(speed, np.ones(window) / window, mode='same')  # (T,)
117
+
118
+
119
def detect_grasp_events(pressure_1d, threshold=5.0, min_duration=10, min_gap=50):
    """Detect pressure onset events (signal rising above `threshold`).

    A crossing counts as contact only if more than 70% of the next
    `min_duration` samples are above the threshold. Contact is considered
    released once fewer than 30% of the next 5 samples are above it. An onset
    is recorded only if it lies more than `min_gap` samples after the
    previously recorded onset.

    Args:
        pressure_1d: 1-D sequence or array of pressure samples.
        threshold: Pressure level that separates contact from no contact.
        min_duration: Number of samples inspected for the persistence check.
        min_gap: Minimum spacing, in samples, between recorded onsets.

    Returns:
        List of onset sample indices in increasing order.
    """
    # np.asarray lets plain Python sequences be compared element-wise too.
    above = np.asarray(pressure_1d) > threshold
    n_samples = len(above)

    onsets = []
    in_contact = False
    for i, is_above in enumerate(above):
        if is_above and not in_contact:
            # Candidate onset, check persistence
            persistent = (i + min_duration < n_samples
                          and np.mean(above[i:i + min_duration]) > 0.7)
            if persistent:
                if not onsets or i - onsets[-1] > min_gap:
                    onsets.append(i)
                in_contact = True
        elif in_contact and not is_above:
            # Check if really released
            if i + 5 < n_samples and np.mean(above[i:i + 5]) < 0.3:
                in_contact = False
    return onsets
141
+
142
+
143
def find_signal_onset(signal, ref_idx, window_frames, threshold_ratio=0.3):
    """Locate the latest activation onset of `signal` before `ref_idx`.

    The window signal[ref_idx - window_frames : ref_idx + 1] is thresholded
    at baseline + threshold_ratio * (peak - baseline), where the baseline is
    the 25th percentile of the earliest part of the window (30% of it, at
    least 10 samples) and the peak is the window maximum.

    If the signal is above threshold at `ref_idx`, the onset is the start of
    the above-threshold run that ends at `ref_idx`. Otherwise it is the first
    above-threshold sample of the last rising edge in the window.

    Returns:
        Onset position relative to `ref_idx` in samples (negative means
        before), or None when the window has fewer than 10 samples, the
        peak is less than 1e-4 above baseline, or no rising edge exists.
    """
    start = max(0, ref_idx - window_frames)
    window = signal[start:ref_idx + 1]  # pre-contact window
    n = len(window)
    if n < 10:
        return None

    # Baseline from the earliest 30% of the window (at least 10 samples)
    baseline = np.percentile(window[:max(10, int(n * 0.3))], 25)

    span = np.max(window) - baseline
    if span < 1e-4:
        return None

    active = window > (baseline + span * threshold_ratio)

    if active[-1]:
        # Active at contact: the run ending here begins right after the last
        # inactive sample (or at the window start if there is none).
        inactive_positions = np.flatnonzero(~active)
        onset_local = int(inactive_positions[-1]) + 1 if inactive_positions.size else 0
    else:
        # Not active at contact: take the rising edge closest to ref_idx.
        rising_edges = np.where(np.diff(active.astype(int)) == 1)[0]
        if len(rising_edges) == 0:
            return None
        onset_local = rising_edges[-1] + 1  # first active frame

    return (start + onset_local) - ref_idx  # negative = before contact
187
+
188
+
189
def is_clean_grasp(emg_env, velocity, pressure_trace, onset, look_back=150, rest_window=50):
    """Decide whether a grasp at `onset` starts from a resting state.

    The rest window is [onset - look_back, onset - look_back + rest_window).
    The samples between the end of that window and `onset` form the
    activation period. A grasp is clean when, for both EMG envelope and
    velocity, the 75th percentile of the activation period exceeds the mean
    of the rest window by more than 10% of the full range of the signal.

    Args:
        emg_env: 1-D EMG envelope.
        velocity: 1-D velocity magnitude.
        pressure_trace: Accepted for interface compatibility; not used.
        onset: Sample index of the pressure onset.
        look_back: Distance, in samples, from onset back to the rest window
            start.
        rest_window: Length of the rest window in samples.

    Returns:
        Truthy when the grasp is clean. False when the rest window would
        start before the recording, the activation period has fewer than 10
        samples, or either signal is practically constant.
    """
    rest_lo = onset - look_back
    rest_hi = onset - (look_back - rest_window)
    if rest_lo < 0:
        return False

    # Levels during the quiescent window
    emg_rest_level = emg_env[rest_lo:rest_hi].mean()
    vel_rest_level = velocity[rest_lo:rest_hi].mean()

    # Everything between rest and contact
    emg_ramp = emg_env[rest_hi:onset]
    vel_ramp = velocity[rest_hi:onset]
    if len(emg_ramp) < 10:
        return False

    emg_rise = np.percentile(emg_ramp, 75) - emg_rest_level
    vel_rise = np.percentile(vel_ramp, 75) - vel_rest_level

    # Full range of each signal, used to express the rise as a fraction
    emg_range = emg_env.max() - emg_env.min()
    vel_range = velocity.max() - velocity.min()
    if emg_range < 1e-6 or vel_range < 1e-6:
        return False

    return (emg_rise / emg_range > 0.1) and (vel_rise / vel_range > 0.1)
225
+
226
+
227
def analyze_one_scenario(vol, scenario):
    """Collect clean grasp events, with EMG and motion timing, for one recording.

    Pressure, EMG and MoCap are loaded from DATASET_DIR/<vol>/<scenario> and
    truncated to a common length. For each hand, pressure onsets that have a
    full look-back window and pass is_clean_grasp are kept when both the EMG
    envelope onset and the hand-velocity onset are found and fall within
    [-1500, 0] ms of contact.

    Returns:
        A list of event dicts with signal snippets around contact ('pressure',
        'emg', 'velocity'), 'hand', 'onset_idx', 'emg_delay_ms' and
        'motion_delay_ms'; None when a modality could not be loaded.
    """
    scenario_dir = os.path.join(DATASET_DIR, vol, scenario)

    pressure = load_pressure(scenario_dir)
    emg = load_emg(scenario_dir)
    mocap_r, mocap_l = load_mocap(scenario_dir, vol, scenario)
    if pressure is None or emg is None or mocap_r is None:
        return None

    # Truncate all streams to the shortest one
    n_frames = min(pressure.shape[0], emg.shape[0], mocap_r.shape[0])
    pressure, emg = pressure[:n_frames], emg[:n_frames]
    mocap_r, mocap_l = mocap_r[:n_frames], mocap_l[:n_frames]

    emg_env = compute_emg_envelope(emg)
    vel_r = compute_velocity(mocap_r)
    vel_l = compute_velocity(mocap_l)

    def plausible(delay_ms):
        # Sanity check: delays should be within [-1500, 0] ms
        return not (delay_ms < -1500 or delay_ms > 0)

    per_hand = (
        ('right', pressure[:, 0], vel_r),
        ('left', pressure[:, 1], vel_l),
    )

    events = []
    for hand_name, hand_pressure, hand_vel in per_hand:
        for onset in detect_grasp_events(hand_pressure, threshold=PRESSURE_THRESHOLD):
            # Need a full look-back window before contact
            if onset < CONTEXT_FRAMES:
                continue

            # Only clean grasps starting from rest
            if not is_clean_grasp(emg_env, hand_vel, hand_pressure, onset):
                continue

            emg_delay = find_signal_onset(emg_env, onset, CONTEXT_FRAMES, threshold_ratio=0.3)
            motion_delay = find_signal_onset(hand_vel, onset, CONTEXT_FRAMES, threshold_ratio=0.3)
            if emg_delay is None or motion_delay is None:
                continue

            # Delays are in frames; the factor 10 converts them to ms
            emg_delay_ms = emg_delay * 10
            motion_delay_ms = motion_delay * 10
            if not plausible(emg_delay_ms):
                continue
            if not plausible(motion_delay_ms):
                continue

            lo = onset - CONTEXT_FRAMES
            hi = onset + 50
            events.append({
                'pressure': hand_pressure[lo:hi],
                'emg': emg_env[lo:hi],
                'velocity': hand_vel[lo:hi],
                'hand': hand_name,
                'onset_idx': onset,
                'emg_delay_ms': emg_delay_ms,
                'motion_delay_ms': motion_delay_ms,
            })

    return events
288
+
289
+
290
def main():
    """Run the grasp-phase timing analysis over the whole dataset.

    Every DATASET_DIR/v*/s* recording that has an 'alignment_metadata.json'
    listing the pressure, emg and mocap modalities is analysed. The function
    prints a summary, writes 'timing_stats.json' to OUTPUT_DIR and, when at
    least 10 events have full context, saves two figures (PNG and PDF):
    the averaged aligned traces and the delay histograms.
    """
    all_events = []
    stats = defaultdict(int)

    for vol_dir in sorted(glob.glob(f"{DATASET_DIR}/v*")):
        vol = os.path.basename(vol_dir)
        for scenario_dir in sorted(glob.glob(f"{vol_dir}/s*")):
            scenario = os.path.basename(scenario_dir)
            meta_path = os.path.join(scenario_dir, 'alignment_metadata.json')
            if not os.path.exists(meta_path):
                continue
            # Context manager so the handle is closed right away instead of
            # being left to the garbage collector for every recording.
            with open(meta_path) as meta_file:
                meta = json.load(meta_file)
            # Need all 3 modalities
            if not {'pressure', 'emg', 'mocap'}.issubset(set(meta['modalities'])):
                stats['no_modality'] += 1
                continue

            events = analyze_one_scenario(vol, scenario)
            if events is None:
                stats['load_error'] += 1
                continue
            all_events.extend(events)
            stats['scenarios'] += 1
            stats['events'] += len(events)
            print(f"[{vol}/{scenario}] {len(events)} grasp events", flush=True)

    print(f"\n=== Summary ===")
    print(f"Scenarios processed: {stats['scenarios']}")
    print(f"Total grasp events: {stats['events']}")
    print(f"Loading errors: {stats['load_error']}")
    print(f"Missing modality: {stats['no_modality']}")

    if not all_events:
        print("No events found!")
        return

    # Extract delays
    emg_delays = np.array([e['emg_delay_ms'] for e in all_events])
    motion_delays = np.array([e['motion_delay_ms'] for e in all_events])

    print(f"\n=== Timing Statistics (ms, negative = before contact) ===")
    print(f"EMG onset delay: mean={emg_delays.mean():.1f} median={np.median(emg_delays):.1f} std={emg_delays.std():.1f}")
    print(f"Motion peak delay: mean={motion_delays.mean():.1f} median={np.median(motion_delays):.1f} std={motion_delays.std():.1f}")

    # Save statistics
    stats_dict = {
        'n_events': len(all_events),
        'emg_delay_ms': {'mean': float(emg_delays.mean()), 'median': float(np.median(emg_delays)),
                         'std': float(emg_delays.std()), 'p25': float(np.percentile(emg_delays, 25)),
                         'p75': float(np.percentile(emg_delays, 75))},
        'motion_delay_ms': {'mean': float(motion_delays.mean()), 'median': float(np.median(motion_delays)),
                            'std': float(motion_delays.std()), 'p25': float(np.percentile(motion_delays, 25)),
                            'p75': float(np.percentile(motion_delays, 75))},
    }
    with open(os.path.join(OUTPUT_DIR, 'timing_stats.json'), 'w') as f:
        json.dump(stats_dict, f, indent=2)

    # ============ Figure 1: Aligned signal traces (averaged) ============
    # Events close to the end of a recording have truncated snippets; keep
    # only those with the full window so they can be stacked.
    valid = [e for e in all_events if len(e['pressure']) == CONTEXT_FRAMES + 50]
    print(f"\nEvents with full context: {len(valid)} / {len(all_events)}")

    if len(valid) < 10:
        print("Not enough events for plotting")
        return

    def normalize(sigs):
        """Stack the snippets and min-max scale each one to [0, 1]."""
        sigs = np.stack(sigs)
        sigs = sigs - sigs.min(axis=1, keepdims=True)
        maxs = sigs.max(axis=1, keepdims=True)
        sigs = sigs / (maxs + 1e-8)
        return sigs

    pressure_stack = normalize([e['pressure'] for e in valid])
    emg_stack = normalize([e['emg'] for e in valid])
    vel_stack = normalize([e['velocity'] for e in valid])

    time_axis = np.arange(-CONTEXT_FRAMES, 50) * 10  # ms

    fig, ax = plt.subplots(figsize=(9, 5))

    # Plot mean with a band of half a standard deviation on either side
    for sigs, color, label in [
        (emg_stack, '#E74C3C', 'EMG envelope'),
        (vel_stack, '#3498DB', 'Hand velocity'),
        (pressure_stack, '#27AE60', 'Pressure (contact)'),
    ]:
        mean = sigs.mean(axis=0)
        std = sigs.std(axis=0)
        ax.plot(time_axis, mean, color=color, linewidth=2.5, label=label)
        ax.fill_between(time_axis, mean - std * 0.5, mean + std * 0.5, color=color, alpha=0.15)

    ax.axvline(0, color='black', linestyle='--', linewidth=1.2, alpha=0.7, label='Contact onset')
    ax.axvline(emg_delays.mean(), color='#E74C3C', linestyle=':', alpha=0.8)
    ax.axvline(motion_delays.mean(), color='#3498DB', linestyle=':', alpha=0.8)

    # Annotations
    ax.annotate(f'EMG\n{emg_delays.mean():.0f}ms',
                xy=(emg_delays.mean(), 0.85), ha='center', fontsize=10, color='#C0392B',
                bbox=dict(boxstyle="round,pad=0.3", fc='#FADBD8', ec='#E74C3C', alpha=0.9))
    ax.annotate(f'Motion\n{motion_delays.mean():.0f}ms',
                xy=(motion_delays.mean(), 0.65), ha='center', fontsize=10, color='#1F618D',
                bbox=dict(boxstyle="round,pad=0.3", fc='#D6EAF8', ec='#3498DB', alpha=0.9))

    ax.set_xlabel('Time relative to contact onset (ms)', fontsize=12)
    ax.set_ylabel('Normalized amplitude', fontsize=12)
    ax.set_title(f'Grasp Phase Timing ({len(valid)} events, {stats["scenarios"]} recordings)',
                 fontsize=13, fontweight='bold')
    ax.set_xlim(-CONTEXT_WINDOW_SEC * 1000, 500)
    ax.legend(loc='upper left', frameon=True, fontsize=10)
    ax.grid(True, alpha=0.3)
    ax.set_ylim(-0.05, 1.1)

    plt.tight_layout()
    fig_path = os.path.join(OUTPUT_DIR, 'grasp_phase_timing.png')
    plt.savefig(fig_path, dpi=150, bbox_inches='tight')
    plt.savefig(fig_path.replace('.png', '.pdf'), bbox_inches='tight')
    plt.close(fig)  # release the figure once it is on disk
    print(f"Saved figure: {fig_path}")

    # ============ Figure 2: Delay distributions ============
    fig, axes = plt.subplots(1, 2, figsize=(11, 4))

    axes[0].hist(emg_delays, bins=30, color='#E74C3C', alpha=0.7, edgecolor='black')
    axes[0].axvline(emg_delays.mean(), color='black', linestyle='--', linewidth=2, label=f'Mean: {emg_delays.mean():.0f}ms')
    axes[0].axvline(np.median(emg_delays), color='grey', linestyle=':', linewidth=2, label=f'Median: {np.median(emg_delays):.0f}ms')
    axes[0].set_xlabel('EMG onset - Contact onset (ms)', fontsize=11)
    axes[0].set_ylabel('Count', fontsize=11)
    axes[0].set_title('EMG → Contact Delay', fontsize=12, fontweight='bold')
    axes[0].legend(fontsize=10)
    axes[0].grid(True, alpha=0.3)

    axes[1].hist(motion_delays, bins=30, color='#3498DB', alpha=0.7, edgecolor='black')
    axes[1].axvline(motion_delays.mean(), color='black', linestyle='--', linewidth=2, label=f'Mean: {motion_delays.mean():.0f}ms')
    axes[1].axvline(np.median(motion_delays), color='grey', linestyle=':', linewidth=2, label=f'Median: {np.median(motion_delays):.0f}ms')
    axes[1].set_xlabel('Motion onset - Contact onset (ms)', fontsize=11)
    axes[1].set_ylabel('Count', fontsize=11)
    axes[1].set_title('Hand Motion → Contact Delay', fontsize=12, fontweight='bold')
    axes[1].legend(fontsize=10)
    axes[1].grid(True, alpha=0.3)

    plt.tight_layout()
    fig2_path = os.path.join(OUTPUT_DIR, 'delay_distributions.png')
    plt.savefig(fig2_path, dpi=150, bbox_inches='tight')
    plt.savefig(fig2_path.replace('.png', '.pdf'), bbox_inches='tight')
    plt.close(fig)  # release the figure once it is on disk
    print(f"Saved figure: {fig2_path}")

    print(f"\nAll outputs saved to: {OUTPUT_DIR}")


if __name__ == '__main__':
    main()
experiments/analysis/modality_viz.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Visualize mocap skeleton frames, IMU waveforms, EMG waveforms."""
2
+ import os, numpy as np, pandas as pd, matplotlib.pyplot as plt
3
+ from mpl_toolkits.mplot3d import Axes3D # noqa
4
+
5
# Input recording and figure output directory. Python does not expand the
# ${PULSE_ROOT} placeholder inside string literals, so expand it from the
# environment explicitly. When PULSE_ROOT is unset the text is left unchanged.
REC = os.path.expandvars("${PULSE_ROOT}/dataset/v1/s1")
OUT = os.path.expandvars("${PULSE_ROOT}/paper/figures")
os.makedirs(OUT, exist_ok=True)
8
+
9
# ---- Skeleton bone definition ----
# Each entry names the two markers joined by one line segment when a
# skeleton frame is drawn.
BONES = [
    # torso
    ("HeadTop", "HeadFront"),
    ("HeadL", "HeadR"),
    ("HeadFront", "SpineTop"),
    ("SpineTop", "Chest"),
    ("Chest", "WaistLFront"),
    ("Chest", "WaistRFront"),
    ("WaistLFront", "WaistLBack"),
    ("WaistRFront", "WaistRBack"),
    ("WaistLBack", "BackL"),
    ("WaistRBack", "BackR"),
    ("BackL", "BackR"),
    ("SpineTop", "LShoulderTop"),
    ("SpineTop", "RShoulderTop"),
    ("LShoulderTop", "LShoulderBack"),
    ("RShoulderTop", "RShoulderBack"),
    # left arm
    ("LShoulderTop", "LArm"),
    ("LArm", "LElbowOut"),
    ("LElbowOut", "LElbowBack"),
    ("LElbowOut", "LForearmRoll"),
    ("LForearmRoll", "LWristOut"),
    ("LWristOut", "LWristIn"),
    ("LWristOut", "LHandOut"),
    ("LWristIn", "LHandIn"),
    ("LHandOut", "LIndex2"),
    ("LIndex2", "LIndexTip"),
    ("LHandOut", "LMiddle2"),
    ("LMiddle2", "LMiddleTip"),
    ("LHandIn", "LRing2"),
    ("LRing2", "LRingTip"),
    ("LHandIn", "LPinky2"),
    ("LPinky2", "LPinkyTip"),
    ("LWristIn", "LThumb1"),
    ("LThumb1", "LThumbTip"),
    # right arm
    ("RShoulderTop", "RArm"),
    ("RArm", "RElbowOut"),
    ("RElbowOut", "RElbowBack"),
    ("RElbowOut", "RForearmRoll"),
    ("RForearmRoll", "RWristOut"),
    ("RWristOut", "RWristIn"),
    ("RWristOut", "RHandOut"),
    ("RWristIn", "RHandIn"),
    ("RHandOut", "RIndex2"),
    ("RIndex2", "RIndexTip"),
    ("RHandOut", "RMiddle2"),
    ("RMiddle2", "RMiddleTip"),
    ("RHandIn", "RRing2"),
    ("RRing2", "RRingTip"),
    ("RHandIn", "RPinky2"),
    ("RPinky2", "RPinkyTip"),
    ("RWristIn", "RThumb1"),
    ("RThumb1", "RThumbTip"),
]
37
+
38
+
39
def load_mocap(path):
    """Read marker trajectories from an aligned MoCap CSV file.

    A marker is recognised by a column named 'Q_<name> X'; its 'Q_<name> Y'
    and 'Q_<name> Z' columns must exist as well. Other columns are ignored,
    apart from 'Time'.

    Returns:
        (time, markers): the 'Time' column as an array, and a dict mapping
        each marker name to an array with one [x, y, z] row per sample.
    """
    table = pd.read_csv(path)
    marker_names = [col[2:-2] for col in table.columns
                    if col.startswith("Q_") and col.endswith(" X")]
    markers = {
        name: np.stack(
            [table[f"Q_{name} {axis}"].to_numpy() for axis in ("X", "Y", "Z")],
            axis=-1,
        )
        for name in marker_names
    }
    return table["Time"].to_numpy(), markers
51
+
52
+
53
def plot_skeletons():
    """Draw four MoCap skeleton frames of the recording in REC.

    The frames are spread evenly between 10% and 90% of the recording. Each
    panel shows the markers as points and the BONES as line segments; markers
    or bones with missing (NaN) coordinates are left out. The figure is saved
    to OUT as 'mocap_skeleton.pdf' and '.png'.
    """
    t, mk = load_mocap(os.path.join(REC, "aligned_mocap_100hz.csv"))
    N = len(t)
    # four frames spread through the recording
    frames = np.linspace(int(0.1*N), int(0.9*N), 4).astype(int)

    fig = plt.figure(figsize=(12, 3.2))
    for panel, fr in enumerate(frames, start=1):
        ax = fig.add_subplot(1, 4, panel, projection='3d')

        # all markers with complete coordinates at this frame
        cloud = np.array([mk[name][fr] for name in mk])
        cloud = cloud[~np.isnan(cloud).any(axis=1)]
        if len(cloud) == 0:
            continue

        # bones
        for a, b in BONES:
            if not (a in mk and b in mk):
                continue
            pa, pb = mk[a][fr], mk[b][fr]
            if np.isnan(pa).any() or np.isnan(pb).any():
                continue
            ax.plot([pa[0], pb[0]], [pa[1], pb[1]], [pa[2], pb[2]],
                    color='#2266aa', lw=1.2)
        ax.scatter(cloud[:, 0], cloud[:, 1], cloud[:, 2], s=4, c='#cc3333', alpha=0.8)

        # cube-shaped limits around the centroid give an equal aspect
        c = cloud.mean(0)
        r = np.ptp(cloud, axis=0).max() / 2
        ax.set_xlim(c[0]-r, c[0]+r)
        ax.set_ylim(c[1]-r, c[1]+r)
        ax.set_zlim(c[2]-r, c[2]+r)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_zticks([])
        ax.set_title(f"t={t[fr]:.1f}s", fontsize=9)
        ax.view_init(elev=12, azim=-75)

    fig.suptitle("MoCap skeleton frames (56-marker Qualisys, v1/s1)", fontsize=11)
    fig.tight_layout()
    out = os.path.join(OUT, "mocap_skeleton.pdf")
    fig.savefig(out, bbox_inches='tight')
    fig.savefig(out.replace('.pdf', '.png'), dpi=150, bbox_inches='tight')
    plt.close(fig)
    print("Saved", out)
89
+
90
+
91
def plot_imu():
    """Plot 3-axis acceleration of five IMU sites for the recording in REC.

    A window of up to 2000 samples centred on the middle of the recording is
    shown, one row per site. The figure is saved to OUT as
    'imu_waveforms.pdf' and '.png'.
    """
    df = pd.read_csv(os.path.join(REC, "aligned_imu_100hz.csv"))
    t = df["time"].to_numpy()
    t = t - t[0]

    # window around the middle of the recording
    mid = len(t)//2
    sl = slice(max(0, mid-1000), min(len(t), mid+1000))

    # sensor id -> row label
    # NOTE(review): the site labels are taken from the original code; the
    # sensor-to-body mapping is described there only as "roughly" - confirm.
    sites = [("WT0", "Wrist R"), ("WT2", "Forearm R"),
             ("WT4", "Upper arm R"), ("WT6", "Shin R"), ("WT9", "Torso")]
    axis_colors = [("x", "#d62728"), ("y", "#2ca02c"), ("z", "#1f77b4")]

    fig, axes = plt.subplots(len(sites), 1, figsize=(9, 6), sharex=True)
    for ax, (sid, lbl) in zip(axes, sites):
        for comp, col in axis_colors:
            trace = df[f"{sid}_acc_{comp}"].to_numpy()[sl]
            ax.plot(t[sl], trace, color=col, lw=0.8, label=f"acc_{comp}")
        ax.set_ylabel(lbl, fontsize=9)
        ax.grid(alpha=0.3)

    axes[0].legend(loc="upper right", ncol=3, fontsize=8)
    axes[-1].set_xlabel("Time (s)")
    fig.suptitle("IMU 3-axis acceleration across 5 body sites (v1/s1, 20s window)", fontsize=11)
    fig.tight_layout()
    out = os.path.join(OUT, "imu_waveforms.pdf")
    fig.savefig(out, bbox_inches='tight')
    fig.savefig(out.replace('.pdf', '.png'), dpi=150, bbox_inches='tight')
    plt.close(fig)
    print("Saved", out)
114
+
115
+
116
def plot_emg():
    """Plot the eight EMG channels of the recording in REC with an envelope.

    A window of up to 2000 samples centred on the middle of the recording is
    shown, one row per channel: the raw signal in grey and a 20-sample
    rolling mean of its absolute value in red. The figure is saved to OUT as
    'emg_waveforms.pdf' and '.png'.
    """
    df = pd.read_csv(os.path.join(REC, "aligned_emg_100hz.csv"))
    t = df["time"].to_numpy()
    t = t - t[0]
    channel_names = [f"emg_{i}" for i in range(1, 9)]

    # window around the middle of the recording
    mid = len(t)//2
    sl = slice(max(0, mid-1000), min(len(t), mid+1000))

    fig, axes = plt.subplots(8, 1, figsize=(9, 7), sharex=True)
    for ax, c in zip(axes, channel_names):
        raw = df[c].to_numpy()[sl]
        ax.plot(t[sl], raw, color="#555", lw=0.5)
        # rectified, smoothed envelope on top of the raw trace
        rectified = pd.Series(np.abs(raw))
        env = rectified.rolling(20, min_periods=1).mean().to_numpy()
        ax.plot(t[sl], env, color="#d62728", lw=0.9)
        ax.set_ylabel(c, fontsize=8)
        ax.grid(alpha=0.3)

    axes[-1].set_xlabel("Time (s)")
    fig.suptitle("Surface EMG 8-channel raw (grey) with rectified envelope (red), v1/s1, 20s window",
                 fontsize=11)
    fig.tight_layout()
    out = os.path.join(OUT, "emg_waveforms.pdf")
    fig.savefig(out, bbox_inches='tight')
    fig.savefig(out.replace('.pdf', '.png'), dpi=150, bbox_inches='tight')
    plt.close(fig)
    print("Saved", out)


if __name__ == "__main__":
    plot_skeletons()
    plot_imu()
    plot_emg()
experiments/analysis/reannotate_actions.py ADDED
@@ -0,0 +1,363 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Re-annotate action segments using LLM (GPT-4o-mini).
4
+ 1. Re-classify existing segments with better accuracy
5
+ 2. Infer actions in unlabeled gaps based on context (scene, surrounding actions)
6
+ 3. Output improved annotations with higher coverage
7
+ """
8
+
9
+ import os
10
+ import sys
11
+ import json
12
+ import re
13
+ import time
14
+ import copy
15
+ import glob
16
+ import urllib.request
17
+ from collections import Counter
18
+
19
# Data roots. ``${PULSE_ROOT}`` is expanded from the environment at import time;
# os.path.expandvars leaves the placeholder untouched when the variable is unset.
ANN_DIR = os.path.expandvars("${PULSE_ROOT}/annotations_by_scene")
OUTPUT_DIR = os.path.expandvars("${PULSE_ROOT}/annotations_v2")
DATASET_DIR = os.path.expandvars("${PULSE_ROOT}/dataset")

API_URL = "https://api.chatanywhere.tech/v1/chat/completions"


def load_api_keys(env=None):
    """Return the API keys listed in the ``PULSE_API_KEYS`` variable.

    Args:
        env: Mapping to read from; defaults to ``os.environ``.

    Returns:
        List of non-empty, whitespace-stripped keys from the comma-separated
        value (an empty list when the variable is missing or blank).
    """
    source = os.environ if env is None else env
    raw = source.get("PULSE_API_KEYS", "")
    return [key.strip() for key in raw.split(",") if key.strip()]


# SECURITY: credentials must never be committed to the repository. Keys that
# were previously hardcoded here should be treated as leaked and revoked.
# Supply keys at run time, e.g.  PULSE_API_KEYS="sk-...,sk-..." python reannotate_actions.py
API_KEYS = load_api_keys()
36
+
37
+ SCENE_DESCRIPTIONS = {
38
+ "s1": "办公桌面整理与工作准备(整理文件、电源线、鼠标、笔记本电脑等)",
39
+ "s2": "快递打包发送(折叠纸箱、放入物品、封箱、贴标签等)",
40
+ "s3": "厨房调料整理(拿取调料瓶、倒调料、拧瓶盖、擦拭等)",
41
+ "s4": "清理餐后桌面(收碗碟、擦桌子、整理餐具、倒残渣等)",
42
+ "s5": "餐前桌面布置(铺桌布、摆放餐具碗碟、放杯子等)",
43
+ "s6": "商务旅行行李箱打包(折叠衣物、放入行李箱、整理物品等)",
44
+ "s7": "冲泡咖啡/饮品(取杯子、放咖啡粉/茶包、倒热水、搅拌等)",
45
+ "s8": "晾衣架整理与衣物收纳(取衣架、挂衣服、折叠衣物等)",
46
+ }
47
+
48
+ ACTION_CATEGORIES = """动作类别定义(共11类):
49
+
50
+ 1. Grasp - 抓取/拿起物体(手从无接触到接触并握住物体)
51
+ 2. Place - 放置/放下物体(将物体放到某个位置并释放)
52
+ 3. Pour - 倾倒/注入液体或颗粒(倒水、倒调料、倒咖啡粉等)
53
+ 4. Wipe - 擦拭/清洁表面(用抹布或手擦桌面、瓶身等)
54
+ 5. Fold - 折叠/卷起(折衣服、折桌布、折纸箱等)
55
+ 6. OpenClose - 打开/关闭/旋开/旋紧(开盒子、拧瓶盖、拉拉链、合箱盖等)
56
+ 7. Stir - 搅拌(搅拌咖啡、搅拌饮品等)
57
+ 8. TearCut - 撕/剪/粘贴(撕胶带、剪快递单、贴标签等)
58
+ 9. Arrange - 整理/摆放/调整位置(摆餐具、整理文件、调整物品位置、理线等)
59
+ 10. Transport - 搬运/移动物体到较远位置(把包裹搬到架子、把碗端到水槽等)
60
+ 11. Idle - 空闲/过渡/无明确操作(双手无目的性动作、等待、观察等)
61
+
62
+ 注意:
63
+ - 只有真正没有任何手部操作时才标Idle
64
+ - "调整姿态"、"检查物体"等属于Arrange
65
+ - "插入"、"装入"等属于Place
66
+ - "提起并移动"如果距离短属于Grasp,距离远属于Transport
67
+ """
68
+
69
+ current_key_idx = 0
70
+ call_count = 0
71
+
72
+
73
def call_llm(prompt, max_tokens=1000, retries=3):
    """Call LLM API with automatic key rotation.

    Args:
        prompt: Text sent as the single user message.
        max_tokens: Completion token limit forwarded to the API.
        retries: Number of passes over ``API_KEYS``; at most
            ``retries * len(API_KEYS)`` requests are attempted.

    Returns:
        The message content of the first choice, or ``None`` when every
        attempt failed (including when ``API_KEYS`` is empty).
    """
    global current_key_idx, call_count

    # The request body does not depend on the key, so build it once.
    payload = json.dumps({
        "model": "gpt-4o-mini",
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "temperature": 0.1,
    }).encode()

    for _ in range(retries * len(API_KEYS)):
        key = API_KEYS[current_key_idx]
        try:
            req = urllib.request.Request(
                API_URL, data=payload,
                headers={
                    "Content-Type": "application/json",
                    "Authorization": f"Bearer {key}",
                }
            )
            # Context manager closes the connection even if read/parse fails.
            with urllib.request.urlopen(req, timeout=30) as resp:
                result = json.loads(resp.read())
            call_count += 1
            return result["choices"][0]["message"]["content"]
        except Exception as e:
            err = str(e)
            if "429" in err or "quota" in err or "limit" in err or "402" in err:
                # Key exhausted, rotate
                print(f" Key {current_key_idx+1} exhausted, rotating...")
                current_key_idx = (current_key_idx + 1) % len(API_KEYS)
            elif "timeout" in err.lower():
                # Transient: retry the same key after a short pause
                time.sleep(1)
            else:
                print(f" API error: {err[:100]}")
                current_key_idx = (current_key_idx + 1) % len(API_KEYS)
                time.sleep(0.5)

    print(" WARNING: All API keys failed!")
    return None
112
+
113
+
114
def reclassify_segments(segments, scene_id):
    """Use LLM to reclassify all segments in a recording.

    Args:
        segments: Annotation segments; each must provide ``timestamp`` and ``task``.
        scene_id: Key into ``SCENE_DESCRIPTIONS`` (a generic description is
            used for unknown ids).

    Returns:
        Dict mapping the 1-based segment id (int) to the action name returned
        by the model, or ``None`` if the call failed or the reply could not be
        parsed. Action names are not validated here.
    """
    scene_desc = SCENE_DESCRIPTIONS.get(scene_id, "日常活动")

    # Build segment list for prompt (ids are 1-based)
    seg_text = "\n".join(
        f"{i}. [{seg['timestamp']}] {seg['task']}"
        for i, seg in enumerate(segments, start=1)
    )

    prompt = f"""你是一个人体动作标注专家。请为以下每个动作片段分配一个动作类别。

场景：{scene_desc}

{ACTION_CATEGORIES}

动作片段列表：
{seg_text}

请严格按以下JSON格式返回，不要添加任何额外文字：
[{{"id": 1, "action": "类别名"}}, {{"id": 2, "action": "类别名"}}, ...]

每个action必须是以下之一：Grasp, Place, Pour, Wipe, Fold, OpenClose, Stir, TearCut, Arrange, Transport, Idle"""

    response = call_llm(prompt, max_tokens=len(segments) * 40)
    if response is None:
        return None

    # Extract the JSON array from the response (the model may wrap it in text)
    match = re.search(r'\[.*\]', response, re.DOTALL)
    if not match:
        print(f" Parse error: no JSON array found, response: {response[:200]}")
        return None
    try:
        results = json.loads(match.group())
        # int() so that ids returned as "1" still match the caller's integer lookup
        return {int(r["id"]): r["action"] for r in results}
    except (json.JSONDecodeError, KeyError, TypeError, ValueError) as e:
        print(f" Parse error: {e}, response: {response[:200]}")
        return None
152
+
153
+
154
def _filter_inferred_segments(results, gap_start, gap_end):
    """Keep only well-formed inferred entries that lie inside the gap.

    An entry is kept when it is a dict with ``timestamp``, ``task`` and
    ``action``, the action is one of the 11 known categories, and its
    ``MM:SS-MM:SS`` range satisfies ``gap_start <= start < end <= gap_end``
    (seconds).
    """
    allowed = {"Grasp", "Place", "Pour", "Wipe", "Fold", "OpenClose",
               "Stir", "TearCut", "Arrange", "Transport", "Idle"}
    if not isinstance(results, list):
        return []
    valid = []
    for r in results:
        if not isinstance(r, dict):
            continue
        if not ("timestamp" in r and "action" in r and "task" in r):
            continue
        if not isinstance(r["action"], str) or r["action"] not in allowed:
            continue
        ts_match = re.match(r'(\d+):(\d+)\s*-\s*(\d+):(\d+)', str(r["timestamp"]))
        if not ts_match:
            continue
        s = int(ts_match.group(1))*60 + int(ts_match.group(2))
        e = int(ts_match.group(3))*60 + int(ts_match.group(4))
        if gap_start <= s < e <= gap_end:
            valid.append(r)
    return valid


def infer_gap_actions(scene_id, before_seg, after_seg, gap_start, gap_end):
    """Use LLM to infer what actions likely happened in an unlabeled gap.

    Args:
        scene_id: Key into ``SCENE_DESCRIPTIONS``.
        before_seg: Segment preceding the gap, or ``None`` at recording start.
        after_seg: Segment following the gap, or ``None`` at recording end.
        gap_start: Gap start in whole seconds.
        gap_end: Gap end in whole seconds.

    Returns:
        List of dicts with ``timestamp``, ``task`` and ``action`` that passed
        validation; an empty list on API failure or unparseable output.
    """
    scene_desc = SCENE_DESCRIPTIONS.get(scene_id, "日常活动")
    gap_duration = gap_end - gap_start

    before_text = f"[{before_seg['timestamp']}] {before_seg['task']}" if before_seg else "（录制开始）"
    after_text = f"[{after_seg['timestamp']}] {after_seg['task']}" if after_seg else "（录制结束）"

    prompt = f"""你是一个人体动作标注专家。在一段日常活动录制中，有一段时间没有被标注。请根据场景和前后动作推断这段时间内最可能发生的动作。

场景：{scene_desc}
未标注时间段：{gap_start//60:02d}:{gap_start%60:02d} - {gap_end//60:02d}:{gap_end%60:02d}（共{gap_duration}秒）
前一个标注动作：{before_text}
后一个标注动作：{after_text}

{ACTION_CATEGORIES}

请推断这段时间内可能发生的动作序列。每个动作段落2-4秒，时间用MM:SS格式。
如果确实是空闲等待，标注为Idle。

严格按以下JSON格式返回，不要添加任何额外文字：
[{{"timestamp": "MM:SS-MM:SS", "task": "动作描述", "action": "类别名"}}]

每个action必须是以下之一：Grasp, Place, Pour, Wipe, Fold, OpenClose, Stir, TearCut, Arrange, Transport, Idle"""

    response = call_llm(prompt, max_tokens=500)
    if response is None:
        return []

    match = re.search(r'\[.*\]', response, re.DOTALL)
    if not match:
        return []
    try:
        results = json.loads(match.group())
    except json.JSONDecodeError as e:
        print(f" Parse error: {e}")
        return []
    # Validate structure, category and timestamps
    return _filter_inferred_segments(results, gap_start, gap_end)
201
+
202
+
203
def get_recording_duration(vol, scenario):
    """Get total recording duration in seconds.

    Reads ``alignment_metadata.json`` under ``DATASET_DIR/vol/scenario``.

    Returns:
        ``aligned_length_sec`` if present, else ``aligned_length_frames / 100``,
        else ``None`` (also ``None`` when the metadata file is missing).
    """
    meta_path = os.path.join(DATASET_DIR, vol, scenario, "alignment_metadata.json")
    if not os.path.exists(meta_path):
        return None
    # Context manager so the file handle is closed deterministically.
    with open(meta_path, encoding="utf-8") as f:
        meta = json.load(f)
    if "aligned_length_sec" in meta:
        return meta["aligned_length_sec"]
    if "aligned_length_frames" in meta:
        # Frame count divided by 100 frames per second
        return meta["aligned_length_frames"] / 100.0
    return None
213
+
214
+
215
def _make_inferred_segment(inf):
    """Wrap one LLM-inferred entry in the annotation-segment schema."""
    return {
        "timestamp": inf["timestamp"],
        "task": inf["task"],
        "action_label": inf["action"],
        "source": "llm_inferred",
        "left_hand": "",
        "right_hand": "",
        "bimanual_interaction": "",
        "objects": [],
    }


def process_one_file(ann_path, vol, scenario):
    """Process one annotation file: reclassify + fill gaps.

    Args:
        ann_path: Path to the annotation JSON (must contain ``segments``).
        vol: Volunteer id, used to look up the recording duration.
        scenario: Scene id, used for the LLM prompts and duration lookup.

    Returns:
        ``(result, stats)`` where ``result`` is the annotation dict with the
        relabelled and gap-filled segments sorted by start time, and ``stats``
        has the keys ``reclassified`` and ``gaps_filled``.
    """
    # The annotations contain non-ASCII text; read as UTF-8 to match how
    # main() writes the output.
    with open(ann_path, encoding="utf-8") as f:
        data = json.load(f)
    segments = data["segments"]

    if not segments:
        return data, {"reclassified": 0, "gaps_filled": 0}

    valid_actions = {"Grasp", "Place", "Pour", "Wipe", "Fold",
                     "OpenClose", "Stir", "TearCut", "Arrange",
                     "Transport", "Idle"}

    # Step 1: Reclassify existing segments
    print(f" Reclassifying {len(segments)} segments...")
    classifications = reclassify_segments(segments, scenario)

    # Missing/invalid labels (or a failed LLM call) fall back to "Idle".
    for i, seg in enumerate(segments):
        action = classifications.get(i + 1) if classifications else None
        if isinstance(action, str) and action in valid_actions:
            seg["action_label"] = action
        else:
            seg["action_label"] = "Idle"

    # Every segment now carries an action_label.
    reclassified = len(segments)

    # Step 2: Find and fill gaps >= 3 seconds
    ts_pattern = re.compile(r'(\d+):(\d+)\s*-\s*(\d+):(\d+)')
    parsed = []
    unparsed = []
    for seg in segments:
        m = ts_pattern.match(seg["timestamp"])
        if m:
            s = int(m.group(1))*60 + int(m.group(2))
            e = int(m.group(3))*60 + int(m.group(4))
            parsed.append((s, e, seg))
        else:
            # Keep segments whose timestamp cannot be parsed instead of
            # silently dropping human annotations from the output.
            unparsed.append(seg)
    # Sort on the times only: comparing the segment dicts themselves (which
    # happens on equal timestamps) raises TypeError.
    parsed.sort(key=lambda item: item[:2])

    total_dur = get_recording_duration(vol, scenario)

    new_segments = list(unparsed)
    gaps_filled = 0

    covered_end = 0      # latest end time seen so far (handles overlapping segments)
    before_seg = None    # segment that ends at covered_end
    for i, (_, end, seg) in enumerate(parsed):
        new_segments.append(seg)
        if end >= covered_end:
            covered_end = end
            before_seg = seg

        # Check gap after this segment
        if i < len(parsed) - 1:
            gap_end = parsed[i + 1][0]
            after_seg = parsed[i + 1][2]
        elif total_dur:
            gap_end = int(total_dur)
            after_seg = None
        else:
            continue

        gap_start = covered_end
        gap_duration = gap_end - gap_start
        if gap_duration >= 3:
            print(f" Filling gap {gap_start}s-{gap_end}s ({gap_duration}s)...")
            inferred = infer_gap_actions(scenario, before_seg, after_seg, gap_start, gap_end)
            for inf in inferred:
                new_segments.append(_make_inferred_segment(inf))
                gaps_filled += 1

    # Also check gap at the beginning
    if parsed and parsed[0][0] >= 3:
        print(f" Filling start gap 0s-{parsed[0][0]}s...")
        inferred = infer_gap_actions(scenario, None, parsed[0][2], 0, parsed[0][0])
        for inf in inferred:
            new_segments.append(_make_inferred_segment(inf))
            gaps_filled += 1

    # Sort by start time (segments without a leading MM:SS sort first)
    def sort_key(seg):
        m = re.match(r'(\d+):(\d+)', seg["timestamp"])
        return int(m.group(1))*60 + int(m.group(2)) if m else 0
    new_segments.sort(key=sort_key)

    result = copy.deepcopy(data)
    result["segments"] = new_segments

    return result, {"reclassified": reclassified, "gaps_filled": gaps_filled}
322
+
323
+
324
def main():
    """Re-annotate every ``v*/s*.json`` file under ``ANN_DIR``.

    Each file is passed through ``process_one_file`` and the result is written
    to the mirrored location under ``OUTPUT_DIR``; a summary of the totals is
    printed at the end.
    """
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    totals = {"reclassified": 0, "gaps_filled": 0}
    n_files = 0

    for vol_dir in sorted(glob.glob(f"{ANN_DIR}/v*")):
        vol = os.path.basename(vol_dir)
        dest_dir = os.path.join(OUTPUT_DIR, vol)
        os.makedirs(dest_dir, exist_ok=True)

        for ann_file in sorted(glob.glob(f"{vol_dir}/s*.json")):
            scenario = os.path.basename(ann_file).replace(".json", "")
            print(f"\n[{vol}/{scenario}]", flush=True)

            result, stats = process_one_file(ann_file, vol, scenario)

            # Persist the improved annotation next to its siblings
            dest_path = os.path.join(dest_dir, f"{scenario}.json")
            with open(dest_path, "w", encoding="utf-8") as handle:
                json.dump(result, handle, ensure_ascii=False, indent=2)

            for key in totals:
                totals[key] += stats[key]
            n_files += 1

            print(f" Done: {stats['reclassified']} reclassified, {stats['gaps_filled']} gaps filled",
                  flush=True)

    total_reclassified = totals["reclassified"]
    total_gaps_filled = totals["gaps_filled"]
    print(f"\n{'='*60}")
    print(f"Total: {n_files} files processed")
    print(f" Reclassified: {total_reclassified} segments")
    print(f" Gap-filled: {total_gaps_filled} new segments")
    print(f" API calls: {call_count}")
    print(f" Output: {OUTPUT_DIR}")
360
+
361
+
362
+ if __name__ == "__main__":
363
+ main()
experiments/data/__init__.py ADDED
File without changes
experiments/data/__pycache__/dataset.cpython-312.pyc ADDED
Binary file (18.8 kB). View file
 
experiments/data/dataset.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Multimodal scene dataset for Experiment 1: Activity Recognition.
3
+ Loads aligned 100Hz multi-modal data, supports modality selection,
4
+ subject-independent splits, and variable-length sequence handling.
5
+ """
6
+
7
+ import os
8
+ import json
9
+ import numpy as np
10
+ import pandas as pd
11
+ import torch
12
+ from torch.utils.data import Dataset, DataLoader
13
+ from torch.nn.utils.rnn import pad_sequence
14
+
15
+ DATASET_DIR = "${PULSE_ROOT}/dataset"
16
+
17
+ MODALITY_FILES = {
18
+ 'mocap': None, # Special: uses aligned_{vol}{scene}_s_Q.tsv (skeleton data)
19
+ 'emg': 'aligned_emg_100hz.csv',
20
+ 'eyetrack': 'aligned_eyetrack_100hz.csv',
21
+ 'imu': 'aligned_imu_100hz.csv',
22
+ 'pressure': 'aligned_pressure_100hz.csv',
23
+ 'video': 'video_features_100hz.npy', # ViT-B/16 (ImageNet)
24
+ 'videomae': 'video_features_videomae_100hz.npy', # VideoMAE (Kinetics-400)
25
+ }
26
+
27
+
28
def get_modality_filepath(scenario_dir, modality, vol=None, scenario=None):
    """Build the path of one modality's data file inside ``scenario_dir``.

    The mocap file name is derived from the recording
    (``aligned_{vol}{scenario}_s_Q.tsv``); every other modality takes its
    file name from ``MODALITY_FILES``.

    Raises:
        ValueError: if ``modality`` is ``'mocap'`` and ``vol`` or ``scenario``
            is ``None``.
    """
    if modality != 'mocap':
        filename = MODALITY_FILES[modality]
    elif vol is None or scenario is None:
        raise ValueError("vol and scenario required for mocap modality")
    else:
        filename = "aligned_{}{}_s_Q.tsv".format(vol, scenario)
    return os.path.join(scenario_dir, filename)
39
+
40
+ SKIP_COLS = {'Frame', 'Time', 'time', 'UTC'}
41
+ SKIP_COL_SUFFIXES = (' Type',)
42
+
43
+ # Eyetrack exports sometimes include volunteer-specific marker/ICA columns.
44
+ # Benchmark inputs use the fixed 24 core gaze columns below; recordings missing
45
+ # any core column are skipped instead of truncating the full dataset.
46
+ EYETRACK_SKIP_PATTERNS = ('Index Of Cognitive Activity', 'Marker Coordinates', 'Markers_')
47
+ EYETRACK_CORE_COLS = [
48
+ 'Dikablis Glasses 3_Eye Data_Original_Pupil X',
49
+ 'Dikablis Glasses 3_Eye Data_Original_Pupil Y',
50
+ 'Dikablis Glasses 3_Eye Data_Original_Left Eye_Pupil X',
51
+ 'Dikablis Glasses 3_Eye Data_Original_Left Eye_Pupil Y',
52
+ 'Dikablis Glasses 3_Eye Data_Original_Left Eye_Pupil Area',
53
+ 'Dikablis Glasses 3_Eye Data_Original_Left Eye_Pupil Height',
54
+ 'Dikablis Glasses 3_Eye Data_Original_Left Eye_Pupil Width',
55
+ 'Dikablis Glasses 3_Eye Data_Original_Left Eye_Fixations_Fixations',
56
+ 'Dikablis Glasses 3_Eye Data_Original_Left Eye_Fixations_Fixations Duration',
57
+ 'Dikablis Glasses 3_Eye Data_Original_Left Eye_Saccades_Saccades',
58
+ 'Dikablis Glasses 3_Eye Data_Original_Left Eye_Saccades_Saccades Duration',
59
+ 'Dikablis Glasses 3_Eye Data_Original_Left Eye_Saccades_Saccades Angle',
60
+ 'Dikablis Glasses 3_Eye Data_Original_Right Eye_Pupil X',
61
+ 'Dikablis Glasses 3_Eye Data_Original_Right Eye_Pupil Y',
62
+ 'Dikablis Glasses 3_Eye Data_Original_Right Eye_Pupil Area',
63
+ 'Dikablis Glasses 3_Eye Data_Original_Right Eye_Pupil Height',
64
+ 'Dikablis Glasses 3_Eye Data_Original_Right Eye_Pupil Width',
65
+ 'Dikablis Glasses 3_Eye Data_Original_Right Eye_Fixations_Fixations',
66
+ 'Dikablis Glasses 3_Eye Data_Original_Right Eye_Fixations_Fixations Duration',
67
+ 'Dikablis Glasses 3_Eye Data_Original_Right Eye_Saccades_Saccades',
68
+ 'Dikablis Glasses 3_Eye Data_Original_Right Eye_Saccades_Saccades Duration',
69
+ 'Dikablis Glasses 3_Eye Data_Original_Right Eye_Saccades_Saccades Angle',
70
+ 'Dikablis Glasses 3_Field Data_Scene Cam_Original_Gaze_Gaze X',
71
+ 'Dikablis Glasses 3_Field Data_Scene Cam_Original_Gaze_Gaze Y',
72
+ ]
73
+ EYETRACK_EXCLUDED_RECORDINGS = {('v1', 's1'), ('v14', 's8')}
74
+
75
+ SCENE_LABELS = {f's{i}': i - 1 for i in range(1, 9)}
76
+ NUM_CLASSES = 8
77
+
78
+ TRAIN_VOLS = ['v1', 'v2', 'v11', 'v12', 'v13', 'v15', 'v16', 'v17', 'v19', 'v20', 'v21', 'v22', 'v23', 'v24']
79
+ VAL_VOLS = [] # No separate val set; use train for early stopping or cross-val
80
+ TEST_VOLS = ['v25', 'v26', 'v27', 'v3']
81
+
82
+
83
def _preprocess_mocap_skeleton(arr, feat_cols):
    """Convert absolute skeleton coords to hip-relative positions + velocity.

    Args:
        arr: (T, F) array whose columns are named by ``feat_cols``.
        feat_cols: Column names; columns ending in ``_X``/``_Y``/``_Z`` are
            treated as positions.

    Returns:
        (T, F + N_pos) array: the input with every position column made
        relative to ``Hips_X``/``Hips_Y``/``Hips_Z``, followed by the
        frame-to-frame difference of the N_pos position columns (first row
        zero). If any of the three hip columns is missing, ``arr`` is returned
        unchanged. The input array is not modified.
    """
    col_to_idx = {c: i for i, c in enumerate(feat_cols)}

    # All three hip axes are needed for the subtraction below.
    hip_idx = [col_to_idx.get(name) for name in ('Hips_X', 'Hips_Y', 'Hips_Z')]
    if any(i is None for i in hip_idx):
        return arr  # No complete hip joint found, skip preprocessing

    # Identify all position columns (_X, _Y, _Z)
    x_indices = [i for i, c in enumerate(feat_cols) if c.endswith('_X')]
    y_indices = [i for i, c in enumerate(feat_cols) if c.endswith('_Y')]
    z_indices = [i for i, c in enumerate(feat_cols) if c.endswith('_Z')]
    all_pos_indices = sorted(x_indices + y_indices + z_indices)

    # 1. Make XYZ positions hip-relative (hip values are read from the
    #    untouched input, so the order of the subtractions does not matter)
    arr_rel = arr.copy()
    hip_xyz = arr[:, hip_idx]  # (T, 3)
    arr_rel[:, x_indices] -= hip_xyz[:, [0]]
    arr_rel[:, y_indices] -= hip_xyz[:, [1]]
    arr_rel[:, z_indices] -= hip_xyz[:, [2]]

    # 2. Compute velocity of position features only
    pos_data = arr_rel[:, all_pos_indices]  # (T, N_pos)
    velocity = np.zeros_like(pos_data)
    velocity[1:] = pos_data[1:] - pos_data[:-1]

    # 3. Concatenate: [hip-relative features, position velocity]
    return np.concatenate([arr_rel, velocity], axis=1)
122
+
123
+
124
def load_modality_array(filepath, modality):
    """Load a modality CSV/TSV/NPY and return it as a float32 numpy array.

    Args:
        filepath: Path to the file; ``.npy`` is loaded directly, ``.tsv`` is
            read tab-separated, anything else comma-separated.
        modality: Modality name; selects eyetrack column filtering, the
            zero-ratio check and mocap preprocessing.

    Returns:
        The array with NaN/inf replaced by 0, or ``None`` when the data is
        unusable: missing ``.npy`` file, excluded eyetrack recording, missing
        eyetrack core columns, empty table, extreme values (> 1e6) or
        mostly zeros.
    """
    # Video features stored as .npy
    if filepath.endswith('.npy'):
        if not os.path.exists(filepath):
            return None
        arr = np.load(filepath).astype(np.float32)
        return np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)

    # Excluded eyetrack recordings are identified by path alone
    # (.../<vol>/<scenario>/<file>), so check before parsing the file.
    if modality == 'eyetrack':
        parts = os.path.normpath(filepath).split(os.sep)
        if len(parts) >= 3 and (parts[-3], parts[-2]) in EYETRACK_EXCLUDED_RECORDINGS:
            return None

    # Mocap uses TSV with tab separator
    sep = '\t' if filepath.endswith('.tsv') else ','
    df = pd.read_csv(filepath, sep=sep, low_memory=False)
    df.columns = [str(c).strip() for c in df.columns]

    feat_cols = [c for c in df.columns
                 if c not in SKIP_COLS
                 and not any(c.endswith(s) for s in SKIP_COL_SUFFIXES)]
    if modality == 'eyetrack':
        # Keep exactly the core columns, in their canonical order
        present = set(feat_cols)
        feat_cols = [c for c in EYETRACK_CORE_COLS if c in present]
        if len(feat_cols) != len(EYETRACK_CORE_COLS):
            return None

    sub = df[feat_cols]
    # Coerce non-numeric columns
    obj_cols = sub.select_dtypes(include=['object']).columns
    if len(obj_cols) > 0:
        sub = sub.copy()
        sub[obj_cols] = sub[obj_cols].apply(pd.to_numeric, errors='coerce')
    arr = sub.values.astype(np.float64)
    arr = np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)

    # No rows or no feature columns: nothing usable (np.max below would
    # raise on a zero-size array).
    if arr.size == 0:
        return None

    # Quality check: reject samples with extreme values (corrupted data)
    if np.max(np.abs(arr)) > 1e6:
        return None  # Corrupted

    # Quality check: reject samples that are mostly zeros (sensor dropout).
    # Pressure and EMG are legitimately zero for long periods (rest, no grip)
    # so we only apply the strict near-total-loss check to the modalities
    # where a flat-zero stream is a clear dropout signal.
    if modality not in ("pressure", "emg"):
        zero_ratio = np.mean(arr == 0.0)
        if zero_ratio > 0.9:
            return None  # Near-total data loss

    # Mocap skeleton: convert to hip-relative + velocity
    if modality == 'mocap' and filepath.endswith('.tsv'):
        arr = _preprocess_mocap_skeleton(arr, feat_cols)
    return arr.astype(np.float32)
174
+
175
+
176
class MultimodalSceneDataset(Dataset):
    """Dataset for scene-level classification from multimodal time series.

    One sample is one recording (``<vol>/<scenario>``): the selected
    modalities are loaded, trimmed to a common length, concatenated along the
    feature axis, temporally downsampled and z-normalised.

    Args:
        volunteers: Volunteer directory names under ``DATASET_DIR``.
        modalities: Modality names to load, in concatenation order.
        downsample: Keep every ``downsample``-th frame.
        stats: Optional ``(mean, std)`` to normalise with; when ``None`` the
            statistics are computed from this dataset.
    """

    def __init__(self, volunteers, modalities, downsample=5, stats=None):
        self.modalities = modalities
        self.downsample = downsample
        self.data = []
        self.labels = []
        self.sample_info = []
        self._modality_dims = {}

        for vol in volunteers:
            vol_dir = os.path.join(DATASET_DIR, vol)
            if not os.path.isdir(vol_dir):
                continue
            for scenario in sorted(os.listdir(vol_dir)):
                scenario_dir = os.path.join(vol_dir, scenario)
                if not os.path.isdir(scenario_dir) or scenario not in SCENE_LABELS:
                    continue
                combined = self._load_recording(vol, scenario, scenario_dir)
                if combined is None:
                    continue
                self.data.append(combined)
                self.labels.append(SCENE_LABELS[scenario])
                self.sample_info.append(f"{vol}/{scenario}")

        print(f" Loaded {len(self.data)} samples, modality dims: {self._modality_dims}, "
              f"total feat dim: {sum(self._modality_dims.values())}", flush=True)

        # Normalization (compute in float64 to avoid overflow)
        if stats is not None:
            self.mean, self.std = stats
        else:
            self._compute_stats()
        for i in range(len(self.data)):
            self.data[i] = ((self.data[i].astype(np.float64) - self.mean) / self.std).astype(np.float32)
            self.data[i] = np.nan_to_num(self.data[i], nan=0.0, posinf=0.0, neginf=0.0)

    def _load_recording(self, vol, scenario, scenario_dir):
        """Load and fuse all requested modalities of one recording.

        Returns the (T', D) downsampled feature array, or ``None`` when the
        recording lacks metadata, a requested modality, a data file, or
        contains data rejected by ``load_modality_array``.
        """
        meta_path = os.path.join(scenario_dir, 'alignment_metadata.json')
        if not os.path.exists(meta_path):
            return None
        with open(meta_path) as f:
            meta = json.load(f)
        if not set(self.modalities).issubset(set(meta['modalities'])):
            return None

        parts = []
        for mod in self.modalities:
            # Shared helper keeps the mocap naming rule in one place
            filepath = get_modality_filepath(scenario_dir, mod, vol, scenario)
            if not os.path.exists(filepath):
                return None
            arr = load_modality_array(filepath, mod)
            if arr is None:
                print(f" SKIP {vol}/{scenario} {mod}: corrupted data", flush=True)
                return None
            parts.append(self._match_dim(arr, mod, vol, scenario))

        # Trim every modality to the shortest one before concatenating
        min_len = min(p.shape[0] for p in parts)
        combined = np.concatenate([p[:min_len] for p in parts], axis=1)
        return combined[::self.downsample]

    def _match_dim(self, arr, mod, vol, scenario):
        """Zero-pad or truncate ``arr`` to the width first seen for ``mod``."""
        expected = self._modality_dims.get(mod)
        if expected is None:
            # First recording of this modality defines the expected width
            self._modality_dims[mod] = arr.shape[1]
            return arr
        if arr.shape[1] == expected:
            return arr
        print(f" WARNING: {vol}/{scenario} {mod} dim {arr.shape[1]} "
              f"!= expected {expected}, padding/truncating",
              flush=True)
        if arr.shape[1] < expected:
            pad = np.zeros((arr.shape[0], expected - arr.shape[1]), dtype=np.float32)
            return np.concatenate([arr, pad], axis=1)
        return arr[:, :expected]

    def _compute_stats(self):
        """Set ``self.mean``/``self.std`` (shape (1, D)) over all loaded frames.

        Raises:
            ValueError: if no recordings were loaded.
        """
        if not self.data:
            raise ValueError(
                "Cannot compute normalization stats: no recordings were loaded "
                "(check DATASET_DIR, the volunteer list and the requested modalities)")
        # Use float64 for accumulation to prevent overflow
        all_frames = np.concatenate(self.data, axis=0).astype(np.float64)
        self.mean = np.mean(all_frames, axis=0, keepdims=True)
        self.std = np.std(all_frames, axis=0, keepdims=True)
        # Constant features: avoid division by (near) zero
        self.std[self.std < 1e-8] = 1.0

    def get_stats(self):
        """Return the ``(mean, std)`` pair used for normalisation."""
        return (self.mean, self.std)

    @property
    def feat_dim(self):
        """Total feature width: sum of the per-modality widths."""
        return sum(self._modality_dims.values())

    @property
    def modality_dims(self):
        """Copy of the modality -> feature-width mapping."""
        return dict(self._modality_dims)

    def get_class_weights(self):
        """Return inverse-frequency class weights, scaled to sum to NUM_CLASSES."""
        counts = np.bincount(self.labels, minlength=NUM_CLASSES).astype(np.float32)
        counts[counts == 0] = 1.0  # absent classes: avoid division by zero
        weights = 1.0 / counts
        weights = weights / weights.sum() * NUM_CLASSES
        return torch.FloatTensor(weights)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return torch.from_numpy(self.data[idx]), self.labels[idx]
290
+
291
+
292
def collate_fn(batch):
    """Collate ``(sequence, label)`` pairs into a zero-padded batch.

    Returns:
        ``(padded, labels, mask, lengths)``: sequences padded to the longest
        one (batch first), labels as a LongTensor, a boolean mask that is
        True on real (non-padding) time steps, and the original lengths.
    """
    seqs, targets = zip(*batch)
    lengths = torch.tensor([seq.size(0) for seq in seqs], dtype=torch.long)
    padded = pad_sequence(seqs, batch_first=True, padding_value=0.0)
    # Position index < sequence length  <=>  real frame
    steps = torch.arange(padded.size(1))
    mask = steps[None, :] < lengths[:, None]
    labels = torch.LongTensor(targets)
    return padded, labels, mask, lengths
301
+
302
+
303
def get_dataloaders(modalities, batch_size=16, downsample=5, num_workers=0):
    """Build the train/val/test DataLoaders plus a summary dict.

    The training set computes the normalisation statistics; the validation
    and test sets reuse them. Only the training loader shuffles.

    Returns:
        ``(train_loader, val_loader, test_loader, info)`` where ``info`` holds
        feature dimensions, class count, split sizes and class weights.
    """
    print("Loading training data...", flush=True)
    train_ds = MultimodalSceneDataset(TRAIN_VOLS, modalities, downsample)
    train_stats = train_ds.get_stats()

    print("Loading validation data...", flush=True)
    val_ds = MultimodalSceneDataset(VAL_VOLS, modalities, downsample, stats=train_stats)

    print("Loading test data...", flush=True)
    test_ds = MultimodalSceneDataset(TEST_VOLS, modalities, downsample, stats=train_stats)

    # Options common to all three loaders
    shared = dict(batch_size=batch_size, collate_fn=collate_fn, num_workers=num_workers)
    train_loader = DataLoader(train_ds, shuffle=True, drop_last=False, **shared)
    val_loader = DataLoader(val_ds, shuffle=False, **shared)
    test_loader = DataLoader(test_ds, shuffle=False, **shared)

    info = dict(
        feat_dim=train_ds.feat_dim,
        modality_dims=train_ds.modality_dims,
        num_classes=NUM_CLASSES,
        train_size=len(train_ds),
        val_size=len(val_ds),
        test_size=len(test_ds),
        class_weights=train_ds.get_class_weights(),
    )
    return train_loader, val_loader, test_loader, info
experiments/data/dataset_forecast.py ADDED
@@ -0,0 +1,319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Frame-level future motor-primitive forecasting dataset.
2
+
3
+ Task definition
4
+ ---------------
5
+ At a sampled anchor time t in a recording:
6
+ past = sensor frames over [t - T_obs, t] ← input
7
+ future = per-frame verb_fine labels over (t, t + T_fut] ← target
8
+
9
+ We use NUM_VERB_FINE (= 17) as a sentinel "idle / no segment" class for
10
+ frames not covered by any annotated segment, so every future frame has a
11
+ valid label (output cardinality = NUM_VERB_FINE + 1 = 18).
12
+
13
+ Anchors are sampled at fixed stride within each recording so the model
14
+ sees both intra-segment future (mostly stationary) and across-boundary
15
+ future (where the next-action label changes — the interesting cases).
16
+ """
17
+ from __future__ import annotations
18
+
19
+ import os
20
+ import sys
21
+ from pathlib import Path
22
+ from typing import Dict, List, Optional, Sequence, Tuple
23
+
24
+ import numpy as np
25
+ import torch
26
+ from torch.utils.data import Dataset
27
+
28
+ THIS = Path(__file__).resolve()
29
+ sys.path.insert(0, str(THIS.parent))
30
+ sys.path.insert(0, str(THIS.parents[1]))
31
+
32
+ try:
33
+ from experiments.dataset_seqpred import (
34
+ SAMPLING_RATE_HZ, _load_recording_sensors, _load_annotations,
35
+ parse_ts_range, TRAIN_VOLS_V3, TEST_VOLS_V3,
36
+ DEFAULT_DATASET_DIR, DEFAULT_ANNOT_DIR,
37
+ )
38
+ from experiments.taxonomy import (
39
+ classify_segment, NUM_VERB_FINE,
40
+ )
41
+ except ModuleNotFoundError:
42
+ from dataset_seqpred import (
43
+ SAMPLING_RATE_HZ, _load_recording_sensors, _load_annotations,
44
+ parse_ts_range, TRAIN_VOLS_V3, TEST_VOLS_V3,
45
+ DEFAULT_DATASET_DIR, DEFAULT_ANNOT_DIR,
46
+ )
47
+ from taxonomy import classify_segment, NUM_VERB_FINE
48
+
49
+
50
+ IDLE_LABEL = NUM_VERB_FINE # = 17, sentinel for "no segment covers this frame"
51
+ NUM_FORECAST_CLASSES = NUM_VERB_FINE + 1 # = 18
52
+
53
+
54
class ForecastDataset(Dataset):
    """Forecast next T_fut seconds of per-frame verb_fine given past T_obs.

    For every recording that has both sensor data and an annotation file, a
    per-frame label timeline is built (frames outside any usable segment are
    ``IDLE_LABEL``), everything is downsampled by ``downsample``, and anchors
    are sampled at a fixed stride. Each item holds:

    * ``x``     -- dict[modality] -> array of shape (T_obs, F_mod): the frames
                   strictly before the anchor.
    * ``y_seq`` -- int array of shape (T_fut,): labels from the anchor onwards.
    * ``meta``  -- dict with ``vol``, ``scene`` and ``anchor_idx``.

    Args:
        volunteers: volunteer directory names to scan under ``dataset_dir``.
        modalities: sensor modalities exposed to the model.
        t_obs_sec / t_fut_sec: observation / forecast horizon in seconds.
        anchor_stride_sec: spacing between consecutive anchors in seconds.
        downsample: keep every ``downsample``-th native frame.
        dataset_dir / annot_dir: roots of sensor recordings and annotations.
        stats: per-modality (mean, std) used for z-scoring; computed from this
            dataset when ``None`` (pass the training stats for a test set).
        expected_dims: per-modality feature width to pad/truncate to.
        contact_only: keep only anchors whose past+future window contains at
            least one frame with summed pressure above ``contact_threshold_g``.
        log: print a summary line (and skipped-recording notices).

    Raises:
        ValueError: if ``downsample`` < 1 or a horizon rounds to zero frames.
        RuntimeError: if no anchor could be collected.
    """

    def __init__(
        self,
        volunteers: Sequence[str],
        modalities: Sequence[str],
        t_obs_sec: float = 1.5,
        t_fut_sec: float = 0.5,
        anchor_stride_sec: float = 0.25,
        downsample: int = 5,
        dataset_dir: Path = DEFAULT_DATASET_DIR,
        annot_dir: Path = DEFAULT_ANNOT_DIR,
        stats: Optional[Dict[str, Tuple[np.ndarray, np.ndarray]]] = None,
        expected_dims: Optional[Dict[str, int]] = None,
        contact_only: bool = False,
        contact_threshold_g: float = 5.0,
        log: bool = True,
    ):
        super().__init__()
        self.modalities = list(modalities)
        self.t_obs_sec = float(t_obs_sec)
        self.t_fut_sec = float(t_fut_sec)
        self.anchor_stride_sec = float(anchor_stride_sec)
        self.downsample = int(downsample)
        if self.downsample < 1:
            raise ValueError(f"downsample must be >= 1, got {downsample}")
        self.sr = SAMPLING_RATE_HZ // self.downsample
        self.dataset_dir = Path(dataset_dir)
        self.annot_dir = Path(annot_dir)
        self.contact_only = bool(contact_only)
        self.contact_threshold_g = float(contact_threshold_g)

        # Output time-step counts (after downsample)
        self.T_obs = int(round(self.t_obs_sec * self.sr))
        self.T_fut = int(round(self.t_fut_sec * self.sr))
        if self.T_obs < 1 or self.T_fut < 1:
            # Zero-length windows would produce empty inputs/targets and NaN stats.
            raise ValueError(
                f"t_obs_sec/t_fut_sec too short for sr={self.sr}Hz: "
                f"T_obs={self.T_obs} T_fut={self.T_fut}"
            )

        self._items: List[dict] = []
        # Pre-seed modality dims if caller (e.g. test set) provides them
        self._modality_dims: Dict[str, int] = dict(expected_dims) if expected_dims else {}

        # Always include pressure for the filter, even if model doesn't see it
        # as input. We separate "filter sensors" (load_mods) from "model input
        # sensors" (self.modalities). Loop-invariant, so built once.
        load_mods = list(dict.fromkeys(list(self.modalities) + ["pressure"]))

        for vol in volunteers:
            vol_dir = self.dataset_dir / vol
            if not vol_dir.is_dir():
                continue
            for scenario_dir in sorted(vol_dir.glob("s*")):
                if not scenario_dir.is_dir():
                    continue
                scene = scenario_dir.name
                annot_path = self.annot_dir / vol / f"{scene}.json"
                if not annot_path.exists():
                    continue

                try:
                    sensors_all = _load_recording_sensors(
                        scenario_dir, vol, scene, load_mods
                    )
                except Exception as exc:
                    # Best-effort: a broken recording is skipped, but no longer silently.
                    if log:
                        print(f"[ForecastDataset] skip {vol}/{scene}: {exc!r}", flush=True)
                    continue
                if sensors_all is None or any(a is None for a in sensors_all.values()):
                    continue
                pressure_full = sensors_all.get("pressure")  # (T, 50)
                # Subset to model-input modalities for everything downstream
                sensors = {m: sensors_all[m] for m in self.modalities}

                # Track modality dim consistency: the first recording seen (or
                # expected_dims) fixes the width; later ones are zero-padded or
                # truncated on the feature axis to match.
                for m, arr in list(sensors.items()):
                    if m in self._modality_dims:
                        target = self._modality_dims[m]
                        if arr.shape[1] != target:
                            if arr.shape[1] < target:
                                pad = np.zeros((arr.shape[0], target - arr.shape[1]),
                                               dtype=np.float32)
                                sensors[m] = np.concatenate([arr, pad], axis=1)
                            else:
                                sensors[m] = arr[:, :target]
                    else:
                        self._modality_dims[m] = arr.shape[1]

                T_avail = min(a.shape[0] for a in sensors.values())
                if T_avail < (self.T_obs + self.T_fut) * self.downsample:
                    continue

                # Build per-frame verb_fine timeline at the native sampling rate
                timeline = np.full(T_avail, IDLE_LABEL, dtype=np.int64)
                segs = _load_annotations(annot_path)
                for seg in segs:
                    a = seg.get("action_annotation", {})
                    labels = classify_segment(a)
                    if labels is None:
                        continue
                    start_sec, end_sec = parse_ts_range(seg.get("timestamp", ""))
                    s = int(round(start_sec * SAMPLING_RATE_HZ))
                    e = int(round(end_sec * SAMPLING_RATE_HZ))
                    s = max(0, s)
                    e = min(T_avail, e)
                    if e > s:
                        # Later segments overwrite earlier ones where they overlap.
                        timeline[s:e] = labels["verb_fine"]

                # Downsample timeline
                timeline_ds = timeline[::self.downsample]
                T_ds = len(timeline_ds)

                # Downsample sensors (kept as full record; windows are sliced below)
                sensors_ds = {m: arr[::self.downsample] for m, arr in sensors.items()}

                # Build per-frame contact mask at the downsampled rate:
                # is the pressure sum over all channels above the threshold?
                if pressure_full is not None:
                    pressure_ds = pressure_full[::self.downsample]
                    contact_ds = pressure_ds.sum(axis=1) > self.contact_threshold_g
                else:
                    contact_ds = np.zeros(T_ds, dtype=bool)

                # Sample anchors at fixed stride (in downsampled frames)
                stride = max(1, int(round(self.anchor_stride_sec * self.sr)))
                first_anchor = self.T_obs
                last_anchor = T_ds - self.T_fut
                # NOTE(review): equality still leaves one valid anchor, which this
                # skips; kept as-is so the anchor set stays unchanged -- confirm intent.
                if last_anchor <= first_anchor:
                    continue

                for anchor in range(first_anchor, last_anchor + 1, stride):
                    # contact-rich filter: any contact frame in past or future window?
                    if self.contact_only:
                        win = contact_ds[max(0, anchor - self.T_obs):
                                         min(T_ds, anchor + self.T_fut)]
                        if not win.any():
                            continue
                    past_slice = {m: arr[anchor - self.T_obs:anchor]
                                  for m, arr in sensors_ds.items()}
                    fut_labels = timeline_ds[anchor:anchor + self.T_fut].copy()
                    # length sanity
                    if any(w.shape[0] != self.T_obs for w in past_slice.values()):
                        continue
                    if fut_labels.shape[0] != self.T_fut:
                        continue
                    self._items.append({
                        "x": past_slice,        # dict[mod] -> (T_obs, F_mod)
                        "y_seq": fut_labels,    # (T_fut,) int labels incl. IDLE_LABEL
                        "meta": {"vol": vol, "scene": scene, "anchor_idx": int(anchor)},
                    })

        if not self._items:
            raise RuntimeError("ForecastDataset: collected 0 anchors. Check annot_dir / modalities.")

        # Per-modality z-score using training stats
        if stats is None:
            stats = self._compute_stats()
        self._stats = stats
        self._apply_stats(stats)

        if log:
            print(f"[ForecastDataset] vols={len(volunteers)} "
                  f"anchors={len(self._items)} "
                  f"T_obs={self.T_obs} T_fut={self.T_fut} "
                  f"contact_only={self.contact_only} "
                  f"modality_dims={self._modality_dims} "
                  f"sr={self.sr}Hz", flush=True)

    # ----- Stats / normalization -----
    def _compute_stats(self) -> Dict[str, Tuple[np.ndarray, np.ndarray]]:
        """Return per-modality (mean, std) over all stored observation windows.

        Near-constant features (std < 1e-6) get std 1.0 so normalization never
        divides by ~0. Modalities that were pre-seeded via ``expected_dims`` but
        have no windows are skipped instead of crashing ``np.concatenate``.
        """
        accs = {m: [] for m in self._modality_dims}
        for it in self._items:
            for m, w in it["x"].items():
                accs[m].append(w)
        out = {}
        for m, ws in accs.items():
            if not ws:
                continue
            cat = np.concatenate(ws, axis=0)
            mu = cat.mean(axis=0)
            sd = cat.std(axis=0)
            sd = np.where(sd < 1e-6, 1.0, sd)
            out[m] = (mu.astype(np.float32), sd.astype(np.float32))
        return out

    def _apply_stats(self, stats):
        """Z-score every stored window in place; modalities without stats are left as-is."""
        for it in self._items:
            for m, w in it["x"].items():
                if m in stats:
                    mu, sd = stats[m]
                    it["x"][m] = ((w - mu) / sd).astype(np.float32)

    # ----- Dataset protocol -----
    def __len__(self):
        return len(self._items)

    def __getitem__(self, idx):
        """Return (x_dict of tensors, y_seq tensor of shape (T_fut,), meta dict)."""
        it = self._items[idx]
        x = {m: torch.from_numpy(np.ascontiguousarray(w)) for m, w in it["x"].items()}
        y_seq = torch.from_numpy(np.ascontiguousarray(it["y_seq"]))  # (T_fut,)
        return x, y_seq, it["meta"]

    @property
    def modality_dims(self):
        """Copy of the per-modality feature widths."""
        return dict(self._modality_dims)

    def class_freq(self) -> np.ndarray:
        """Count future-frame labels per class over all items (length NUM_FORECAST_CLASSES)."""
        c = np.zeros(NUM_FORECAST_CLASSES, dtype=np.int64)
        for it in self._items:
            # Vectorized per-item histogram instead of a Python loop over frames.
            c += np.bincount(it["y_seq"], minlength=NUM_FORECAST_CLASSES)
        return c
257
+
258
+
259
def collate_forecast(batch):
    """Stack (x_dict, y_seq, meta) -> batched tensors. All samples share T_obs/T_fut.

    Args:
        batch: non-empty sequence of ``(x_dict, y_seq, meta)`` samples, where
            ``x_dict`` maps modality name -> tensor (T_obs, F_mod) and ``y_seq``
            is a tensor of shape (T_fut,).

    Returns:
        ``(x_out, y_out, metas)`` with ``x_out[m]`` of shape (B, T_obs, F_mod),
        ``y_out`` of shape (B, T_fut) and ``metas`` a list of the meta objects.

    Raises:
        ValueError: if ``batch`` is empty.
    """
    if len(batch) == 0:
        raise ValueError("collate_forecast: empty batch")
    xs, ys, metas = zip(*batch)
    # The modality set of the first sample defines the keys of the batch.
    x_out = {m: torch.stack([x[m] for x in xs], dim=0) for m in xs[0].keys()}
    y_out = torch.stack(ys, dim=0)  # (B, T_fut)
    return x_out, y_out, list(metas)
269
+
270
+
271
def build_train_test(
    modalities: Sequence[str],
    t_obs_sec: float = 1.5,
    t_fut_sec: float = 0.5,
    anchor_stride_sec: float = 0.25,
    downsample: int = 5,
    dataset_dir: Path = DEFAULT_DATASET_DIR,
    annot_dir: Path = DEFAULT_ANNOT_DIR,
    contact_only: bool = False,
    contact_threshold_g: float = 5.0,
    train_vols: Optional[Sequence[str]] = None,
    test_vols: Optional[Sequence[str]] = None,
):
    """Build the train/test ``ForecastDataset`` pair with shared normalization.

    The training set computes its own z-score statistics and modality widths;
    the test set reuses both so the two splits are normalized identically.

    Args:
        modalities: sensor modalities exposed to the model.
        t_obs_sec, t_fut_sec, anchor_stride_sec, downsample, dataset_dir,
        annot_dir, contact_only, contact_threshold_g: forwarded unchanged to
            both ``ForecastDataset`` constructors.
        train_vols / test_vols: optional volunteer lists; default to
            ``TRAIN_VOLS_V3`` / ``TEST_VOLS_V3``.

    Returns:
        ``(train, test)`` datasets.
    """
    if train_vols is None:
        train_vols = TRAIN_VOLS_V3
    if test_vols is None:
        test_vols = TEST_VOLS_V3
    # Arguments common to both splits.
    shared = dict(
        modalities=modalities,
        t_obs_sec=t_obs_sec, t_fut_sec=t_fut_sec,
        anchor_stride_sec=anchor_stride_sec, downsample=downsample,
        dataset_dir=dataset_dir, annot_dir=annot_dir,
        contact_only=contact_only, contact_threshold_g=contact_threshold_g,
    )
    train = ForecastDataset(train_vols, stats=None, log=True, **shared)
    test = ForecastDataset(
        test_vols,
        # Reuse training statistics and feature widths on the test split.
        stats=train._stats, expected_dims=train._modality_dims, log=True,
        **shared,
    )
    return train, test
299
+
300
+
301
if __name__ == "__main__":
    # Smoke test: build both splits and print sizes, label histograms and one
    # sample's shapes.
    import argparse
    ap = argparse.ArgumentParser()
    ap.add_argument("--modalities", type=str, default="imu,emg,eyetrack,mocap,pressure")
    ap.add_argument("--t_obs", type=float, default=1.5)
    ap.add_argument("--t_fut", type=float, default=0.5)
    ap.add_argument("--stride", type=float, default=0.25)
    args = ap.parse_args()
    # Tolerate "imu, emg" style input: strip whitespace and drop empty entries,
    # otherwise " emg" would be passed to the loader as a modality name.
    mods = [m.strip() for m in args.modalities.split(",") if m.strip()]
    if not mods:
        ap.error("--modalities must name at least one modality")
    tr, te = build_train_test(
        modalities=mods,
        t_obs_sec=args.t_obs, t_fut_sec=args.t_fut,
        anchor_stride_sec=args.stride,
    )
    print(f"\nTrain={len(tr)} Test={len(te)} T_obs={tr.T_obs} T_fut={tr.T_fut}")
    print(f"Train class freq:\n{tr.class_freq()}")
    print(f"Test class freq:\n{te.class_freq()}")
    x, y, meta = tr[0]
    print(f"Sample: x={ {m: tuple(v.shape) for m,v in x.items()} } y_seq={tuple(y.shape)}")
experiments/data/dataset_grasp_state.py ADDED
@@ -0,0 +1,571 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Anchor-based binary "is_grasping" classification dataset (T5 v3 / TGSR).
2
+
3
+ At each sampled anchor t in a recording:
4
+ past = sensor frames over [t - T_obs, t) ← input (anchor frame excluded)
5
+ label = majority vote of grasp-annotation mask over [t, t + T_fut) ← binary class
6
+
7
+ Ground-truth source: annotations_v3 verb segments. A frame is marked
8
+ "is_grasp" if it falls inside a segment whose action_name belongs to
9
+ GRASP_VERBS (set below). The label is annotation-derived, completely
10
+ independent of pressure — so adding/removing pressure as input does
11
+ NOT leak the label.
12
+
13
+ This is the cleanest test of "does pressure improve recognition of
14
+ object-interaction state when human-annotated grasp segments are GT?"
15
+ """
16
+ from __future__ import annotations
17
+
18
+ import json
19
+ import sys
20
+ from pathlib import Path
21
+ from typing import Dict, List, Optional, Sequence, Tuple
22
+
23
+ import numpy as np
24
+ import torch
25
+ from torch.utils.data import Dataset
26
+
27
+ THIS = Path(__file__).resolve()
28
+ sys.path.insert(0, str(THIS.parent))
29
+ sys.path.insert(0, str(THIS.parents[1]))
30
+
31
+ try:
32
+ from experiments.dataset_seqpred import (
33
+ SAMPLING_RATE_HZ, _load_recording_sensors,
34
+ TRAIN_VOLS_V3, TEST_VOLS_V3,
35
+ DEFAULT_DATASET_DIR, DEFAULT_ANNOT_DIR,
36
+ )
37
+ except ModuleNotFoundError:
38
+ from dataset_seqpred import (
39
+ SAMPLING_RATE_HZ, _load_recording_sensors,
40
+ TRAIN_VOLS_V3, TEST_VOLS_V3,
41
+ DEFAULT_DATASET_DIR, DEFAULT_ANNOT_DIR,
42
+ )
43
+
44
+
45
# Verbs whose annotated segments count as "grasping" when building the
# per-frame grasp mask (see build_grasp_mask).
GRASP_VERBS = {
    "grasp", "hold", "pick_up", "move", "place", "put_down",
    "pull", "rotate", "insert", "remove",
}
# User-specified subset of action verbs that mean "the object has been lifted
# off its resting surface and held in hand" (used as Class 2 stricter definition).
# Note: "open" is listed here although it is not a member of GRASP_VERBS.
LIFT_VERBS = {"grasp", "open", "move", "pick_up", "hold"}

# Multi-class verb taxonomy (annotations_v3 verb_fine universe).
# Verb 0 = background (anchor outside any segment).
VERB_LIST = [
    "background",
    "grasp", "move", "place", "adjust", "pick_up",
    "close", "put_down", "pull", "hold", "open",
    "rotate", "release", "push", "insert", "remove",
    "align", "stabilize",
]
# Verb name -> class index (position in VERB_LIST).
VERB_TO_IDX = {v: i for i, v in enumerate(VERB_LIST)}

# Top-15 most common object categories with non-zero coverage in the
# pressure-bearing test set (annotations_v3 survey of TRAIN+TEST_VOLS_V3).
# Index 0 = "_other": anchor outside any segment OR object not in top-15.
# Note: "coat" excluded because it appears only in v14, which has no
# pressure-aligned sessions and is silently dropped by the loader.
OBJECT_TOP_LIST = [
    "_other",
    "sealed jar", "towel", "tablecloth", "box", "pot",
    "rice bowl", "tape", "pants", "spoon", "plate",
    "marker", "cloth", "laptop", "toothbrush case", "tea canister",
]
# Object name -> class index (position in OBJECT_TOP_LIST).
OBJECT_TO_IDX = {o: i for i, o in enumerate(OBJECT_TOP_LIST)}
# Display names for the 4-way contact-transition code stored as "event_type".
EVENT_NAMES = {0: "non-contact", 1: "pre-contact", 2: "steady-grip", 3: "release"}
# Display names for label_mode="binary" and label_mode="three_class".
CLASS_NAMES_BINARY = {0: "non-grasp", 1: "grasp"}
CLASS_NAMES_THREE = {0: "no-grasp", 1: "attempted", 2: "sustained"}
# Back-compat default (used by binary code paths)
CLASS_NAMES = CLASS_NAMES_BINARY
81
+
82
+
83
+ def _parse_one(x: str, fmt_mode: str) -> float:
84
+ p = x.split(":")
85
+ if len(p) == 2:
86
+ return int(p[0]) * 60 + int(p[1])
87
+ if fmt_mode == "hhmmss":
88
+ return int(p[0]) * 3600 + int(p[1]) * 60 + int(p[2])
89
+ return int(p[0]) * 60 + int(p[1]) + int(p[2]) / 30.0 # mmssff @ 30fps
90
+
91
+
92
+ def _detect_fmt(segments, rec_sec: float) -> str:
93
+ for s in segments:
94
+ b = s["timestamp"].split("-")[1]
95
+ p = b.split(":")
96
+ if len(p) == 3:
97
+ hh = int(p[0]) * 3600 + int(p[1]) * 60 + int(p[2])
98
+ if hh > rec_sec * 1.05:
99
+ return "mmssff"
100
+ return "hhmmss"
101
+
102
+
103
def build_object_label(annot_path: Path, n_frames: int,
                       sr: int = SAMPLING_RATE_HZ) -> np.ndarray:
    """Per-frame object index (top-15 + '_other' fallback as class 0).

    Args:
        annot_path: annotation JSON with a top-level ``"segments"`` list.
        n_frames: length of the returned array.
        sr: frames per second used to convert timestamps to frame indices.

    Returns:
        int8 array of shape (n_frames,). Frames inside a segment whose
        ``object_name`` is in ``OBJECT_TO_IDX`` carry that index; all other
        frames are 0. A missing or unreadable file yields all zeros.
    """
    label = np.zeros(n_frames, dtype=np.int8)
    if not annot_path.exists():
        return label
    try:
        # Context manager so the file handle is closed deterministically.
        with open(annot_path, encoding="utf-8") as fh:
            ann = json.load(fh)
    except Exception:
        return label
    segments = ann.get("segments", [])
    if not segments:
        return label
    rec_sec = n_frames / sr
    fmt = _detect_fmt(segments, rec_sec)
    for s in segments:
        # `or {}` tolerates an explicit null action_annotation.
        obj = (s.get("action_annotation") or {}).get("object_name")
        idx = OBJECT_TO_IDX.get(obj, 0)
        if idx == 0:
            continue  # leave as 0 ("_other"/background)
        try:
            a, b = s["timestamp"].split("-")
            t0 = _parse_one(a, fmt)
            t1 = _parse_one(b, fmt)
        except Exception:
            continue
        # Drop empty/inverted segments and ones ending >10% past the recording.
        if t1 <= t0 or t1 > rec_sec * 1.10:
            continue
        i0 = max(0, int(round(t0 * sr)))
        i1 = min(n_frames, int(round(t1 * sr)))
        label[i0:i1] = idx
    return label
134
+
135
+
136
def build_lift_eligible_mask(annot_path: Path, n_frames: int,
                             sr: int = SAMPLING_RATE_HZ) -> np.ndarray:
    """Per-frame bool: True if frame is inside a segment that meets the
    lifted-grasp criterion: verb ∈ LIFT_VERBS OR hand_type == 'both'.
    Used by 3-class label_mode when require_lift_for_sustained=True.

    Args:
        annot_path: annotation JSON with a top-level ``"segments"`` list.
        n_frames: length of the returned array.
        sr: frames per second used to convert timestamps to frame indices.

    Returns:
        bool array of shape (n_frames,); all False when the file is missing
        or unreadable.
    """
    mask = np.zeros(n_frames, dtype=bool)
    if not annot_path.exists():
        return mask
    try:
        # Context manager so the file handle is closed deterministically.
        with open(annot_path, encoding="utf-8") as fh:
            ann = json.load(fh)
    except Exception:
        return mask
    segments = ann.get("segments", [])
    if not segments:
        return mask
    rec_sec = n_frames / sr
    fmt = _detect_fmt(segments, rec_sec)
    for s in segments:
        # `or {}` tolerates an explicit null action_annotation.
        a = s.get("action_annotation") or {}
        verb = a.get("action_name")
        hand = a.get("hand_type", "")
        is_lift = (verb in LIFT_VERBS) or (hand == "both")
        if not is_lift:
            continue
        try:
            ts0, ts1 = s["timestamp"].split("-")
            t0 = _parse_one(ts0, fmt)
            t1 = _parse_one(ts1, fmt)
        except Exception:
            continue
        # Drop empty/inverted segments and ones ending >10% past the recording.
        if t1 <= t0 or t1 > rec_sec * 1.10:
            continue
        i0 = max(0, int(round(t0 * sr)))
        i1 = min(n_frames, int(round(t1 * sr)))
        mask[i0:i1] = True
    return mask
171
+
172
+
173
def build_verb_label(annot_path: Path, n_frames: int,
                     sr: int = SAMPLING_RATE_HZ) -> np.ndarray:
    """Per-frame verb index (int8). Default (no segment) = 0 (background).

    Args:
        annot_path: annotation JSON with a top-level ``"segments"`` list.
        n_frames: length of the returned array.
        sr: frames per second used to convert timestamps to frame indices.

    Returns:
        int8 array of shape (n_frames,) holding ``VERB_TO_IDX`` indices; verbs
        not in ``VERB_TO_IDX`` and frames outside any segment stay 0. A
        missing or unreadable file yields all zeros.
    """
    label = np.zeros(n_frames, dtype=np.int8)
    if not annot_path.exists():
        return label
    try:
        # Context manager so the file handle is closed deterministically.
        with open(annot_path, encoding="utf-8") as fh:
            ann = json.load(fh)
    except Exception:
        return label
    segments = ann.get("segments", [])
    if not segments:
        return label
    rec_sec = n_frames / sr
    fmt = _detect_fmt(segments, rec_sec)
    for s in segments:
        # `or {}` tolerates an explicit null action_annotation.
        verb = (s.get("action_annotation") or {}).get("action_name")
        v_idx = VERB_TO_IDX.get(verb, 0)  # unknown verb → background
        if v_idx == 0:
            continue
        try:
            a, b = s["timestamp"].split("-")
            t0 = _parse_one(a, fmt)
            t1 = _parse_one(b, fmt)
        except Exception:
            continue
        # Drop empty/inverted segments and ones ending >10% past the recording.
        if t1 <= t0 or t1 > rec_sec * 1.10:
            continue
        i0 = max(0, int(round(t0 * sr)))
        i1 = min(n_frames, int(round(t1 * sr)))
        label[i0:i1] = v_idx
    return label
204
+
205
+
206
def build_grasp_mask(annot_path: Path, n_frames: int,
                     sr: int = SAMPLING_RATE_HZ) -> np.ndarray:
    """Return bool array of shape (n_frames,).

    A frame is True when it lies inside an annotated segment whose
    ``action_name`` is in ``GRASP_VERBS``.

    Args:
        annot_path: annotation JSON with a top-level ``"segments"`` list.
        n_frames: length of the returned array.
        sr: frames per second used to convert timestamps to frame indices.

    Returns:
        bool mask; all False when the file is missing or unreadable.
    """
    mask = np.zeros(n_frames, dtype=bool)
    if not annot_path.exists():
        return mask
    try:
        # Context manager so the file handle is closed deterministically.
        with open(annot_path, encoding="utf-8") as fh:
            ann = json.load(fh)
    except Exception:
        return mask
    segments = ann.get("segments", [])
    if not segments:
        return mask
    rec_sec = n_frames / sr
    fmt = _detect_fmt(segments, rec_sec)
    for s in segments:
        # `or {}` tolerates an explicit null action_annotation.
        verb = (s.get("action_annotation") or {}).get("action_name")
        if verb not in GRASP_VERBS:
            continue
        try:
            a, b = s["timestamp"].split("-")
            t0 = _parse_one(a, fmt)
            t1 = _parse_one(b, fmt)
        except Exception:
            continue
        # Drop empty/inverted segments and ones ending >10% past the recording.
        if t1 <= t0 or t1 > rec_sec * 1.10:
            continue
        i0 = max(0, int(round(t0 * sr)))
        i1 = min(n_frames, int(round(t1 * sr)))
        mask[i0:i1] = True
    return mask
236
+
237
+
238
class GraspStateDataset(Dataset):
    """Predict a label for the future window from past sensor signals.

    At each anchor the input is the ``T_obs`` frames before the anchor and the
    label summarizes the ``T_fut`` frames starting at the anchor. What the
    label means depends on ``label_mode``:

    * ``"binary"``      -- 1 if at least ``majority_threshold`` of the future
                           frames lie in a grasp-verb segment, else 0.
    * ``"three_class"`` -- 0 no-grasp; otherwise 2 ("sustained") if the future
                           window contains a contiguous pressure-contact run
                           of at least ``sustained_threshold_sec`` (and, with
                           ``require_lift_for_sustained``, the majority of it
                           is lift-eligible), else 1 ("attempted").
    * ``"verb"``        -- most frequent ``VERB_LIST`` index in the window.
    * ``"object"``      -- most frequent ``OBJECT_TOP_LIST`` index in the window.

    Every item also carries an ``event_type`` in {0, 1, 2, 3} (see
    ``EVENT_NAMES``) derived from pressure contact in the past/future windows.
    Pressure is always loaded for this purpose, even when it is not one of
    ``input_modalities``.

    Raises:
        ValueError: unknown ``label_mode``, ``downsample`` < 1, or a horizon
            that rounds to zero frames.
        RuntimeError: if no anchor could be collected.
    """

    def __init__(
        self,
        volunteers: Sequence[str],
        input_modalities: Sequence[str],
        t_obs_sec: float = 1.0,
        t_fut_sec: float = 0.5,
        anchor_stride_sec: float = 0.25,
        downsample: int = 5,
        dataset_dir: Path = DEFAULT_DATASET_DIR,
        annot_dir: Path = DEFAULT_ANNOT_DIR,
        contact_threshold_g: float = 5.0,       # legacy sum-threshold (kept for back-compat, unused if use_per_cell_contact=True)
        per_cell_threshold_g: float = 10.0,     # per-cell threshold to declare a sensor cell "active"
        min_active_cells: int = 3,              # need ≥ this many active cells to declare contact
        use_per_cell_contact: bool = True,      # use per-cell active-count for contact detection
        label_mode: str = "binary",             # "binary", "three_class", "verb", or "object"
        sustained_threshold_sec: float = 0.3,   # (3-class only) min contiguous contact for "Sustained"
        require_lift_for_sustained: bool = False,  # (3-class only) Class 2 also requires a lift-eligible segment
        per_class_max: Optional[int] = None,    # cap on anchors kept per class (None = no cap)
        input_stats: Optional[Dict[str, Tuple[np.ndarray, np.ndarray]]] = None,
        expected_input_dims: Optional[Dict[str, int]] = None,
        majority_threshold: float = 0.5,
        rng_seed: int = 0,
        log: bool = True,
    ):
        super().__init__()
        self.input_modalities = list(input_modalities)
        self.t_obs_sec = float(t_obs_sec)
        self.t_fut_sec = float(t_fut_sec)
        self.anchor_stride_sec = float(anchor_stride_sec)
        self.downsample = int(downsample)
        if self.downsample < 1:
            raise ValueError(f"downsample must be >= 1, got {downsample}")
        self.sr = SAMPLING_RATE_HZ // self.downsample
        self.dataset_dir = Path(dataset_dir)
        self.annot_dir = Path(annot_dir)
        self.contact_threshold_g = float(contact_threshold_g)
        self.per_cell_threshold_g = float(per_cell_threshold_g)
        self.min_active_cells = int(min_active_cells)
        self.use_per_cell_contact = bool(use_per_cell_contact)
        self.label_mode = str(label_mode)
        if self.label_mode not in ("binary", "three_class", "verb", "object"):
            raise ValueError(f"label_mode must be binary|three_class|verb|object, got {label_mode}")
        if self.label_mode == "binary":
            self.num_classes = 2
        elif self.label_mode == "three_class":
            self.num_classes = 3
        elif self.label_mode == "verb":
            self.num_classes = len(VERB_LIST)
        else:  # object
            self.num_classes = len(OBJECT_TOP_LIST)
        self.sustained_threshold_sec = float(sustained_threshold_sec)
        self.require_lift_for_sustained = bool(require_lift_for_sustained)
        self.per_class_max = per_class_max
        self.majority_threshold = float(majority_threshold)
        self.T_obs = int(round(self.t_obs_sec * self.sr))
        self.T_fut = int(round(self.t_fut_sec * self.sr))
        if self.T_obs < 1 or self.T_fut < 1:
            # Empty windows make every .mean() below NaN and silently label
            # all anchors 0, so reject them up front.
            raise ValueError(
                f"t_obs_sec/t_fut_sec too short for sr={self.sr}Hz: "
                f"T_obs={self.T_obs} T_fut={self.T_fut}"
            )

        self._items: List[dict] = []
        self._modality_dims: Dict[str, int] = dict(expected_input_dims) if expected_input_dims else {}
        rng = np.random.default_rng(rng_seed)

        # Load pressure even if not in inputs, for event_type stratification.
        load_mods = list(dict.fromkeys(list(self.input_modalities) + ["pressure"]))

        # Per-class anchor pool
        pools: Dict[int, List[dict]] = {c: [] for c in range(self.num_classes)}
        sustained_thresh_frames = int(round(self.sustained_threshold_sec * self.sr))

        for vol in volunteers:
            vol_dir = self.dataset_dir / vol
            if not vol_dir.is_dir():
                continue
            for scenario_dir in sorted(vol_dir.glob("s*")):
                if not scenario_dir.is_dir():
                    continue
                scene = scenario_dir.name
                annot_path = self.annot_dir / vol / f"{scene}.json"
                if not annot_path.exists():
                    continue
                try:
                    sensors_all = _load_recording_sensors(
                        scenario_dir, vol, scene, load_mods
                    )
                except Exception as exc:
                    # Best-effort: a broken recording is skipped, but no longer silently.
                    if log:
                        print(f"[GraspStateDataset] skip {vol}/{scene}: {exc!r}", flush=True)
                    continue
                if sensors_all is None or any(a is None for a in sensors_all.values()):
                    continue

                pressure_full = sensors_all["pressure"]  # (T, 50)
                input_arrs = {m: sensors_all[m] for m in self.input_modalities}
                for m, arr in list(input_arrs.items()):
                    self._enforce_dim(input_arrs, m, arr, self._modality_dims)

                # Common length across all inputs and pressure.
                T_avail = min(a.shape[0] for a in input_arrs.values())
                T_avail = min(T_avail, pressure_full.shape[0])
                if T_avail < (self.T_obs + self.T_fut) * self.downsample:
                    continue

                # Build per-frame annotation tracks at the native rate, then
                # downsample. Only the track needed by label_mode is built
                # besides the grasp mask.
                mask_full = build_grasp_mask(annot_path, T_avail,
                                             sr=SAMPLING_RATE_HZ)
                if self.label_mode == "verb":
                    verb_full = build_verb_label(annot_path, T_avail, sr=SAMPLING_RATE_HZ)
                    verb_ds = verb_full[:T_avail:self.downsample]
                else:
                    verb_ds = None
                if self.label_mode == "object":
                    obj_full = build_object_label(annot_path, T_avail, sr=SAMPLING_RATE_HZ)
                    obj_ds = obj_full[:T_avail:self.downsample]
                else:
                    obj_ds = None
                if self.label_mode == "three_class" and self.require_lift_for_sustained:
                    lift_full = build_lift_eligible_mask(annot_path, T_avail, sr=SAMPLING_RATE_HZ)
                    lift_eligible_ds = lift_full[:T_avail:self.downsample]
                else:
                    lift_eligible_ds = None
                input_ds = {m: arr[:T_avail:self.downsample] for m, arr in input_arrs.items()}
                pressure_ds = pressure_full[:T_avail:self.downsample]
                mask_ds = mask_full[:T_avail:self.downsample]
                T_ds = mask_ds.shape[0]
                if self.use_per_cell_contact:
                    # n_active per frame: count cells with value > per_cell_threshold_g
                    n_active = (pressure_ds > self.per_cell_threshold_g).sum(axis=1)
                    contact_frame = n_active >= self.min_active_cells
                else:
                    pressure_sum = pressure_ds.sum(axis=1)
                    contact_frame = pressure_sum > self.contact_threshold_g

                stride = max(1, int(round(self.anchor_stride_sec * self.sr)))
                first_anchor = self.T_obs
                last_anchor = T_ds - self.T_fut
                # NOTE(review): equality still leaves one valid anchor, which this
                # skips; kept as-is so the anchor set stays unchanged -- confirm intent.
                if last_anchor <= first_anchor:
                    continue

                for anchor in range(first_anchor, last_anchor + 1, stride):
                    fut_mask = mask_ds[anchor:anchor + self.T_fut]
                    if fut_mask.shape[0] != self.T_fut:
                        continue
                    annotation_is_grasp = fut_mask.mean() >= self.majority_threshold

                    if self.label_mode == "binary":
                        label = int(annotation_is_grasp)
                    elif self.label_mode == "three_class":
                        if not annotation_is_grasp:
                            label = 0  # NoGrasp
                        else:
                            # longest contiguous run of contact in future window
                            fut_contact = contact_frame[anchor:anchor + self.T_fut]
                            longest = 0
                            cur = 0
                            for v in fut_contact:
                                if v:
                                    cur += 1
                                    longest = max(longest, cur)
                                else:
                                    cur = 0
                            is_sustained = longest >= sustained_thresh_frames
                            if is_sustained and self.require_lift_for_sustained:
                                # Demote to Class 1 unless majority of future window is in
                                # a "lift-eligible" segment (verb ∈ LIFT_VERBS or hand=both).
                                fut_lift = lift_eligible_ds[anchor:anchor + self.T_fut]
                                if fut_lift.mean() < 0.5:
                                    is_sustained = False
                            label = 2 if is_sustained else 1
                    elif self.label_mode == "verb":
                        fut_v = verb_ds[anchor:anchor + self.T_fut]
                        counts = np.bincount(fut_v, minlength=self.num_classes)
                        label = int(np.argmax(counts))
                    else:  # object — majority object in future window
                        fut_o = obj_ds[anchor:anchor + self.T_fut]
                        counts = np.bincount(fut_o, minlength=self.num_classes)
                        label = int(np.argmax(counts))

                    # event_type for stratification (4-class transition taxonomy):
                    # "high" = more than half of the window's frames are in contact.
                    past_high = contact_frame[anchor - self.T_obs:anchor].mean() > 0.5
                    fut_high = contact_frame[anchor:anchor + self.T_fut].mean() > 0.5
                    if not past_high and not fut_high:
                        et = 0
                    elif not past_high and fut_high:
                        et = 1
                    elif past_high and fut_high:
                        et = 2
                    else:
                        et = 3

                    past_slice = {m: arr[anchor - self.T_obs:anchor]
                                  for m, arr in input_ds.items()}
                    if any(w.shape[0] != self.T_obs for w in past_slice.values()):
                        continue

                    item = {
                        "x": past_slice,
                        "label": label,
                        "event_type": et,
                        "meta": {"vol": vol, "scene": scene, "anchor_idx": int(anchor)},
                    }
                    pools[label].append(item)

        # Balance classes if requested (cap larger pool to per_class_max)
        if self.per_class_max is not None:
            for c, pool in pools.items():
                if len(pool) > self.per_class_max:
                    idx = rng.choice(len(pool), size=self.per_class_max, replace=False)
                    # sorted() keeps the surviving anchors in collection order.
                    pools[c] = [pool[i] for i in sorted(idx)]
        self._items = [it for c in range(self.num_classes) for it in pools[c]]

        if not self._items:
            raise RuntimeError("GraspStateDataset: collected 0 anchors.")

        # Z-score inputs
        if input_stats is None:
            input_stats = self._compute_input_stats()
        self._input_stats = input_stats
        self._apply_input_stats(input_stats)

        if log:
            if self.label_mode == "binary":
                class_names = CLASS_NAMES_BINARY
            elif self.label_mode == "three_class":
                class_names = CLASS_NAMES_THREE
            elif self.label_mode == "verb":
                class_names = {i: v for i, v in enumerate(VERB_LIST)}
            else:  # object
                class_names = {i: v for i, v in enumerate(OBJECT_TOP_LIST)}
            counts_class = {class_names[c]: sum(1 for it in self._items if it["label"] == c)
                            for c in range(self.num_classes)}
            counts_event = {EVENT_NAMES[k]: sum(1 for it in self._items if it["event_type"] == k)
                            for k in (0, 1, 2, 3)}
            print(f"[GraspStateDataset] vols={len(volunteers)} "
                  f"inputs={self.input_modalities} "
                  f"anchors={len(self._items)} class={counts_class} "
                  f"event={counts_event} "
                  f"T_obs={self.T_obs} T_fut={self.T_fut} sr={self.sr}Hz "
                  f"input_dims={self._modality_dims}", flush=True)

    @staticmethod
    def _enforce_dim(arrs, m, arr, dim_dict):
        """Make ``arrs[m]`` match the registered feature width of modality ``m``.

        The first array seen for a modality registers its width in
        ``dim_dict``; later arrays are zero-padded or truncated on axis 1.
        """
        if m in dim_dict:
            tgt = dim_dict[m]
            if arr.shape[1] != tgt:
                if arr.shape[1] < tgt:
                    pad = np.zeros((arr.shape[0], tgt - arr.shape[1]), dtype=np.float32)
                    arrs[m] = np.concatenate([arr, pad], axis=1)
                else:
                    arrs[m] = arr[:, :tgt]
        else:
            dim_dict[m] = arr.shape[1]

    def _compute_input_stats(self):
        """Return per-modality (mean, std) over all stored observation windows.

        Near-constant features (std < 1e-6) get std 1.0. Modalities that were
        pre-seeded via ``expected_input_dims`` but have no windows are skipped
        instead of crashing ``np.concatenate``.
        """
        accs = {m: [] for m in self._modality_dims}
        for it in self._items:
            for m, w in it["x"].items():
                accs[m].append(w)
        out = {}
        for m, ws in accs.items():
            if not ws:
                continue
            cat = np.concatenate(ws, axis=0)
            mu = cat.mean(axis=0).astype(np.float32)
            sd = cat.std(axis=0)
            sd = np.where(sd < 1e-6, 1.0, sd)
            out[m] = (mu, sd.astype(np.float32))
        return out

    def _apply_input_stats(self, stats):
        """Z-score every stored window in place; modalities without stats are left as-is."""
        for it in self._items:
            for m, w in it["x"].items():
                if m in stats:
                    mu, sd = stats[m]
                    it["x"][m] = ((w - mu) / sd).astype(np.float32)

    def __len__(self):
        return len(self._items)

    def __getitem__(self, idx):
        """Return (x_dict of tensors, int label, int event_type, meta dict)."""
        it = self._items[idx]
        x = {m: torch.from_numpy(np.ascontiguousarray(w)) for m, w in it["x"].items()}
        label = int(it["label"])
        et = int(it["event_type"])
        return x, label, et, it["meta"]

    @property
    def modality_dims(self):
        """Copy of the per-modality feature widths."""
        return dict(self._modality_dims)
510
+
511
+
512
def collate_grasp_state(batch):
    """Batch ``(x_dict, label, event_type, meta)`` samples.

    Returns a 4-tuple: dict of per-modality tensors stacked on a new leading
    batch axis, a long tensor of labels, a long tensor of event types, and
    the list of meta objects.
    """
    inputs, targets, events, infos = zip(*batch)
    stacked = {}
    # The first sample's modality keys define the batch keys.
    for name in list(inputs[0].keys()):
        stacked[name] = torch.stack([sample[name] for sample in inputs], dim=0)
    label_tensor = torch.tensor(targets, dtype=torch.long)
    event_tensor = torch.tensor(events, dtype=torch.long)
    return stacked, label_tensor, event_tensor, list(infos)
519
+
520
+
521
def build_grasp_train_test(
    input_modalities,
    t_obs_sec=1.0, t_fut_sec=0.5, anchor_stride_sec=0.25,
    downsample=5,
    dataset_dir=DEFAULT_DATASET_DIR, annot_dir=DEFAULT_ANNOT_DIR,
    contact_threshold_g=5.0, per_class_max=None,
    label_mode="binary", sustained_threshold_sec=0.3,
    require_lift_for_sustained=False,
    rng_seed=0,
    train_vols=None, test_vols=None,
    per_cell_threshold_g=10.0, min_active_cells=3,
    use_per_cell_contact=True,
):
    """Build the train/test ``GraspStateDataset`` pair with shared normalization.

    The training set computes its own z-score statistics and modality widths;
    the test set reuses both. ``per_class_max`` caps only the training set,
    and the test set uses ``rng_seed + 1``.

    Args:
        input_modalities: sensor modalities exposed to the model.
        train_vols / test_vols: volunteer lists; default to ``TRAIN_VOLS_V3``
            / ``TEST_VOLS_V3``.
        per_cell_threshold_g, min_active_cells, use_per_cell_contact: contact
            detection settings forwarded to both datasets. The defaults equal
            the ``GraspStateDataset`` defaults, so existing callers are
            unaffected. ``contact_threshold_g`` only takes effect when
            ``use_per_cell_contact`` is False.
        Remaining arguments are forwarded unchanged to ``GraspStateDataset``.

    Returns:
        ``(train, test)`` datasets.
    """
    if train_vols is None:
        train_vols = TRAIN_VOLS_V3
    if test_vols is None:
        test_vols = TEST_VOLS_V3
    # Arguments common to both splits.
    shared = dict(
        input_modalities=input_modalities,
        t_obs_sec=t_obs_sec, t_fut_sec=t_fut_sec,
        anchor_stride_sec=anchor_stride_sec, downsample=downsample,
        dataset_dir=dataset_dir, annot_dir=annot_dir,
        contact_threshold_g=contact_threshold_g,
        per_cell_threshold_g=per_cell_threshold_g,
        min_active_cells=min_active_cells,
        use_per_cell_contact=use_per_cell_contact,
        label_mode=label_mode, sustained_threshold_sec=sustained_threshold_sec,
        require_lift_for_sustained=require_lift_for_sustained,
    )
    train = GraspStateDataset(
        train_vols, per_class_max=per_class_max,
        rng_seed=rng_seed, log=True,
        **shared,
    )
    test = GraspStateDataset(
        test_vols, per_class_max=None,  # don't cap test
        input_stats=train._input_stats,
        expected_input_dims=train._modality_dims,
        rng_seed=rng_seed + 1, log=True,
        **shared,
    )
    return train, test
557
+
558
+
559
if __name__ == "__main__":
    # Smoke test: build both splits and print one training sample's shapes.
    import argparse
    ap = argparse.ArgumentParser()
    ap.add_argument("--input_modalities", default="emg,imu,mocap")
    ap.add_argument("--t_obs", type=float, default=1.0)
    ap.add_argument("--t_fut", type=float, default=0.5)
    args = ap.parse_args()
    # Tolerate "emg, imu" style input: strip whitespace and drop empty entries,
    # otherwise " imu" would be passed to the loader as a modality name.
    mods = [m.strip() for m in args.input_modalities.split(",") if m.strip()]
    if not mods:
        ap.error("--input_modalities must name at least one modality")
    tr, te = build_grasp_train_test(
        input_modalities=mods,
        t_obs_sec=args.t_obs, t_fut_sec=args.t_fut,
    )
    x, y, et, meta = tr[0]
    print(f"sample: x={ {m: tuple(v.shape) for m,v in x.items()} } y={y} et={et}")
experiments/data/dataset_seqpred.py ADDED
@@ -0,0 +1,533 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Segment-to-Next-Segment Triplet Prediction dataset (T10).
3
+
4
+ For every annotated action segment k in every recording:
5
+ anchor_t = start_time(segment_k) - T_fut (seconds)
6
+ observation = sensor frames in [anchor_t - T_obs, anchor_t]
7
+ target = triplet labels of segment_k: (verb_fine, verb_composite,
8
+ noun, hand)
9
+
10
+ Segments whose observation window would spill before t=0 of the recording
11
+ are skipped (no left-padding), so we never mix noise with real sensor data.
12
+
13
+ Strategy A is enforced in taxonomy.classify_segment(): segments whose noun is
14
+ not in the kept set (<50 occurrences) are dropped entirely.
15
+
16
+ Per-modality tensors are returned as a dict so downstream models can either
17
+ concat them (single-flow baselines) or keep them separate (our cross-modal
18
+ fusion model). collate_triplet() returns a boolean mask alongside the sensor
19
+ tensors so variable-length windows can be padded within a batch.
20
+ """
21
+
22
+ from __future__ import annotations
23
+
24
+ # pandas must be imported BEFORE torch/numpy to avoid a GLIBCXX load-order bug
25
+ # on this cluster.
26
+ import pandas as pd
27
+
28
+ import json
29
+ import os
30
+ import sys
31
+ from pathlib import Path
32
+ from typing import Dict, List, Optional, Sequence, Tuple
33
+
34
+ import numpy as np
35
+ import torch
36
+ from torch.utils.data import Dataset
37
+
38
+ # Make sibling modules importable from either (a) the neurips26 root, or
39
+ # (b) the frozen row/code/ folder (populated by setup_row.sh).
40
+ _THIS = Path(__file__).resolve()
41
+ sys.path.insert(0, str(_THIS.parent)) # code/ itself
42
+ sys.path.insert(0, str(_THIS.parent.parent)) # neurips26/
43
+
44
+ try:
45
+ from data.dataset import ( # noqa: E402
46
+ MODALITY_FILES, load_modality_array,
47
+ )
48
+ from experiments.taxonomy import ( # noqa: E402
49
+ classify_segment, NOUN, NUM_VERB_FINE, NUM_VERB_COMPOSITE, NUM_NOUN,
50
+ NUM_HAND,
51
+ )
52
+ except ModuleNotFoundError:
53
+ from dataset import ( # noqa: E402
54
+ MODALITY_FILES, load_modality_array,
55
+ )
56
+ from taxonomy import ( # noqa: E402
57
+ classify_segment, NOUN, NUM_VERB_FINE, NUM_VERB_COMPOSITE, NUM_NOUN,
58
+ NUM_HAND,
59
+ )
60
+
61
+ # ---------------------------------------------------------------------------
62
+ # Constants
63
+ # ---------------------------------------------------------------------------
64
+
65
# Dataset and annotation roots. The frozen row/code/ folders sit at arbitrary
# depths under the repo, so relative-to-__file__ discovery is unreliable.
# DAILYACT_REPO overrides the root (e.g. when running on a mirror). Environment
# references in the value are expanded, so the "${PULSE_ROOT}" default resolves
# to the directory named by $PULSE_ROOT instead of a relative path that is
# literally called "${PULSE_ROOT}". If PULSE_ROOT is unset the text is left
# unchanged, exactly as before.
REPO = Path(os.path.expandvars(
    os.environ.get("DAILYACT_REPO", "${PULSE_ROOT}")
))
DEFAULT_DATASET_DIR = REPO / "aligned_gy"
DEFAULT_ANNOT_DIR = REPO / "annotations_v3"

SAMPLING_RATE_HZ = 100
# 5x downsample -> 20 Hz. Matches the existing pipeline in dataset.py.
DEFAULT_DOWNSAMPLE = 5

VALID_MODALITIES = ("mocap", "emg", "eyetrack", "imu", "pressure")

# Fixed subject-independent split. Hand-picked 5 test volunteers with full
# 8-scene coverage, spread across the ID range.
# NOTE(review): the original comment says any annotated volunteer not listed
# here is "assumed to be train data", but TRAIN_VOLS_V3 is an explicit list;
# newly annotated volunteers must be added to it by hand.
TEST_VOLS_V3 = ["v14", "v30", "v34", "v38", "v41"]
TRAIN_VOLS_V3 = [
    "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
    "v11", "v12", "v13", "v15", "v16", "v17", "v18", "v19", "v20",
    "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29",
    "v31", "v32", "v33", "v35", "v36", "v37", "v39", "v40",
]
assert set(TRAIN_VOLS_V3).isdisjoint(TEST_VOLS_V3), "Split must be disjoint"
92
+
93
+
94
+ # ---------------------------------------------------------------------------
95
+ # Helpers
96
+ # ---------------------------------------------------------------------------
97
+
98
+ def _parse_ts(ts: str) -> float:
99
+ """Parse 'HH:MM:SS' or 'MM:SS' (or 'M:S') into seconds."""
100
+ parts = ts.strip().split(":")
101
+ try:
102
+ if len(parts) == 2:
103
+ return float(parts[0]) * 60 + float(parts[1])
104
+ if len(parts) == 3:
105
+ return float(parts[0]) * 3600 + float(parts[1]) * 60 + float(parts[2])
106
+ except ValueError:
107
+ return 0.0
108
+ return 0.0
109
+
110
+
111
def parse_ts_range(ts_range: str) -> Tuple[float, float]:
    """Split 'START-END' at the first '-' and parse both sides into seconds.

    Each side is parsed by _parse_ts ('MM:SS' or 'HH:MM:SS'). A string
    without any '-' yields (0.0, 0.0).
    """
    start_txt, dash, end_txt = ts_range.partition("-")
    if not dash:
        return 0.0, 0.0
    return _parse_ts(start_txt), _parse_ts(end_txt)
117
+
118
+
119
+ def _load_recording_sensors(
120
+ scenario_dir: Path, vol: str, scenario: str,
121
+ modalities: Sequence[str],
122
+ ) -> Optional[Dict[str, np.ndarray]]:
123
+ """Load each requested modality as a (T, F_mod) float32 array at 100 Hz.
124
+
125
+ Returns None if any requested modality is missing or corrupted."""
126
+ out: Dict[str, np.ndarray] = {}
127
+ for mod in modalities:
128
+ if mod == "mocap":
129
+ fp = scenario_dir / f"aligned_{vol}{scenario}_s_Q.tsv"
130
+ else:
131
+ fp = scenario_dir / MODALITY_FILES[mod]
132
+ if not fp.exists():
133
+ return None
134
+ arr = load_modality_array(str(fp), mod)
135
+ if arr is None:
136
+ return None
137
+ out[mod] = arr.astype(np.float32)
138
+ # Align lengths across modalities (take min); all start at sensor t=0.
139
+ T = min(a.shape[0] for a in out.values())
140
+ for m in out:
141
+ out[m] = out[m][:T]
142
+ return out
143
+
144
+
145
+ def _load_annotations(annot_path: Path) -> List[dict]:
146
+ with open(annot_path) as f:
147
+ d = json.load(f)
148
+ return d.get("segments", [])
149
+
150
+
151
+ # ---------------------------------------------------------------------------
152
+ # Dataset
153
+ # ---------------------------------------------------------------------------
154
+
155
class TripletSeqPredDataset(Dataset):
    """One sample per kept annotated segment of every recording.

    Windowing modes:
        recognition   the input window is the segment's own [start, end] span.
        anticipation  the input window is [anchor - t_obs, anchor] with
                      anchor = start - t_fut; segments whose window would
                      begin before t=0 are dropped.

    __getitem__ returns the 4-tuple (x, y, meta, prev):
        x:    dict {mod_name: float32 tensor (T_frames, F_mod)}, z-scored
        y:    label dict produced by classify_segment(); class_counts() and
              collate_triplet() read its 'verb_fine', 'verb_composite',
              'noun' and 'hand' entries
        meta: dict {'vol', 'scene', 'seg_idx'} plus 'anchor_sec'
              (anticipation) or 'start_sec' / 'end_sec' (recognition)
        prev: dict {'verb_composite', 'noun'} holding the labels of the
              previous kept segment of the same recording, or
              NUM_VERB_COMPOSITE / NUM_NOUN as a begin-of-sequence sentinel

    After construction `stats_counts` holds the scan / keep / drop counters.
    """

    def __init__(
        self,
        volunteers: Sequence[str],
        modalities: Sequence[str] = ("imu", "mocap", "emg", "eyetrack", "pressure"),
        t_obs_sec: float = 8.0,
        t_fut_sec: float = 2.0,
        downsample: int = DEFAULT_DOWNSAMPLE,
        dataset_dir: Path = DEFAULT_DATASET_DIR,
        annot_dir: Path = DEFAULT_ANNOT_DIR,
        stats: Optional[Dict[str, Tuple[np.ndarray, np.ndarray]]] = None,
        min_seg_duration_sec: float = 0.4,
        log: bool = True,
        mode: str = "recognition",
    ):
        """Scan the recordings of `volunteers` and build all samples in memory.

        Args:
            volunteers: volunteer ids; each is a sub-directory of dataset_dir.
            modalities: modalities to load, each one of VALID_MODALITIES.
            t_obs_sec: observation window length (anticipation mode).
            t_fut_sec: gap between window end and segment start
                (anticipation mode).
            downsample: keep every `downsample`-th 100 Hz frame; must be >= 1.
            dataset_dir: root holding <vol>/<scene>/ sensor files.
            annot_dir: root holding <vol>/<scene>.json annotation files.
            stats: per-modality (mean, std) pairs, each shaped (1, F), to
                reuse (typically the training split's); computed from this
                split when None.
            min_seg_duration_sec: shorter segments are dropped.
            log: print a one-off summary of the scan.
            mode: 'recognition' or 'anticipation'.

        Raises:
            ValueError: unknown modality or mode, downsample < 1, or `stats`
                lacking a requested modality.
            RuntimeError: no sample survived filtering.
        """
        for m in modalities:
            if m not in VALID_MODALITIES:
                raise ValueError(f"Unknown modality: {m}")
        if mode not in ("recognition", "anticipation"):
            raise ValueError(f"mode must be 'recognition' or 'anticipation', got {mode!r}")
        # Guard the integer division and the slice step below.
        if int(downsample) < 1:
            raise ValueError(f"downsample must be >= 1, got {downsample!r}")

        self.modalities = tuple(modalities)
        self.t_obs_sec = float(t_obs_sec)
        self.t_fut_sec = float(t_fut_sec)
        self.downsample = int(downsample)
        self.dataset_dir = Path(dataset_dir)
        self.annot_dir = Path(annot_dir)
        self.mode = mode

        # Effective obs-window length in frames at the post-downsample rate.
        sr = SAMPLING_RATE_HZ // self.downsample
        self.T_frames = int(round(self.t_obs_sec * sr))  # used only for anticipation
        self._sr_down = sr

        self._items: List[dict] = []
        self._modality_dims: Dict[str, int] = {}

        # When re-using another split's stats, pin each requested modality's
        # feature width to the width of those stats so a mean of one width is
        # never applied to data of another. Only requested modalities are
        # registered, which keeps modality_dims / total_feat_dim accurate.
        if stats is not None:
            missing = [m for m in self.modalities if m not in stats]
            if missing:
                raise ValueError(f"stats is missing modalities: {missing}")
            for m in self.modalities:
                self._modality_dims[m] = stats[m][0].shape[1]

        stats_counts = {
            "recordings_scanned": 0,
            "recordings_used": 0,
            "segments_seen": 0,
            "seg_dropped_label": 0,      # classify_segment() returned None
            "seg_dropped_too_early": 0,  # obs window would start before t=0
            "seg_dropped_past_end": 0,   # window runs past the recorded data
            "seg_dropped_short": 0,
            "seg_kept": 0,
        }

        valid_scenes = {f"s{i}" for i in range(1, 9)}

        for vol in volunteers:
            vol_dir = self.dataset_dir / vol
            if not vol_dir.is_dir():
                continue
            for scenario_dir in sorted(vol_dir.glob("s*")):
                if not scenario_dir.is_dir():
                    continue
                scene = scenario_dir.name
                if scene not in valid_scenes:
                    continue

                annot_path = self.annot_dir / vol / f"{scene}.json"
                if not annot_path.exists():
                    continue

                stats_counts["recordings_scanned"] += 1

                sensors = _load_recording_sensors(scenario_dir, vol, scene,
                                                  self.modalities)
                if sensors is None:
                    continue

                # Record the first-seen feature width per modality and force
                # every later recording to that width.
                for m in list(sensors):
                    arr = sensors[m]
                    if m in self._modality_dims:
                        sensors[m] = self._match_feature_dim(
                            arr, self._modality_dims[m])
                    else:
                        self._modality_dims[m] = arr.shape[1]

                T_avail = min(arr.shape[0] for arr in sensors.values())

                segs = _load_annotations(annot_path)
                rec_used = False
                # Begin-of-sequence sentinels for the first kept segment.
                prev_vc, prev_n = NUM_VERB_COMPOSITE, NUM_NOUN
                for seg_idx, seg in enumerate(segs):
                    stats_counts["segments_seen"] += 1
                    labels = classify_segment(seg.get("action_annotation", {}))
                    if labels is None:
                        stats_counts["seg_dropped_label"] += 1
                        # A skipped segment does not update the prev context.
                        continue

                    start_sec, end_sec = parse_ts_range(seg.get("timestamp", ""))
                    if end_sec - start_sec < min_seg_duration_sec:
                        stats_counts["seg_dropped_short"] += 1
                        continue

                    if self.mode == "anticipation":
                        anchor_sec = start_sec - self.t_fut_sec
                        obs_start_sec = anchor_sec - self.t_obs_sec
                        if obs_start_sec < 0:
                            stats_counts["seg_dropped_too_early"] += 1
                            continue
                        i0 = int(round(obs_start_sec * SAMPLING_RATE_HZ))
                        i1 = int(round(anchor_sec * SAMPLING_RATE_HZ))
                        meta_extra = {"anchor_sec": anchor_sec}
                    else:  # recognition
                        # Use the segment's own [start, end] as the input window.
                        i0 = int(round(start_sec * SAMPLING_RATE_HZ))
                        i1 = int(round(end_sec * SAMPLING_RATE_HZ))
                        meta_extra = {"start_sec": start_sec, "end_sec": end_sec}

                    if i1 > T_avail:
                        stats_counts["seg_dropped_past_end"] += 1
                        continue
                    if i0 < 0:
                        i0 = 0  # safety; recognition mode shouldn't hit this

                    # Decimate to every `downsample`-th frame. Copy so the
                    # stored window does not keep the whole 100 Hz recording
                    # array alive for the rest of construction.
                    window: Dict[str, np.ndarray] = {
                        m: arr[i0:i1:self.downsample].copy()
                        for m, arr in sensors.items()
                    }

                    # Must have at least 4 post-downsample frames to be useful.
                    if min(w.shape[0] for w in window.values()) < 4:
                        stats_counts["seg_dropped_short"] += 1
                        continue

                    self._items.append({
                        "x": window,
                        "y": labels,
                        "prev": {"verb_composite": prev_vc, "noun": prev_n},
                        "meta": {
                            "vol": vol, "scene": scene,
                            "seg_idx": seg_idx, **meta_extra,
                        },
                    })
                    stats_counts["seg_kept"] += 1
                    # Update context for the next kept segment in this recording.
                    prev_vc = labels["verb_composite"]
                    prev_n = labels["noun"]
                    rec_used = True

                if rec_used:
                    stats_counts["recordings_used"] += 1

        if len(self._items) == 0:
            raise RuntimeError(
                "No samples collected. Check annot_dir, modalities, t_obs, t_fut."
            )

        # Per-modality z-score normalization using training-set stats.
        if stats is None:
            stats = self._compute_stats()
        self._stats = stats
        self._apply_stats(stats)

        if log:
            print(f"[TripletSeqPredDataset:{self.mode}] "
                  f"vols={len(volunteers)} "
                  f"recs_scan={stats_counts['recordings_scanned']} "
                  f"recs_used={stats_counts['recordings_used']} "
                  f"segs_seen={stats_counts['segments_seen']} "
                  f"kept={stats_counts['seg_kept']} "
                  f"drop_label={stats_counts['seg_dropped_label']} "
                  f"drop_early={stats_counts['seg_dropped_too_early']} "
                  f"drop_past_end={stats_counts['seg_dropped_past_end']} "
                  f"drop_short={stats_counts['seg_dropped_short']}",
                  flush=True)
            print(f" modality_dims={self._modality_dims} "
                  f"T_frames={self.T_frames} sr_down={sr}Hz",
                  flush=True)
        self.stats_counts = stats_counts

    @staticmethod
    def _match_feature_dim(arr: np.ndarray, target: int) -> np.ndarray:
        """Zero-pad or truncate the feature axis of `arr` to `target` columns."""
        width = arr.shape[1]
        if width == target:
            return arr
        if width < target:
            pad = np.zeros((arr.shape[0], target - width), dtype=np.float32)
            return np.concatenate([arr, pad], axis=1)
        return arr[:, :target]

    # ----- stats (per-modality mean/std on training split) -----
    def _compute_stats(self) -> Dict[str, Tuple[np.ndarray, np.ndarray]]:
        """Per-modality (mean, std) over all frames of all stored windows.

        Both arrays are float32 with shape (1, F_mod); a std below 1e-8 is
        replaced by 1.0 so constant features do not divide by ~0.
        """
        acc: Dict[str, List[np.ndarray]] = {m: [] for m in self.modalities}
        for it in self._items:
            for m, w in it["x"].items():
                acc[m].append(w.astype(np.float64))
        out: Dict[str, Tuple[np.ndarray, np.ndarray]] = {}
        for m, arrs in acc.items():
            cat = np.concatenate(arrs, axis=0)
            mu = cat.mean(axis=0, keepdims=True)
            sd = cat.std(axis=0, keepdims=True)
            sd[sd < 1e-8] = 1.0
            out[m] = (mu.astype(np.float32), sd.astype(np.float32))
        return out

    def _apply_stats(self, stats: Dict[str, Tuple[np.ndarray, np.ndarray]]) -> None:
        """Z-score every stored window in place; NaN/inf results become 0."""
        for it in self._items:
            for m, w in it["x"].items():
                mu, sd = stats[m]
                z = (w.astype(np.float32) - mu) / sd
                z = np.nan_to_num(z, nan=0.0, posinf=0.0, neginf=0.0)
                it["x"][m] = z.astype(np.float32)

    def get_stats(self) -> Dict[str, Tuple[np.ndarray, np.ndarray]]:
        """Return the (mean, std) pairs used to normalize this split."""
        return self._stats

    # ----- Dataset protocol -----
    def __len__(self) -> int:
        return len(self._items)

    def __getitem__(self, idx: int):
        """Return (x, y, meta, prev) for sample `idx`; see the class docstring."""
        it = self._items[idx]
        x = {m: torch.from_numpy(w) for m, w in it["x"].items()}
        y = it["y"]
        meta = it["meta"]
        prev = it.get("prev", {"verb_composite": NUM_VERB_COMPOSITE, "noun": NUM_NOUN})
        return x, y, meta, prev

    # ----- convenience -----
    @property
    def modality_dims(self) -> Dict[str, int]:
        """Copy of the per-modality feature widths."""
        return dict(self._modality_dims)

    @property
    def total_feat_dim(self) -> int:
        """Sum of all per-modality feature widths."""
        return sum(self._modality_dims.values())

    def class_counts(self) -> Dict[str, np.ndarray]:
        """Histogram of kept samples per class for each of the four label heads."""
        vf = np.zeros(NUM_VERB_FINE, dtype=np.int64)
        vc = np.zeros(NUM_VERB_COMPOSITE, dtype=np.int64)
        n = np.zeros(NUM_NOUN, dtype=np.int64)
        h = np.zeros(NUM_HAND, dtype=np.int64)
        for it in self._items:
            y = it["y"]
            vf[y["verb_fine"]] += 1
            vc[y["verb_composite"]] += 1
            n[y["noun"]] += 1
            h[y["hand"]] += 1
        return {"verb_fine": vf, "verb_composite": vc, "noun": n, "hand": h}
413
+
414
+
415
+ # ---------------------------------------------------------------------------
416
+ # Collate: pad each modality to the max T_frames in the batch
417
+ # ---------------------------------------------------------------------------
418
+
419
def collate_triplet(batch):
    """Collate dataset samples into zero-padded batch tensors.

    Samples may be (x, y, meta) or (x, y, meta, prev); without `prev` every
    sample gets the BOS sentinel (NUM_VERB_COMPOSITE / NUM_NOUN).

    Returns the tuple (x, mask, lens, y, meta, prev):
        x:    dict[mod] -> float32 tensor (B, T_max, F_mod), zero-padded
        mask: bool tensor (B, T_max), True on real (non-padded) frames
        lens: long tensor (B,), frame count of each sample, taken from the
              first modality
        y:    dict of long tensors (B,) for 'verb_fine', 'verb_composite',
              'noun', 'hand'
        meta: list of per-sample meta dicts
        prev: dict {'verb_composite': long (B,), 'noun': long (B,)}
    """
    if len(batch[0]) >= 4:
        xs, ys, metas, prevs = zip(*batch)
    else:
        xs, ys, metas = zip(*batch)
        prevs = [{"verb_composite": NUM_VERB_COMPOSITE, "noun": NUM_NOUN} for _ in batch]

    batch_size = len(batch)
    mod_names = list(xs[0].keys())
    first_mod = mod_names[0]
    lens = torch.tensor([sample[first_mod].shape[0] for sample in xs], dtype=torch.long)
    longest = int(lens.max().item())

    x_out: Dict[str, torch.Tensor] = {}
    for name in mod_names:
        feat_dim = xs[0][name].shape[1]
        buf = torch.zeros(batch_size, longest, feat_dim, dtype=torch.float32)
        for row, sample in enumerate(xs):
            seq = sample[name]
            buf[row, :seq.shape[0]] = seq
        x_out[name] = buf

    # True wherever the frame index is below that sample's length.
    mask = torch.arange(longest)[None, :] < lens[:, None]

    def _long_column(records, key):
        return torch.tensor([rec[key] for rec in records], dtype=torch.long)

    y_out = {}
    for key in ("verb_fine", "verb_composite", "noun", "hand"):
        y_out[key] = _long_column(ys, key)
    prev_out = {
        "verb_composite": _long_column(prevs, "verb_composite"),
        "noun": _long_column(prevs, "noun"),
    }
    return x_out, mask, lens, y_out, list(metas), prev_out
465
+
466
+
467
+ # ---------------------------------------------------------------------------
468
+ # Convenience: build paired train/test datasets with shared normalization
469
+ # ---------------------------------------------------------------------------
470
+
471
def build_train_test(
    modalities: Sequence[str] = ("imu", "mocap", "emg", "eyetrack", "pressure"),
    t_obs_sec: float = 8.0,
    t_fut_sec: float = 2.0,
    downsample: int = DEFAULT_DOWNSAMPLE,
    dataset_dir: Path = DEFAULT_DATASET_DIR,
    annot_dir: Path = DEFAULT_ANNOT_DIR,
    mode: str = "recognition",
) -> Tuple["TripletSeqPredDataset", "TripletSeqPredDataset"]:
    """Build the TRAIN_VOLS_V3 / TEST_VOLS_V3 datasets with shared settings.

    The training split is built first; the test split is then normalized
    with the training split's statistics (train.get_stats()).

    Returns:
        (train_dataset, test_dataset)
    """
    shared = dict(
        modalities=modalities,
        t_obs_sec=t_obs_sec,
        t_fut_sec=t_fut_sec,
        downsample=downsample,
        dataset_dir=dataset_dir,
        annot_dir=annot_dir,
        mode=mode,
    )
    train_ds = TripletSeqPredDataset(TRAIN_VOLS_V3, **shared)
    test_ds = TripletSeqPredDataset(
        TEST_VOLS_V3, stats=train_ds.get_stats(), **shared
    )
    return train_ds, test_ds
492
+
493
+
494
+ # ---------------------------------------------------------------------------
495
+ # CLI: quick sanity check
496
+ # ---------------------------------------------------------------------------
497
+
498
if __name__ == "__main__":
    # Quick sanity check: build both splits, print per-head class counts and
    # the shapes of the first few samples of each split.
    import argparse

    ap = argparse.ArgumentParser()
    ap.add_argument("--modalities", type=str, default="imu,emg,eyetrack")
    ap.add_argument("--t_obs", type=float, default=8.0)
    ap.add_argument("--t_fut", type=float, default=2.0)
    ap.add_argument("--mode", type=str, default="recognition",
                    choices=("recognition", "anticipation"),
                    help="Windowing mode passed to build_train_test")
    ap.add_argument("--smoke_n", type=int, default=3,
                    help="Inspect first N samples per split")
    args = ap.parse_args()

    mods = args.modalities.split(",")
    print(f"Building train/test with modalities={mods} "
          f"t_obs={args.t_obs}s t_fut={args.t_fut}s ...")
    train, test = build_train_test(
        modalities=mods,
        t_obs_sec=args.t_obs,
        t_fut_sec=args.t_fut,
        mode=args.mode,
    )
    print(f"train: {len(train)} samples | test: {len(test)} samples")

    for name, ds in [("train", train), ("test", test)]:
        counts = ds.class_counts()
        print(f"\n[{name}] class counts:")
        print(" verb_fine:", counts["verb_fine"].tolist())
        print(" verb_composite:", counts["verb_composite"].tolist())
        print(" noun (sum):", int(counts["noun"].sum()),
              "nonzero:", int((counts["noun"] > 0).sum()))
        print(" hand:", counts["hand"].tolist())

        print(f"\n[{name}] first {args.smoke_n} samples:")
        for i in range(min(args.smoke_n, len(ds))):
            # __getitem__ yields four values; the last is the prev-label context.
            x, y, meta, prev = ds[i]
            shape_str = " ".join(f"{m}:{tuple(x[m].shape)}" for m in x)
            # meta carries 'anchor_sec' only in anticipation mode; recognition
            # samples carry 'start_sec' / 'end_sec' instead.
            if "anchor_sec" in meta:
                where = f"anchor={meta['anchor_sec']:.2f}s"
            else:
                where = f"span={meta['start_sec']:.2f}-{meta['end_sec']:.2f}s"
            print(f" {i:3d} {meta['vol']}/{meta['scene']}#{meta['seg_idx']:3d} "
                  f"{where} y={y} prev={prev} {shape_str}")
experiments/data/dataset_signal_forecast.py ADDED
@@ -0,0 +1,391 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Frame-level future *signal* forecasting dataset (T8 v2).
2
+
3
+ Task definition
4
+ ---------------
5
+ At a sampled anchor t in a recording:
6
+ past = sensor frames over [t - T_obs, t] ← input
7
+ future = target-modality frames over (t, t + T_fut] ← regression target
8
+
9
+ Unlike the v1 ForecastDataset (which targets per-frame verb-fine class), this
10
+ predicts the raw *signal* values of one chosen target modality. This directly
11
+ tests the Johansson 1984 / monzee 2003 hypothesis that cutaneous force
12
+ feedback drives sub-second motor planning at the *signal* level (motor
13
+ commands / kinematics), not at the level of slow-changing semantic verbs.
14
+
15
+ Anchor stratification (4 event types based on contact transitions)
16
+ ------------------------------------------------------------------
17
+ For each candidate anchor, we compute pressure_sum on past and future windows
18
+ and label it by the (past_majority_contact, future_majority_contact) pair:
19
+
20
+ type 0 = non-contact (past low, future low) — control: pressure ~ 0
21
+ type 1 = pre-contact (past low, future high) — pressure foretells onset
22
+ type 2 = steady-grip (past high, future high) — sustained contact dynamics
23
+ type 3 = release (past high, future low) — letting-go dynamics
24
+
25
+ Per-event-type counts are reported and (optionally) capped to balance.
26
+ Evaluation is broken down per event type so we can see WHERE pressure helps.
27
+ """
28
+ from __future__ import annotations
29
+
30
+ import sys
31
+ from pathlib import Path
32
+ from typing import Dict, List, Optional, Sequence, Tuple
33
+
34
+ import numpy as np
35
+ import torch
36
+ from torch.utils.data import Dataset
37
+
38
+ THIS = Path(__file__).resolve()
39
+ sys.path.insert(0, str(THIS.parent))
40
+ sys.path.insert(0, str(THIS.parents[1]))
41
+
42
+ try:
43
+ from experiments.dataset_seqpred import (
44
+ SAMPLING_RATE_HZ, _load_recording_sensors,
45
+ TRAIN_VOLS_V3, TEST_VOLS_V3,
46
+ DEFAULT_DATASET_DIR, DEFAULT_ANNOT_DIR,
47
+ )
48
+ except ModuleNotFoundError:
49
+ from dataset_seqpred import (
50
+ SAMPLING_RATE_HZ, _load_recording_sensors,
51
+ TRAIN_VOLS_V3, TEST_VOLS_V3,
52
+ DEFAULT_DATASET_DIR, DEFAULT_ANNOT_DIR,
53
+ )
54
+
55
+
56
+ EVENT_NAMES = {0: "non-contact", 1: "pre-contact", 2: "steady-grip", 3: "release"}
57
+
58
+
59
+ class SignalForecastDataset(Dataset):
60
+ """Predict future T_fut frames of `target_modality` from past T_obs of `input_modalities`."""
61
+
62
+ def __init__(
63
+ self,
64
+ volunteers: Sequence[str],
65
+ input_modalities: Sequence[str],
66
+ target_modality: str,
67
+ t_obs_sec: float = 1.5,
68
+ t_fut_sec: float = 0.5,
69
+ anchor_stride_sec: float = 0.25,
70
+ downsample: int = 5,
71
+ dataset_dir: Path = DEFAULT_DATASET_DIR,
72
+ annot_dir: Path = DEFAULT_ANNOT_DIR,
73
+ contact_threshold_g: float = 5.0,
74
+ per_event_max: Optional[int] = None,
75
+ input_stats: Optional[Dict[str, Tuple[np.ndarray, np.ndarray]]] = None,
76
+ target_stats: Optional[Tuple[np.ndarray, np.ndarray]] = None,
77
+ future_pressure_stats: Optional[Tuple[np.ndarray, np.ndarray]] = None,
78
+ expected_input_dims: Optional[Dict[str, int]] = None,
79
+ expected_target_dim: Optional[int] = None,
80
+ include_future_pressure: bool = False,
81
+ rng_seed: int = 0,
82
+ log: bool = True,
83
+ ):
84
+ super().__init__()
85
+ self.input_modalities = list(input_modalities)
86
+ self.target_modality = str(target_modality)
87
+ self.t_obs_sec = float(t_obs_sec)
88
+ self.t_fut_sec = float(t_fut_sec)
89
+ self.anchor_stride_sec = float(anchor_stride_sec)
90
+ self.downsample = int(downsample)
91
+ self.sr = SAMPLING_RATE_HZ // self.downsample
92
+ self.dataset_dir = Path(dataset_dir)
93
+ self.annot_dir = Path(annot_dir)
94
+ self.contact_threshold_g = float(contact_threshold_g)
95
+ self.per_event_max = per_event_max
96
+ self.include_future_pressure = bool(include_future_pressure)
97
+ self.T_obs = int(round(self.t_obs_sec * self.sr))
98
+ self.T_fut = int(round(self.t_fut_sec * self.sr))
99
+
100
+ self._items: List[dict] = []
101
+ self._modality_dims: Dict[str, int] = dict(expected_input_dims) if expected_input_dims else {}
102
+ self._target_dim: int = int(expected_target_dim) if expected_target_dim else -1
103
+ rng = np.random.default_rng(rng_seed)
104
+
105
+ # Modalities to load: union of inputs + target + pressure (for filter)
106
+ load_mods = list(dict.fromkeys(
107
+ list(self.input_modalities) + [self.target_modality, "pressure"]
108
+ ))
109
+
110
+ # Per-event-type pool of candidate anchor records
111
+ pools: Dict[int, List[dict]] = {0: [], 1: [], 2: [], 3: []}
112
+
113
+ for vol in volunteers:
114
+ vol_dir = self.dataset_dir / vol
115
+ if not vol_dir.is_dir():
116
+ continue
117
+ for scenario_dir in sorted(vol_dir.glob("s*")):
118
+ if not scenario_dir.is_dir():
119
+ continue
120
+ scene = scenario_dir.name
121
+ annot_path = self.annot_dir / vol / f"{scene}.json"
122
+ if not annot_path.exists():
123
+ continue
124
+ try:
125
+ sensors_all = _load_recording_sensors(
126
+ scenario_dir, vol, scene, load_mods
127
+ )
128
+ except Exception:
129
+ continue
130
+ if sensors_all is None or any(a is None for a in sensors_all.values()):
131
+ continue
132
+
133
+ pressure_full = sensors_all["pressure"] # (T, 50)
134
+ target_full = sensors_all[self.target_modality]
135
+ input_arrs = {m: sensors_all[m] for m in self.input_modalities}
136
+
137
+ # Track input modality dims
138
+ for m, arr in input_arrs.items():
139
+ self._enforce_dim(input_arrs, m, arr, self._modality_dims)
140
+ # Track target dim
141
+ if self._target_dim < 0:
142
+ self._target_dim = target_full.shape[1]
143
+ elif target_full.shape[1] != self._target_dim:
144
+ if target_full.shape[1] < self._target_dim:
145
+ pad = np.zeros((target_full.shape[0], self._target_dim - target_full.shape[1]),
146
+ dtype=np.float32)
147
+ target_full = np.concatenate([target_full, pad], axis=1)
148
+ else:
149
+ target_full = target_full[:, :self._target_dim]
150
+
151
+ T_avail = min(a.shape[0] for a in input_arrs.values())
152
+ T_avail = min(T_avail, target_full.shape[0], pressure_full.shape[0])
153
+ if T_avail < (self.T_obs + self.T_fut) * self.downsample:
154
+ continue
155
+
156
+ # Downsample to 20 Hz
157
+ input_ds = {m: arr[:T_avail:self.downsample] for m, arr in input_arrs.items()}
158
+ target_ds = target_full[:T_avail:self.downsample]
159
+ pressure_ds = pressure_full[:T_avail:self.downsample]
160
+ T_ds = target_ds.shape[0]
161
+ pressure_sum = pressure_ds.sum(axis=1) # (T_ds,)
162
+
163
+ stride = max(1, int(round(self.anchor_stride_sec * self.sr)))
164
+ first_anchor = self.T_obs
165
+ last_anchor = T_ds - self.T_fut
166
+ if last_anchor <= first_anchor:
167
+ continue
168
+
169
+ for anchor in range(first_anchor, last_anchor + 1, stride):
170
+ past_p = pressure_sum[anchor - self.T_obs:anchor]
171
+ fut_p = pressure_sum[anchor:anchor + self.T_fut]
172
+ past_high = (past_p > self.contact_threshold_g).mean() > 0.5
173
+ fut_high = (fut_p > self.contact_threshold_g).mean() > 0.5
174
+ if not past_high and not fut_high:
175
+ et = 0
176
+ elif not past_high and fut_high:
177
+ et = 1
178
+ elif past_high and fut_high:
179
+ et = 2
180
+ else:
181
+ et = 3
182
+
183
+ past_slice = {m: arr[anchor - self.T_obs:anchor]
184
+ for m, arr in input_ds.items()}
185
+ past_target_last = target_ds[anchor - 1].copy() # (target_dim,)
186
+ fut_target = target_ds[anchor:anchor + self.T_fut].copy()
187
+ if any(w.shape[0] != self.T_obs for w in past_slice.values()):
188
+ continue
189
+ if fut_target.shape[0] != self.T_fut:
190
+ continue
191
+
192
+ item = {
193
+ "x": past_slice,
194
+ "y": fut_target,
195
+ "y_last": past_target_last, # for persistence
196
+ "event_type": int(et),
197
+ "meta": {"vol": vol, "scene": scene, "anchor_idx": int(anchor)},
198
+ }
199
+ if self.include_future_pressure:
200
+ fut_press = pressure_ds[anchor:anchor + self.T_fut].copy()
201
+ if fut_press.shape[0] != self.T_fut:
202
+ continue
203
+ item["fp"] = fut_press # (T_fut, 50)
204
+ pools[et].append(item)
205
+
206
+ # Cap per-event count if requested (uniform downsample for balance)
207
+ for et, pool in pools.items():
208
+ if self.per_event_max is not None and len(pool) > self.per_event_max:
209
+ idx = rng.choice(len(pool), size=self.per_event_max, replace=False)
210
+ pools[et] = [pool[i] for i in sorted(idx)]
211
+ self._items = [it for et in (0, 1, 2, 3) for it in pools[et]]
212
+
213
+ if not self._items:
214
+ raise RuntimeError("SignalForecastDataset: collected 0 anchors.")
215
+
216
+ # Z-score inputs and target separately
217
+ if input_stats is None:
218
+ input_stats = self._compute_input_stats()
219
+ self._input_stats = input_stats
220
+ self._apply_input_stats(input_stats)
221
+ if target_stats is None:
222
+ target_stats = self._compute_target_stats()
223
+ self._target_stats = target_stats
224
+ self._apply_target_stats(target_stats)
225
+ if self.include_future_pressure:
226
+ if future_pressure_stats is None:
227
+ future_pressure_stats = self._compute_fp_stats()
228
+ self._fp_stats = future_pressure_stats
229
+ self._apply_fp_stats(future_pressure_stats)
230
+ else:
231
+ self._fp_stats = None
232
+
233
+ if log:
234
+ counts = {EVENT_NAMES[k]: sum(1 for it in self._items if it["event_type"] == k)
235
+ for k in (0, 1, 2, 3)}
236
+ print(f"[SignalForecastDataset] vols={len(volunteers)} "
237
+ f"target={self.target_modality} inputs={self.input_modalities} "
238
+ f"anchors={len(self._items)} {counts} "
239
+ f"T_obs={self.T_obs} T_fut={self.T_fut} sr={self.sr}Hz "
240
+ f"input_dims={self._modality_dims} target_dim={self._target_dim}",
241
+ flush=True)
242
+
243
+ @staticmethod
244
+ def _enforce_dim(arrs, m, arr, dim_dict):
245
+ if m in dim_dict:
246
+ target = dim_dict[m]
247
+ if arr.shape[1] != target:
248
+ if arr.shape[1] < target:
249
+ pad = np.zeros((arr.shape[0], target - arr.shape[1]), dtype=np.float32)
250
+ arrs[m] = np.concatenate([arr, pad], axis=1)
251
+ else:
252
+ arrs[m] = arr[:, :target]
253
+ else:
254
+ dim_dict[m] = arr.shape[1]
255
+
256
+ def _compute_input_stats(self):
257
+ accs = {m: [] for m in self._modality_dims}
258
+ for it in self._items:
259
+ for m, w in it["x"].items():
260
+ accs[m].append(w)
261
+ out = {}
262
+ for m, ws in accs.items():
263
+ cat = np.concatenate(ws, axis=0)
264
+ mu = cat.mean(axis=0).astype(np.float32)
265
+ sd = cat.std(axis=0); sd = np.where(sd < 1e-6, 1.0, sd)
266
+ out[m] = (mu, sd.astype(np.float32))
267
+ return out
268
+
269
+ def _apply_input_stats(self, stats):
270
+ for it in self._items:
271
+ for m, w in it["x"].items():
272
+ if m in stats:
273
+ mu, sd = stats[m]
274
+ it["x"][m] = ((w - mu) / sd).astype(np.float32)
275
+
276
+ def _compute_target_stats(self):
277
+ ys = np.concatenate([it["y"] for it in self._items], axis=0)
278
+ mu = ys.mean(axis=0).astype(np.float32)
279
+ sd = ys.std(axis=0); sd = np.where(sd < 1e-6, 1.0, sd)
280
+ return (mu, sd.astype(np.float32))
281
+
282
+ def _apply_target_stats(self, stats):
283
+ mu, sd = stats
284
+ for it in self._items:
285
+ it["y"] = ((it["y"] - mu) / sd).astype(np.float32)
286
+ it["y_last"] = ((it["y_last"] - mu) / sd).astype(np.float32)
287
+
288
+ def _compute_fp_stats(self):
289
+ fps = np.concatenate([it["fp"] for it in self._items], axis=0)
290
+ mu = fps.mean(axis=0).astype(np.float32)
291
+ sd = fps.std(axis=0); sd = np.where(sd < 1e-6, 1.0, sd)
292
+ return (mu, sd.astype(np.float32))
293
+
294
+ def _apply_fp_stats(self, stats):
295
+ mu, sd = stats
296
+ for it in self._items:
297
+ it["fp"] = ((it["fp"] - mu) / sd).astype(np.float32)
298
+
299
+ def __len__(self):
300
+ return len(self._items)
301
+
302
+ def __getitem__(self, idx):
303
+ it = self._items[idx]
304
+ x = {m: torch.from_numpy(np.ascontiguousarray(w)) for m, w in it["x"].items()}
305
+ y = torch.from_numpy(np.ascontiguousarray(it["y"])) # (T_fut, target_dim)
306
+ y_last = torch.from_numpy(np.ascontiguousarray(it["y_last"])) # (target_dim,)
307
+ et = int(it["event_type"])
308
+ if self.include_future_pressure:
309
+ fp = torch.from_numpy(np.ascontiguousarray(it["fp"])) # (T_fut, 50)
310
+ return x, y, y_last, fp, et, it["meta"]
311
+ return x, y, y_last, et, it["meta"]
312
+
313
+ @property
314
+ def modality_dims(self):
315
+ return dict(self._modality_dims)
316
+
317
+ @property
318
+ def target_dim(self):
319
+ return self._target_dim
320
+
321
+
322
+ def collate_signal_forecast(batch):
323
+ if len(batch[0]) == 6: # has future pressure
324
+ xs, ys, ylasts, fps, ets, metas = zip(*batch)
325
+ mods = list(xs[0].keys())
326
+ x_out = {m: torch.stack([x[m] for x in xs], dim=0) for m in mods}
327
+ y_out = torch.stack(ys, dim=0)
328
+ yl_out = torch.stack(ylasts, dim=0)
329
+ fp_out = torch.stack(fps, dim=0) # (B, T_fut, 50)
330
+ et_out = torch.tensor(ets, dtype=torch.long)
331
+ return x_out, y_out, yl_out, fp_out, et_out, list(metas)
332
+ xs, ys, ylasts, ets, metas = zip(*batch)
333
+ mods = list(xs[0].keys())
334
+ x_out = {m: torch.stack([x[m] for x in xs], dim=0) for m in mods}
335
+ y_out = torch.stack(ys, dim=0)
336
+ yl_out = torch.stack(ylasts, dim=0)
337
+ et_out = torch.tensor(ets, dtype=torch.long)
338
+ return x_out, y_out, yl_out, et_out, list(metas)
339
+
340
+
341
+ def build_signal_train_test(
342
+ input_modalities, target_modality,
343
+ t_obs_sec=1.5, t_fut_sec=0.5, anchor_stride_sec=0.25,
344
+ downsample=5,
345
+ dataset_dir=DEFAULT_DATASET_DIR, annot_dir=DEFAULT_ANNOT_DIR,
346
+ contact_threshold_g=5.0, per_event_max=None,
347
+ include_future_pressure=False,
348
+ rng_seed=0,
349
+ ):
350
+ train = SignalForecastDataset(
351
+ TRAIN_VOLS_V3, input_modalities=input_modalities,
352
+ target_modality=target_modality,
353
+ t_obs_sec=t_obs_sec, t_fut_sec=t_fut_sec,
354
+ anchor_stride_sec=anchor_stride_sec, downsample=downsample,
355
+ dataset_dir=dataset_dir, annot_dir=annot_dir,
356
+ contact_threshold_g=contact_threshold_g, per_event_max=per_event_max,
357
+ include_future_pressure=include_future_pressure,
358
+ rng_seed=rng_seed, log=True,
359
+ )
360
+ test = SignalForecastDataset(
361
+ TEST_VOLS_V3, input_modalities=input_modalities,
362
+ target_modality=target_modality,
363
+ t_obs_sec=t_obs_sec, t_fut_sec=t_fut_sec,
364
+ anchor_stride_sec=anchor_stride_sec, downsample=downsample,
365
+ dataset_dir=dataset_dir, annot_dir=annot_dir,
366
+ contact_threshold_g=contact_threshold_g, per_event_max=per_event_max,
367
+ input_stats=train._input_stats, target_stats=train._target_stats,
368
+ future_pressure_stats=train._fp_stats,
369
+ expected_input_dims=train._modality_dims,
370
+ expected_target_dim=train._target_dim,
371
+ include_future_pressure=include_future_pressure,
372
+ rng_seed=rng_seed + 1, log=True,
373
+ )
374
+ return train, test
375
+
376
+
377
+ if __name__ == "__main__":
378
+ import argparse
379
+ ap = argparse.ArgumentParser()
380
+ ap.add_argument("--input_modalities", default="imu")
381
+ ap.add_argument("--target_modality", default="imu")
382
+ ap.add_argument("--t_obs", type=float, default=1.5)
383
+ ap.add_argument("--t_fut", type=float, default=0.5)
384
+ args = ap.parse_args()
385
+ tr, te = build_signal_train_test(
386
+ input_modalities=args.input_modalities.split(","),
387
+ target_modality=args.target_modality,
388
+ t_obs_sec=args.t_obs, t_fut_sec=args.t_fut,
389
+ )
390
+ x, y, y_last, et, meta = tr[0]
391
+ print(f"Sample: x={ {m: tuple(v.shape) for m,v in x.items()} } y={tuple(y.shape)} y_last={tuple(y_last.shape)} event_type={et}")
experiments/nets/__init__.py ADDED
File without changes
experiments/nets/__pycache__/models_seqpred.cpython-312.pyc ADDED
Binary file (44.4 kB). View file
 
experiments/nets/baselines_published/__init__.py ADDED
File without changes
experiments/nets/baselines_published/baselines.py ADDED
@@ -0,0 +1,488 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Published baselines for T1 Scene Recognition, reproduced on DailyAct-5M.
3
+
4
+ Each method accepts a concatenated feature tensor (B, T, F_total) where F_total
5
+ is the sum of the active modality dims; the per-modality slices are recorded in
6
+ the `modality_dims` dict. Each method then uses the subset of modalities its
7
+ original paper intended.
8
+
9
+ All methods output a (B, num_classes) logit tensor.
10
+ """
11
+ import math
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+
16
+
17
+ def _slice(x, mod_dims, wanted):
18
+ """Slice the concatenated feature tensor to keep only `wanted` modalities,
19
+ in mod_dims order. mod_dims is an ordered dict. Returns
20
+ {name: tensor(B,T,d_name)} (the per-modality slices only, no concat)."""
21
+ parts = {}
22
+ offset = 0
23
+ for name, d in mod_dims.items():
24
+ if name in wanted:
25
+ parts[name] = x[..., offset:offset + d]
26
+ offset += d
27
+ assert len(parts) > 0, f"None of {wanted} in {list(mod_dims.keys())}"
28
+ return parts
29
+
30
+
31
+ # ---------------------------------------------------------------------------
32
+ # 1) ST-GCN (Yan et al., AAAI 2018)
33
+ # Spatio-temporal graph CNN for skeleton action recognition.
34
+ # We treat the 52-joint MoCap skeleton as the graph (n_joints defaults to 52).
35
+ # ---------------------------------------------------------------------------
36
+
37
+ class STGCNBlock(nn.Module):
38
+ def __init__(self, in_ch, out_ch, n_joints, stride=1, dropout=0.2):
39
+ super().__init__()
40
+ # Spatial graph conv: learnable adjacency (fully learned, no handcrafted A)
41
+ self.A = nn.Parameter(torch.eye(n_joints) + 0.1 * torch.randn(n_joints, n_joints))
42
+ self.spatial = nn.Conv2d(in_ch, out_ch, kernel_size=(1, 1), bias=False)
43
+ self.spatial_bn = nn.BatchNorm2d(out_ch)
44
+ self.temporal = nn.Conv2d(out_ch, out_ch, kernel_size=(9, 1),
45
+ padding=(4, 0), stride=(stride, 1))
46
+ self.temporal_bn = nn.BatchNorm2d(out_ch)
47
+ self.dropout = nn.Dropout(dropout)
48
+ if in_ch != out_ch or stride != 1:
49
+ self.res = nn.Conv2d(in_ch, out_ch, kernel_size=1,
50
+ stride=(stride, 1))
51
+ else:
52
+ self.res = nn.Identity()
53
+
54
+ def forward(self, x):
55
+ # x: (B, C, T, V)
56
+ res = self.res(x)
57
+ # spatial: aggregate along joints via A
58
+ h = self.spatial(x)
59
+ h = torch.einsum('bctv,vw->bctw', h, F.softmax(self.A, dim=-1))
60
+ h = self.spatial_bn(h)
61
+ h = F.relu(h)
62
+ # temporal
63
+ h = self.temporal(h)
64
+ h = self.temporal_bn(h)
65
+ h = self.dropout(h)
66
+ return F.relu(h + res)
67
+
68
+
69
+ class STGCN(nn.Module):
70
+ """ST-GCN on MoCap skeleton. We assume the MoCap modality is 620-dim
71
+ (hip-relative + velocity) and linearly project it to n_joints (default 52) joints x 3 coords."""
72
+ def __init__(self, feat_dim_mocap, num_classes, hidden=64, n_joints=52):
73
+ super().__init__()
74
+ self.n_joints = n_joints
75
+ # MoCap feat is (T, 620). 52 joints × 4 (xyz+quat_type), or we take per-joint xyz-only = 156.
76
+ # In this repo, 620 = 52 markers * 4 cols + velocity features. We'll
77
+ # reshape by slicing to 3*52=156 "primary" coords, padded if needed.
78
+ self.coord_dim = 3 # we'll treat each joint as having 3 coords (XYZ)
79
+ self.proj_in = nn.Linear(feat_dim_mocap, n_joints * self.coord_dim)
80
+
81
+ self.blocks = nn.ModuleList([
82
+ STGCNBlock(self.coord_dim, hidden, n_joints),
83
+ STGCNBlock(hidden, hidden, n_joints),
84
+ STGCNBlock(hidden, hidden * 2, n_joints, stride=2),
85
+ STGCNBlock(hidden * 2, hidden * 2, n_joints),
86
+ STGCNBlock(hidden * 2, hidden * 4, n_joints, stride=2),
87
+ STGCNBlock(hidden * 4, hidden * 4, n_joints),
88
+ ])
89
+ self.head = nn.Sequential(
90
+ nn.Dropout(0.3),
91
+ nn.Linear(hidden * 4, num_classes),
92
+ )
93
+
94
+ def forward(self, x_mocap, mask=None):
95
+ # x_mocap: (B, T, feat_dim_mocap)
96
+ B, T, _ = x_mocap.shape
97
+ h = self.proj_in(x_mocap) # (B, T, n_joints * 3)
98
+ h = h.reshape(B, T, self.n_joints, self.coord_dim).permute(0, 3, 1, 2) # (B, C, T, V)
99
+ for blk in self.blocks:
100
+ h = blk(h)
101
+ # Global mean pool over time & joints (with mask if provided)
102
+ if mask is not None:
103
+ # mask: (B, T), h: (B, C, T', V) where T' may be < T due to stride
104
+ T_ = h.shape[2]
105
+ m = mask[:, :T_].float().unsqueeze(1).unsqueeze(-1) # (B, 1, T', 1)
106
+ h = (h * m).sum(dim=(2, 3)) / (m.sum(dim=(2, 3)) * h.shape[3] + 1e-8)
107
+ else:
108
+ h = h.mean(dim=(2, 3))
109
+ return self.head(h)
110
+
111
+
112
+ # ---------------------------------------------------------------------------
113
+ # 2) CTR-GCN (Chen et al., ICCV 2021)
114
+ # Channel-wise Topology Refinement GCN — learns a separate adjacency
115
+ # matrix per channel group, known as SOTA for skeleton action recognition.
116
+ # ---------------------------------------------------------------------------
117
+
118
+ class CTRGC(nn.Module):
119
+ """Simplified CTR-GC block: learnable per-channel topology refinement."""
120
+ def __init__(self, in_ch, out_ch, n_joints, rel_reduction=4):
121
+ super().__init__()
122
+ self.n_joints = n_joints
123
+ self.conv1 = nn.Conv2d(in_ch, out_ch // rel_reduction, 1)
124
+ self.conv2 = nn.Conv2d(in_ch, out_ch // rel_reduction, 1)
125
+ self.conv3 = nn.Conv2d(in_ch, out_ch, 1)
126
+ self.alpha = nn.Parameter(torch.zeros(1))
127
+ self.A = nn.Parameter(torch.eye(n_joints) + 0.1 * torch.randn(n_joints, n_joints))
128
+
129
+ def forward(self, x):
130
+ # x: (B, C, T, V)
131
+ q = self.conv1(x).mean(dim=2) # (B, C', V)
132
+ k = self.conv2(x).mean(dim=2) # (B, C', V)
133
+ v = self.conv3(x) # (B, C_out, T, V)
134
+ # Channel-specific topology refinement
135
+ topology = F.softmax(torch.tanh(q.unsqueeze(-1) - k.unsqueeze(-2)), dim=-1)
136
+ # topology: (B, C', V, V); we average across channels to get a shared (B, V, V)
137
+ topology = topology.mean(dim=1)
138
+ A = self.A.unsqueeze(0) + self.alpha * topology
139
+ # apply A to v
140
+ out = torch.einsum('bctv,bvw->bctw', v, A)
141
+ return out
142
+
143
+
144
+ class CTRGCNBlock(nn.Module):
145
+ def __init__(self, in_ch, out_ch, n_joints, stride=1):
146
+ super().__init__()
147
+ self.gc = CTRGC(in_ch, out_ch, n_joints)
148
+ self.bn = nn.BatchNorm2d(out_ch)
149
+ self.tcn = nn.Sequential(
150
+ nn.Conv2d(out_ch, out_ch, (9, 1), padding=(4, 0), stride=(stride, 1)),
151
+ nn.BatchNorm2d(out_ch),
152
+ )
153
+ if in_ch != out_ch or stride != 1:
154
+ self.res = nn.Conv2d(in_ch, out_ch, 1, stride=(stride, 1))
155
+ else:
156
+ self.res = nn.Identity()
157
+
158
+ def forward(self, x):
159
+ res = self.res(x)
160
+ h = self.gc(x)
161
+ h = self.bn(h)
162
+ h = F.relu(h)
163
+ h = self.tcn(h)
164
+ return F.relu(h + res)
165
+
166
+
167
+ class CTRGCN(nn.Module):
168
+ def __init__(self, feat_dim_mocap, num_classes, hidden=64, n_joints=52):
169
+ super().__init__()
170
+ self.n_joints = n_joints
171
+ self.coord_dim = 3
172
+ self.proj_in = nn.Linear(feat_dim_mocap, n_joints * self.coord_dim)
173
+ self.blocks = nn.ModuleList([
174
+ CTRGCNBlock(self.coord_dim, hidden, n_joints),
175
+ CTRGCNBlock(hidden, hidden, n_joints),
176
+ CTRGCNBlock(hidden, hidden * 2, n_joints, stride=2),
177
+ CTRGCNBlock(hidden * 2, hidden * 4, n_joints, stride=2),
178
+ ])
179
+ self.head = nn.Sequential(
180
+ nn.Dropout(0.3),
181
+ nn.Linear(hidden * 4, num_classes),
182
+ )
183
+
184
+ def forward(self, x_mocap, mask=None):
185
+ B, T, _ = x_mocap.shape
186
+ h = self.proj_in(x_mocap)
187
+ h = h.reshape(B, T, self.n_joints, self.coord_dim).permute(0, 3, 1, 2)
188
+ for blk in self.blocks:
189
+ h = blk(h)
190
+ h = h.mean(dim=(2, 3))
191
+ return self.head(h)
192
+
193
+
194
+ # ---------------------------------------------------------------------------
195
+ # 3) LIMU-BERT (Xu et al., SenSys 2021)
196
+ # IMU self-supervised pretraining via masked reconstruction + fine-tune.
197
+ # We implement a simpler variant: BERT-style encoder with optional
198
+ # pretraining head.
199
+ # ---------------------------------------------------------------------------
200
+
201
+ class LIMUBertEncoder(nn.Module):
202
+ def __init__(self, feat_dim_imu, hidden=128, n_layers=4, n_heads=4, dropout=0.1):
203
+ super().__init__()
204
+ self.in_proj = nn.Linear(feat_dim_imu, hidden)
205
+ self.pos = nn.Parameter(torch.zeros(1, 4096, hidden))
206
+ nn.init.trunc_normal_(self.pos, std=0.02)
207
+ layer = nn.TransformerEncoderLayer(
208
+ d_model=hidden, nhead=n_heads, dim_feedforward=4 * hidden,
209
+ dropout=dropout, batch_first=True, activation='gelu',
210
+ )
211
+ self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
212
+
213
+ def forward(self, x, mask):
214
+ T = x.size(1)
215
+ h = self.in_proj(x) + self.pos[:, :T, :]
216
+ h = self.encoder(h, src_key_padding_mask=~mask)
217
+ return h
218
+
219
+
220
+ class LIMUBert(nn.Module):
221
+ """Supervised-only variant: encoder + classifier head. Paper's
222
+ pretraining is a masked-recon objective; for simplicity we report the
223
+ supervised-only baseline here."""
224
+ def __init__(self, feat_dim_imu, num_classes, hidden=128, n_layers=4,
225
+ n_heads=4, dropout=0.1):
226
+ super().__init__()
227
+ self.encoder = LIMUBertEncoder(feat_dim_imu, hidden, n_layers, n_heads, dropout)
228
+ self.head = nn.Sequential(
229
+ nn.LayerNorm(hidden),
230
+ nn.Dropout(dropout),
231
+ nn.Linear(hidden, num_classes),
232
+ )
233
+
234
+ def forward(self, x_imu, mask):
235
+ h = self.encoder(x_imu, mask)
236
+ m = mask.unsqueeze(-1).float()
237
+ pooled = (h * m).sum(dim=1) / m.sum(dim=1).clamp(min=1.0)
238
+ return self.head(pooled)
239
+
240
+
241
+ # ---------------------------------------------------------------------------
242
+ # 4) EMG-CNN (standard 1D CNN baseline from sEMG classification literature)
243
+ # E.g. Atzori et al. — multi-layer CNN with moving-window input.
244
+ # ---------------------------------------------------------------------------
245
+
246
+ class EMGCNN(nn.Module):
247
+ def __init__(self, feat_dim_emg, num_classes, hidden=64):
248
+ super().__init__()
249
+ self.cnn = nn.Sequential(
250
+ nn.Conv1d(feat_dim_emg, hidden, 7, padding=3),
251
+ nn.BatchNorm1d(hidden), nn.ReLU(), nn.Dropout(0.3),
252
+ nn.Conv1d(hidden, hidden * 2, 5, padding=2),
253
+ nn.BatchNorm1d(hidden * 2), nn.ReLU(), nn.Dropout(0.3),
254
+ nn.Conv1d(hidden * 2, hidden * 4, 3, padding=1),
255
+ nn.BatchNorm1d(hidden * 4), nn.ReLU(),
256
+ )
257
+ self.head = nn.Linear(hidden * 4, num_classes)
258
+
259
+ def forward(self, x_emg, mask):
260
+ # (B, T, 8) -> (B, 8, T) for conv1d
261
+ h = self.cnn(x_emg.transpose(1, 2))
262
+ # Masked pool
263
+ m = mask.unsqueeze(1).float()
264
+ T_ = h.size(2)
265
+ if m.size(2) != T_:
266
+ m = F.adaptive_avg_pool1d(m, T_)
267
+ m = (m > 0.5).float()
268
+ pooled = (h * m).sum(dim=2) / m.sum(dim=2).clamp(min=1.0)
269
+ return self.head(pooled)
270
+
271
+
272
+ # ---------------------------------------------------------------------------
273
+ # 5) ActionSense baseline (DelPreto et al., NeurIPS '22)
274
+ # Simple 3-layer MLP per modality + shared LSTM + classifier.
275
+ # ---------------------------------------------------------------------------
276
+
277
+ class ActionSenseLSTM(nn.Module):
278
+ def __init__(self, modality_dims: dict, num_classes, hidden=128):
279
+ super().__init__()
280
+ self.mod_names = list(modality_dims.keys())
281
+ self.mod_dims = modality_dims
282
+ self.per_mod = nn.ModuleDict({
283
+ name: nn.Sequential(
284
+ nn.Linear(d, hidden), nn.ReLU(), nn.Dropout(0.2),
285
+ nn.Linear(hidden, hidden), nn.ReLU(),
286
+ ) for name, d in modality_dims.items()
287
+ })
288
+ concat_dim = hidden * len(modality_dims)
289
+ self.lstm = nn.LSTM(concat_dim, hidden, num_layers=2,
290
+ batch_first=True, bidirectional=True, dropout=0.2)
291
+ self.head = nn.Linear(hidden * 2, num_classes)
292
+
293
+ def forward(self, x, mask):
294
+ # x: (B, T, F_total), slice by modality
295
+ offset = 0
296
+ feats = []
297
+ for name in self.mod_names:
298
+ d = self.mod_dims[name]
299
+ x_m = x[..., offset:offset + d]
300
+ offset += d
301
+ feats.append(self.per_mod[name](x_m))
302
+ h = torch.cat(feats, dim=-1) # (B, T, hidden * M)
303
+ h, _ = self.lstm(h)
304
+ m = mask.unsqueeze(-1).float()
305
+ pooled = (h * m).sum(dim=1) / m.sum(dim=1).clamp(min=1.0)
306
+ return self.head(pooled)
307
+
308
+
309
+ # ---------------------------------------------------------------------------
310
+ # 6) MulT (Multimodal Transformer, Tsai et al., ACL 2019)
311
+ # Core idea: cross-modal attention between every pair of modalities.
312
+ # For a 3-modality input (A, B, C), produce
313
+ # {A->B, A->C, B->A, B->C, C->A, C->B} via directed cross-attention.
314
+ # ---------------------------------------------------------------------------
315
+
316
+ class CrossModalTransformer(nn.Module):
317
+ def __init__(self, d_model, n_heads=4, n_layers=2, dropout=0.1):
318
+ super().__init__()
319
+ self.layers = nn.ModuleList([
320
+ nn.TransformerDecoderLayer(
321
+ d_model=d_model, nhead=n_heads, dim_feedforward=4 * d_model,
322
+ dropout=dropout, batch_first=True, activation='gelu',
323
+ ) for _ in range(n_layers)
324
+ ])
325
+
326
+ def forward(self, q, kv, q_mask, kv_mask):
327
+ # q: (B, T_q, D), kv: (B, T_kv, D)
328
+ h = q
329
+ for layer in self.layers:
330
+ h = layer(h, kv,
331
+ tgt_key_padding_mask=~q_mask,
332
+ memory_key_padding_mask=~kv_mask)
333
+ return h
334
+
335
+
336
+ class MulT(nn.Module):
337
+ """Multimodal Transformer. Uses MoCap + EMG + IMU as 3 modalities
338
+ (EyeTrack/Pressure omitted to match original 3-mod paper design)."""
339
+ def __init__(self, modality_dims: dict, num_classes, d_model=128,
340
+ n_layers=2, n_heads=4, dropout=0.1):
341
+ super().__init__()
342
+ self.mod_names = [m for m in ['mocap', 'emg', 'imu'] if m in modality_dims]
343
+ if len(self.mod_names) < 2:
344
+ self.mod_names = list(modality_dims.keys())[:3]
345
+ self.mod_dims = {m: modality_dims[m] for m in self.mod_names}
346
+ self.in_proj = nn.ModuleDict({
347
+ m: nn.Linear(d, d_model) for m, d in self.mod_dims.items()
348
+ })
349
+ # Pairwise cross-attention
350
+ self.cross = nn.ModuleDict({
351
+ f"{a}_to_{b}": CrossModalTransformer(d_model, n_heads, n_layers, dropout)
352
+ for a in self.mod_names for b in self.mod_names if a != b
353
+ })
354
+ # Self-attention after cross
355
+ self.self_tx = nn.ModuleDict({
356
+ m: nn.TransformerEncoder(
357
+ nn.TransformerEncoderLayer(
358
+ d_model=d_model, nhead=n_heads,
359
+ dim_feedforward=4 * d_model, dropout=dropout,
360
+ batch_first=True, activation='gelu',
361
+ ), num_layers=1,
362
+ ) for m in self.mod_names
363
+ })
364
+ total_dim = d_model * len(self.mod_names) * len(self.mod_names)
365
+ self.head = nn.Sequential(
366
+ nn.LayerNorm(total_dim),
367
+ nn.Dropout(dropout),
368
+ nn.Linear(total_dim, num_classes),
369
+ )
370
+
371
+ def forward(self, x, mask):
372
+ # Slice modalities from x
373
+ offset = 0
374
+ projs = {}
375
+ # Walk through all known mod_dims to find offsets
376
+ # We need the FULL modality_dims order, which we don't have here;
377
+ # expect caller to already supply x with exactly mod_names in order.
378
+ # Workaround: assume caller passes mod_names order matching projection.
379
+ for m in self.mod_names:
380
+ d = self.mod_dims[m]
381
+ projs[m] = self.in_proj[m](x[..., offset:offset + d])
382
+ offset += d
383
+
384
+ # Cross-attention: each modality attends to each other
385
+ fused = {m: [] for m in self.mod_names}
386
+ for a in self.mod_names:
387
+ for b in self.mod_names:
388
+ if a == b:
389
+ fused[a].append(projs[a])
390
+ else:
391
+ out = self.cross[f"{a}_to_{b}"](projs[a], projs[b], mask, mask)
392
+ fused[a].append(out)
393
+
394
+ # Self-attention + pool per modality
395
+ pooled = []
396
+ for a in self.mod_names:
397
+ # Concat all attended-to representations along feature dim
398
+ cat = torch.cat(fused[a], dim=-1) # (B, T, D * M)
399
+ # Actually re-project back to D per stream, then self-attn on stacked
400
+ # Simplified: self-attention over concatenated, pool, flatten
401
+ # Here we just pool each separately
402
+ for i, rep in enumerate(fused[a]):
403
+ rep = self.self_tx[a](rep)
404
+ m = mask.unsqueeze(-1).float()
405
+ p = (rep * m).sum(dim=1) / m.sum(dim=1).clamp(min=1.0)
406
+ pooled.append(p)
407
+
408
+ h = torch.cat(pooled, dim=-1)
409
+ return self.head(h)
410
+
411
+
412
+ # ---------------------------------------------------------------------------
413
+ # 7) Perceiver IO (Jaegle et al., ICML 2021)
414
+ # Cross-attention from a fixed-size latent query set to all input tokens,
415
+ # repeated for a few iterations.
416
+ # ---------------------------------------------------------------------------
417
+
418
+ class PerceiverBlock(nn.Module):
419
+ def __init__(self, latent_dim, n_heads, dropout):
420
+ super().__init__()
421
+ self.ca = nn.MultiheadAttention(
422
+ latent_dim, n_heads, dropout=dropout, batch_first=True,
423
+ )
424
+ self.norm1 = nn.LayerNorm(latent_dim)
425
+ self.sa = nn.TransformerEncoderLayer(
426
+ d_model=latent_dim, nhead=n_heads,
427
+ dim_feedforward=4 * latent_dim, dropout=dropout,
428
+ batch_first=True, activation='gelu',
429
+ )
430
+
431
+ def forward(self, latents, inputs, input_kpm):
432
+ # Cross-attn: latents attend to inputs
433
+ h, _ = self.ca(latents, inputs, inputs, key_padding_mask=input_kpm)
434
+ latents = self.norm1(latents + h)
435
+ # Self-attn on latents
436
+ latents = self.sa(latents)
437
+ return latents
438
+
439
+
440
+ class PerceiverIO(nn.Module):
441
+ """Perceiver with N learnable latent queries; supports any modality mix."""
442
+ def __init__(self, modality_dims: dict, num_classes,
443
+ latent_dim=128, n_latents=32, n_layers=3, n_heads=4, dropout=0.1):
444
+ super().__init__()
445
+ self.mod_names = list(modality_dims.keys())
446
+ self.mod_dims = modality_dims
447
+ # Per-modality input projection to latent_dim, with modality-id embedding
448
+ self.in_proj = nn.ModuleDict({
449
+ m: nn.Linear(d, latent_dim) for m, d in modality_dims.items()
450
+ })
451
+ self.mod_emb = nn.Parameter(torch.randn(len(self.mod_names), latent_dim) * 0.02)
452
+ # Positional encoding (shared)
453
+ self.pos = nn.Parameter(torch.zeros(1, 4096, latent_dim))
454
+ nn.init.trunc_normal_(self.pos, std=0.02)
455
+ # Learnable latents
456
+ self.latents = nn.Parameter(torch.randn(n_latents, latent_dim) * 0.02)
457
+ self.blocks = nn.ModuleList([
458
+ PerceiverBlock(latent_dim, n_heads, dropout) for _ in range(n_layers)
459
+ ])
460
+ self.head = nn.Sequential(
461
+ nn.LayerNorm(latent_dim),
462
+ nn.Linear(latent_dim, num_classes),
463
+ )
464
+
465
+ def forward(self, x, mask):
466
+ B, T, _ = x.shape
467
+ # Project each modality + add modality embedding
468
+ offset = 0
469
+ tokens = []
470
+ for i, m in enumerate(self.mod_names):
471
+ d = self.mod_dims[m]
472
+ tok = self.in_proj[m](x[..., offset:offset + d]) # (B, T, D)
473
+ tok = tok + self.mod_emb[i]
474
+ offset += d
475
+ tokens.append(tok)
476
+ # Concatenate along TIME dim, add shared pos enc per-modality
477
+ # Each modality gets its own time sequence concatenated
478
+ # Simpler: mean across modalities (like early fusion in latent space) + pos
479
+ h = torch.stack(tokens, dim=2).mean(dim=2) # (B, T, D)
480
+ h = h + self.pos[:, :T, :]
481
+ input_kpm = ~mask # (B, T), True = ignore
482
+ # Iterative cross-attention
483
+ latents = self.latents.unsqueeze(0).expand(B, -1, -1) # (B, N, D)
484
+ for blk in self.blocks:
485
+ latents = blk(latents, h, input_kpm)
486
+ # Mean-pool latents
487
+ pooled = latents.mean(dim=1)
488
+ return self.head(pooled)
experiments/nets/baselines_published/syncfuse.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SyncFuse — our proposed method for T1 scene recognition.
3
+
4
+ Four components (all toggleable via args for ablation):
5
+
6
+ (1) Modality dropout: per-sample independent Bernoulli(p=0.3) drop on each
7
+ modality during training; at test time all modalities
8
+ are active. Keeps at least 1 modality.
9
+ (2) Pretrained transfer: each per-modality backbone is optionally loaded from
10
+ an independently pretrained single-modality
11
+ checkpoint and frozen during fine-tuning.
12
+ (3) Cross-modal temporal-shift attention:
13
+ a late cross-attention block where EMG queries
14
+ attend to MoCap keys/values at a LEARNED temporal
15
+ offset Δ (Gumbel-softmax over {-3,...,+3} bins at
16
+ 20 Hz = ±150 ms). Motivated by the paper's case-study
17
+ finding (EMG leads motion by ~20 ms sub-frame).
18
+ (4) Learnable late fusion:
19
+ per-modality classifier logits are combined with a
20
+ learnable softmax-weighted average (temperature is
21
+ also learned). Equivalent to `late_agg='learned'`
22
+ in the repo's existing LateFusionModel.
23
+ """
24
+ import torch
25
+ import torch.nn as nn
26
+ import torch.nn.functional as F
27
+ import random
28
+
29
+
30
+ def masked_mean(x, mask):
31
+ m = mask.unsqueeze(-1).float()
32
+ return (x * m).sum(dim=1) / m.sum(dim=1).clamp(min=1.0)
33
+
34
+
35
+ # ---------------------------------------------------------------------------
36
+ # Per-modality Transformer branch (same as repo's TransformerBackbone)
37
+ # ---------------------------------------------------------------------------
38
+
39
+ class ModTransformer(nn.Module):
40
+ def __init__(self, feat_dim, hidden=128, n_layers=2, n_heads=4, dropout=0.1):
41
+ super().__init__()
42
+ self.in_proj = nn.Linear(feat_dim, hidden)
43
+ self.pos = nn.Parameter(torch.zeros(1, 4096, hidden))
44
+ nn.init.trunc_normal_(self.pos, std=0.02)
45
+ layer = nn.TransformerEncoderLayer(
46
+ d_model=hidden, nhead=n_heads, dim_feedforward=4 * hidden,
47
+ dropout=dropout, batch_first=True, activation='gelu',
48
+ )
49
+ self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
50
+ self.output_dim = hidden
51
+
52
+ def forward(self, x, mask):
53
+ # x: (B, T, feat_dim)
54
+ T = x.size(1)
55
+ h = self.in_proj(x) + self.pos[:, :T, :]
56
+ h = self.encoder(h, src_key_padding_mask=~mask)
57
+ return h # (B, T, hidden) — token-level, NOT pooled
58
+
59
+
60
+ # ---------------------------------------------------------------------------
61
+ # (3) Cross-modal temporal-shift attention
62
+ # ---------------------------------------------------------------------------
63
+
64
+ class TemporalShiftAttention(nn.Module):
65
+ """Multi-head attention where keys/values are temporally shifted by a learned
66
+ offset Δ relative to the queries. Δ is drawn from a discrete set {-3,...,+3} via
67
+ straight-through Gumbel-softmax: we sample ONE shift per forward pass,
68
+ but the softmax weights flow gradient back through shift_logits.
69
+
70
+ At 20 Hz bins, ±3 ≈ ±150 ms, which brackets the paper's ~20 ms EMG-motion
71
+ lead. Memory cost is ~1 attention pass (not 7)."""
72
+ def __init__(self, d_model, n_heads=4, dropout=0.1, max_shift=3,
73
+ gumbel_tau=1.0):
74
+ super().__init__()
75
+ self.max_shift = max_shift
76
+ self.shifts = list(range(-max_shift, max_shift + 1))
77
+ self.shift_logits = nn.Parameter(torch.zeros(len(self.shifts)))
78
+ self.tau = gumbel_tau
79
+ self.attn = nn.MultiheadAttention(
80
+ d_model, n_heads, dropout=dropout, batch_first=True,
81
+ )
82
+ self.norm = nn.LayerNorm(d_model)
83
+
84
+ def _shift_tensor(self, x, shift, mask):
85
+ if shift == 0:
86
+ return x, mask
87
+ B, T, D = x.shape
88
+ if shift > 0:
89
+ pad = torch.zeros(B, shift, D, device=x.device, dtype=x.dtype)
90
+ x_s = torch.cat([x[:, shift:, :], pad], dim=1)
91
+ m_s = torch.cat([mask[:, shift:],
92
+ torch.zeros(B, shift, device=mask.device, dtype=torch.bool)],
93
+ dim=1)
94
+ else:
95
+ s = -shift
96
+ pad = torch.zeros(B, s, D, device=x.device, dtype=x.dtype)
97
+ x_s = torch.cat([pad, x[:, :-s, :]], dim=1)
98
+ m_s = torch.cat([torch.zeros(B, s, device=mask.device, dtype=torch.bool),
99
+ mask[:, :-s]], dim=1)
100
+ return x_s, m_s
101
+
102
def forward(self, q_tokens, kv_tokens, q_mask, kv_mask, hard=False):
    """Attend from ``q_tokens`` to a temporally shifted copy of ``kv_tokens``.

    Args:
        q_tokens: query tokens, batch-first.
        kv_tokens: key/value tokens; shifted along time before attention.
        q_mask: accepted for interface symmetry; not used by this method.
        kv_mask: validity mask of ``kv_tokens`` (True = valid); its
            shifted inverse is used as ``key_padding_mask``.
        hard: if True (or when the module is in eval mode) the arg-max
            shift is used deterministically.

    Returns:
        ``LayerNorm(q_tokens + attention_output)``.
    """
    if hard or not self.training:
        # Eval: take the argmax shift.
        with torch.no_grad():
            idx = int(self.shift_logits.argmax().item())
        gate = None
    else:
        # Training: straight-through Gumbel-softmax samples one shift; the
        # selected entry (value 1 in the forward pass) scales the output so
        # that gradient reaches ``shift_logits``.
        one_hot = F.gumbel_softmax(self.shift_logits, tau=self.tau, hard=True)
        idx = int(one_hot.argmax().item())
        gate = one_hot[idx]

    shifted_kv, shifted_mask = self._shift_tensor(
        kv_tokens, self.shifts[idx], kv_mask)

    key_padding_mask = ~shifted_mask
    # A sample whose keys are ALL padded (e.g. a sequence shorter than the
    # shift) would make the attention softmax run over an empty set and
    # return NaN. Un-mask those rows for the attention call and discard
    # their (meaningless) attention output afterwards.
    no_valid_key = key_padding_mask.all(dim=1)
    key_padding_mask = key_padding_mask & ~no_valid_key.unsqueeze(1)

    out, _ = self.attn(q_tokens, shifted_kv, shifted_kv,
                       key_padding_mask=key_padding_mask)
    out = out.masked_fill(no_valid_key.view(-1, 1, 1), 0.0)

    if gate is not None:
        out = out * gate
    return self.norm(q_tokens + out)
125
+
126
+
127
+ # ---------------------------------------------------------------------------
128
+ # SyncFuse main model
129
+ # ---------------------------------------------------------------------------
130
+
131
class SyncFuse(nn.Module):
    """Per-modality Transformer branches with logit-level (late) fusion.

    Each modality gets its own ``ModTransformer`` branch and classifier.
    Optionally the 'emg' and 'mocap' token streams exchange information
    through ``TemporalShiftAttention``, and the per-modality logits are
    combined with learned softmax weights (or a plain mean).
    """

    def __init__(self, modality_dims: dict, num_classes, hidden=128, n_heads=4,
                 n_layers=2, dropout=0.1,
                 use_xmod_shift=True, use_learned_late=True):
        """
        Args:
            modality_dims: ordered mapping {modality_name: feature_width};
                the order defines how the concatenated input is sliced.
            num_classes: number of output classes.
            hidden: token width of every branch.
            n_heads, n_layers, dropout: forwarded to the branches.
            use_xmod_shift: enable EMG<->MoCap shift attention (only takes
                effect when both 'emg' and 'mocap' are present).
            use_learned_late: learn fusion weights instead of averaging.
        """
        super().__init__()
        self.mod_names = list(modality_dims.keys())
        self.mod_dims = modality_dims
        self.use_xmod_shift = use_xmod_shift
        self.use_learned_late = use_learned_late

        self.branches = nn.ModuleDict({
            m: ModTransformer(d, hidden, n_layers, n_heads, dropout)
            for m, d in modality_dims.items()
        })
        self.classifiers = nn.ModuleDict({
            m: nn.Sequential(nn.LayerNorm(hidden), nn.Dropout(dropout),
                             nn.Linear(hidden, num_classes))
            for m in self.mod_names
        })

        # Cross-modal temporal-shift: apply to EMG branch attending to MoCap
        # (and MoCap->EMG), only when both modalities are present.
        if use_xmod_shift and 'emg' in self.mod_names and 'mocap' in self.mod_names:
            self.xmod_emg2mocap = TemporalShiftAttention(hidden, n_heads, dropout)
            self.xmod_mocap2emg = TemporalShiftAttention(hidden, n_heads, dropout)
        else:
            self.xmod_emg2mocap = None
            self.xmod_mocap2emg = None

        if use_learned_late:
            self.late_logits = nn.Parameter(torch.zeros(len(self.mod_names)))
            self.late_temperature = nn.Parameter(torch.ones(1))

    def load_pretrained(self, pretrain_paths: dict, freeze=True):
        """Load pretrained single-modality checkpoints into branches.

        Args:
            pretrain_paths: {modality_name: path_to_checkpoint_state_dict}.
                Names without a matching branch are ignored.
            freeze: if True, parameters of every branch that received at
                least one tensor get ``requires_grad = False``.

        Only checkpoint keys starting with ``"backbone."`` are considered;
        the prefix is stripped and the remainder must exist in the branch's
        state dict.
        """
        prefix = 'backbone.'
        for m, path in pretrain_paths.items():
            if m not in self.branches:
                continue
            try:
                sd = torch.load(path, weights_only=True, map_location='cpu')
            except TypeError:
                # torch versions whose ``load`` has no ``weights_only`` kwarg.
                sd = torch.load(path, map_location='cpu')
            branch = self.branches[m]
            # Build the key set once instead of once per checkpoint entry.
            branch_keys = set(branch.state_dict().keys())
            # Map SingleModel keys ("backbone.X.*") -> branch keys. Strip the
            # leading prefix only: ``str.replace`` would also remove inner
            # "backbone." segments and silently drop those tensors.
            mapped = {}
            for k, v in sd.items():
                if not k.startswith(prefix):
                    continue
                new_k = k[len(prefix):]
                if new_k in branch_keys:
                    mapped[new_k] = v
            if mapped:
                branch.load_state_dict(mapped, strict=False)
                if freeze:
                    for p in branch.parameters():
                        p.requires_grad = False
                print(f" [SyncFuse] loaded {len(mapped)} tensors into branch '{m}' (frozen={freeze})")

    def forward(self, x, mask, mod_dropout_p=0.0, training_time=True):
        """
        x: (B, T, F_total) concatenated features
        mask: (B, T)
        mod_dropout_p: probability of dropping each modality (training only)
        training_time: extra switch; modality dropout runs only when this
            flag, ``self.training`` and ``mod_dropout_p > 0`` all hold.

        Returns the fused class logits, one row per sample.
        """
        B, T, _ = x.shape
        n_mods = len(self.mod_names)

        # Slice modality features
        offset = 0
        feats = {}
        for m in self.mod_names:
            d = self.mod_dims[m]
            feats[m] = x[..., offset:offset + d]
            offset += d

        # (1) Modality dropout — per sample, independent per modality
        apply_dropout = training_time and self.training and mod_dropout_p > 0
        active = {m: torch.ones(B, dtype=torch.bool, device=x.device) for m in self.mod_names}
        if apply_dropout:
            drop_map = {m: (torch.rand(B, device=x.device) < mod_dropout_p)
                        for m in self.mod_names}
            all_dropped = torch.stack([drop_map[m] for m in self.mod_names], dim=0).all(dim=0)  # (B,)
            if all_dropped.any():
                # for all-dropped samples, un-drop one random modality
                rows = all_dropped.nonzero(as_tuple=True)[0]
                rescue_idx = torch.randint(0, n_mods, (rows.numel(),),
                                           device=x.device)
                for i, m in enumerate(self.mod_names):
                    drop_map[m][rows[rescue_idx == i]] = False
            for m in self.mod_names:
                active[m] = ~drop_map[m]
                # zero out dropped features for that branch
                feats[m] = feats[m] * active[m].view(B, 1, 1).float()

        # Per-modality encoding
        tokens = {}
        for m in self.mod_names:
            tokens[m] = self.branches[m](feats[m], mask)  # (B, T, hidden)

        # (3) Cross-modal temporal-shift (bidirectional EMG <-> MoCap)
        # NOTE(review): the MoCap update below attends to the EMG tokens that
        # were already updated on the previous statement, so the two
        # directions are applied sequentially rather than symmetrically —
        # confirm this is intended.
        if self.xmod_emg2mocap is not None:
            tokens['emg'] = self.xmod_emg2mocap(
                tokens['emg'], tokens['mocap'], mask, mask,
                hard=not self.training,
            )
            tokens['mocap'] = self.xmod_mocap2emg(
                tokens['mocap'], tokens['emg'], mask, mask,
                hard=not self.training,
            )

        # Pool and classify per modality
        logits_per = []
        for m in self.mod_names:
            pooled = masked_mean(tokens[m], mask)
            logits_per.append(self.classifiers[m](pooled))
        stacked = torch.stack(logits_per, dim=0)  # (M, B, C)

        if self.use_learned_late:
            learned_w = F.softmax(
                self.late_logits / self.late_temperature.clamp(min=0.1), dim=0)

        if apply_dropout:
            # Mask out logits from dropped modalities (so they don't dominate)
            # and re-normalize the weights across the active ones.
            act_mask = torch.stack([active[m].float() for m in self.mod_names], dim=0)  # (M, B)
            if self.use_learned_late:
                w = learned_w.view(-1, 1) * act_mask  # (M, B)
            else:
                w = act_mask
            w = w / w.sum(dim=0, keepdim=True).clamp(min=1e-6)
            return (stacked * w.unsqueeze(-1)).sum(dim=0)

        # (4) Learnable late fusion (or simple mean)
        if self.use_learned_late:
            return (stacked * learned_w.view(-1, 1, 1)).sum(dim=0)
        return stacked.mean(dim=0)
experiments/nets/models.py ADDED
@@ -0,0 +1,648 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Model definitions for Experiment 1: Scene Recognition.
3
+ Backbones: CNN1D, BiLSTM, Transformer, TinyHAR (DeepConvLSTM and InceptionTime are imported on demand from experiments.published_models)
4
+ Fusion: Early (default), Late, Attention, WeightedLate, GatedLate, Stacking, Product, MoE, FeatureConcat
5
+
6
+ Supports optional per-modality projection via proj_dim parameter:
7
+ proj_dim > 0: project each modality to proj_dim before backbone
8
+ proj_dim = 0: no projection, use raw features (original behavior)
9
+ """
10
+
11
+ import math
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+
16
+
17
+ # ============================================================
18
+ # Per-modality projection
19
+ # ============================================================
20
+
21
class ModalityProjector(nn.Module):
    """Bring every modality's raw feature slice to a common width.

    The input is the feature-wise concatenation of all modalities, in the
    order of ``modality_dims``. Each slice goes through its own
    Linear -> LayerNorm -> ReLU stack and the results are concatenated.
    """

    def __init__(self, modality_dims, proj_dim):
        """
        Args:
            modality_dims: mapping {modality_name: raw feature width}.
            proj_dim: output width of every per-modality projection.
        """
        super().__init__()
        self.mod_names = list(modality_dims.keys())
        self.mod_dims = list(modality_dims.values())
        self.proj_dim = proj_dim
        self.projectors = nn.ModuleList([
            nn.Sequential(
                nn.Linear(width, proj_dim),
                nn.LayerNorm(proj_dim),
                nn.ReLU(),
            )
            for width in self.mod_dims
        ])

    @property
    def output_dim(self):
        """Width of the concatenated output: ``proj_dim`` per modality."""
        return len(self.mod_dims) * self.proj_dim

    def forward(self, x):
        """x: (B, T, total_raw_dim) -> (B, T, proj_dim * M)"""
        projected = []
        start = 0
        for layer, width in zip(self.projectors, self.mod_dims):
            stop = start + width
            projected.append(layer(x[:, :, start:stop]))
            start = stop
        return torch.cat(projected, dim=-1)
50
+
51
+
52
+ # ============================================================
53
+ # Per-modality hidden dim scaling (used when proj_dim=0)
54
+ # ============================================================
55
+
56
+ def _compute_per_modality_hidden(mod_dim, base_hidden_dim):
57
+ if mod_dim >= 128:
58
+ return max(base_hidden_dim, 48)
59
+ elif mod_dim >= 32:
60
+ return base_hidden_dim
61
+ else:
62
+ return max(16, base_hidden_dim // 2)
63
+
64
+
65
+ # ============================================================
66
+ # Backbones
67
+ # ============================================================
68
+
69
class CNN1DBackbone(nn.Module):
    """Three-stage 1-D CNN followed by (masked) mean pooling over time.

    Input is (B, T, C); output is (B, hidden_dim).
    """

    def __init__(self, input_dim, hidden_dim=128):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv1d(input_dim, 64, kernel_size=7, padding=3),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(0.1),
        )
        self.conv2 = nn.Sequential(
            nn.Conv1d(64, 128, kernel_size=5, padding=2),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.1),
        )
        self.conv3 = nn.Sequential(
            nn.Conv1d(128, hidden_dim, kernel_size=3, padding=1),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
        )
        self.output_dim = hidden_dim

    def forward(self, x, mask=None):
        """Encode ``x``; with ``mask`` (B, T) only True frames are averaged."""
        h = x.permute(0, 2, 1)  # channels first for Conv1d
        for stage in (self.conv1, self.conv2, self.conv3):
            h = stage(h)
        if mask is None:
            return h.mean(2)
        keep = mask.unsqueeze(1).float()
        # clamp avoids a division by zero for samples without a valid frame
        n_valid = mask.sum(1, keepdim=True).float().clamp(min=1)
        return (h * keep).sum(2) / n_valid
96
+
97
+
98
class LSTMBackbone(nn.Module):
    """Bidirectional LSTM with additive attention pooling over time.

    Input is (B, T, C); output is (B, 2 * hidden_dim).
    """

    def __init__(self, input_dim, hidden_dim=128, num_layers=2, dropout=0.2):
        super().__init__()
        # nn.LSTM only applies dropout between stacked layers.
        inter_layer_dropout = dropout if num_layers > 1 else 0
        self.lstm = nn.LSTM(
            input_dim, hidden_dim, num_layers=num_layers,
            batch_first=True, bidirectional=True,
            dropout=inter_layer_dropout,
        )
        self.attn = nn.Linear(hidden_dim * 2, 1)
        self.output_dim = hidden_dim * 2

    def forward(self, x, mask=None):
        """Encode ``x``; frames where ``mask`` is False get zero attention."""
        states, _ = self.lstm(x)
        frame_scores = self.attn(states).squeeze(-1)
        if mask is not None:
            frame_scores = frame_scores.masked_fill(~mask, float('-inf'))
        alpha = torch.softmax(frame_scores, dim=1).unsqueeze(-1)
        return (states * alpha).sum(dim=1)
117
+
118
+
119
class TinyHARBackbone(nn.Module):
    """TinyHAR-style backbone (the file attributes it to Zhou et al., ISWC 2022).

    Pipeline: parallel temporal convolutions with growing kernel sizes,
    a self-attention + feed-forward block over the concatenated features,
    and pooling over time with a single learned query.

    Input:  (B, T, C) with optional mask
    Output: (B, output_dim), where ``output_dim`` is ``hidden_dim`` rounded
            to ``num_scales * max(4, hidden_dim // num_scales)``.
    """

    def __init__(self, input_dim, hidden_dim=128, num_scales=4):
        super().__init__()
        scale_dim = max(4, hidden_dim // num_scales)
        actual_hidden = scale_dim * num_scales

        # Multi-scale temporal convolutions; kernel sizes 3, 5, 7, ...
        branches = []
        for scale in range(num_scales):
            ks = 2 * scale + 3
            branches.append(nn.Sequential(
                nn.Conv1d(input_dim, scale_dim, kernel_size=ks, padding=ks // 2),
                nn.BatchNorm1d(scale_dim),
                nn.ReLU(),
            ))
        self.convs = nn.ModuleList(branches)

        # Largest head count <= min(4, actual_hidden // 8) dividing the width
        # (falls back to a single head).
        first_guess = max(1, min(4, actual_hidden // 8))
        nhead = next(h for h in range(first_guess, 0, -1)
                     if actual_hidden % h == 0)
        self.channel_attn = nn.MultiheadAttention(
            actual_hidden, num_heads=nhead, batch_first=True, dropout=0.1,
        )
        self.channel_norm = nn.LayerNorm(actual_hidden)
        self.channel_ff = nn.Sequential(
            nn.Linear(actual_hidden, actual_hidden),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(actual_hidden, actual_hidden),
        )
        self.ff_norm = nn.LayerNorm(actual_hidden)

        # Temporal attention pooling with one learned query token
        self.temporal_query = nn.Parameter(torch.randn(1, 1, actual_hidden) * 0.02)
        self.temporal_attn = nn.MultiheadAttention(
            actual_hidden, num_heads=1, batch_first=True, dropout=0.1,
        )

        self.output_dim = actual_hidden

    def forward(self, x, mask=None):
        """Encode ``x`` (B, T, C); frames with ``mask`` False are ignored as keys."""
        batch, _, _ = x.shape
        channels_first = x.permute(0, 2, 1)  # (B, C, T)

        # Multi-scale feature extraction -> (B, T, actual_hidden)
        h = torch.cat([branch(channels_first) for branch in self.convs], dim=1)
        h = h.permute(0, 2, 1)

        # Self-attention + feed-forward, each with residual and LayerNorm
        pad = None if mask is None else ~mask
        mixed, _ = self.channel_attn(h, h, h, key_padding_mask=pad)
        h = self.channel_norm(h + mixed)
        h = self.ff_norm(h + self.channel_ff(h))

        # Temporal attention pooling
        query = self.temporal_query.expand(batch, -1, -1)  # (B, 1, actual_hidden)
        pooled, _ = self.temporal_attn(query, h, h, key_padding_mask=pad)
        return pooled.squeeze(1)  # (B, actual_hidden)
189
+
190
+
191
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to a (B, T, d_model) input.

    Even feature indices carry ``sin(pos * w_i)``, odd indices
    ``cos(pos * w_i)`` with ``w_i = 10000 ** (-2i / d_model)``. The table
    is stored as the non-trainable buffer ``pe`` of shape
    (1, max_len, d_model).
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        """
        Args:
            d_model: feature width (even or odd).
            dropout: dropout probability applied after adding the encoding.
            max_len: longest sequence length the table supports.
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer cosine column than sine columns;
        # without the slice the assignment raises a shape-mismatch error.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the first ``x.size(1)`` encodings to ``x`` and apply dropout."""
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
206
+
207
+
208
class TransformerBackbone(nn.Module):
    """Linear input projection + positional encoding + Transformer encoder,
    mean-pooled over the valid time steps.

    Input is (B, T, C); output is (B, d_model).
    """

    def __init__(self, input_dim, d_model=128, nhead=4, num_layers=2, dropout=0.1):
        super().__init__()
        self.input_proj = nn.Linear(input_dim, d_model)
        self.pos_enc = PositionalEncoding(d_model, dropout=dropout)
        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=4 * d_model,
            dropout=dropout,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.output_dim = d_model

    def forward(self, x, mask=None):
        """Encode ``x``; ``mask`` (B, T) is True on valid frames."""
        h = self.pos_enc(self.input_proj(x))
        if mask is None:
            return self.encoder(h, src_key_padding_mask=None).mean(1)
        h = self.encoder(h, src_key_padding_mask=~mask)
        keep = mask.unsqueeze(-1).float()
        # clamp avoids a division by zero for samples without a valid frame
        n_valid = mask.sum(1, keepdim=True).float().clamp(min=1)
        return (h * keep).sum(1) / n_valid
230
+
231
+
232
+ # ============================================================
233
+ # Full models
234
+ # ============================================================
235
+
236
def get_backbone(name, input_dim, hidden_dim=128):
    """Instantiate a backbone by its short name.

    Supported names: 'cnn', 'lstm', 'transformer', 'tinyhar' (defined in
    this module) and 'deepconvlstm', 'inceptiontime' (imported lazily from
    ``experiments.published_models``).

    Raises:
        ValueError: if ``name`` is not one of the supported names.
    """
    if name == 'cnn':
        return CNN1DBackbone(input_dim, hidden_dim)
    if name == 'lstm':
        return LSTMBackbone(input_dim, hidden_dim)
    if name == 'transformer':
        return TransformerBackbone(input_dim, hidden_dim)
    if name == 'tinyhar':
        return TinyHARBackbone(input_dim, hidden_dim)
    # The remaining backbones live in another module and are only imported
    # when requested.
    if name == 'deepconvlstm':
        from experiments.published_models import DeepConvLSTMBackbone
        return DeepConvLSTMBackbone(input_dim, hidden_dim)
    if name == 'inceptiontime':
        from experiments.published_models import InceptionTimeBackbone
        return InceptionTimeBackbone(input_dim, hidden_dim)
    raise ValueError(f"Unknown backbone: {name}")
253
+
254
+
255
def _make_branch(backbone_name, raw_dim, hidden_dim, proj_dim):
    """Build the (projector, backbone) pair for one modality.

    With ``proj_dim > 0`` the modality is first projected to ``proj_dim``
    (Linear -> LayerNorm -> ReLU) and the backbone uses ``hidden_dim``.
    Otherwise there is no projector (``None`` is returned in its place) and
    the backbone width is scaled by the raw feature width.
    """
    if not proj_dim > 0:
        scaled_hidden = _compute_per_modality_hidden(raw_dim, hidden_dim)
        return None, get_backbone(backbone_name, raw_dim, scaled_hidden)
    projector = nn.Sequential(
        nn.Linear(raw_dim, proj_dim),
        nn.LayerNorm(proj_dim),
        nn.ReLU(),
    )
    return projector, get_backbone(backbone_name, proj_dim, hidden_dim)
271
+
272
+
273
class SingleModel(nn.Module):
    """One backbone plus a linear classifier.

    Serves both single-modality models and early fusion (all modalities
    concatenated along the feature axis before the backbone).
    """

    def __init__(self, backbone_name, input_dim, num_classes, hidden_dim=128,
                 modality_dims=None, proj_dim=0):
        """
        Args:
            backbone_name: name understood by ``get_backbone``.
            input_dim: total raw feature width.
            num_classes: number of output classes.
            hidden_dim: backbone width.
            modality_dims: optional {name: width}; required for projection.
            proj_dim: if > 0 (and ``modality_dims`` is non-empty) every
                modality is projected to this width before the backbone.
        """
        super().__init__()
        use_projection = proj_dim > 0 and modality_dims
        if use_projection:
            self.projector = ModalityProjector(modality_dims, proj_dim)
            backbone_in = self.projector.output_dim
        else:
            self.projector = None
            backbone_in = input_dim
        self.backbone = get_backbone(backbone_name, backbone_in, hidden_dim)
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(self.backbone.output_dim, num_classes),
        )

    def forward(self, x, mask=None):
        """Return class logits for ``x``; ``mask`` is handed to the backbone."""
        inputs = x if self.projector is None else self.projector(x)
        return self.classifier(self.backbone(inputs, mask))
296
+
297
+
298
class LateFusionModel(nn.Module):
    """Late fusion: separate backbone per modality, configurable logit aggregation.

    late_agg='mean': simple average (original)
    late_agg='confidence': entropy-based confidence weighting (0 extra params)
    late_agg='learned': temperature-scaled learned weights (M+1 extra params)

    Any other ``late_agg`` value falls back to the simple average.
    """

    def __init__(self, backbone_name, modality_dims, num_classes, hidden_dim=64,
                 proj_dim=0, late_agg='mean'):
        """
        Args:
            backbone_name: name understood by ``get_backbone``.
            modality_dims: ordered {name: raw feature width}; defines how
                the concatenated input is sliced.
            num_classes: number of output classes.
            hidden_dim: base backbone width.
            proj_dim: if > 0, per-modality projection width.
            late_agg: aggregation mode, see class docstring.
        """
        super().__init__()
        self.mod_names = list(modality_dims.keys())
        self.mod_dims = list(modality_dims.values())
        self.late_agg = late_agg
        self.projectors = nn.ModuleList()
        self.backbones = nn.ModuleList()
        self.classifiers = nn.ModuleList()
        for dim in self.mod_dims:
            proj, bb = _make_branch(backbone_name, dim, hidden_dim, proj_dim)
            self.projectors.append(proj if proj is not None else nn.Identity())
            self.backbones.append(bb)
            self.classifiers.append(nn.Sequential(
                nn.Dropout(0.5), nn.Linear(bb.output_dim, num_classes),
            ))
        self._has_proj = proj_dim > 0

        M = len(self.mod_dims)
        if late_agg == 'learned':
            self.modality_logits = nn.Parameter(torch.zeros(M))
            self.temperature = nn.Parameter(torch.ones(1))

    def forward(self, x, mask=None):
        """Return fused logits (B, C) for the concatenated input ``x``."""
        offset = 0
        all_logits = []
        for i, dim in enumerate(self.mod_dims):
            x_mod = x[:, :, offset:offset + dim]
            offset += dim
            if self._has_proj:
                x_mod = self.projectors[i](x_mod)
            feat = self.backbones[i](x_mod, mask)
            all_logits.append(self.classifiers[i](feat))

        stacked = torch.stack(all_logits, dim=0)  # (M, B, C)

        if self.late_agg == 'confidence':
            # Weight by confidence: low entropy -> high weight
            probs = F.softmax(stacked, dim=-1)  # (M, B, C)
            entropy = -(probs * (probs + 1e-8).log()).sum(dim=-1)  # (M, B)
            weights = F.softmax(-entropy, dim=0).unsqueeze(-1)  # (M, B, 1)
            return (stacked * weights).sum(dim=0)
        elif self.late_agg == 'learned':
            # Floor the learned temperature: a value that drifts to 0 gives
            # NaN/inf weights and a negative one inverts the ranking. The
            # 0.1 floor mirrors the clamp used by SyncFuse's late fusion.
            temperature = self.temperature.clamp(min=0.1)
            weights = F.softmax(self.modality_logits / temperature, dim=0)
            return (stacked * weights.view(-1, 1, 1)).sum(dim=0)
        else:  # 'mean'
            return stacked.mean(dim=0)
353
+
354
+
355
class AttentionFusionModel(nn.Module):
    """Attention fusion over per-modality feature vectors.

    Every modality is encoded by its own backbone, mapped to a common width,
    and the resulting M vectors are treated as a length-M token sequence for
    one Transformer encoder layer. The mean of the output tokens is
    classified.
    """

    def __init__(self, backbone_name, modality_dims, num_classes, hidden_dim=64, proj_dim=0):
        super().__init__()
        self.mod_names = list(modality_dims.keys())
        self.mod_dims = list(modality_dims.values())
        unified_dim = hidden_dim

        projs, encoders, aligners = [], [], []
        for raw_dim in self.mod_dims:
            proj, bb = _make_branch(backbone_name, raw_dim, hidden_dim, proj_dim)
            projs.append(proj if proj else nn.Identity())
            encoders.append(bb)
            # Only add a linear map when the backbone width differs.
            if bb.output_dim == unified_dim:
                aligners.append(nn.Identity())
            else:
                aligners.append(nn.Linear(bb.output_dim, unified_dim))
        self.projectors = nn.ModuleList(projs)
        self.backbones = nn.ModuleList(encoders)
        self.feat_projections = nn.ModuleList(aligners)
        self._has_proj = proj_dim > 0

        # Largest head count among 4, 2, 1 that divides the token width.
        nhead = next(h for h in (4, 2, 1) if unified_dim % h == 0)
        self.cross_attn = nn.TransformerEncoderLayer(
            d_model=unified_dim, nhead=nhead, dim_feedforward=unified_dim * 2,
            dropout=0.1, batch_first=True,
        )
        self.classifier = nn.Sequential(
            nn.Dropout(0.5), nn.Linear(unified_dim, num_classes),
        )

    def forward(self, x, mask=None):
        """Return class logits (B, C) for the concatenated input ``x``."""
        per_modality = []
        start = 0
        for i, width in enumerate(self.mod_dims):
            chunk = x[:, :, start:start + width]
            start += width
            if self._has_proj:
                chunk = self.projectors[i](chunk)
            encoded = self.backbones[i](chunk, mask)
            per_modality.append(self.feat_projections[i](encoded))
        tokens = self.cross_attn(torch.stack(per_modality, dim=1))
        return self.classifier(tokens.mean(dim=1))
399
+
400
+
401
class WeightedLateFusionModel(nn.Module):
    """Late fusion with one learned, input-independent weight per modality.

    The per-modality logits are combined with ``softmax(modality_weights)``.
    """

    def __init__(self, backbone_name, modality_dims, num_classes, hidden_dim=64, proj_dim=0):
        super().__init__()
        self.mod_names = list(modality_dims.keys())
        self.mod_dims = list(modality_dims.values())
        projs, encoders, heads = [], [], []
        for raw_dim in self.mod_dims:
            proj, bb = _make_branch(backbone_name, raw_dim, hidden_dim, proj_dim)
            projs.append(proj if proj else nn.Identity())
            encoders.append(bb)
            heads.append(nn.Sequential(
                nn.Dropout(0.5), nn.Linear(bb.output_dim, num_classes),
            ))
        self.projectors = nn.ModuleList(projs)
        self.backbones = nn.ModuleList(encoders)
        self.classifiers = nn.ModuleList(heads)
        self._has_proj = proj_dim > 0
        # Equal initial weights (softmax of a constant vector is uniform).
        self.modality_weights = nn.Parameter(torch.ones(len(self.mod_dims)))

    def forward(self, x, mask=None):
        """Return the weighted sum of per-modality logits, shape (B, C)."""
        per_modality = []
        start = 0
        for i, width in enumerate(self.mod_dims):
            stop = start + width
            chunk = x[:, :, start:stop]
            start = stop
            if self._has_proj:
                chunk = self.projectors[i](chunk)
            per_modality.append(self.classifiers[i](self.backbones[i](chunk, mask)))
        mix = F.softmax(self.modality_weights, dim=0).view(-1, 1, 1)
        return (torch.stack(per_modality, dim=0) * mix).sum(dim=0)
432
+
433
+
434
class GatedLateFusionModel(nn.Module):
    """Late fusion with input-dependent modality weights.

    A small MLP looks at the concatenated backbone features of all
    modalities and emits one score per modality; the softmax of these scores
    weights the per-modality logits for each sample.
    """

    def __init__(self, backbone_name, modality_dims, num_classes, hidden_dim=64, proj_dim=0):
        super().__init__()
        self.mod_names = list(modality_dims.keys())
        self.mod_dims = list(modality_dims.values())
        n_modalities = len(self.mod_dims)
        projs, encoders, heads = [], [], []
        for raw_dim in self.mod_dims:
            proj, bb = _make_branch(backbone_name, raw_dim, hidden_dim, proj_dim)
            projs.append(proj if proj else nn.Identity())
            encoders.append(bb)
            heads.append(nn.Sequential(
                nn.Dropout(0.5), nn.Linear(bb.output_dim, num_classes),
            ))
        self.projectors = nn.ModuleList(projs)
        self.backbones = nn.ModuleList(encoders)
        self.classifiers = nn.ModuleList(heads)
        self._has_proj = proj_dim > 0
        gate_in = sum(bb.output_dim for bb in encoders)
        self.gate = nn.Sequential(
            nn.Linear(gate_in, 32), nn.ReLU(), nn.Linear(32, n_modalities),
        )

    def forward(self, x, mask=None):
        """Return gated per-sample mixture of modality logits, shape (B, C)."""
        feats = []
        logits = []
        start = 0
        for i, width in enumerate(self.mod_dims):
            chunk = x[:, :, start:start + width]
            start += width
            if self._has_proj:
                chunk = self.projectors[i](chunk)
            encoded = self.backbones[i](chunk, mask)
            feats.append(encoded)
            logits.append(self.classifiers[i](encoded))
        gate_scores = self.gate(torch.cat(feats, dim=1))
        mix = F.softmax(gate_scores, dim=1).unsqueeze(-1)  # (B, M, 1)
        return (torch.stack(logits, dim=1) * mix).sum(dim=1)
472
+
473
+
474
class StackingFusionModel(nn.Module):
    """Stacking: a small meta-learner maps the concatenated per-modality
    logits (M * num_classes values) to the final class logits.
    """

    def __init__(self, backbone_name, modality_dims, num_classes, hidden_dim=64, proj_dim=0):
        super().__init__()
        self.mod_names = list(modality_dims.keys())
        self.mod_dims = list(modality_dims.values())
        n_modalities = len(self.mod_dims)
        projs, encoders, heads = [], [], []
        for raw_dim in self.mod_dims:
            proj, bb = _make_branch(backbone_name, raw_dim, hidden_dim, proj_dim)
            projs.append(proj if proj else nn.Identity())
            encoders.append(bb)
            heads.append(nn.Sequential(
                nn.Dropout(0.5), nn.Linear(bb.output_dim, num_classes),
            ))
        self.projectors = nn.ModuleList(projs)
        self.backbones = nn.ModuleList(encoders)
        self.classifiers = nn.ModuleList(heads)
        self._has_proj = proj_dim > 0
        self.meta_learner = nn.Sequential(
            nn.Linear(n_modalities * num_classes, 32),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(32, num_classes),
        )

    def forward(self, x, mask=None):
        """Return meta-learner logits (B, C) for the concatenated input."""
        base_logits = []
        start = 0
        for i, width in enumerate(self.mod_dims):
            stop = start + width
            chunk = x[:, :, start:stop]
            start = stop
            if self._has_proj:
                chunk = self.projectors[i](chunk)
            base_logits.append(self.classifiers[i](self.backbones[i](chunk, mask)))
        return self.meta_learner(torch.cat(base_logits, dim=1))
508
+
509
+
510
class ProductOfExpertsModel(nn.Module):
    """Product of experts: the per-modality log-probabilities are summed.

    The returned tensor is the (unnormalised) sum of ``log_softmax`` outputs,
    not a re-normalised distribution.
    """

    def __init__(self, backbone_name, modality_dims, num_classes, hidden_dim=64, proj_dim=0):
        super().__init__()
        self.mod_names = list(modality_dims.keys())
        self.mod_dims = list(modality_dims.values())
        projs, encoders, heads = [], [], []
        for raw_dim in self.mod_dims:
            proj, bb = _make_branch(backbone_name, raw_dim, hidden_dim, proj_dim)
            projs.append(proj if proj else nn.Identity())
            encoders.append(bb)
            heads.append(nn.Sequential(
                nn.Dropout(0.5), nn.Linear(bb.output_dim, num_classes),
            ))
        self.projectors = nn.ModuleList(projs)
        self.backbones = nn.ModuleList(encoders)
        self.classifiers = nn.ModuleList(heads)
        self._has_proj = proj_dim > 0

    def forward(self, x, mask=None):
        """Return the summed log-probabilities (B, C); ``None`` if the model
        has no modality."""
        total = None
        start = 0
        for i, width in enumerate(self.mod_dims):
            chunk = x[:, :, start:start + width]
            start += width
            if self._has_proj:
                chunk = self.projectors[i](chunk)
            expert_logits = self.classifiers[i](self.backbones[i](chunk, mask))
            expert_log_p = F.log_softmax(expert_logits, dim=1)
            if total is None:
                total = expert_log_p
            else:
                total = total + expert_log_p
        return total
540
+
541
+
542
class MoEFusionModel(nn.Module):
    """Mixture-of-experts fusion with top-k routing over modalities.

    A linear router scores the modalities from the concatenated backbone
    features; for every sample the ``top_k`` (at most 2) best-scoring
    modalities are kept and their logits mixed with softmax weights computed
    over the selected scores only.
    """

    def __init__(self, backbone_name, modality_dims, num_classes, hidden_dim=64, proj_dim=0):
        super().__init__()
        self.mod_names = list(modality_dims.keys())
        self.mod_dims = list(modality_dims.values())
        n_experts = len(self.mod_dims)
        self.top_k = min(2, n_experts)
        projs, encoders, heads = [], [], []
        for raw_dim in self.mod_dims:
            proj, bb = _make_branch(backbone_name, raw_dim, hidden_dim, proj_dim)
            projs.append(proj if proj else nn.Identity())
            encoders.append(bb)
            heads.append(nn.Sequential(
                nn.Dropout(0.5), nn.Linear(bb.output_dim, num_classes),
            ))
        self.projectors = nn.ModuleList(projs)
        self.backbones = nn.ModuleList(encoders)
        self.classifiers = nn.ModuleList(heads)
        self._has_proj = proj_dim > 0
        router_in = sum(bb.output_dim for bb in encoders)
        self.router = nn.Linear(router_in, n_experts)

    def forward(self, x, mask=None):
        """Return the routed mixture of expert logits, shape (B, C)."""
        feats = []
        expert_logits = []
        start = 0
        for i, width in enumerate(self.mod_dims):
            chunk = x[:, :, start:start + width]
            start += width
            if self._has_proj:
                chunk = self.projectors[i](chunk)
            encoded = self.backbones[i](chunk, mask)
            feats.append(encoded)
            expert_logits.append(self.classifiers[i](encoded))

        scores = self.router(torch.cat(feats, dim=1))  # (B, M)
        best_scores, best_experts = scores.topk(self.top_k, dim=1)
        mix = F.softmax(best_scores, dim=1)  # (B, top_k)

        per_expert = torch.stack(expert_logits, dim=1)  # (B, M, C)
        n_classes = per_expert.size(-1)
        gather_index = best_experts.unsqueeze(-1).expand(-1, -1, n_classes)
        chosen = per_expert.gather(1, gather_index)  # (B, top_k, C)
        return (chosen * mix.unsqueeze(-1)).sum(dim=1)
583
+
584
+
585
class FeatureConcatFusionModel(nn.Module):
    """Feature-level fusion: one backbone per modality, the pooled feature
    vectors are concatenated and classified by a shared two-layer head.
    """

    def __init__(self, backbone_name, modality_dims, num_classes, hidden_dim=64, proj_dim=0):
        super().__init__()
        self.mod_names = list(modality_dims.keys())
        self.mod_dims = list(modality_dims.values())
        projs, encoders = [], []
        for raw_dim in self.mod_dims:
            proj, bb = _make_branch(backbone_name, raw_dim, hidden_dim, proj_dim)
            projs.append(proj if proj else nn.Identity())
            encoders.append(bb)
        self.projectors = nn.ModuleList(projs)
        self.backbones = nn.ModuleList(encoders)
        self._has_proj = proj_dim > 0
        joint_width = sum(bb.output_dim for bb in encoders)
        self.classifier = nn.Sequential(
            nn.LayerNorm(joint_width),
            nn.Dropout(0.5),
            nn.Linear(joint_width, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x, mask=None):
        """Return class logits (B, C) for the concatenated input ``x``."""
        feats = []
        start = 0
        for i, width in enumerate(self.mod_dims):
            stop = start + width
            chunk = x[:, :, start:stop]
            start = stop
            if self._has_proj:
                chunk = self.projectors[i](chunk)
            feats.append(self.backbones[i](chunk, mask))
        return self.classifier(torch.cat(feats, dim=1))
622
+
623
+
624
def build_model(backbone_name, fusion, input_dim, modality_dims, num_classes,
                hidden_dim=128, proj_dim=0, late_agg='mean'):
    """Create the model for a fusion strategy.

    Args:
        backbone_name: name understood by ``get_backbone``.
        fusion: one of 'early', 'late', 'attention', 'weighted_late',
            'gated_late', 'stacking', 'product', 'moe', 'feat_concat'.
        input_dim: total raw feature width (used by 'early' only).
        modality_dims: ordered {name: raw feature width}.
        num_classes: number of output classes.
        hidden_dim: base backbone width.
        proj_dim: per-modality projection width; 0 means no projection
            (raw features).
        late_agg: logit aggregation mode, used by 'late' only.

    Raises:
        ValueError: for an unknown ``fusion`` name.
    """
    # The two strategies with their own constructor signatures come first.
    if fusion == 'early':
        return SingleModel(backbone_name, input_dim, num_classes, hidden_dim,
                           modality_dims=modality_dims, proj_dim=proj_dim)
    if fusion == 'late':
        return LateFusionModel(backbone_name, modality_dims, num_classes, hidden_dim,
                               proj_dim, late_agg=late_agg)

    # All remaining strategies share one constructor signature.
    if fusion == 'attention':
        fusion_cls = AttentionFusionModel
    elif fusion == 'weighted_late':
        fusion_cls = WeightedLateFusionModel
    elif fusion == 'gated_late':
        fusion_cls = GatedLateFusionModel
    elif fusion == 'stacking':
        fusion_cls = StackingFusionModel
    elif fusion == 'product':
        fusion_cls = ProductOfExpertsModel
    elif fusion == 'moe':
        fusion_cls = MoEFusionModel
    elif fusion == 'feat_concat':
        fusion_cls = FeatureConcatFusionModel
    else:
        raise ValueError(f"Unknown fusion: {fusion}")
    return fusion_cls(backbone_name, modality_dims, num_classes, hidden_dim, proj_dim)
experiments/nets/models_forecast.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Frame-level future forecasting models.
2
+
3
+ Baselines (all sharing the same forecast head signature; RULSTMForecast and AVTForecast are also defined further down in this file):
4
+ - TransformerForecast (our DAF-style)
5
+ - FUTRForecast (Transformer encoder + parallel query decoder)
6
+ - DeepConvLSTMForecast (Ordoñez & Roggen 2016 wearable HAR backbone)
7
+
8
+ All take a dict {mod: (B, T_obs, F_mod)} and output (B, T_fut, num_classes).
9
+ """
10
+ from __future__ import annotations
11
+ from typing import Dict, List
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+
17
+
18
+ # ---------------------------------------------------------------------------
19
+ # Shared per-modality projection: each modality -> hidden dim d_model
20
+ # ---------------------------------------------------------------------------
21
+
22
+ class _PerModalityProj(nn.Module):
23
+ def __init__(self, modality_dims: Dict[str, int], d_model: int):
24
+ super().__init__()
25
+ self.proj = nn.ModuleDict({
26
+ m: nn.Linear(d, d_model) for m, d in modality_dims.items()
27
+ })
28
+ self.mod_emb = nn.Parameter(torch.zeros(len(modality_dims), d_model))
29
+ nn.init.trunc_normal_(self.mod_emb, std=0.02)
30
+ self.mods = list(modality_dims.keys())
31
+
32
+ def forward(self, x: Dict[str, torch.Tensor]) -> torch.Tensor:
33
+ # Concatenate per-modality projections along time? Or sum?
34
+ # We sum modality-projected features per time step (with modality
35
+ # embedding broadcast). Equivalent to early-fusion at the d_model
36
+ # space and is what a "modality-aware Transformer" typically uses.
37
+ out = None
38
+ for i, m in enumerate(self.mods):
39
+ h = self.proj[m](x[m]) + self.mod_emb[i]
40
+ out = h if out is None else out + h
41
+ return out / len(self.mods) # (B, T_obs, d_model)
42
+
43
+
44
+ # ---------------------------------------------------------------------------
45
+ # 1. Transformer (DAF-style) forecast model
46
+ # ---------------------------------------------------------------------------
47
+
48
class TransformerForecast(nn.Module):
    """Transformer encoder over the observed window, read out by learned future queries.

    The fused per-frame embedding receives a learned positional parameter,
    passes through a Transformer encoder, and `t_fut` learned query tokens
    cross-attend once to the encoded sequence. The attended queries are
    layer-normalised and mapped to class logits.
    """

    def __init__(self, modality_dims: Dict[str, int], num_classes: int,
                 t_obs: int, t_fut: int, d_model: int = 128,
                 n_heads: int = 4, n_layers: int = 2, dropout: float = 0.1):
        super().__init__()
        self.t_obs = t_obs
        self.t_fut = t_fut
        self.num_classes = num_classes

        # Input side: modality fusion + learned positions for the observed frames.
        self.embed = _PerModalityProj(modality_dims, d_model)
        self.pos = nn.Parameter(torch.zeros(1, t_obs, d_model))
        nn.init.trunc_normal_(self.pos, std=0.02)

        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=4 * d_model,
            dropout=dropout,
            batch_first=True,
            activation="gelu",
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=n_layers)

        # Output side: one learned query per future step.
        self.queries = nn.Parameter(torch.zeros(1, t_fut, d_model))
        nn.init.trunc_normal_(self.queries, std=0.02)
        self.cross_attn = nn.MultiheadAttention(
            d_model, n_heads, dropout=dropout, batch_first=True
        )
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, num_classes)

    def forward(self, x: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Map a dict of per-modality observations to (B, T_fut, C) logits."""
        memory = self.encoder(self.embed(x) + self.pos)          # (B, T_obs, D)
        batch = memory.size(0)
        fut_queries = self.queries.expand(batch, -1, -1)         # (B, T_fut, D)
        attended, _ = self.cross_attn(fut_queries, memory, memory,
                                      need_weights=False)
        return self.head(self.norm(attended))                    # (B, T_fut, C)
79
+
80
+
81
+ # ---------------------------------------------------------------------------
82
+ # 2. FUTR-style forecast (Future Transformer, Gong et al. CVPR 2022)
83
+ # Same encoder + parallel query decoder. We add a small Transformer
84
+ # decoder so it's not literally identical to TransformerForecast.
85
+ # ---------------------------------------------------------------------------
86
+
87
class FUTRForecast(nn.Module):
    """FUTR-style forecaster: Transformer encoder plus a parallel query decoder.

    Differs from `TransformerForecast` by decoding the learned future
    queries with a full `nn.TransformerDecoder` stack instead of a single
    cross-attention layer.
    """

    def __init__(self, modality_dims: Dict[str, int], num_classes: int,
                 t_obs: int, t_fut: int, d_model: int = 128,
                 n_heads: int = 4, n_enc: int = 2, n_dec: int = 1,
                 dropout: float = 0.1):
        super().__init__()
        self.t_obs = t_obs
        self.t_fut = t_fut
        self.num_classes = num_classes

        self.embed = _PerModalityProj(modality_dims, d_model)
        self.pos = nn.Parameter(torch.zeros(1, t_obs, d_model))
        nn.init.trunc_normal_(self.pos, std=0.02)

        # Encoder and decoder layers share every hyper-parameter.
        ffn_width = 4 * d_model
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=ffn_width,
            dropout=dropout, batch_first=True, activation="gelu",
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_enc)
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=ffn_width,
            dropout=dropout, batch_first=True, activation="gelu",
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=n_dec)

        # One learned query token per future step, decoded in parallel.
        self.queries = nn.Parameter(torch.zeros(1, t_fut, d_model))
        nn.init.trunc_normal_(self.queries, std=0.02)
        self.head = nn.Linear(d_model, num_classes)

    def forward(self, x: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Encode the observed window and decode all future steps at once."""
        encoded = self.encoder(self.embed(x) + self.pos)          # (B, T_obs, D)
        fut_queries = self.queries.expand(encoded.size(0), -1, -1)  # (B, T_fut, D)
        decoded = self.decoder(fut_queries, encoded)
        return self.head(decoded)                                  # (B, T_fut, C)
118
+
119
+
120
+ # ---------------------------------------------------------------------------
121
+ # 3. DeepConvLSTM-style forecast
122
+ # ---------------------------------------------------------------------------
123
+
124
class DeepConvLSTMForecast(nn.Module):
    """DeepConvLSTM-style forecaster: four temporal conv blocks followed by an LSTM.

    Modalities are concatenated on the feature axis, passed through four
    Conv1d/BatchNorm/ReLU/Dropout blocks, then an LSTM. The top layer's
    final hidden state is mapped by one linear layer to all
    `t_fut * num_classes` logits, which are reshaped to (B, T_fut, C).
    """

    def __init__(self, modality_dims: Dict[str, int], num_classes: int,
                 t_obs: int, t_fut: int, conv_filters: int = 64,
                 lstm_hidden: int = 128, n_lstm_layers: int = 2,
                 dropout: float = 0.1):
        super().__init__()
        self.t_obs = t_obs
        self.t_fut = t_fut
        self.num_classes = num_classes
        self.mods = list(modality_dims.keys())

        # Four conv blocks; the first consumes the concatenated raw features,
        # the last uses a fixed dropout of 0.2 instead of `dropout`.
        conv_blocks = []
        prev_channels = sum(modality_dims.values())
        for depth in range(4):
            drop_p = dropout if depth < 3 else 0.2
            conv_blocks.append(nn.Sequential(
                nn.Conv1d(prev_channels, conv_filters, kernel_size=5, padding=2),
                nn.BatchNorm1d(conv_filters),
                nn.ReLU(),
                nn.Dropout(drop_p),
            ))
            prev_channels = conv_filters
        self.convs = nn.ModuleList(conv_blocks)

        # Inter-layer LSTM dropout only applies when there is more than one layer.
        lstm_dropout = dropout if n_lstm_layers > 1 else 0
        self.lstm = nn.LSTM(
            conv_filters, lstm_hidden, num_layers=n_lstm_layers,
            batch_first=True, dropout=lstm_dropout,
        )
        self.head = nn.Linear(lstm_hidden, t_fut * num_classes)

    def forward(self, x: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Return (B, T_fut, C) logits from a dict of (B, T_obs, F_mod) tensors."""
        stacked = torch.cat([x[name] for name in self.mods], dim=-1)  # (B, T_obs, F_total)
        signal = stacked.transpose(1, 2)                               # (B, F_total, T_obs)
        for block in self.convs:
            signal = block(signal)
        _, (last_hidden, _) = self.lstm(signal.transpose(1, 2))        # input (B, T_obs, filters)
        summary = last_hidden[-1]                                      # (B, lstm_hidden)
        return self.head(summary).view(-1, self.t_fut, self.num_classes)
163
+
164
+
165
+ # ---------------------------------------------------------------------------
166
+ # 4. RU-LSTM (Furnari & Farinella, TPAMI 2020, "Rolling-Unrolling LSTMs for action
167
+ # anticipation"). Two-phase LSTM: a "rolling" phase encodes past, an
168
+ # "unrolling" phase autoregressively decodes future tokens.
169
+ # ---------------------------------------------------------------------------
170
+
171
class RULSTMForecast(nn.Module):
    """Rolling-unrolling LSTM forecaster.

    A "rolling" LSTM consumes the fused observed sequence; its final
    (hidden, cell) state initialises an "unrolling" LSTM that is fed the
    same learned token at each of the `t_fut` future steps. Every
    unrolling output is mapped to class logits.
    """

    def __init__(self, modality_dims: Dict[str, int], num_classes: int,
                 t_obs: int, t_fut: int, d_model: int = 128,
                 n_lstm_layers: int = 2, dropout: float = 0.1):
        super().__init__()
        self.t_obs = t_obs
        self.t_fut = t_fut
        self.num_classes = num_classes
        self.embed = _PerModalityProj(modality_dims, d_model)

        # Both LSTMs share one configuration so states can be handed over directly.
        inter_layer_dropout = dropout if n_lstm_layers > 1 else 0
        self.rolling = nn.LSTM(
            d_model, d_model, num_layers=n_lstm_layers,
            batch_first=True, dropout=inter_layer_dropout,
        )
        self.unrolling = nn.LSTM(
            d_model, d_model, num_layers=n_lstm_layers,
            batch_first=True, dropout=inter_layer_dropout,
        )

        # Learned input token that is repeated for every future step.
        self.fut_init = nn.Parameter(torch.zeros(1, 1, d_model))
        nn.init.trunc_normal_(self.fut_init, std=0.02)
        self.head = nn.Linear(d_model, num_classes)

    def forward(self, x: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Encode the past, then unroll `t_fut` steps; returns (B, T_fut, C)."""
        observed = self.embed(x)                                   # (B, T_obs, D)
        _, past_state = self.rolling(observed)
        hidden, cell = past_state
        future_in = self.fut_init.expand(observed.size(0), self.t_fut, -1)
        unrolled, _ = self.unrolling(future_in, (hidden, cell))
        return self.head(unrolled)                                 # (B, T_fut, C)
200
+
201
+
202
+ # ---------------------------------------------------------------------------
203
+ # 5. AVT (Girdhar & Grauman ICCV 2021, "Anticipative Video Transformer").
204
+ # Causal Transformer over the concatenation of past + future tokens.
205
+ # ---------------------------------------------------------------------------
206
+
207
class AVTForecast(nn.Module):
    """AVT-style forecaster: one causal Transformer over [past | future] tokens.

    Learned future tokens are appended to the fused past sequence, a
    learned positional parameter covering `t_obs + t_fut` positions is
    added, and a Transformer encoder runs under a causal mask. Only the
    outputs at the future positions are classified.
    """

    def __init__(self, modality_dims: Dict[str, int], num_classes: int,
                 t_obs: int, t_fut: int, d_model: int = 128,
                 n_heads: int = 4, n_layers: int = 2, dropout: float = 0.1):
        super().__init__()
        self.t_obs = t_obs
        self.t_fut = t_fut
        self.num_classes = num_classes
        self.embed = _PerModalityProj(modality_dims, d_model)

        total_len = t_obs + t_fut
        self.pos = nn.Parameter(torch.zeros(1, total_len, d_model))
        nn.init.trunc_normal_(self.pos, std=0.02)

        block = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=4 * d_model,
            dropout=dropout, batch_first=True, activation="gelu",
        )
        self.encoder = nn.TransformerEncoder(block, num_layers=n_layers)

        self.fut_tokens = nn.Parameter(torch.zeros(1, t_fut, d_model))
        nn.init.trunc_normal_(self.fut_tokens, std=0.02)
        self.head = nn.Linear(d_model, num_classes)

        # Boolean mask, True strictly above the diagonal: position i may not
        # attend to any position j > i in the concatenated sequence.
        visible = torch.tril(torch.ones(total_len, total_len)).bool()
        self.register_buffer("causal_mask", ~visible)

    def forward(self, x: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Return (B, T_fut, C) logits read from the future-token positions."""
        past = self.embed(x)                                       # (B, T_obs, D)
        future = self.fut_tokens.expand(past.size(0), -1, -1)      # (B, T_fut, D)
        tokens = torch.cat([past, future], dim=1) + self.pos
        encoded = self.encoder(tokens, mask=self.causal_mask)
        return self.head(encoded[:, self.t_obs:, :])               # (B, T_fut, C)
239
+
240
+
241
+ # ---------------------------------------------------------------------------
242
+ # Builder
243
+ # ---------------------------------------------------------------------------
244
+
245
def build_forecast_model(name: str, modality_dims: Dict[str, int],
                         num_classes: int, t_obs: int, t_fut: int,
                         d_model: int = 128, dropout: float = 0.1) -> nn.Module:
    """Build a forecast model from its (case-insensitive) name.

    Recognised names: "daf"/"transformer", "futr", "deepconvlstm",
    "rulstm"/"ru-lstm"/"ru_lstm", and "avt". `d_model` is forwarded to
    every model except DeepConvLSTMForecast, which does not take it.

    Raises:
        ValueError: if the lower-cased name matches none of the above.
    """
    key = name.lower()
    # Keyword arguments accepted by every forecast model.
    window = dict(t_obs=t_obs, t_fut=t_fut, dropout=dropout)

    if key in ("daf", "transformer"):
        model = TransformerForecast(modality_dims, num_classes,
                                    d_model=d_model, **window)
    elif key == "futr":
        model = FUTRForecast(modality_dims, num_classes,
                             d_model=d_model, **window)
    elif key == "deepconvlstm":
        model = DeepConvLSTMForecast(modality_dims, num_classes, **window)
    elif key in ("rulstm", "ru-lstm", "ru_lstm"):
        model = RULSTMForecast(modality_dims, num_classes,
                               d_model=d_model, **window)
    elif key == "avt":
        model = AVTForecast(modality_dims, num_classes,
                            d_model=d_model, **window)
    else:
        raise ValueError(f"Unknown forecast model: {key!r}")
    return model
experiments/nets/models_forecast_priv.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Models for T8 v3 — privileged future-pressure conditioning.
2
+
3
+ Wraps the existing TransformerForecast (DAF) to accept future pressure as
4
+ side-channel context. The future pressure trajectory is encoded into T_fut
5
+ tokens that get appended to the past memory; future queries cross-attend
6
+ over the union (past sensors + future pressure). This is privileged
7
+ information (oracle) — at test time we'd not have future pressure — so
8
+ this is a hypothesis-test setup, not a deployable forecaster.
9
+ """
10
+ from __future__ import annotations
11
+ from typing import Dict
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+
16
+
17
+ class _PerModalityProj(nn.Module):
18
+ def __init__(self, modality_dims, d_model):
19
+ super().__init__()
20
+ self.proj = nn.ModuleDict({
21
+ m: nn.Linear(d, d_model) for m, d in modality_dims.items()
22
+ })
23
+ self.mod_emb = nn.Parameter(torch.zeros(len(modality_dims), d_model))
24
+ nn.init.trunc_normal_(self.mod_emb, std=0.02)
25
+ self.mods = list(modality_dims.keys())
26
+
27
+ def forward(self, x):
28
+ out = None
29
+ for i, m in enumerate(self.mods):
30
+ h = self.proj[m](x[m]) + self.mod_emb[i]
31
+ out = h if out is None else out + h
32
+ return out / len(self.mods)
33
+
34
+
35
class DAFFuturePressure(nn.Module):
    """DAF backbone whose future queries also see the future pressure trajectory.

    The observed window is encoded by a Transformer encoder. The future
    pressure sequence is linearly projected and tagged with its own
    positional and segment parameters. Both are concatenated along time
    into one memory that the learned future queries cross-attend to.
    """

    def __init__(self, modality_dims: Dict[str, int], target_dim: int,
                 t_obs: int, t_fut: int, future_pressure_dim: int = 50,
                 d_model: int = 128, n_heads: int = 4, n_layers: int = 2,
                 dropout: float = 0.1):
        super().__init__()
        self.t_obs = t_obs
        self.t_fut = t_fut

        # --- past-sensor encoder ---
        self.embed = _PerModalityProj(modality_dims, d_model)
        self.pos = nn.Parameter(torch.zeros(1, t_obs, d_model))
        nn.init.trunc_normal_(self.pos, std=0.02)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=4 * d_model,
            dropout=dropout, batch_first=True, activation="gelu",
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

        # --- future-pressure tokens: projection + position + segment id ---
        self.fp_proj = nn.Linear(future_pressure_dim, d_model)
        self.fp_pos = nn.Parameter(torch.zeros(1, t_fut, d_model))
        nn.init.trunc_normal_(self.fp_pos, std=0.02)
        self.fp_seg = nn.Parameter(torch.zeros(1, 1, d_model))
        nn.init.trunc_normal_(self.fp_seg, std=0.02)

        # --- query decoder ---
        self.queries = nn.Parameter(torch.zeros(1, t_fut, d_model))
        nn.init.trunc_normal_(self.queries, std=0.02)
        self.cross_attn = nn.MultiheadAttention(
            d_model, n_heads, dropout=dropout, batch_first=True
        )
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, target_dim)

    def forward(self, x: Dict[str, torch.Tensor],
                future_pressure: torch.Tensor) -> torch.Tensor:
        """Return (B, T_fut, target_dim) predictions.

        Args:
            x: per-modality observed tensors consumed by `self.embed`.
            future_pressure: tensor whose last dimension matches
                `future_pressure_dim` and whose time axis must broadcast
                against the `t_fut`-long positional parameter.
        """
        past_tokens = self.encoder(self.embed(x) + self.pos)        # (B, T_obs, D)
        pressure_tokens = self.fp_proj(future_pressure) + self.fp_pos + self.fp_seg
        context = torch.cat([past_tokens, pressure_tokens], dim=1)  # (B, T_obs+T_fut, D)
        fut_queries = self.queries.expand(context.size(0), -1, -1)  # (B, T_fut, D)
        attended, _ = self.cross_attn(fut_queries, context, context,
                                      need_weights=False)
        return self.head(self.norm(attended))                       # (B, T_fut, target_dim)
experiments/nets/models_seqpred.py ADDED
@@ -0,0 +1,806 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Models for T10 Triplet Next-Action Prediction.
3
+
4
+ Three classes live here:
5
+
6
+ * TripletHead — shared head module producing (verb_fine, verb_composite,
7
+ noun, hand) logits from a pooled feature vector.
8
+ * DeepConvLSTMTriplet — single-flow CNN+LSTM baseline (concatenates all
9
+ available modalities along the feature axis).
10
+ * DailyActFormer — our full-modality cross-modal Transformer that keeps
11
+ each modality in its own stem, fuses via a modality
12
+ token, and runs a causal temporal Transformer. Supports
13
+ the anticipatory auxiliary loss mentioned in the paper
14
+ plan (currently as a stub; enabled later in training).
15
+
16
+ All models take:
17
+ x: dict[mod_name -> (B, T, F_mod)]
18
+ mask: BoolTensor (B, T)
19
+ and return a dict:
20
+ {'verb_fine': (B, NUM_VERB_FINE),
21
+ 'verb_composite': (B, NUM_VERB_COMPOSITE),
22
+ 'noun': (B, NUM_NOUN),
23
+ 'hand': (B, NUM_HAND)}
24
+ """
25
+
26
+ from __future__ import annotations
27
+
28
+ import math
29
+ import sys
30
+ from pathlib import Path
31
+ from typing import Dict, List, Optional, Sequence
32
+
33
+ import torch
34
+ import torch.nn as nn
35
+ import torch.nn.functional as F
36
+
37
+ # Importable from either (a) neurips26 root, or (b) frozen row/code/ folder.
38
# Resolved absolute path of this module; used only to extend sys.path below.
_THIS = Path(__file__).resolve()
# Prepend this file's directory and then its parent directory to sys.path
# (the parent ends up first), so that the bare `taxonomy` fallback import
# below can be resolved from either location.
sys.path.insert(0, str(_THIS.parent))
sys.path.insert(0, str(_THIS.parent.parent))

# Prefer the package-qualified taxonomy module; fall back to a top-level
# `taxonomy` module when the `experiments` package is not importable.
try:
    from experiments.taxonomy import (
        NUM_VERB_FINE, NUM_VERB_COMPOSITE, NUM_NOUN, NUM_HAND,
    )
except ModuleNotFoundError:
    from taxonomy import (
        NUM_VERB_FINE, NUM_VERB_COMPOSITE, NUM_NOUN, NUM_HAND,
    )
50
+
51
+ # ---------------------------------------------------------------------------
52
+ # Shared triplet head
53
+ # ---------------------------------------------------------------------------
54
+
55
+ class _PrevActionConcat(nn.Module):
56
+ """Embeds the previous-segment (verb_composite, noun) ground-truth labels
57
+ and concatenates them to a pooled feature vector. Used by every model
58
+ when `use_prev_action=True`. The +1 vocab slot is the BOS / no-prev
59
+ sentinel emitted by the dataset for the first kept segment of each
60
+ recording. Output dim added to pooled = 2 * prev_emb_dim."""
61
+
62
+ def __init__(self, prev_emb_dim: int = 32):
63
+ super().__init__()
64
+ from taxonomy import NUM_VERB_COMPOSITE as _NVC, NUM_NOUN as _NN # noqa
65
+ self.vc_emb = nn.Embedding(_NVC + 1, prev_emb_dim)
66
+ self.n_emb = nn.Embedding(_NN + 1, prev_emb_dim)
67
+ self.out_dim = 2 * prev_emb_dim
68
+
69
+ def forward(self, pooled: torch.Tensor,
70
+ prev_v_comp: Optional[torch.Tensor] = None,
71
+ prev_noun: Optional[torch.Tensor] = None) -> torch.Tensor:
72
+ if prev_v_comp is None or prev_noun is None:
73
+ B = pooled.size(0)
74
+ prev_v_comp = torch.full((B,), self.vc_emb.num_embeddings - 1,
75
+ dtype=torch.long, device=pooled.device)
76
+ prev_noun = torch.full((B,), self.n_emb.num_embeddings - 1,
77
+ dtype=torch.long, device=pooled.device)
78
+ pe = torch.cat([self.vc_emb(prev_v_comp), self.n_emb(prev_noun)], dim=-1)
79
+ return torch.cat([pooled, pe], dim=-1)
80
+
81
+
82
class TripletHead(nn.Module):
    """Shared prediction head producing four sets of logits from one feature vector.

    The feature is layer-normalised, passed through a Linear/GELU/Dropout
    trunk, and then fed to four independent linear classifiers whose output
    widths come from the taxonomy constants.
    """

    _OUTPUTS = ("verb_fine", "verb_composite", "noun", "hand")

    def __init__(self, feat_dim: int, hidden: int = 256, dropout: float = 0.2):
        super().__init__()
        self.norm = nn.LayerNorm(feat_dim)
        self.trunk = nn.Sequential(
            nn.Linear(feat_dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
        )
        # One linear classifier per output, all reading the shared trunk.
        self.verb_fine = nn.Linear(hidden, NUM_VERB_FINE)
        self.verb_composite = nn.Linear(hidden, NUM_VERB_COMPOSITE)
        self.noun = nn.Linear(hidden, NUM_NOUN)
        self.hand = nn.Linear(hidden, NUM_HAND)

    def forward(self, feat: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Return a dict with keys verb_fine, verb_composite, noun, hand."""
        shared = self.trunk(self.norm(feat))
        return {key: getattr(self, key)(shared) for key in self._OUTPUTS}
104
+
105
+
106
+ def _masked_mean_pool(h: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
107
+ """Mean over the time axis of `h` (B, T, D) using a boolean mask (B, T)."""
108
+ m = mask.to(h.dtype).unsqueeze(-1)
109
+ return (h * m).sum(dim=1) / m.sum(dim=1).clamp(min=1.0)
110
+
111
+
112
+ # ---------------------------------------------------------------------------
113
+ # Baseline: DeepConvLSTM (Ordonez & Roggen 2016) adapted for triplet prediction
114
+ # ---------------------------------------------------------------------------
115
+
116
class DeepConvLSTMTriplet(nn.Module):
    """Single-flow CNN+LSTM baseline for triplet prediction.

    All modalities are concatenated on the feature axis, passed through a
    stack of Conv1d/BatchNorm/ReLU/Dropout blocks and an LSTM. The top
    layer's final hidden state (optionally extended with previous-action
    embeddings) is fed to a `TripletHead`.

    Args:
        modality_dims: mapping modality name -> feature width. Its key
            order defines the channel layout of the first convolution.
        conv_filters: output channels of every conv block.
        conv_kernel: temporal kernel size (padding is `conv_kernel // 2`).
        num_conv_layers: number of conv blocks.
        lstm_hidden: LSTM hidden size.
        num_lstm_layers: number of stacked LSTM layers.
        dropout: dropout for conv blocks (the last block uses
            `dropout + 0.1`), LSTM inter-layer dropout, and the head.
        head_hidden: hidden width of the `TripletHead` trunk.
        use_prev_action: append previous-action embeddings to the pooled state.
        prev_emb_dim: embedding width per previous-action label.
    """

    def __init__(
        self,
        modality_dims: Dict[str, int],
        conv_filters: int = 64,
        conv_kernel: int = 5,
        num_conv_layers: int = 4,
        lstm_hidden: int = 128,
        num_lstm_layers: int = 2,
        dropout: float = 0.2,
        head_hidden: int = 256,
        use_prev_action: bool = False,
        prev_emb_dim: int = 32,
    ):
        super().__init__()
        self.modality_dims = dict(modality_dims)
        self.use_prev_action = use_prev_action
        in_ch = sum(modality_dims.values())

        convs: List[nn.Module] = []
        c = in_ch
        for i in range(num_conv_layers):
            convs.append(nn.Sequential(
                nn.Conv1d(c, conv_filters, conv_kernel, padding=conv_kernel // 2),
                nn.BatchNorm1d(conv_filters),
                nn.ReLU(),
                # The final conv block gets slightly stronger dropout.
                nn.Dropout(dropout if i < num_conv_layers - 1 else dropout + 0.1),
            ))
            c = conv_filters
        self.convs = nn.Sequential(*convs)

        self.lstm = nn.LSTM(
            conv_filters, lstm_hidden, num_layers=num_lstm_layers,
            batch_first=True, bidirectional=False,
            dropout=dropout if num_lstm_layers > 1 else 0.0,
        )
        head_in = lstm_hidden
        if use_prev_action:
            self.prev_concat = _PrevActionConcat(prev_emb_dim)
            head_in += self.prev_concat.out_dim
        else:
            self.prev_concat = None
        self.head = TripletHead(head_in, hidden=head_hidden, dropout=dropout)

    def forward(
        self, x: Dict[str, torch.Tensor], mask: torch.Tensor,
        prev_v_comp: Optional[torch.Tensor] = None,
        prev_noun: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """Predict triplet logits from a dict of (B, T, F_mod) tensors.

        Args:
            x: per-modality inputs; must contain every modality the model
                was built with. Extra keys are ignored.
            mask: accepted for interface parity with the other models but
                not used here, so the final LSTM state includes any padded
                frames present in the input.
            prev_v_comp: previous verb-composite ids, used only when
                `use_prev_action` is set.
            prev_noun: previous noun ids, used only when `use_prev_action`
                is set.

        Raises:
            ValueError: if a modality the model was built with is missing
                from `x`.
        """
        missing = [m for m in self.modality_dims if m not in x]
        if missing:
            raise ValueError(
                f"Missing modalities in input: {missing}; "
                f"expected {list(self.modality_dims)}"
            )
        # Concatenate in the order the model was built with, not the caller's
        # dict order: the first conv's input channels are laid out in that
        # order, so a reordered dict would otherwise silently permute them.
        feats = torch.cat([x[m] for m in self.modality_dims], dim=-1).transpose(1, 2)
        feats = self.convs(feats).transpose(1, 2)
        _, (h_n, _) = self.lstm(feats)
        pooled = h_n[-1]  # final hidden state of the top LSTM layer
        if self.use_prev_action:
            pooled = self.prev_concat(pooled, prev_v_comp, prev_noun)
        return self.head(pooled)
174
+
175
+
176
+ # ---------------------------------------------------------------------------
177
+ # Our model: DailyActFormer
178
+ # ---------------------------------------------------------------------------
179
+
180
+ class _ModalityStem(nn.Module):
181
+ """Multi-scale 1-D conv stem (kernels 3, 5, 9) per modality.
182
+
183
+ Borrowed from HandFormer (the top-1 baseline on T10 recognition): three
184
+ parallel convolutions capture fast (k=3, ~0.15s @ 20Hz), medium (k=5),
185
+ and slow (k=9, ~0.45s) temporal patterns. Output is a 1×1 fusion of
186
+ the three branches, projected back to d_model.
187
+ """
188
+
189
+ def __init__(self, in_dim: int, d_model: int, kernels=(3, 5, 9),
190
+ dropout: float = 0.1):
191
+ super().__init__()
192
+ self.kernels = kernels
193
+ self.branches = nn.ModuleList([
194
+ nn.Conv1d(in_dim, d_model, k, padding=k // 2) for k in kernels
195
+ ])
196
+ self.merge = nn.Sequential(
197
+ nn.GELU(),
198
+ nn.Conv1d(d_model * len(kernels), d_model, 1),
199
+ )
200
+ self.norm = nn.LayerNorm(d_model)
201
+ self.drop = nn.Dropout(dropout)
202
+
203
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
204
+ # x: (B, T, F_in) -> (B, F_in, T) for conv1d
205
+ z = x.transpose(1, 2)
206
+ multi = [c(z) for c in self.branches] # each (B, D, T)
207
+ h = self.merge(torch.cat(multi, dim=1)).transpose(1, 2) # (B, T, D)
208
+ return self.drop(self.norm(h))
209
+
210
+
211
+ class _QueryPool(nn.Module):
212
+ """Learnable-query cross-attention pooling (replaces mean pool).
213
+
214
+ Inspired by FUTR (the top-5 baseline winner): a single learnable query
215
+ cross-attends to the entire encoder output, producing one summary vector.
216
+ Compared to a plain mean pool this lets the model weight informative
217
+ frames more heavily.
218
+ """
219
+
220
+ def __init__(self, d_model: int, n_heads: int = 4, dropout: float = 0.1):
221
+ super().__init__()
222
+ self.q = nn.Parameter(torch.zeros(1, 1, d_model))
223
+ nn.init.trunc_normal_(self.q, std=0.02)
224
+ self.attn = nn.MultiheadAttention(
225
+ d_model, n_heads, dropout=dropout, batch_first=True,
226
+ )
227
+ self.norm = nn.LayerNorm(d_model)
228
+
229
+ def forward(self, h: torch.Tensor, key_padding_mask: Optional[torch.Tensor]):
230
+ # h: (B, T, D); key_padding_mask: (B, T) where True = pad-to-mask-out
231
+ B = h.size(0)
232
+ q = self.q.expand(B, -1, -1)
233
+ out, _ = self.attn(q, h, h, key_padding_mask=key_padding_mask,
234
+ need_weights=False)
235
+ return self.norm(out.squeeze(1))
236
+
237
+
238
+ class _CrossModalTemporalShift(nn.Module):
239
+ """Cross-modal temporal-shift attention between two modalities.
240
+
241
+ Motivation (paper case study, §sec:grasp-phase-main): EMG activation leads
242
+ motion onset by a sub-frame ~20ms in our 100Hz recordings. After the 5x
243
+ downsample to 20Hz, that lag is ~0.4 frames, but per-subject variability
244
+ plus slack in our segment annotations introduces a few frames of drift
245
+ that a fixed alignment cannot capture.
246
+
247
+ We learn a discrete temporal shift Δ ∈ {-max_shift, …, +max_shift} frames
248
+ applied to one of the two modalities (EMG by default), so the shifted
249
+ tokens align with the other branch (MoCap) before cross-modal fusion. The
250
+ shift is sampled via straight-through Gumbel-softmax during training; at
251
+ inference we take the argmax (deterministic).
252
+
253
+ Inputs are per-modality token sequences (B, T, D). Outputs the same shape.
254
+ Only the `shift_modality` branch is shifted; other modalities pass through.
255
+ """
256
+
257
+ def __init__(self, max_shift: int = 3, tau: float = 1.0):
258
+ super().__init__()
259
+ self.max_shift = max_shift
260
+ self.tau = tau
261
+ # Logits over 2*max_shift+1 categorical shift candidates.
262
+ self.shift_logits = nn.Parameter(torch.zeros(2 * max_shift + 1))
263
+
264
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
265
+ # x: (B, T, D); produce a shifted version that's a soft-blend over
266
+ # the shift dimension. Hard at inference, gumbel-softmax at training.
267
+ if self.training:
268
+ w = F.gumbel_softmax(self.shift_logits, tau=self.tau, hard=True, dim=-1)
269
+ else:
270
+ w = F.one_hot(self.shift_logits.argmax(),
271
+ num_classes=2 * self.max_shift + 1).float()
272
+ shifted = []
273
+ for i, s in enumerate(range(-self.max_shift, self.max_shift + 1)):
274
+ shifted.append(w[i] * torch.roll(x, shifts=s, dims=1))
275
+ return torch.stack(shifted, dim=0).sum(dim=0)
276
+
277
+
278
+ class _CausalTransformerBlock(nn.Module):
279
+ """Standard Transformer encoder block with a strictly causal attention mask."""
280
+
281
+ def __init__(self, d_model: int, n_heads: int, mlp_ratio: float = 4.0,
282
+ dropout: float = 0.1):
283
+ super().__init__()
284
+ self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout,
285
+ batch_first=True)
286
+ self.norm1 = nn.LayerNorm(d_model)
287
+ self.norm2 = nn.LayerNorm(d_model)
288
+ mlp_dim = int(d_model * mlp_ratio)
289
+ self.mlp = nn.Sequential(
290
+ nn.Linear(d_model, mlp_dim), nn.GELU(), nn.Dropout(dropout),
291
+ nn.Linear(mlp_dim, d_model), nn.Dropout(dropout),
292
+ )
293
+
294
+ def forward(self, x: torch.Tensor, attn_mask: torch.Tensor,
295
+ key_padding_mask: Optional[torch.Tensor]) -> torch.Tensor:
296
+ h = self.norm1(x)
297
+ h, _ = self.attn(h, h, h, attn_mask=attn_mask,
298
+ key_padding_mask=key_padding_mask, need_weights=False)
299
+ x = x + h
300
+ x = x + self.mlp(self.norm2(x))
301
+ return x
302
+
303
+
304
+ class DailyActFormer(nn.Module):
305
+ """Cross-modal Transformer that uses every available modality.
306
+
307
+ Architecture outline:
308
+ per-modality stem → learnable modality embedding →
309
+ concat across time (each frame -> M modality tokens) →
310
+ 1 fusion-layer cross-modal attention (compress M→1 per frame) →
311
+ temporal Transformer (bidirectional by default; causal when
312
+ `causal=True` for anticipation-style next-action prediction)
313
+ → pooled → TripletHead
314
+
315
+ For simplicity the fusion step is an attention pooling with learnable
316
+ queries, rather than a full cross-modal block. This keeps the parameter
317
+ count modest (2–4 M range with d_model=128).
318
+ """
319
+
320
+ def __init__(
321
+ self,
322
+ modality_dims: Dict[str, int],
323
+ d_model: int = 128,
324
+ n_layers: int = 4,
325
+ n_heads: int = 4,
326
+ dropout: float = 0.1,
327
+ head_hidden: int = 256,
328
+ max_T: int = 256,
329
+ causal: bool = False,
330
+ xshift_modality: Optional[str] = "emg",
331
+ xshift_max: int = 3,
332
+ use_prev_action: bool = False,
333
+ prev_emb_dim: int = 32,
334
+ ):
335
+ super().__init__()
336
+ self.modalities = list(modality_dims.keys())
337
+ self.causal = causal
338
+ self.use_prev_action = use_prev_action
339
+
340
+ # Prev-action concat (shared helper)
341
+ if use_prev_action:
342
+ self.prev_concat = _PrevActionConcat(prev_emb_dim)
343
+ self._prev_extra_dim = self.prev_concat.out_dim
344
+ else:
345
+ self.prev_concat = None
346
+ self._prev_extra_dim = 0
347
+
348
+ # 0) Cross-modal temporal-shift block on one branch (EMG by default).
349
+ # Disabled if `xshift_modality` is None or not present.
350
+ if xshift_modality is not None and xshift_modality in modality_dims:
351
+ self.xshift_modality = xshift_modality
352
+ self.xshift = _CrossModalTemporalShift(max_shift=xshift_max)
353
+ else:
354
+ self.xshift_modality = None
355
+ self.xshift = None
356
+
357
+ # 1) per-modality 1-D conv stems (each produces d_model features/frame)
358
+ self.stems = nn.ModuleDict({
359
+ m: _ModalityStem(F, d_model, dropout=dropout)
360
+ for m, F in modality_dims.items()
361
+ })
362
+
363
+ # 2) modality embedding (broadcast-add to per-modality tokens)
364
+ self.modality_embed = nn.Parameter(
365
+ torch.zeros(len(self.modalities), d_model)
366
+ )
367
+ nn.init.trunc_normal_(self.modality_embed, std=0.02)
368
+
369
+ # 3) per-frame cross-modal fusion: use a single learnable query token
370
+ self.fusion_q = nn.Parameter(torch.zeros(1, 1, d_model))
371
+ self.fusion_kv = nn.LayerNorm(d_model)
372
+ self.fusion_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
373
+
374
+ # 4) positional embedding along time (post-fusion)
375
+ self.pos_embed = nn.Parameter(torch.zeros(1, max_T, d_model))
376
+ nn.init.trunc_normal_(self.pos_embed, std=0.02)
377
+ self.max_T = max_T
378
+
379
+ # 5) causal temporal Transformer
380
+ self.temporal_norm = nn.LayerNorm(d_model)
381
+ self.temporal = nn.ModuleList([
382
+ _CausalTransformerBlock(d_model, n_heads, dropout=dropout)
383
+ for _ in range(n_layers)
384
+ ])
385
+
386
+ # 6) Pool: learnable-query cross-attention (replaces mean pool, FUTR-style)
387
+ self.pool = _QueryPool(d_model, n_heads=n_heads, dropout=dropout)
388
+
389
+ # 7) triplet head: input dim = d_model + (optional prev-action embed)
390
+ head_in = d_model + self._prev_extra_dim
391
+ self.head = TripletHead(head_in, hidden=head_hidden, dropout=dropout)
392
+
393
+ nn.init.trunc_normal_(self.fusion_q, std=0.02)
394
+
395
+ # ---- helpers ----
396
+ def _causal_mask(self, T: int, device) -> torch.Tensor:
397
+ # MultiheadAttention wants additive mask with -inf above diag.
398
+ m = torch.full((T, T), float("-inf"), device=device)
399
+ m.triu_(diagonal=1)
400
+ return m
401
+
402
+ # ---- forward ----
403
+ def forward(
404
+ self, x: Dict[str, torch.Tensor], mask: torch.Tensor,
405
+ prev_v_comp: Optional[torch.Tensor] = None,
406
+ prev_noun: Optional[torch.Tensor] = None,
407
+ return_features: bool = False,
408
+ ) -> Dict[str, torch.Tensor]:
409
+ # Stems: per-modality token streams
410
+ stem_tokens: List[torch.Tensor] = []
411
+ mods_in = [m for m in self.modalities if m in x]
412
+ if not mods_in:
413
+ raise ValueError("No modality from the model signature was provided.")
414
+ for i, m in enumerate(mods_in):
415
+ h = self.stems[m](x[m]) # (B, T, D)
416
+ # Cross-modal temporal shift: apply to one branch (e.g. EMG) so it
417
+ # aligns with the others before fusion. Implements paper SyncFuse's
418
+ # main novelty (sub-frame anticipatory coupling between EMG/MoCap).
419
+ if self.xshift is not None and m == self.xshift_modality:
420
+ h = self.xshift(h)
421
+ h = h + self.modality_embed[self.modalities.index(m)]
422
+ stem_tokens.append(h)
423
+
424
+ # Cross-modal fusion: per-frame, attend learnable query over the M stacked
425
+ # modality tokens. Output is (B, T, D).
426
+ B, T, D = stem_tokens[0].shape
427
+ # stack -> (B, T, M, D) -> reshape as (B*T, M, D)
428
+ stacked = torch.stack(stem_tokens, dim=2) # (B, T, M, D)
429
+ M = stacked.size(2)
430
+ stacked = stacked.reshape(B * T, M, D)
431
+ kv = self.fusion_kv(stacked)
432
+ q = self.fusion_q.expand(B * T, -1, -1)
433
+ fused, _ = self.fusion_attn(q, kv, kv, need_weights=False)
434
+ fused = fused.reshape(B, T, D) # (B, T, D)
435
+
436
+ # Positional embedding + causal temporal Transformer
437
+ if T > self.max_T:
438
+ raise ValueError(f"T={T} exceeds max_T={self.max_T}")
439
+ h = fused + self.pos_embed[:, :T, :]
440
+ h = self.temporal_norm(h)
441
+
442
+ attn_mask = self._causal_mask(T, h.device) if self.causal else None
443
+ key_padding = ~mask if mask is not None else None
444
+ for block in self.temporal:
445
+ h = block(h, attn_mask=attn_mask, key_padding_mask=key_padding)
446
+
447
+ # Pool: learnable-query cross-attention (FUTR-style) over valid frames
448
+ pooled = self.pool(h, key_padding_mask=key_padding)
449
+
450
+ # Optional: condition on previous segment's labels
451
+ if self.use_prev_action:
452
+ pooled = self.prev_concat(pooled, prev_v_comp, prev_noun)
453
+
454
+ logits = self.head(pooled)
455
+ if return_features:
456
+ logits["_pooled"] = pooled
457
+ return logits
458
+
459
+
460
+ # ===========================================================================
461
+ # Published baselines, sensor-adapted. Each keeps the original paper's key
462
+ # idea (rolling+unrolling LSTM for RULSTM, causal encoder–decoder for FUTR,
463
+ # early modality-token fusion for AFFT, etc.) but swaps the RGB/feature input
464
+ # for our multimodal sensor streams, and the classification head for our
465
+ # shared TripletHead.
466
+ # ===========================================================================
467
+
468
+
469
+ # ---------------------------------------------------------------------------
470
+ # RULSTM (Furnari & Farinella, TPAMI 2020) — sensor-adapted
471
+ # Per-modality rolling LSTM summarises the past, a second unrolling LSTM
472
+ # takes R-LSTM state and walks `future_steps` steps forward to mimic
473
+ # anticipation without needing future sensor data. Fusion is late: each
474
+ # modality branch yields a feature vector; the vectors are averaged and fed to one shared head.
475
+ # ---------------------------------------------------------------------------
476
+
477
+ class _RULSTMBranch(nn.Module):
478
+ def __init__(self, in_dim: int, hidden: int, future_steps: int,
479
+ dropout: float = 0.2):
480
+ super().__init__()
481
+ self.future_steps = future_steps
482
+ self.rolling = nn.LSTM(in_dim, hidden, batch_first=True)
483
+ self.unrolling = nn.LSTMCell(hidden, hidden)
484
+ self.drop = nn.Dropout(dropout)
485
+ self.out_dim = hidden
486
+
487
+ def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
488
+ # x: (B, T, F_in), mask: (B, T)
489
+ # Pack-free: LSTM on padded sequences is fine since we pool from h_n.
490
+ _, (h_n, c_n) = self.rolling(x) # (1, B, H)
491
+ h = h_n.squeeze(0); c = c_n.squeeze(0)
492
+ inp = h
493
+ for _ in range(self.future_steps):
494
+ h, c = self.unrolling(inp, (h, c))
495
+ inp = h
496
+ return self.drop(h)
497
+
498
+
499
class RULSTMTriplet(nn.Module):
    """Sensor-adapted RULSTM: one rolling/unrolling branch per modality.

    Branch features are averaged across the provided modalities and passed
    (optionally concatenated with a previous-action embedding) to a single
    shared ``TripletHead``.
    """

    def __init__(self, modality_dims: Dict[str, int], hidden: int = 128,
                 future_steps: int = 8, dropout: float = 0.2,
                 head_hidden: int = 256,
                 use_prev_action: bool = False, prev_emb_dim: int = 32):
        super().__init__()
        self.use_prev_action = use_prev_action
        self.branches = nn.ModuleDict({
            m: _RULSTMBranch(F, hidden, future_steps, dropout)
            for m, F in modality_dims.items()
        })
        head_in = hidden
        if use_prev_action:
            self.prev_concat = _PrevActionConcat(prev_emb_dim)
            head_in += self.prev_concat.out_dim
        else:
            self.prev_concat = None
        self.head = TripletHead(head_in, hidden=head_hidden, dropout=dropout)

    def forward(self, x, mask, prev_v_comp=None, prev_noun=None):
        """Average the branch features of every known modality found in ``x``.

        Keys of ``x`` that have no branch are ignored instead of raising
        ``KeyError``.

        Raises:
            ValueError: if ``x`` contains none of the model's modalities.
        """
        feats = [
            branch(x[name], mask)
            for name, branch in self.branches.items()
            if name in x
        ]
        if not feats:
            raise ValueError("No modality from the model signature was provided.")
        fused = torch.stack(feats, dim=0).mean(dim=0)
        if self.use_prev_action:
            fused = self.prev_concat(fused, prev_v_comp, prev_noun)
        return self.head(fused)
526
+
527
+
528
+ # ---------------------------------------------------------------------------
529
+ # FUTR (Gong et al., CVPR 2022) — sensor-adapted
530
+ # Transformer encoder over observation frames (with per-frame feature from
531
+ # concat(modalities)). A decoder query attends over the encoder memory to
532
+ # produce a single future-action embedding which is fed into the triplet
533
+ # head. No autoregressive decoding — we only predict 1 target segment.
534
+ # ---------------------------------------------------------------------------
535
+
536
class FUTRTriplet(nn.Module):
    """Sensor-adapted FUTR: Transformer encoder plus one learnable future query.

    Per-frame features are the concatenation of all modalities, projected to
    ``d_model`` and encoded. A single learned query cross-attends over the
    encoder output; the result feeds the triplet head (optionally concatenated
    with a previous-action embedding).
    """

    def __init__(self, modality_dims: Dict[str, int], d_model: int = 128,
                 n_heads: int = 4, n_layers: int = 3, dropout: float = 0.1,
                 head_hidden: int = 256, max_T: int = 256,
                 use_prev_action: bool = False, prev_emb_dim: int = 32):
        super().__init__()
        self.use_prev_action = use_prev_action
        # Channel layout of the concatenated input is fixed by modality_dims,
        # not by the key order of the dict passed to forward().
        self.modalities = list(modality_dims.keys())
        in_dim = sum(modality_dims.values())
        self.in_proj = nn.Linear(in_dim, d_model)
        self.pos = nn.Parameter(torch.zeros(1, max_T, d_model))
        nn.init.trunc_normal_(self.pos, std=0.02)
        self.max_T = max_T

        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=4 * d_model,
            dropout=dropout, batch_first=True, activation="gelu",
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=n_layers)

        self.future_q = nn.Parameter(torch.zeros(1, 1, d_model))
        nn.init.trunc_normal_(self.future_q, std=0.02)
        self.cross_attn = nn.MultiheadAttention(
            d_model, n_heads, dropout=dropout, batch_first=True,
        )
        head_in = d_model
        if use_prev_action:
            self.prev_concat = _PrevActionConcat(prev_emb_dim)
            head_in += self.prev_concat.out_dim
        else:
            self.prev_concat = None
        self.head = TripletHead(head_in, hidden=head_hidden, dropout=dropout)

    def _concat_modalities(self, x):
        """Concatenate inputs along the feature axis in ``modality_dims`` order.

        Extra keys in ``x`` are ignored.

        Raises:
            ValueError: if a modality the model was built with is missing.
        """
        missing = [m for m in self.modalities if m not in x]
        if missing:
            raise ValueError(f"Missing modalities: {missing}")
        return torch.cat([x[m] for m in self.modalities], dim=-1)

    def forward(self, x, mask, prev_v_comp=None, prev_noun=None):
        feats = self._concat_modalities(x)
        B, T, _ = feats.shape
        if T > self.max_T:
            raise ValueError(f"T={T} exceeds FUTR max_T={self.max_T}")
        h = self.in_proj(feats) + self.pos[:, :T, :]
        # ~mask marks the frames to ignore (mask is True on valid frames).
        h = self.encoder(h, src_key_padding_mask=~mask)
        q = self.future_q.expand(B, -1, -1)
        out, _ = self.cross_attn(q, h, h, key_padding_mask=~mask,
                                 need_weights=False)
        pooled = out.squeeze(1)
        if self.use_prev_action:
            pooled = self.prev_concat(pooled, prev_v_comp, prev_noun)
        return self.head(pooled)
582
+
583
+
584
+ # ---------------------------------------------------------------------------
585
+ # AFFT (Zhong et al., WACV 2023) — sensor-adapted
586
+ # Per-modality tokens (one per frame per modality) are concatenated into a
587
+ # long token sequence of length T*M and passed through an encoder with
588
+ # causal temporal attention so the model must anticipate strictly from the
589
+ # past. Fusion happens "anticipatively" inside the attention.
590
+ # ---------------------------------------------------------------------------
591
+
592
class AFFTTriplet(nn.Module):
    """Sensor-adapted AFFT: early fusion over a T*M modality-token sequence.

    Every modality contributes one token per frame. Tokens are laid out
    frame-major, passed through Transformer blocks with a frame-level causal
    mask, and the M tokens of the last *valid* frame are averaged to form the
    feature for the triplet head.
    """

    def __init__(self, modality_dims: Dict[str, int], d_model: int = 96,
                 n_heads: int = 4, n_layers: int = 3, dropout: float = 0.1,
                 head_hidden: int = 256, max_T: int = 256,
                 use_prev_action: bool = False, prev_emb_dim: int = 32):
        super().__init__()
        self.use_prev_action = use_prev_action
        self.modalities = list(modality_dims.keys())
        self.stems = nn.ModuleDict({
            m: nn.Linear(F, d_model) for m, F in modality_dims.items()
        })
        self.mod_embed = nn.Parameter(
            torch.zeros(len(self.modalities), d_model)
        )
        nn.init.trunc_normal_(self.mod_embed, std=0.02)
        self.pos = nn.Parameter(torch.zeros(1, max_T, d_model))
        nn.init.trunc_normal_(self.pos, std=0.02)
        self.max_T = max_T
        self.d_model = d_model

        self.blocks = nn.ModuleList([
            _CausalTransformerBlock(d_model, n_heads, dropout=dropout)
            for _ in range(n_layers)
        ])
        head_in = d_model
        if use_prev_action:
            self.prev_concat = _PrevActionConcat(prev_emb_dim)
            head_in += self.prev_concat.out_dim
        else:
            self.prev_concat = None
        self.head = TripletHead(head_in, hidden=head_hidden, dropout=dropout)

    def _expand_causal_mask(self, T: int, M: int, device) -> torch.Tensor:
        # Token layout: [m0_t0, m1_t0, ..., mM_t0, m0_t1, ..., mM_t(T-1)]
        # Token at (m, t) can attend to all (m', t') with t' <= t.
        ts = torch.arange(T, device=device).unsqueeze(1).expand(-1, M).reshape(-1)
        return ts[:, None] < ts[None, :]        # True where future (mask out)

    @staticmethod
    def _last_valid_frame(mask: torch.Tensor) -> torch.Tensor:
        """Index of the last True entry per row of a (B, T) boolean mask.

        Rows without any valid frame fall back to ``T - 1``.
        """
        T = mask.size(1)
        positions = torch.arange(T, device=mask.device)
        last = (mask.long() * positions).max(dim=1).values
        fallback = torch.full_like(last, T - 1)
        return torch.where(mask.any(dim=1), last, fallback)

    def forward(self, x, mask, prev_v_comp=None, prev_noun=None):
        """Encode the token sequence and read out the last valid frame.

        Raises:
            ValueError: if ``x`` holds none of the model's modalities or the
                sequence is longer than ``max_T``.
        """
        mods = [m for m in self.modalities if m in x]
        if not mods:
            raise ValueError("No modality from the model signature was provided.")
        B, T, _ = x[mods[0]].shape
        if T > self.max_T:
            raise ValueError(f"T={T} exceeds AFFT max_T={self.max_T}")

        # One (B, T, D) token stream per modality, tagged with its embedding.
        per_mod_tokens = [
            self.stems[m](x[m]) + self.mod_embed[self.modalities.index(m)]
            for m in mods
        ]
        stacked = torch.stack(per_mod_tokens, dim=2)            # (B, T, M, D)
        M = stacked.size(2)
        tokens = stacked.reshape(B, T * M, self.d_model)

        # All M tokens of a frame share that frame's positional embedding.
        pos_per_frame = self.pos[:, :T, :].unsqueeze(2).expand(-1, -1, M, -1)
        tokens = tokens + pos_per_frame.reshape(1, T * M, self.d_model)

        # Additive attention mask: -inf on future frames, 0 elsewhere.
        future = self._expand_causal_mask(T, M, tokens.device)
        attn_mask = torch.zeros(
            T * M, T * M, device=tokens.device,
        ).masked_fill(future, float("-inf"))
        kp = (~mask).unsqueeze(2).expand(-1, -1, M).reshape(B, T * M)
        for blk in self.blocks:
            tokens = blk(tokens, attn_mask=attn_mask, key_padding_mask=kp)

        # Read the tokens of the last valid frame of each sample; taking
        # tokens[:, -M:] would read a padded frame on shorter sequences.
        frames = tokens.reshape(B, T, M, self.d_model)
        batch_idx = torch.arange(B, device=tokens.device)
        last_slice = frames[batch_idx, self._last_valid_frame(mask)]  # (B, M, D)
        pooled = last_slice.mean(dim=1)
        if self.use_prev_action:
            pooled = self.prev_concat(pooled, prev_v_comp, prev_noun)
        return self.head(pooled)
657
+
658
+
659
+ # ---------------------------------------------------------------------------
660
+ # HandFormer (Shamil et al., ECCV 2024) — sensor-adapted
661
+ # Originally on 3D hand poses. We feed it only the MoCap modality (which
662
+ # contains 10 fingertip joints). Multi-scale 1-D conv over time, followed
663
+ # by a Transformer. If MoCap is not in `modalities`, falls back to whatever
664
+ # is provided (but then it's no longer the paper's "pose-only" setup).
665
+ # ---------------------------------------------------------------------------
666
+
667
class HandFormerTriplet(nn.Module):
    """Sensor-adapted HandFormer: multi-scale temporal conv + Transformer.

    The modalities listed in ``modality_dims`` are concatenated per frame,
    passed through parallel 1-D convolutions with different kernel sizes,
    merged by a 1x1 convolution, encoded by a Transformer and mean-pooled
    over the valid frames before the triplet head.
    """

    def __init__(self, modality_dims: Dict[str, int], d_model: int = 128,
                 n_heads: int = 4, n_layers: int = 3, kernels=(3, 5, 9),
                 dropout: float = 0.1, head_hidden: int = 256, max_T: int = 256,
                 use_prev_action: bool = False, prev_emb_dim: int = 32):
        super().__init__()
        self.use_prev_action = use_prev_action
        # Channel layout of the concatenated input is fixed by modality_dims,
        # not by the key order of the dict passed to forward().
        self.modalities = list(modality_dims.keys())
        in_dim = sum(modality_dims.values())
        self.multi_conv = nn.ModuleList([
            nn.Conv1d(in_dim, d_model, k, padding=k // 2) for k in kernels
        ])
        self.conv_merge = nn.Conv1d(d_model * len(kernels), d_model, 1)

        self.pos = nn.Parameter(torch.zeros(1, max_T, d_model))
        nn.init.trunc_normal_(self.pos, std=0.02)
        self.max_T = max_T

        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=4 * d_model,
            dropout=dropout, batch_first=True, activation="gelu",
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=n_layers)
        head_in = d_model
        if use_prev_action:
            self.prev_concat = _PrevActionConcat(prev_emb_dim)
            head_in += self.prev_concat.out_dim
        else:
            self.prev_concat = None
        self.head = TripletHead(head_in, hidden=head_hidden, dropout=dropout)

    def _concat_modalities(self, x):
        """Concatenate inputs along the feature axis in ``modality_dims`` order.

        Extra keys in ``x`` are ignored.

        Raises:
            ValueError: if a modality the model was built with is missing.
        """
        missing = [m for m in self.modalities if m not in x]
        if missing:
            raise ValueError(f"Missing modalities: {missing}")
        return torch.cat([x[m] for m in self.modalities], dim=-1)

    def forward(self, x, mask, prev_v_comp=None, prev_noun=None):
        # Conv1d expects channels first: (B, T, F) -> (B, F, T).
        feats = self._concat_modalities(x).transpose(1, 2)
        multi = [c(feats) for c in self.multi_conv]
        h = self.conv_merge(torch.cat(multi, dim=1))
        h = h.transpose(1, 2)
        T = h.size(1)
        if T > self.max_T:
            raise ValueError(f"T={T} exceeds HandFormer max_T={self.max_T}")
        h = h + self.pos[:, :T, :]
        # ~mask marks the frames to ignore (mask is True on valid frames).
        h = self.encoder(h, src_key_padding_mask=~mask)
        pooled = _masked_mean_pool(h, mask)
        if self.use_prev_action:
            pooled = self.prev_concat(pooled, prev_v_comp, prev_noun)
        return self.head(pooled)
711
+
712
+
713
+ # ---------------------------------------------------------------------------
714
+ # Placeholder ActionLLM — a conv-stem sensor encoder + a 2-layer Transformer
715
+ # trained from scratch as a surrogate. The *full* LoRA+Qwen version lives in
716
+ # `train_pred.py` and can be wired in later if the surrogate is too weak.
717
+ # ---------------------------------------------------------------------------
718
+
719
class ActionLLMSurrogate(nn.Module):
    """Surrogate encoder: two-layer conv stem + small Transformer encoder.

    The modalities listed in ``modality_dims`` are concatenated per frame,
    passed through the conv stem, encoded, mean-pooled over the valid frames
    and fed to the triplet head.
    """

    def __init__(self, modality_dims: Dict[str, int], d_model: int = 192,
                 n_heads: int = 6, n_layers: int = 2, dropout: float = 0.1,
                 head_hidden: int = 256, max_T: int = 256,
                 use_prev_action: bool = False, prev_emb_dim: int = 32):
        super().__init__()
        self.use_prev_action = use_prev_action
        # Channel layout of the concatenated input is fixed by modality_dims,
        # not by the key order of the dict passed to forward().
        self.modalities = list(modality_dims.keys())
        in_dim = sum(modality_dims.values())
        self.stem = nn.Sequential(
            nn.Conv1d(in_dim, d_model, 5, padding=2),
            nn.GELU(),
            nn.Conv1d(d_model, d_model, 5, padding=2),
        )
        self.pos = nn.Parameter(torch.zeros(1, max_T, d_model))
        nn.init.trunc_normal_(self.pos, std=0.02)
        self.max_T = max_T
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=4 * d_model,
            dropout=dropout, batch_first=True, activation="gelu",
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=n_layers)
        head_in = d_model
        if use_prev_action:
            self.prev_concat = _PrevActionConcat(prev_emb_dim)
            head_in += self.prev_concat.out_dim
        else:
            self.prev_concat = None
        self.head = TripletHead(head_in, hidden=head_hidden, dropout=dropout)

    def _concat_modalities(self, x):
        """Concatenate inputs along the feature axis in ``modality_dims`` order.

        Extra keys in ``x`` are ignored.

        Raises:
            ValueError: if a modality the model was built with is missing.
        """
        missing = [m for m in self.modalities if m not in x]
        if missing:
            raise ValueError(f"Missing modalities: {missing}")
        return torch.cat([x[m] for m in self.modalities], dim=-1)

    def forward(self, x, mask, prev_v_comp=None, prev_noun=None):
        # Conv1d expects channels first: (B, T, F) -> (B, F, T).
        feats = self._concat_modalities(x).transpose(1, 2)
        h = self.stem(feats).transpose(1, 2)
        T = h.size(1)
        if T > self.max_T:
            raise ValueError(f"T={T} exceeds ActionLLM max_T={self.max_T}")
        h = h + self.pos[:, :T, :]
        # ~mask marks the frames to ignore (mask is True on valid frames).
        h = self.encoder(h, src_key_padding_mask=~mask)
        pooled = _masked_mean_pool(h, mask)
        if self.use_prev_action:
            pooled = self.prev_concat(pooled, prev_v_comp, prev_noun)
        return self.head(pooled)
760
+
761
+
762
+ # ---------------------------------------------------------------------------
763
+ # Factory
764
+ # ---------------------------------------------------------------------------
765
+
766
def build_model(
    name: str, modality_dims: Dict[str, int], **kwargs,
) -> nn.Module:
    """Instantiate a model by (case-insensitive) name or alias.

    ``modality_dims`` and any extra keyword arguments are forwarded to the
    selected model's constructor.

    Raises:
        ValueError: if ``name`` matches no known alias.
    """
    key = name.lower()
    # Classes are resolved lazily so that only the requested one is looked up.
    registry = (
        (("deepconvlstm", "dcl"), lambda: DeepConvLSTMTriplet),
        (("dailyactformer", "ours", "daf"), lambda: DailyActFormer),
        (("rulstm",), lambda: RULSTMTriplet),
        (("futr",), lambda: FUTRTriplet),
        (("afft",), lambda: AFFTTriplet),
        (("handformer",), lambda: HandFormerTriplet),
        (("actionllm",), lambda: ActionLLMSurrogate),
    )
    for aliases, resolve in registry:
        if key in aliases:
            return resolve()(modality_dims, **kwargs)
    raise ValueError(f"Unknown model: {key}")
785
+
786
+
787
+ # ---------------------------------------------------------------------------
788
+ # Smoke-test: build each model, run a random batch, check output shapes.
789
+ # ---------------------------------------------------------------------------
790
+
791
if __name__ == "__main__":
    # Smoke test: build every registered model, push one random batch through
    # it and print the parameter count plus the four output shapes.
    B, T = 2, 160
    dims = {"imu": 180, "emg": 8, "eyetrack": 24}
    x = {}
    for m, d in dims.items():
        x[m] = torch.randn(B, T, d)
    mask = torch.ones(B, T, dtype=torch.bool)

    model_names = ("deepconvlstm", "dailyactformer", "rulstm", "futr", "afft",
                   "handformer", "actionllm")
    for name in model_names:
        model = build_model(name, dims)
        n_params = 0
        for p in model.parameters():
            n_params += p.numel()
        out = model(x, mask)
        vf = tuple(out['verb_fine'].shape)
        vc = tuple(out['verb_composite'].shape)
        nn_shape = tuple(out['noun'].shape)
        hand = tuple(out['hand'].shape)
        print(f"{name:16s} params={n_params:>10,} shapes="
              f"vf={vf} vc={vc} n={nn_shape} h={hand}")
experiments/nets/published_models.py ADDED
@@ -0,0 +1,699 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Published baseline models for NeurIPS 2026 benchmark experiments.
3
+
4
+ Contains faithful implementations of 6 published models:
5
+ 1. DeepConvLSTM (Ordonez & Roggen, Sensors 2016) - Exp1/Exp3
6
+ 2. InceptionTime (Fawaz et al., DMKD 2020) - Exp1/Exp3
7
+ 3. MS-TCN++ (Li et al., TPAMI 2020) - Exp2
8
+ 4. DiffAct (Liu et al., ICCV 2023) - Exp2
9
+ 5. UnderPressure (Mourot et al., SCA/CGF 2022) - Exp3/Exp4a
10
+ 6. emg2pose (Meta, NeurIPS 2024 D&B) - Exp4b
11
+ """
12
+
13
+ import math
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ import numpy as np
18
+
19
+
20
+ # ============================================================
21
+ # 1. DeepConvLSTM (Ordonez & Roggen, Sensors 2016)
22
+ # "Deep Convolutional and LSTM Recurrent Neural Networks
23
+ # for Multimodal Wearable Activity Recognition"
24
+ # 4 Conv layers -> 2 LSTM layers -> pooling/per-frame output
25
+ # ============================================================
26
+
27
class DeepConvLSTMBackbone(nn.Module):
    """DeepConvLSTM backbone for sequence-level classification (Exp1).

    A stack of Conv1d -> BatchNorm -> ReLU -> Dropout layers followed by a
    unidirectional LSTM; the top layer's final hidden state is the feature.

    Input:  (B, T, C), optional mask (B, T) with True on valid frames
    Output: (B, output_dim)
    """

    def __init__(self, input_dim, hidden_dim=128, num_conv_layers=4,
                 conv_filters=64, conv_kernel=5, num_lstm_layers=2):
        super().__init__()
        conv_layers = []
        in_ch = input_dim
        for i in range(num_conv_layers):
            out_ch = conv_filters
            conv_layers.append(nn.Sequential(
                nn.Conv1d(in_ch, out_ch, conv_kernel, padding=conv_kernel // 2),
                nn.BatchNorm1d(out_ch),
                nn.ReLU(),
                # The last conv layer uses a higher dropout rate.
                nn.Dropout(0.1 if i < num_conv_layers - 1 else 0.2),
            ))
            in_ch = out_ch
        self.convs = nn.ModuleList(conv_layers)

        self.lstm = nn.LSTM(
            conv_filters, hidden_dim, num_layers=num_lstm_layers,
            batch_first=True, bidirectional=False,
            # nn.LSTM applies dropout only between stacked layers.
            dropout=0.2 if num_lstm_layers > 1 else 0,
        )
        self.output_dim = hidden_dim

    def forward(self, x, mask=None):
        """Return the LSTM state after the last valid frame of each sequence.

        With ``mask=None`` (or an all-True mask) the whole sequence is
        consumed. Otherwise the valid frames are compacted to the front and
        the batch is packed, so padded frames never update the LSTM state.
        The convolutions still run over the full padded sequence.
        """
        # x: (B, T, C) -> Conv expects (B, C, T)
        x = x.permute(0, 2, 1)
        for conv in self.convs:
            x = conv(x)
        x = x.permute(0, 2, 1)  # (B, T, conv_filters)

        if mask is None or bool(mask.all()):
            _, (h_n, _) = self.lstm(x)
            return h_n[-1]  # (B, hidden_dim)

        from torch.nn.utils.rnn import pack_padded_sequence

        valid = mask.to(torch.bool)
        # Stable sort on the "is padding" flag moves valid frames to the
        # front while keeping their temporal order.
        _, order = torch.sort((~valid).long(), dim=1, stable=True)
        x = torch.gather(x, 1, order.unsqueeze(-1).expand_as(x))
        # pack_padded_sequence rejects zero lengths; a fully padded sample
        # falls back to consuming a single frame.
        lengths = valid.sum(dim=1).clamp(min=1).cpu()
        packed = pack_padded_sequence(
            x, lengths, batch_first=True, enforce_sorted=False,
        )
        _, (h_n, _) = self.lstm(packed)
        return h_n[-1]  # (B, hidden_dim)
68
+
69
+
70
class DeepConvLSTMContact(nn.Module):
    """Per-frame DeepConvLSTM variant for contact detection (Exp3).

    Conv stack -> 2-layer bidirectional LSTM -> linear layer with 2 outputs
    per frame.

    Input:  (B, T, C)
    Output: (B, T, 2)
    """

    def __init__(self, input_dim, hidden_dim=64, num_conv_layers=4,
                 conv_filters=64, conv_kernel=5):
        super().__init__()
        # Channel widths seen by consecutive conv layers: C, F, F, ...
        widths = [input_dim] + [conv_filters] * num_conv_layers
        self.convs = nn.ModuleList([
            nn.Sequential(
                nn.Conv1d(c_in, c_out, conv_kernel, padding=conv_kernel // 2),
                nn.BatchNorm1d(c_out),
                nn.ReLU(),
                nn.Dropout(0.1),
            )
            for c_in, c_out in zip(widths[:-1], widths[1:])
        ])
        self.lstm = nn.LSTM(conv_filters, hidden_dim, num_layers=2,
                            batch_first=True, bidirectional=True, dropout=0.2)
        # Forward and backward LSTM states are concatenated, hence 2 * hidden.
        self.head = nn.Linear(hidden_dim * 2, 2)

    def forward(self, x):
        feats = x.transpose(1, 2)          # (B, C, T) for Conv1d
        for conv in self.convs:
            feats = conv(feats)
        per_frame, _ = self.lstm(feats.transpose(1, 2))
        return self.head(per_frame)
102
+
103
+
104
+ # ============================================================
105
+ # 2. InceptionTime (Fawaz et al., DMKD 2020)
106
+ # "InceptionTime: Finding AlexNet for Time Series Classification"
107
+ # Inception modules with multi-scale convolutions + residual
108
+ # ============================================================
109
+
110
class InceptionModule(nn.Module):
    """One Inception module: bottleneck, parallel convs and a max-pool branch.

    Output has ``n_filters * (len(kernel_sizes) + 1)`` channels and, for odd
    kernel sizes, the same length as the input.
    """

    def __init__(self, in_channels, n_filters=32, kernel_sizes=(9, 19, 39),
                 bottleneck_channels=32):
        super().__init__()
        # 1x1 bottleneck feeding every convolutional branch.
        self.bottleneck = nn.Conv1d(in_channels, bottleneck_channels, 1, bias=False)

        # One branch per kernel size; (k - 1) // 2 padding keeps the length
        # unchanged for odd k.
        self.convs = nn.ModuleList([
            nn.Conv1d(bottleneck_channels, n_filters, k,
                      padding=(k - 1) // 2, bias=False)
            for k in kernel_sizes
        ])

        # Max-pool branch works on the raw input, not on the bottleneck.
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool1d(3, stride=1, padding=1),
            nn.Conv1d(in_channels, n_filters, 1, bias=False),
        )

        n_branches = len(kernel_sizes) + 1
        self.bn = nn.BatchNorm1d(n_filters * n_branches)
        self.relu = nn.ReLU()

    def forward(self, x):
        # x: (B, C, T)
        squeezed = self.bottleneck(x)
        branches = []
        for conv in self.convs:
            branches.append(conv(squeezed))
        branches.append(self.maxpool_conv(x))
        merged = torch.cat(branches, dim=1)
        return self.relu(self.bn(merged))
143
+
144
+
145
class InceptionBlock(nn.Module):
    """``depth`` Inception modules in sequence plus a shortcut connection.

    The shortcut is the identity when the input already has ``n_filters * 4``
    channels; otherwise it is a 1x1 conv + BatchNorm projection
    (``use_residual`` records whether that projection exists).
    """

    def __init__(self, in_channels, n_filters=32, depth=3):
        super().__init__()
        n_out = n_filters * 4  # 3 conv branches + 1 maxpool branch
        # Only the first module sees the block's input width.
        module_inputs = [in_channels if i == 0 else n_out for i in range(depth)]
        self.modules_list = nn.ModuleList(
            [InceptionModule(width, n_filters) for width in module_inputs]
        )

        # Shortcut projection, needed only when channel counts differ.
        self.use_residual = (in_channels != n_out)
        if self.use_residual:
            self.residual = nn.Sequential(
                nn.Conv1d(in_channels, n_out, 1, bias=False),
                nn.BatchNorm1d(n_out),
            )
        self.relu = nn.ReLU()

    def forward(self, x):
        out = x
        for mod in self.modules_list:
            out = mod(out)
        shortcut = self.residual(x) if self.use_residual else x
        return self.relu(out + shortcut)
173
+
174
+
175
class InceptionTimeBackbone(nn.Module):
    """InceptionTime feature extractor for sequence-level classification (Exp1).

    Stacked Inception blocks followed by average pooling over time
    (restricted to masked-in frames when a mask is given).

    Input:  (B, T, C), optional mask (B, T)
    Output: (B, output_dim) with output_dim = n_filters * 4
    """

    def __init__(self, input_dim, hidden_dim=128, n_filters=32, num_blocks=2, depth=3):
        super().__init__()
        block_width = n_filters * 4
        stack = []
        for idx in range(num_blocks):
            # First block consumes the raw channels, later ones the block width.
            stack.append(
                InceptionBlock(input_dim if idx == 0 else block_width, n_filters, depth)
            )
        self.blocks = nn.ModuleList(stack)
        self.output_dim = block_width

    def forward(self, x, mask=None):
        feats = x.permute(0, 2, 1)          # (B, T, C) -> (B, C, T)
        for block in self.blocks:
            feats = block(feats)
        if mask is None:
            return feats.mean(2)
        # Masked mean: zero the masked-out frames, divide by the per-sample
        # count of kept frames (at least 1 to avoid division by zero).
        weights = mask.unsqueeze(1).float()
        counts = mask.sum(1, keepdim=True).float().clamp(min=1)
        return (feats * weights).sum(2) / counts  # (B, n_filters*4)
203
+
204
+
205
class InceptionTimeContact(nn.Module):
    """Per-frame InceptionTime variant for contact detection (Exp3).

    Stacked Inception blocks followed by a 1x1 convolution with 2 output
    channels per frame.

    Input:  (B, T, C)
    Output: (B, T, 2)
    """

    def __init__(self, input_dim, hidden_dim=64, n_filters=32, num_blocks=2, depth=3):
        super().__init__()
        block_width = n_filters * 4
        stack = []
        for idx in range(num_blocks):
            # First block consumes the raw channels, later ones the block width.
            stack.append(
                InceptionBlock(input_dim if idx == 0 else block_width, n_filters, depth)
            )
        self.blocks = nn.ModuleList(stack)
        self.head = nn.Conv1d(block_width, 2, 1)

    def forward(self, x):
        feats = x.permute(0, 2, 1)          # channels first for Conv1d
        for block in self.blocks:
            feats = block(feats)
        logits = self.head(feats)
        return logits.permute(0, 2, 1)  # (B, T, 2)
228
+
229
+
230
+ # ============================================================
231
+ # 3. MS-TCN++ (Li et al., TPAMI 2020)
232
+ # "MS-TCN++: Multi-Stage Temporal Convolutional Network
233
+ # for Action Segmentation"
234
+ # Key improvement: dual dilated layers in each residual block
235
+ # ============================================================
236
+
237
class DualDilatedResBlock(nn.Module):
    """Residual block with two parallel dilated convolutions.

    Both kernel-3 branches read the same input with their own dilation, are
    passed through ReLU and summed; a 1x1 conv + BatchNorm + ReLU + dropout
    follows, and the block input is added back. Shape is preserved.
    """

    def __init__(self, channels, dilation1, dilation2):
        super().__init__()
        # padding == dilation keeps the temporal length for kernel size 3.
        self.conv1_dilated = nn.Conv1d(
            channels, channels, 3, padding=dilation1, dilation=dilation1,
        )
        self.conv2_dilated = nn.Conv1d(
            channels, channels, 3, padding=dilation2, dilation=dilation2,
        )
        self.conv_fusion = nn.Conv1d(channels, channels, 1)
        self.bn = nn.BatchNorm1d(channels)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        branch_a = F.relu(self.conv1_dilated(x))
        branch_b = F.relu(self.conv2_dilated(x))
        fused = self.conv_fusion(branch_a + branch_b)
        fused = self.dropout(F.relu(self.bn(fused)))
        return fused + x
267
+
268
+
269
class MSTCNPPStage(nn.Module):
    """One MS-TCN++ stage: 1x1 input conv, dual-dilated blocks, 1x1 classifier.

    Block ``i`` uses dilations ``2**i`` and ``2**(i+1)``; the final block uses
    ``2**i`` for both branches.
    """

    def __init__(self, in_channels, hidden_channels, num_classes, num_layers=10):
        super().__init__()
        self.input_conv = nn.Conv1d(in_channels, hidden_channels, 1)
        blocks = []
        last = num_layers - 1
        for i in range(num_layers):
            near = 2 ** i
            far = near if i == last else 2 ** (i + 1)
            blocks.append(DualDilatedResBlock(hidden_channels, near, far))
        self.layers = nn.ModuleList(blocks)
        self.output_conv = nn.Conv1d(hidden_channels, num_classes, 1)

    def forward(self, x):
        h = self.input_conv(x)
        for layer in self.layers:
            h = layer(h)
        return self.output_conv(h)
287
+
288
+
289
class MSTCNPP(nn.Module):
    """Multi-stage MS-TCN++ for temporal action segmentation (Exp2).

    The first stage maps input features to class scores; each later stage
    refines the channel-wise softmax of the previous stage's scores.

    Input:  (B, T, C)
    Output: list with one (B, T, num_classes) tensor per stage
    """

    def __init__(self, input_dim, num_classes, hidden_dim=64, num_stages=4, num_layers=10):
        super().__init__()
        # Stage 0 reads features; refinement stages read class probabilities.
        stage_list = [MSTCNPPStage(input_dim, hidden_dim, num_classes, num_layers)]
        for _ in range(num_stages - 1):
            stage_list.append(
                MSTCNPPStage(num_classes, hidden_dim, num_classes, num_layers)
            )
        self.stages = nn.ModuleList(stage_list)

    def forward(self, x):
        cur = x.permute(0, 2, 1)  # (B, C, T)
        outputs = []
        final_idx = len(self.stages) - 1
        for idx, stage in enumerate(self.stages):
            cur = stage(cur)
            outputs.append(cur.permute(0, 2, 1))  # (B, T, num_classes)
            # Every stage but the last hands class probabilities to the next.
            if idx != final_idx:
                cur = F.softmax(cur, dim=1)
        return outputs
315
+
316
+
317
+ # ============================================================
318
+ # 4. DiffAct (Liu et al., ICCV 2023)
319
+ # "Diffusion Action Segmentation"
320
+ # Denoising diffusion model for iterative action refinement.
321
+ # Simplified but faithful implementation.
322
+ # ============================================================
323
+
324
class ConditionalLayerNorm(nn.Module):
    """Normalisation layer used by the DiffAct denoising blocks.

    Despite the name, ``forward`` receives only the feature map; no timestep
    input reaches this module. It wraps a single-group ``GroupNorm``, so the
    statistics are computed per sample over all channels and positions.
    """

    def __init__(self, channels):
        super().__init__()
        self.norm = nn.GroupNorm(num_groups=1, num_channels=channels)

    def forward(self, x):
        normalised = self.norm(x)
        return normalised
333
+
334
+
335
class DiffActBlock(nn.Module):
    """Residual block of the DiffAct denoiser.

    Pipeline: norm -> dilated conv (k=3) -> ReLU -> add projected time
    embedding -> norm -> 1x1 conv -> ReLU -> dropout -> add block input.
    """

    def __init__(self, channels, dilation, time_emb_dim):
        super().__init__()
        # padding == dilation keeps the sequence length unchanged for kernel size 3.
        self.conv1 = nn.Conv1d(channels, channels, kernel_size=3,
                               padding=dilation, dilation=dilation)
        self.conv2 = nn.Conv1d(channels, channels, kernel_size=1)
        self.norm1 = ConditionalLayerNorm(channels)
        self.norm2 = ConditionalLayerNorm(channels)
        self.time_proj = nn.Linear(time_emb_dim, channels)
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, x, time_emb):
        hidden = F.relu(self.conv1(self.norm1(x)))
        # (B, C) -> (B, C, 1) so the embedding broadcasts along the time axis.
        hidden = hidden + self.time_proj(time_emb).unsqueeze(-1)
        hidden = F.relu(self.conv2(self.norm2(hidden)))
        return self.dropout(hidden) + x
357
+
358
+
359
class DiffActConditionEncoder(nn.Module):
    """Dilated-conv feature encoder whose output conditions the denoiser.

    A 1x1 conv lifts the input to ``hidden_dim`` channels; each following
    layer is Conv1d(k=3) -> BatchNorm -> ReLU -> Dropout(0.1) wrapped in a
    residual connection. Dilations cycle through 1, 2, 4, 8, 16.
    """

    def __init__(self, input_dim, hidden_dim, num_layers=6):
        super().__init__()
        self.input_conv = nn.Conv1d(input_dim, hidden_dim, kernel_size=1)
        stack = []
        for depth in range(num_layers):
            rate = 2 ** (depth % 5)
            stack.append(nn.Sequential(
                nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, padding=rate, dilation=rate),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Dropout(p=0.1),
            ))
        self.layers = nn.ModuleList(stack)

    def forward(self, x):
        feats = self.input_conv(x)
        for layer in self.layers:
            # Residual connection around every conv layer.
            feats = layer(feats) + feats
        return feats
380
+
381
+
382
class SinusoidalTimeEmbedding(nn.Module):
    """Sinusoidal embedding of the diffusion timestep followed by a small MLP.

    Args:
        dim: embedding width; also the input and output width of the MLP.

    ``forward(t)`` takes integer or float timesteps of shape (B,) and returns
    a (B, dim) tensor.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim),
        )

    def forward(self, t):
        half_dim = self.dim // 2
        # Guard the denominator: with dim in (2, 3) half_dim - 1 is 0 and the
        # original expression raised ZeroDivisionError.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=t.device) * -log_step)
        angles = t.unsqueeze(-1).float() * freqs.unsqueeze(0)
        emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
        if self.dim % 2 == 1:
            # sin/cos only fill 2 * half_dim columns; pad one zero column so the
            # width matches the first Linear layer for odd dim.
            emb = F.pad(emb, (0, 1))
        return self.mlp(emb)
401
+
402
+
403
class DiffAct(nn.Module):
    """DiffAct-style diffusion action segmentation model (Exp2).

    A dilated-conv condition encoder produces per-frame features and an
    initial class prediction. A timestep-conditioned denoising network then
    operates on a noisy class-score sequence concatenated with those features.

    Training mode: the softmax of the initial prediction is detached, noised
    at a random timestep and passed through the denoiser once.
    Eval mode: starts from Gaussian noise and applies ``num_diffusion_steps``
    reverse updates.

    NOTE(review): in training mode the raw denoiser output is returned next
    to ``initial_logits``, while in eval mode the same output is used as the
    noise estimate of the reverse update. Whether these two uses agree depends
    on the loss applied in the training script, which is not visible here;
    confirm before relying on the second output at inference.

    Input:  (B, T, C)
    Output: [initial_logits, denoiser_output], each (B, T, num_classes)
    """

    def __init__(self, input_dim, num_classes, hidden_dim=64,
                 num_encoder_layers=6, num_denoise_layers=6,
                 num_diffusion_steps=10):
        super().__init__()
        self.num_classes = num_classes
        self.num_steps = num_diffusion_steps

        # Condition encoder: temporal features extracted from the input.
        self.condition_encoder = DiffActConditionEncoder(input_dim, hidden_dim, num_encoder_layers)

        # Per-frame classifier on the condition features (non-diffusion prediction).
        self.initial_head = nn.Conv1d(hidden_dim, num_classes, 1)

        # Timestep embedding shared by all denoising blocks.
        self.time_emb = SinusoidalTimeEmbedding(hidden_dim)

        # Denoising network: input is noisy class scores stacked on condition features.
        self.denoise_input = nn.Conv1d(num_classes + hidden_dim, hidden_dim, 1)
        self.denoise_blocks = nn.ModuleList()
        for i in range(num_denoise_layers):
            dilation = 2 ** (i % 5)
            self.denoise_blocks.append(DiffActBlock(hidden_dim, dilation, hidden_dim))
        self.denoise_output = nn.Conv1d(hidden_dim, num_classes, 1)

        # Cosine noise schedule, stored as buffers.
        self._setup_noise_schedule()

    def _setup_noise_schedule(self):
        """Register cosine-schedule buffers; betas are clamped to [1e-4, 0.999]."""
        steps = self.num_steps
        s = 0.008
        t = torch.linspace(0, steps, steps + 1)
        alphas_cumprod = torch.cos(((t / steps) + s) / (1 + s) * math.pi * 0.5) ** 2
        alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
        betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
        betas = torch.clamp(betas, 0.0001, 0.999)
        # Recompute the cumulative product from the clamped betas so the
        # registered buffers are mutually consistent.
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        self.register_buffer('betas', betas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1 - alphas_cumprod))

    def _add_noise(self, x_start, t, noise=None):
        """Return ``sqrt(a_bar[t]) * x_start + sqrt(1 - a_bar[t]) * noise``.

        ``t`` indexes the schedule buffers per batch element; ``noise``
        defaults to fresh Gaussian noise shaped like ``x_start``.
        """
        if noise is None:
            noise = torch.randn_like(x_start)
        sqrt_alpha = self.sqrt_alphas_cumprod[t].view(-1, 1, 1)
        sqrt_one_minus = self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1)
        return sqrt_alpha * x_start + sqrt_one_minus * noise

    def _denoise_step(self, x_noisy, cond_features, time_emb):
        """Run the denoising network once on noisy scores plus condition features."""
        x = torch.cat([x_noisy, cond_features], dim=1)  # (B, num_classes + hidden, T)
        x = self.denoise_input(x)
        for block in self.denoise_blocks:
            x = block(x, time_emb)
        return self.denoise_output(x)

    def forward(self, x):
        """
        Training: returns [initial_logits, single-pass denoiser output]
        Eval:     returns [initial_logits, result of the iterative reverse loop]
        """
        x_in = x.permute(0, 2, 1)  # (B, C, T)
        B, _, T = x_in.shape

        cond = self.condition_encoder(x_in)  # (B, hidden, T)
        initial_logits = self.initial_head(cond).permute(0, 2, 1)  # (B, T, num_classes)

        if self.training:
            # Noise the (detached) initial class probabilities at a random timestep.
            x_start = F.softmax(initial_logits, dim=-1).permute(0, 2, 1)  # (B, num_classes, T)
            t = torch.randint(0, self.num_steps, (B,), device=x.device)
            noise = torch.randn_like(x_start)
            x_noisy = self._add_noise(x_start.detach(), t, noise)
            time_emb = self.time_emb(t)
            denoised = self._denoise_step(x_noisy, cond, time_emb)
            return [initial_logits, denoised.permute(0, 2, 1)]

        # Eval: iterate the reverse update from pure Gaussian noise.
        x_t = torch.randn(B, self.num_classes, T, device=x.device)
        for step in reversed(range(self.num_steps)):
            t = torch.full((B,), step, device=x.device, dtype=torch.long)
            time_emb = self.time_emb(t)
            pred_noise = self._denoise_step(x_t, cond, time_emb)
            # Simplified DDPM update; only beta and sqrt(1 - a_bar) are needed.
            beta = self.betas[step]
            x_t = (1 / torch.sqrt(1 - beta)) * (
                x_t - beta / self.sqrt_one_minus_alphas_cumprod[step] * pred_noise
            )
            if step > 0:
                # Half-scaled noise is re-injected on every step but the last.
                x_t = x_t + torch.sqrt(beta) * torch.randn_like(x_t) * 0.5
        return [initial_logits, x_t.permute(0, 2, 1)]
509
+
510
+
511
+ # ============================================================
512
+ # 5. UnderPressure (Mourot et al., SCA/CGF 2022)
513
+ # "UnderPressure: Deep Learning for Foot Contact Detection,
514
+ # Ground Reaction Force Estimation and Footskate Cleanup"
515
+ # GRU-based architecture for contact detection + force regression.
516
+ # Adapted for hand contact detection and MoCap->Pressure prediction.
517
+ # ============================================================
518
+
519
class UnderPressureContact(nn.Module):
    """UnderPressure-style network adapted for hand contact detection (Exp3).

    Two Conv1d+BatchNorm+ReLU layers (kernels 7 and 5) feed a bidirectional
    GRU; a two-layer MLP maps every frame to two outputs.

    Input:  (B, T, C)
    Output: (B, T, 2)  [right_contact, left_contact]
    """

    def __init__(self, input_dim, hidden_dim=64, num_gru_layers=2):
        super().__init__()
        # Local temporal patterns; paddings keep the sequence length unchanged.
        conv_stack = [
            nn.Conv1d(input_dim, hidden_dim, kernel_size=7, padding=3),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
        ]
        self.feature_extractor = nn.Sequential(*conv_stack)

        # Inter-layer dropout is only requested when the GRU is stacked.
        gru_dropout = 0.2 if num_gru_layers > 1 else 0
        self.gru = nn.GRU(
            hidden_dim, hidden_dim,
            num_layers=num_gru_layers,
            batch_first=True,
            bidirectional=True,
            dropout=gru_dropout,
        )

        # The bidirectional GRU emits 2 * hidden_dim features per frame.
        self.contact_head = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, 2),
        )

    def forward(self, x):
        channels_first = x.permute(0, 2, 1)  # (B, C, T)
        local_feats = self.feature_extractor(channels_first).permute(0, 2, 1)  # (B, T, hidden)
        sequence_feats, _ = self.gru(local_feats)
        return self.contact_head(sequence_feats)  # (B, T, 2)
558
+
559
+
560
class UnderPressureRegressor(nn.Module):
    """UnderPressure-style network adapted for MoCap -> pressure regression (Exp4a).

    Three Conv1d+BatchNorm+ReLU layers (kernels 7, 5, 3) feed a bidirectional
    GRU; a two-layer MLP regresses ``output_dim`` values per frame.

    Input:  (B, T, input_dim)
    Output: (B, T, output_dim)
    """

    def __init__(self, input_dim, output_dim, hidden_dim=128, num_gru_layers=2):
        super().__init__()
        conv_stack = []
        in_channels = input_dim
        for kernel in (7, 5, 3):
            # padding = kernel // 2 keeps the sequence length unchanged.
            conv_stack.extend([
                nn.Conv1d(in_channels, hidden_dim, kernel, padding=kernel // 2),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
            ])
            in_channels = hidden_dim
        self.feature_extractor = nn.Sequential(*conv_stack)

        # Inter-layer dropout is only requested when the GRU is stacked.
        gru_dropout = 0.2 if num_gru_layers > 1 else 0
        self.gru = nn.GRU(
            hidden_dim, hidden_dim,
            num_layers=num_gru_layers,
            batch_first=True,
            bidirectional=True,
            dropout=gru_dropout,
        )

        # The bidirectional GRU emits 2 * hidden_dim features per frame.
        self.regression_head = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        local_feats = self.feature_extractor(x.permute(0, 2, 1))  # (B, hidden, T)
        sequence_feats, _ = self.gru(local_feats.permute(0, 2, 1))
        return self.regression_head(sequence_feats)
598
+
599
+
600
+ # ============================================================
601
+ # 6. emg2pose (Meta/Facebook Research, NeurIPS 2024 D&B)
602
+ # "emg2pose: A Large and Diverse Benchmark for
603
+ # Surface Electromyographic Hand Pose Estimation"
604
+ # CNN feature extractor + Transformer encoder,
605
+ # with optional velocity-based integration (vemg2pose).
606
+ # ============================================================
607
+
608
class EMG2PoseEncoder(nn.Module):
    """Multi-scale CNN front end followed by a Transformer encoder (emg2pose style).

    Three parallel Conv1d branches (kernels 3, 7, 15) are concatenated along
    the channel axis, projected with a 1x1 conv and passed through a
    batch-first ``nn.TransformerEncoder``.

    Args:
        input_dim: number of input channels.
        hidden_dim: feature width after the projection and inside the
            Transformer. ``nn.TransformerEncoderLayer`` requires it to be
            divisible by ``nhead``.
        num_transformer_layers: number of encoder layers.
        nhead: number of attention heads.

    Input:  (B, T, input_dim)
    Output: (B, T, hidden_dim)
    """

    def __init__(self, input_dim, hidden_dim=128, num_transformer_layers=4, nhead=4):
        super().__init__()
        # The branch widths must sum to hidden_dim because ``proj`` expects
        # exactly hidden_dim input channels. The large branch takes the
        # remainder; for hidden_dim divisible by 4 this equals the previous
        # hidden_dim // 4, for other values it avoids a channel mismatch.
        small_width = hidden_dim // 2
        medium_width = hidden_dim // 4
        large_width = hidden_dim - small_width - medium_width

        # Multi-scale CNN feature extractor
        self.conv_small = nn.Sequential(
            nn.Conv1d(input_dim, small_width, 3, padding=1),
            nn.BatchNorm1d(small_width),
            nn.ReLU(),
        )
        self.conv_medium = nn.Sequential(
            nn.Conv1d(input_dim, medium_width, 7, padding=3),
            nn.BatchNorm1d(medium_width),
            nn.ReLU(),
        )
        self.conv_large = nn.Sequential(
            nn.Conv1d(input_dim, large_width, 15, padding=7),
            nn.BatchNorm1d(large_width),
            nn.ReLU(),
        )
        # Projection to hidden_dim
        self.proj = nn.Sequential(
            nn.Conv1d(hidden_dim, hidden_dim, 1),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
        )
        # Transformer encoder for temporal modeling
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=nhead,
            dim_feedforward=hidden_dim * 4,
            dropout=0.1, batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_transformer_layers)

    def forward(self, x):
        # x: (B, T, C) -> (B, C, T) for the Conv1d branches
        x_t = x.permute(0, 2, 1)
        f_small = self.conv_small(x_t)
        f_medium = self.conv_medium(x_t)
        f_large = self.conv_large(x_t)
        feat = torch.cat([f_small, f_medium, f_large], dim=1)
        feat = self.proj(feat).permute(0, 2, 1)  # (B, T, hidden)
        return self.transformer(feat)
652
+
653
+
654
class EMG2Pose(nn.Module):
    """emg2pose-style regressor from EMG to hand pose (Exp4b).

    With ``use_velocity=True`` (vemg2pose variant) the head predicts a
    per-frame velocity that is integrated by a cumulative sum over time and
    offset by a learnable initial position. Otherwise the head predicts the
    positions directly.

    Input:  (B, T, input_dim)   [EMG channels]
    Output: (B, T, output_dim)  [hand joint positions]
    """

    def __init__(self, input_dim, output_dim, hidden_dim=128,
                 num_transformer_layers=4, use_velocity=True):
        super().__init__()
        self.use_velocity = use_velocity
        self.encoder = EMG2PoseEncoder(input_dim, hidden_dim, num_transformer_layers)

        # Both variants share the head architecture; only the attribute name
        # (and therefore the state-dict key) differs.
        head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim // 2, output_dim),
        )
        if use_velocity:
            self.velocity_head = head
            # Learnable offset added to the integrated velocity.
            self.initial_pos = nn.Parameter(torch.zeros(1, 1, output_dim))
        else:
            self.position_head = head

    def forward(self, x):
        features = self.encoder(x)  # (B, T, hidden)
        if not self.use_velocity:
            return self.position_head(features)
        velocity = self.velocity_head(features)  # (B, T, output_dim)
        # Integrate along the time axis, then shift by the learned start pose.
        return torch.cumsum(velocity, dim=1) + self.initial_pos
experiments/s9_primitives.json ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "version": "s9_docx_2025_12_05",
3
+ "source": "${PULSE_ROOT}",
4
+ "categories": ["hand", "arm", "body", "fine", "composite"],
5
+ "primitives": [
6
+ {"id": 0, "category": "hand", "zh": "伸手", "en": "reach", "note": "forward/up/down/side"},
7
+ {"id": 1, "category": "hand", "zh": "抓握", "en": "grasp", "note": "pinch / hold / clamp"},
8
+ {"id": 2, "category": "hand", "zh": "松开", "en": "release", "note": "release object"},
9
+ {"id": 3, "category": "hand", "zh": "旋转手腕", "en": "rotate_wrist", "note": "twist / turn"},
10
+ {"id": 4, "category": "hand", "zh": "按压", "en": "press", "note": "downward force"},
11
+ {"id": 5, "category": "hand", "zh": "拉动", "en": "pull", "note": "toward self"},
12
+ {"id": 6, "category": "hand", "zh": "推动", "en": "push", "note": "outward force"},
13
+ {"id": 7, "category": "hand", "zh": "滑动", "en": "slide", "note": "translation motion"},
14
+ {"id": 8, "category": "hand", "zh": "捏合", "en": "pinch", "note": "two/multi finger pinch"},
15
+ {"id": 9, "category": "hand", "zh": "展开", "en": "spread_fingers", "note": "fingers open"},
16
+
17
+ {"id": 10, "category": "arm", "zh": "抬起", "en": "raise_arm", "note": "arm up"},
18
+ {"id": 11, "category": "arm", "zh": "放下", "en": "lower_arm", "note": "arm down"},
19
+ {"id": 12, "category": "arm", "zh": "伸展", "en": "extend_arm", "note": "arm straight"},
20
+ {"id": 13, "category": "arm", "zh": "弯曲", "en": "bend_elbow", "note": "elbow bend"},
21
+ {"id": 14, "category": "arm", "zh": "摆动", "en": "swing_arm", "note": "left-right / forward-back"},
22
+ {"id": 15, "category": "arm", "zh": "环绕", "en": "circle_arm", "note": "circular motion"},
23
+
24
+ {"id": 16, "category": "body", "zh": "弯腰", "en": "bend_torso", "note": "lean forward"},
25
+ {"id": 17, "category": "body", "zh": "直立", "en": "stand_upright", "note": "return to standing"},
26
+ {"id": 18, "category": "body", "zh": "蹲下", "en": "squat_down", "note": "lower center of mass"},
27
+ {"id": 19, "category": "body", "zh": "站起", "en": "stand_up", "note": "return to height"},
28
+ {"id": 20, "category": "body", "zh": "转身", "en": "turn_body", "note": "torso rotate"},
29
+ {"id": 21, "category": "body", "zh": "侧身", "en": "lean_side", "note": "torso tilt"},
30
+ {"id": 22, "category": "body", "zh": "迈步", "en": "step", "note": "shift position"},
31
+
32
+ {"id": 23, "category": "fine", "zh": "插入", "en": "insert", "note": "object enters"},
33
+ {"id": 24, "category": "fine", "zh": "拔出", "en": "extract", "note": "object exits"},
34
+ {"id": 25, "category": "fine", "zh": "折叠", "en": "fold", "note": "change shape"},
35
+ {"id": 26, "category": "fine", "zh": "撕扯", "en": "tear", "note": "separate"},
36
+ {"id": 27, "category": "fine", "zh": "擦拭", "en": "wipe", "note": "back-and-forth"},
37
+
38
+ {"id": 28, "category": "composite", "zh": "拿起物品", "en": "pick_up_object", "note": "reach -> grasp -> raise"},
39
+ {"id": 29, "category": "composite", "zh": "放下物品", "en": "put_down_object", "note": "move -> release -> retract"},
40
+ {"id": 30, "category": "composite", "zh": "移动物品", "en": "move_object", "note": "pick_up -> move -> put_down"},
41
+ {"id": 31, "category": "composite", "zh": "交换手持物", "en": "transfer_between_hands","note": "one hand grasp -> other hand take -> first release"},
42
+ {"id": 32, "category": "composite", "zh": "打开盖子", "en": "open_lid", "note": "grasp -> rotate/lift"},
43
+ {"id": 33, "category": "composite", "zh": "关闭盖子", "en": "close_lid", "note": "align -> press/rotate"},
44
+ {"id": 34, "category": "composite", "zh": "倒入液体", "en": "pour_liquid", "note": "lift -> tilt -> control flow -> reset"},
45
+ {"id": 35, "category": "composite", "zh": "舀取", "en": "scoop", "note": "insert -> raise -> move"},
46
+ {"id": 36, "category": "composite", "zh": "打开柜门", "en": "open_cabinet_door", "note": "grasp handle -> pull"},
47
+ {"id": 37, "category": "composite", "zh": "关闭柜门", "en": "close_cabinet_door", "note": "push -> confirm"},
48
+ {"id": 38, "category": "composite", "zh": "打开抽屉", "en": "open_drawer", "note": "grasp -> pull out"},
49
+ {"id": 39, "category": "composite", "zh": "按下开关", "en": "press_switch", "note": "reach -> press"},
50
+ {"id": 40, "category": "composite", "zh": "折叠衣物", "en": "fold_clothing", "note": "spread -> fold -> flatten"},
51
+ {"id": 41, "category": "composite", "zh": "叠放物品", "en": "stack_objects", "note": "pick_up -> align -> place gently"},
52
+ {"id": 42, "category": "composite", "zh": "排列物品", "en": "arrange_objects", "note": "move -> adjust spacing -> align"},
53
+ {"id": 43, "category": "composite", "zh": "分类收纳", "en": "sort_and_store", "note": "identify -> group -> place"},
54
+ {"id": 44, "category": "composite", "zh": "擦拭表面", "en": "wipe_surface", "note": "take cloth -> press -> back-and-forth"},
55
+ {"id": 45, "category": "composite", "zh": "扫除垃圾", "en": "sweep_debris", "note": "broom -> gather -> dustpan"},
56
+ {"id": 46, "category": "composite", "zh": "倾倒垃圾", "en": "dump_trash", "note": "lift container -> align -> tilt -> pour"},
57
+ {"id": 47, "category": "composite", "zh": "喷洒液体", "en": "spray_liquid", "note": "press nozzle -> move -> release"},
58
+ {"id": 48, "category": "composite", "zh": "撕胶带", "en": "tear_tape", "note": "pull -> tear off"},
59
+ {"id": 49, "category": "composite", "zh": "贴标签", "en": "stick_label", "note": "peel -> align -> press"},
60
+ {"id": 50, "category": "composite", "zh": "包裹物品", "en": "wrap_object", "note": "spread wrap -> place item -> fold -> seal"},
61
+ {"id": 51, "category": "composite", "zh": "系绳打结", "en": "tie_knot", "note": "cross -> through -> tighten"},
62
+ {"id": 52, "category": "composite", "zh": "拿起笔", "en": "pick_up_pen", "note": "pinch -> adjust grip"},
63
+ {"id": 53, "category": "composite", "zh": "写字", "en": "write", "note": "controlled motion -> apply pressure"},
64
+ {"id": 54, "category": "composite", "zh": "翻页", "en": "turn_page", "note": "pinch corner -> flip"},
65
+ {"id": 55, "category": "composite", "zh": "插入电源", "en": "plug_in_power", "note": "align -> push in"},
66
+ {"id": 56, "category": "composite", "zh": "连接线缆", "en": "connect_cable", "note": "align connector -> insert -> confirm"},
67
+ {"id": 57, "category": "composite", "zh": "组装部件", "en": "assemble_parts", "note": "align -> snap/screw"},
68
+ {"id": 58, "category": "composite", "zh": "称重", "en": "weigh", "note": "place item -> read scale"},
69
+ {"id": 59, "category": "composite", "zh": "量取", "en": "measure_volume", "note": "pour -> read marking -> adjust"},
70
+ {"id": 60, "category": "composite", "zh": "计数", "en": "count", "note": "move one by one -> tally"},
71
+ {"id": 61, "category": "composite", "zh": "挂衣服", "en": "hang_clothing", "note": "take hanger -> insert garment -> hang"},
72
+ {"id": 62, "category": "composite", "zh": "铲猫砂", "en": "scoop_litter", "note": "insert -> raise -> sift -> pour"},
73
+ {"id": 63, "category": "composite", "zh": "搅拌", "en": "stir", "note": "insert spoon -> circular motion"},
74
+ {"id": 64, "category": "composite", "zh": "剪切", "en": "cut", "note": "hold scissors -> align -> close"}
75
+ ]
76
+ }
experiments/slurm/freeze_all_rows.sh ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Create folder structure for ALL rows across Tables 1, 3, 4, 5, 7 and
# freeze the current experiments/ code into each one. After this you can
# cd into any <table>/<row>/ and run ./run.sh to submit 5 SLURM seeds.
#
# Re-running this script is safe: it will re-freeze the code (overwrite the
# snapshot), but won't clobber any existing seeds/ outputs.
#
# Environment:
#   BASEDIR     project root (optional; defaults to PULSE_ROOT)
#   PULSE_ROOT  project root, required when BASEDIR is not set
set -euo pipefail

# Abort with a readable message instead of a bare "unbound variable" error.
BASEDIR=${BASEDIR:-${PULSE_ROOT:?set BASEDIR or PULSE_ROOT to the project root}}
EXP="${BASEDIR}/experiments"
SETUP="${EXP}/setup_row.sh"

COMMON="--epochs 40 --batch_size 32 --lr 3e-4 --weight_decay 1e-4 \
  --patience 12 --label_smoothing 0.05 --use_class_weights \
  --num_workers 2"

ALL5="imu,emg,eyetrack,mocap,pressure"

#######################################
# Freeze one table row by delegating to setup_row.sh.
# Globals:   SETUP, COMMON (read)
# Arguments: $1 - table directory name
#            $2 - row directory name
#            $3 - human-readable description
#            $4 - row-specific CLI flags
#######################################
row() {
  # COMMON goes first so that a row-specific flag that also appears in COMMON
  # (e.g. row05's --label_smoothing 0.0) is the last occurrence on the command
  # line. NOTE(review): this assumes the trainer uses last-wins option parsing
  # (argparse-style); the trainer itself is not visible from this script.
  bash "${SETUP}" --table "$1" --row "$2" --desc "$3" --cli "${COMMON} $4"
}

# ============================================================
# Table 1: Main comparison at T_fut=2s
# ============================================================
T1=table1_main_comparison
# The directory must exist before the README is written into it.
mkdir -p "${BASEDIR}/${T1}"
cat > "${BASEDIR}/${T1}/README.md" <<'EOF'
# Table 1: Main Comparison (Next-Action Prediction, T_fut = 2 s)

Each baseline is run on its most favourable modality subset; our model
(DailyActFormer) uses all 5 synchronised modalities. 5 seeds per row;
report mean ± std of Verb fine Top-1/5, Noun Top-1/5, Hand Top-1, Action
Top-1 (= verb ∧ noun ∧ hand). Action Top-1 is the headline metric.

| Row | Method            | Family          | Modalities          |
|-----|-------------------|-----------------|---------------------|
| 01  | DailyActFormer    | cross-modal Trf | imu+emg+eye+mocap+P |
| 02  | DeepConvLSTM      | CNN+LSTM (IMU)  | imu                 |
| 03  | DeepConvLSTM 3mod | CNN+LSTM        | imu+mocap+emg       |
| 04  | RULSTM            | rolling LSTM    | imu+mocap           |
| 05  | FUTR              | long-term Trf   | mocap+imu+emg       |
| 06  | AFFT              | multimodal Trf  | imu+emg+eye+mocap   |
| 07  | HandFormer        | hand-pose Trf   | mocap (fingers)     |
| 08  | ActionLLM (LoRA)  | LLM-based       | imu+emg+eye         |
EOF

row "${T1}" row01_ours_dailyactformer_all5 \
  "Our model, all 5 modalities (headline row)" \
  "--model dailyactformer --modalities ${ALL5} --t_obs 8 --t_fut 2"

row "${T1}" row02_deepconvlstm_imu \
  "DeepConvLSTM on IMU only (classic HAR baseline)" \
  "--model deepconvlstm --modalities imu --t_obs 8 --t_fut 2"

row "${T1}" row03_deepconvlstm_3mod \
  "DeepConvLSTM on IMU+MoCap+EMG (best 3-modality concat)" \
  "--model deepconvlstm --modalities imu,mocap,emg --t_obs 8 --t_fut 2"

row "${T1}" row04_rulstm_imu_mocap \
  "RULSTM, rolling-unrolling LSTM (IMU + MoCap late fusion)" \
  "--model rulstm --modalities imu,mocap --t_obs 8 --t_fut 2"

row "${T1}" row05_futr_3mod \
  "FUTR (causal transformer) on MoCap+IMU+EMG" \
  "--model futr --modalities mocap,imu,emg --t_obs 8 --t_fut 2"

row "${T1}" row06_afft_4mod \
  "AFFT (anticipative feature fusion transformer) on 4 modalities" \
  "--model afft --modalities imu,emg,eyetrack,mocap --t_obs 8 --t_fut 2"

row "${T1}" row07_handformer_mocap \
  "HandFormer (skeleton-only ECCV'24) on MoCap finger joints" \
  "--model handformer --modalities mocap --t_obs 8 --t_fut 2"

row "${T1}" row08_actionllm_3mod \
  "ActionLLM (Qwen2.5-0.5B + LoRA) on IMU+EMG+EyeTrack" \
  "--model actionllm --modalities imu,emg,eyetrack --t_obs 8 --t_fut 2"

# ============================================================
# Table 3: Horizon curve (DailyActFormer)
# ============================================================
T3=table3_horizon_curve
mkdir -p "${BASEDIR}/${T3}"
cat > "${BASEDIR}/${T3}/README.md" <<'EOF'
# Table 3: Prediction Horizon Curve (DailyActFormer, all 5 modalities)

Same model, varying T_fut. Expect monotonic drop in Action Top-1 as
horizon grows; plot line graph in the paper alongside this table.
EOF
HORIZONS=(1 2 5 10 15)
for i in "${!HORIZONS[@]}"; do
  tfut="${HORIZONS[$i]}"
  printf -v idx '%02d' "$((i + 1))"
  row "${T3}" "row${idx}_ours_tfut${tfut}s" \
    "Our model at T_fut=${tfut}s" \
    "--model dailyactformer --modalities ${ALL5} --t_obs 8 --t_fut ${tfut}"
done

# ============================================================
# Table 4: Modality ablation on DailyActFormer (T_fut=2s)
# ============================================================
T4=table4_modality_ablation
mkdir -p "${BASEDIR}/${T4}"
cat > "${BASEDIR}/${T4}/README.md" <<'EOF'
# Table 4: Modality Ablation (DailyActFormer, T_fut = 2 s)

Same model, progressively remove modalities. Each row trained from scratch.
EOF
row "${T4}" row01_full_5mod    "Full 5-modality (reference)"       "--model dailyactformer --modalities imu,emg,eyetrack,mocap,pressure --t_obs 8 --t_fut 2"
row "${T4}" row02_no_pressure  "Drop pressure"                     "--model dailyactformer --modalities imu,emg,eyetrack,mocap --t_obs 8 --t_fut 2"
row "${T4}" row03_no_eyetrack  "Drop eye-tracking"                 "--model dailyactformer --modalities imu,emg,mocap,pressure --t_obs 8 --t_fut 2"
row "${T4}" row04_no_emg       "Drop EMG"                          "--model dailyactformer --modalities imu,eyetrack,mocap,pressure --t_obs 8 --t_fut 2"
row "${T4}" row05_no_imu       "Drop IMU"                          "--model dailyactformer --modalities emg,eyetrack,mocap,pressure --t_obs 8 --t_fut 2"
row "${T4}" row06_no_mocap     "Drop MoCap"                        "--model dailyactformer --modalities imu,emg,eyetrack,pressure --t_obs 8 --t_fut 2"
row "${T4}" row07_imu_emg_only "Only IMU + EMG (physiology-light)" "--model dailyactformer --modalities imu,emg --t_obs 8 --t_fut 2"
row "${T4}" row08_mocap_only   "Only MoCap (skeleton-only)"        "--model dailyactformer --modalities mocap --t_obs 8 --t_fut 2"

# ============================================================
# Table 5: Component ablation (DailyActFormer switches)
# ============================================================
T5=table5_component_ablation
mkdir -p "${BASEDIR}/${T5}"
cat > "${BASEDIR}/${T5}/README.md" <<'EOF'
# Table 5: Component Ablation (DailyActFormer, T_fut = 2 s)

Each row toggles one architectural/training component of our model.
Component flags are implemented as CLI switches on train_seqpred.py;
see models_seqpred.py for the corresponding model options.
EOF
row "${T5}" row01_full \
  "Full model (reference)" \
  "--model dailyactformer --modalities ${ALL5} --t_obs 8 --t_fut 2"
row "${T5}" row02_no_composite_head \
  "Drop the auxiliary verb-composite head (lambda=0)" \
  "--model dailyactformer --modalities ${ALL5} --t_obs 8 --t_fut 2 --lambda_verb_composite 0.0"
row "${T5}" row03_equal_lambda \
  "Equal-weight all 4 heads (no prior on verb>hand)" \
  "--model dailyactformer --modalities ${ALL5} --t_obs 8 --t_fut 2 --lambda_verb_composite 1.0 --lambda_hand 1.0"
row "${T5}" row04_no_class_weight \
  "No inverse-frequency class weighting" \
  "--model dailyactformer --modalities ${ALL5} --t_obs 8 --t_fut 2 --lambda_verb_composite 0.5"
# row04 re-exposes the default; the variable-off is the absence of --use_class_weights
# We patch this manually — strip the flag out of the generated run.sh, since
# COMMON always carries it.
row04_run_script="${BASEDIR}/${T5}/row04_no_class_weight/run.sh"
if [[ -e "${row04_run_script}" ]]; then
  sed -i 's/--use_class_weights //g' "${row04_run_script}"
fi

row "${T5}" row05_no_label_smoothing \
  "Label smoothing off" \
  "--model dailyactformer --modalities ${ALL5} --t_obs 8 --t_fut 2 --label_smoothing 0.0"

# ============================================================
# Table 7: Missing-modality robustness (train once, eval 6 ways)
# ============================================================
T7=table7_missing_modality
mkdir -p "${BASEDIR}/${T7}"
cat > "${BASEDIR}/${T7}/README.md" <<'EOF'
# Table 7: Missing-Modality Robustness (T_fut = 2 s)

Train DailyActFormer with random per-modality dropout (p=0.3). At test time,
evaluate under 6 configurations: full / drop one modality each. Only the
training job has its own folder; eval uses the trained checkpoint to fill
multiple rows of the final table.
EOF
row "${T7}" row01_train_with_modality_dropout \
  "DailyActFormer trained with --modality_dropout 0.3" \
  "--model dailyactformer --modalities ${ALL5} --t_obs 8 --t_fut 2 --modality_dropout 0.3"
# The 6 test-time configurations (full / no_P / no_E / no_emg / no_imu /
# no_mocap) will be produced by a separate eval script that loads the
# checkpoint from row01 and runs evaluate() with modality subsets. See
# experiments/tasks/eval_missing_modality.py (TBD).

echo ""
echo "[ok] Froze rows under:"
echo " ${BASEDIR}/{${T1},${T3},${T4},${T5},${T7}}/"
experiments/slurm/run_ablation_fix.sh ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
#SBATCH --job-name=ablation_fix
#SBATCH --partition=gpuA800
#SBATCH --gres=gpu:1
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=4
#SBATCH --mem=32G
#SBATCH --time=1:00:00
#SBATCH --output=${PULSE_ROOT}/results/ablation_fix_%j.log

# Fix: mocap+emg late+pretrained — pretrain MOCAP branch (idx=0) instead of emg
#
# Runs train_exp1.py once per seed with late fusion on mocap+emg, loading the
# pretrained mocap backbone and passing --freeze_backbone_idx 0. Only the last
# five output lines of each run are kept in the job log.
#
# NOTE(review): sbatch reads #SBATCH lines itself and does not apply shell
# variable expansion to them, so ${PULSE_ROOT} in --output above is presumably
# taken literally — confirm, and pass an explicit path via
# `sbatch --output=...` if so.
#
# Environment: PULSE_ROOT - project root (required).

# pipefail: without it the pipeline status is that of `tail`, so a crashed
# training run would be reported as success.
set -euo pipefail
export PYTHONUNBUFFERED=1

: "${PULSE_ROOT:?PULSE_ROOT must point to the project root}"

PYTHON=python
BASEDIR=${PULSE_ROOT}
SCRIPT="${BASEDIR}/experiments/train_exp1.py"
OUTDIR="${BASEDIR}/results/modality_ablation"
# Flags shared by every seed; an array keeps each argument a separate word
# even when OUTDIR contains spaces.
COMMON=(
  --model transformer --epochs 100 --batch_size 16 --lr 1e-3
  --weight_decay 1e-4 --hidden_dim 128 --downsample 5 --patience 15
  --proj_dim 0 --output_dir "${OUTDIR}"
)
SEEDS=(42 123 456 789 2024)

PT_MOCAP="${BASEDIR}/results/exp1_v8/transformer_mocap_early/model_best.pt"

echo "=== Fix: mocap+emg / late+pretrained(mocap, idx=0) ==="
for seed in "${SEEDS[@]}"; do
  echo " mocap+emg seed=$seed"
  "${PYTHON}" "${SCRIPT}" --modalities mocap,emg --fusion late --seed "${seed}" \
    --pretrained_backbone "${PT_MOCAP}" --freeze_backbone_idx 0 \
    --tag "ablation_pt_s${seed}" "${COMMON[@]}" 2>&1 | tail -5
done

echo "=== Done ==="
experiments/slurm/run_ablation_fusion.sh ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ #SBATCH --job-name=ablation_fuse
3
+ #SBATCH --partition=gpuA800
4
+ #SBATCH --gres=gpu:2
5
+ #SBATCH --nodes=1
6
+ #SBATCH --ntasks=1
7
+ #SBATCH --cpus-per-task=8
8
+ #SBATCH --mem=64G
9
+ #SBATCH --time=4:00:00
10
+ # NOTE: sbatch does not expand shell variables in #SBATCH lines; this path is relative to the submit directory (submit from the project root).
+ #SBATCH --output=results/ablation_fusion_%j.log
11
+
12
+ # Test confidence-weighted and learned-weight fusion on all multi-modal combos
13
+ # Compare against existing mean fusion results
14
+
15
+ set -eo pipefail  # pipefail: each run is piped through tail; without it a failed training run is hidden behind tail's exit status
16
+ export PYTHONUNBUFFERED=1
17
+
18
+ PYTHON=python
19
+ BASEDIR=${PULSE_ROOT:?PULSE_ROOT must be set to the project root}  # abort early instead of resolving paths against /
20
+ SCRIPT=${BASEDIR}/experiments/train_exp1.py
21
+ OUTDIR=${BASEDIR}/results/modality_ablation
22
+ COMMON="--model transformer --epochs 100 --batch_size 16 --lr 1e-3 --weight_decay 1e-4 --hidden_dim 128 --downsample 5 --patience 15 --proj_dim 0 --output_dir $OUTDIR"
23
+ SEEDS=(42 123 456 789 2024)
24
+
25
+ PT_IMU=${BASEDIR}/results/exp1_v7/transformer_imu_early/model_best.pt
26
+ PT_MOCAP=${BASEDIR}/results/exp1_v8/transformer_mocap_early/model_best.pt
27
+
28
+ echo "=== Ablation: Confidence & Learned Fusion ==="
29
+
30
+ # ============================================================
31
+ # GPU 0: confidence-weighted fusion
32
+ # ============================================================
33
+ (
34
+ export CUDA_VISIBLE_DEVICES=0
35
+
36
+ # mocap+imu / confidence / pretrained imu (idx=1)
37
+ echo "--- GPU0: mocap+imu / confidence ---"
38
+ for seed in "${SEEDS[@]}"; do
39
+ echo " mocap+imu confidence seed=$seed"
40
+ $PYTHON $SCRIPT --modalities mocap,imu --fusion late --late_agg confidence \
41
+ --seed $seed --pretrained_backbone $PT_IMU --freeze_backbone_idx 1 \
42
+ --tag ablation_conf_s${seed} $COMMON 2>&1 | tail -3
43
+ done
44
+
45
+ # emg+imu / confidence / pretrained imu (idx=1)
46
+ echo "--- GPU0: emg+imu / confidence ---"
47
+ for seed in "${SEEDS[@]}"; do
48
+ echo " emg+imu confidence seed=$seed"
49
+ $PYTHON $SCRIPT --modalities emg,imu --fusion late --late_agg confidence \
50
+ --seed $seed --pretrained_backbone $PT_IMU --freeze_backbone_idx 1 \
51
+ --tag ablation_conf_s${seed} $COMMON 2>&1 | tail -3
52
+ done
53
+
54
+ # mocap+emg / confidence / pretrained mocap (idx=0)
55
+ echo "--- GPU0: mocap+emg / confidence ---"
56
+ for seed in "${SEEDS[@]}"; do
57
+ echo " mocap+emg confidence seed=$seed"
58
+ $PYTHON $SCRIPT --modalities mocap,emg --fusion late --late_agg confidence \
59
+ --seed $seed --pretrained_backbone $PT_MOCAP --freeze_backbone_idx 0 \
60
+ --tag ablation_conf_s${seed} $COMMON 2>&1 | tail -3
61
+ done
62
+
63
+ # mocap+emg+imu / confidence / pretrained imu (idx=0, modalities=imu,mocap,emg)
64
+ echo "--- GPU0: mocap+emg+imu / confidence ---"
65
+ for seed in "${SEEDS[@]}"; do
66
+ echo " mocap+emg+imu confidence seed=$seed"
67
+ $PYTHON $SCRIPT --modalities imu,mocap,emg --fusion late --late_agg confidence \
68
+ --seed $seed --pretrained_backbone $PT_IMU --freeze_backbone_idx 0 \
69
+ --tag ablation_conf_s${seed} $COMMON 2>&1 | tail -3
70
+ done
71
+
72
+ echo "--- GPU0 Done ---"
73
+ ) &
74
+ PID0=$!
75
+
76
+ # ============================================================
77
+ # GPU 1: learned-weight fusion
78
+ # ============================================================
79
+ (
80
+ export CUDA_VISIBLE_DEVICES=1
81
+
82
+ # mocap+imu / learned / pretrained imu (idx=1)
83
+ echo "--- GPU1: mocap+imu / learned ---"
84
+ for seed in "${SEEDS[@]}"; do
85
+ echo " mocap+imu learned seed=$seed"
86
+ $PYTHON $SCRIPT --modalities mocap,imu --fusion late --late_agg learned \
87
+ --seed $seed --pretrained_backbone $PT_IMU --freeze_backbone_idx 1 \
88
+ --tag ablation_lrn_s${seed} $COMMON 2>&1 | tail -3
89
+ done
90
+
91
+ # emg+imu / learned / pretrained imu (idx=1)
92
+ echo "--- GPU1: emg+imu / learned ---"
93
+ for seed in "${SEEDS[@]}"; do
94
+ echo " emg+imu learned seed=$seed"
95
+ $PYTHON $SCRIPT --modalities emg,imu --fusion late --late_agg learned \
96
+ --seed $seed --pretrained_backbone $PT_IMU --freeze_backbone_idx 1 \
97
+ --tag ablation_lrn_s${seed} $COMMON 2>&1 | tail -3
98
+ done
99
+
100
+ # mocap+emg / learned / pretrained mocap (idx=0)
101
+ echo "--- GPU1: mocap+emg / learned ---"
102
+ for seed in "${SEEDS[@]}"; do
103
+ echo " mocap+emg learned seed=$seed"
104
+ $PYTHON $SCRIPT --modalities mocap,emg --fusion late --late_agg learned \
105
+ --seed $seed --pretrained_backbone $PT_MOCAP --freeze_backbone_idx 0 \
106
+ --tag ablation_lrn_s${seed} $COMMON 2>&1 | tail -3
107
+ done
108
+
109
+ # mocap+emg+imu / learned / pretrained imu (idx=0, modalities=imu,mocap,emg)
110
+ echo "--- GPU1: mocap+emg+imu / learned ---"
111
+ for seed in "${SEEDS[@]}"; do
112
+ echo " mocap+emg+imu learned seed=$seed"
113
+ $PYTHON $SCRIPT --modalities imu,mocap,emg --fusion late --late_agg learned \
114
+ --seed $seed --pretrained_backbone $PT_IMU --freeze_backbone_idx 0 \
115
+ --tag ablation_lrn_s${seed} $COMMON 2>&1 | tail -3
116
+ done
117
+
118
+ echo "--- GPU1 Done ---"
119
+ ) &
120
+ PID1=$!
121
+
122
+ rc=0; wait "$PID0" || rc=$?; wait "$PID1" || rc=$?; [ "$rc" -eq 0 ] || exit "$rc"  # wait on each worker separately: `wait A B` only reports B's status, so a GPU0 failure was silently dropped
123
+
124
+ # ============================================================
125
+ # Collect results
126
+ # ============================================================
127
+ echo ""
128
+ echo "=== Fusion Comparison ==="
129
+ $PYTHON -c "
130
+ import json, os, numpy as np
131
+
132
+ base = '$OUTDIR'
133
+ v8_base = '${BASEDIR}/results/exp1_v8_multiseed'
134
+ v9_base = '${BASEDIR}/results/exp1_v9'
135
+ seeds = [42, 123, 456, 789, 2024]
136
+
137
+ configs = [
138
+ # (label, pattern_template)
139
+ # mean (from previous ablation run)
140
+ ('mocap+imu / mean', base + '/transformer_mocap-imu_late_ablation_pt_s{}/results.json'),
141
+ ('mocap+imu / confidence', base + '/transformer_mocap-imu_late_ablation_conf_s{}/results.json'),
142
+ ('mocap+imu / learned', base + '/transformer_mocap-imu_late_ablation_lrn_s{}/results.json'),
143
+ ('emg+imu / mean', base + '/transformer_emg-imu_late_ablation_pt_s{}/results.json'),
144
+ ('emg+imu / confidence', base + '/transformer_emg-imu_late_ablation_conf_s{}/results.json'),
145
+ ('emg+imu / learned', base + '/transformer_emg-imu_late_ablation_lrn_s{}/results.json'),
146
+ ('mocap+emg / mean', base + '/transformer_mocap-emg_late_ablation_pt_s{}/results.json'),
147
+ ('mocap+emg / confidence', base + '/transformer_mocap-emg_late_ablation_conf_s{}/results.json'),
148
+ ('mocap+emg / learned', base + '/transformer_mocap-emg_late_ablation_lrn_s{}/results.json'),
149
+ ('3mod / mean', v9_base + '/transformer_imu-mocap-emg_late_pt_s{}/results.json'),
150
+ ('3mod / confidence', base + '/transformer_imu-mocap-emg_late_ablation_conf_s{}/results.json'),
151
+ ('3mod / learned', base + '/transformer_imu-mocap-emg_late_ablation_lrn_s{}/results.json'),
152
+ ]
153
+
154
+ print(f'{\"Config\":<30} {\"F1 (mean±std)\":<20} {\"Acc (mean±std)\":<20} N')
155
+ print('-' * 75)
156
+ for label, pat in configs:
157
+ f1s, accs = [], []
158
+ for s in seeds:
159
+ path = pat.format(s)
160
+ if os.path.exists(path):
161
+ with open(path) as f:
162
+ d = json.load(f)
163
+ f1s.append(d['test_macro_f1'])
164
+ accs.append(d['test_accuracy'])
165
+ if f1s:
166
+ f1 = np.array(f1s)
167
+ acc = np.array(accs)
168
+ print(f'{label:<30} {f1.mean():.3f}±{f1.std():.3f} {acc.mean():.3f}±{acc.std():.3f} {len(f1s)}')
169
+ else:
170
+ print(f'{label:<30} (no results)')
171
+ "
172
+
173
+ echo ""
174
+ echo "=== All done ==="
experiments/slurm/run_asformer_exp3.sh ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ #SBATCH --partition=gpuA800
3
+ #SBATCH --nodes=1
4
+ #SBATCH --ntasks=1
5
+ #SBATCH --cpus-per-task=4
6
+ #SBATCH --gres=gpu:1
7
+ #SBATCH --mem=32G
8
+ #SBATCH --time=4:00:00
9
+ #SBATCH --job-name=ASF_exp3
10
+ # NOTE: sbatch does not expand shell variables in #SBATCH lines; this path is relative to the submit directory (submit from the project root).
+ #SBATCH --output=results/asformer_exp3_%j.log
11
+
12
+ set -eo pipefail  # pipefail: each run is piped through tail; without it a failed training run is hidden behind tail's exit status
13
+ PYTHON=python
14
+ PROJECT=${PULSE_ROOT:?PULSE_ROOT must be set to the project root}  # an empty value would make the following unquoted cd fall back to $HOME
15
+ cd $PROJECT
16
+
17
+ EXP3_OUT=$PROJECT/results/published_baselines/exp3_asformer
18
+ mkdir -p $EXP3_OUT
19
+
20
+ echo "=== ASFormer Contact Detection ==="
21
+
22
+ for MOD in mocap emg imu "mocap,emg" "mocap,emg,eyetrack" "mocap,emg,eyetrack,imu"; do
23
+ echo "--- ASFormer / ${MOD} ---"
24
+ $PYTHON experiments/train_exp3.py \
25
+ --model asformer --modalities $MOD \
26
+ --hidden_dim 64 --epochs 50 --batch_size 32 \
27
+ --lr 1e-3 --weight_decay 1e-4 --downsample 2 \
28
+ --seed 42 --output_dir $EXP3_OUT 2>&1 | tail -8
29
+ done
30
+
31
+ echo ""
32
+ echo "=== Results ==="
33
+ for f in $EXP3_OUT/*/results.json; do
34
+ if [ -f "$f" ]; then
35
+ $PYTHON -c "
36
+ import json
37
+ with open('$f') as fp:
38
+ r = json.load(fp)
39
+ mods = ','.join(r.get('input_modalities', []))
40
+ m = r.get('test_metrics', {})
41
+ print(f' ASFormer | {mods:<30} | R_F1={m.get(\"right_f1\",0):.4f} L_F1={m.get(\"left_f1\",0):.4f} Avg_F1={m.get(\"avg_f1\",0):.4f}')
42
+ "
43
+ fi
44
+ done
experiments/slurm/run_exp1.sh ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ #SBATCH -J exp1_scene
3
+ #SBATCH -p gpuA800
4
+ #SBATCH --gres=gpu:1
5
+ #SBATCH -N 1
6
+ #SBATCH -n 1
7
+ #SBATCH --cpus-per-task=8
8
+ #SBATCH --mem=64G
9
+ #SBATCH -t 12:00:00
10
+ # NOTE: sbatch does not expand shell variables in #SBATCH lines; these paths are relative to the submit directory (submit from the project root; results/exp1 must already exist).
+ #SBATCH -o results/exp1/slurm_%j.out
11
+ #SBATCH -e results/exp1/slurm_%j.err
12
+
13
+ export PYTHONUNBUFFERED=1
14
+
15
+ echo "=== Job Info ==="
16
+ echo "Job ID: $SLURM_JOB_ID"
17
+ echo "Node: $SLURM_NODELIST"
18
+ echo "Start time: $(date)"
19
+ nvidia-smi --query-gpu=name,memory.total --format=csv,noheader
20
+ echo "================"
21
+
22
+ PYTHON=python
23
+ SCRIPT=${PULSE_ROOT:?PULSE_ROOT must be set to the project root}/experiments/train_exp1.py  # abort early instead of resolving paths against /
24
+ OUTDIR=${PULSE_ROOT}/results/exp1
25
+
26
+ cd ${PULSE_ROOT}
27
+
28
+ $PYTHON $SCRIPT --run_all \
29
+ --epochs 100 \
30
+ --batch_size 16 \
31
+ --lr 1e-3 \
32
+ --weight_decay 1e-4 \
33
+ --hidden_dim 128 \
34
+ --downsample 5 \
35
+ --patience 15 \
36
+ --seed 42 \
37
+ --output_dir $OUTDIR
38
+
39
+ echo "=== Done ==="
40
+ echo "End time: $(date)"
experiments/slurm/run_exp1_fusion.sh ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # Submit all fusion experiments as individual 1-GPU SLURM jobs
3
+ # SLURM scheduler will automatically place them on any available GPU
4
+
5
+ PYTHON=python
6
+ SCRIPT=${PULSE_ROOT:?PULSE_ROOT must be set to the project root}/experiments/train_exp1.py  # abort before any job is submitted with paths rooted at /
7
+ OUTDIR=${PULSE_ROOT}/results/exp1
8
+ LOGDIR=${OUTDIR}/slurm_logs
9
+ mkdir -p $LOGDIR
10
+
11
+ COMMON_ARGS="--model transformer --epochs 100 --batch_size 16 --lr 1e-3 --weight_decay 1e-4 --hidden_dim 128 --downsample 5 --patience 15 --seed 42 --output_dir $OUTDIR"
12
+
13
+ FUSIONS=(weighted_late gated_late stacking product moe late attention)
14
+ MODALITIES=("mocap,emg,eyetrack" "mocap,emg,eyetrack,imu,pressure")
15
+
16
+ for fusion in "${FUSIONS[@]}"; do
17
+ for mods in "${MODALITIES[@]}"; do
18
+ mod_tag=$(echo $mods | tr ',' '-')
19
+ job_name="f_${fusion}_${mod_tag}"
20
+ sbatch \
21
+ -J "$job_name" \
22
+ -p gpuA800 \
23
+ --gres=gpu:1 \
24
+ -N 1 -n 1 \
25
+ --cpus-per-task=8 \
26
+ --mem=32G \
27
+ -t 3:00:00 \
28
+ -o "${LOGDIR}/${job_name}_%j.out" \
29
+ -e "${LOGDIR}/${job_name}_%j.err" \
30
+ --export=ALL \
31
+ --wrap="export PYTHONUNBUFFERED=1; cd ${PULSE_ROOT}; $PYTHON $SCRIPT --fusion $fusion --modalities $mods $COMMON_ARGS"
32
+ echo "Submitted: $job_name"
33
+ done
34
+ done
35
+
36
+ echo "All 14 fusion experiments submitted!"
experiments/slurm/run_exp1_parallel.sh ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # Scene Recognition (Exp1) - Parallelized version
3
+ # Part 1: 9 modality combos × 3 backbones = 27 jobs (early fusion)
4
+ # Part 2: 7 fusion methods × transformer × (3-core + all-5) = 14 jobs
5
+ # Total: 41 jobs
6
+
7
+ PYTHON=python
8
+ BASEDIR=${PULSE_ROOT:?PULSE_ROOT must be set to the project root}  # abort before any job is submitted with paths rooted at /
9
+ SCRIPT=${BASEDIR}/experiments/train_exp1.py
10
+ OUTDIR=${BASEDIR}/results/exp1_v2
11
+ LOGDIR=${OUTDIR}/slurm_logs
12
+ mkdir -p $LOGDIR
13
+
14
+ COMMON="--epochs 100 --batch_size 16 --lr 1e-3 --weight_decay 1e-4 --hidden_dim 128 --downsample 5 --patience 15 --seed 42 --output_dir $OUTDIR"
15
+
16
+ MODS=("mocap" "emg" "eyetrack" "imu" "pressure" "mocap,emg,eyetrack" "mocap,emg,eyetrack,imu" "mocap,emg,eyetrack,pressure" "mocap,emg,eyetrack,imu,pressure")
17
+ MODELS=("cnn" "lstm" "transformer")
18
+
19
+ # Part 1: Modality ablation × 3 backbones
20
+ echo "=== Part 1: Modality Ablation (27 jobs) ==="
21
+ for mods in "${MODS[@]}"; do
22
+ mod_tag=$(echo $mods | tr ',' '-')
23
+ for model in "${MODELS[@]}"; do
24
+ sbatch \
25
+ -J "exp1_${model}_${mod_tag}" \
26
+ -p gpuA800 \
27
+ --gres=gpu:1 \
28
+ -N 1 -n 1 \
29
+ --cpus-per-task=4 \
30
+ --mem=32G \
31
+ -t 2:00:00 \
32
+ -o "${LOGDIR}/${model}_${mod_tag}_early_%j.out" \
33
+ -e "${LOGDIR}/${model}_${mod_tag}_early_%j.err" \
34
+ --export=ALL \
35
+ --wrap="export PYTHONUNBUFFERED=1; cd ${BASEDIR}; $PYTHON $SCRIPT --model $model --modalities $mods --fusion early $COMMON"
36
+ echo " Submitted: $model / $mods / early"
37
+ done
38
+ done
39
+
40
+ # Part 2: Fusion methods × transformer
41
+ FUSIONS=("late" "attention" "weighted_late" "gated_late" "stacking" "product" "moe")
42
+ FUSION_MODS=("mocap,emg,eyetrack" "mocap,emg,eyetrack,imu,pressure")
43
+
44
+ echo ""
45
+ echo "=== Part 2: Fusion Ablation (14 jobs) ==="
46
+ for fmods in "${FUSION_MODS[@]}"; do
47
+ fmod_tag=$(echo $fmods | tr ',' '-')
48
+ for fusion in "${FUSIONS[@]}"; do
49
+ sbatch \
50
+ -J "exp1_tf_${fusion}_${fmod_tag}" \
51
+ -p gpuA800 \
52
+ --gres=gpu:1 \
53
+ -N 1 -n 1 \
54
+ --cpus-per-task=4 \
55
+ --mem=32G \
56
+ -t 2:00:00 \
57
+ -o "${LOGDIR}/transformer_${fmod_tag}_${fusion}_%j.out" \
58
+ -e "${LOGDIR}/transformer_${fmod_tag}_${fusion}_%j.err" \
59
+ --export=ALL \
60
+ --wrap="export PYTHONUNBUFFERED=1; cd ${BASEDIR}; $PYTHON $SCRIPT --model transformer --modalities $fmods --fusion $fusion $COMMON"
61
+ echo " Submitted: transformer / $fmods / $fusion"
62
+ done
63
+ done
64
+
65
+ echo ""
66
+ echo "Total: 41 jobs | Scene Recognition | Updated IMU data"
67
+ echo "Results: $OUTDIR"
experiments/slurm/run_exp1_small.sh ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # Exp1 small model: hidden_dim=32, dropout=0.5, weight_decay=1e-3
3
+ # 3 modalities: mocap, emg, imu (exclude pressure & eyetrack)
4
+ # Output: results/exp1_small
5
+
6
+ PYTHON=python
7
+ SCRIPT=${PULSE_ROOT:?PULSE_ROOT must be set to the project root}/experiments/train_exp1.py  # abort before any job is submitted with paths rooted at /
8
+ OUTDIR=${PULSE_ROOT}/results/exp1_small
9
+ LOGDIR=${OUTDIR}/slurm_logs
10
+ mkdir -p $LOGDIR
11
+
12
+ COMMON="--model transformer --epochs 100 --batch_size 16 --lr 1e-3 --weight_decay 1e-3 --hidden_dim 32 --downsample 5 --patience 15 --seed 42 --output_dir $OUTDIR"
13
+
14
+ # ============================================================
15
+ # Part 1: Single modality (early fusion = single backbone)
16
+ # ============================================================
17
+ for mod in mocap emg imu; do
18
+ job_name="s_${mod}"
19
+ sbatch \
20
+ -J "$job_name" \
21
+ -p gpuA800 \
22
+ --gres=gpu:1 \
23
+ -N 1 -n 1 \
24
+ --cpus-per-task=8 \
25
+ --mem=32G \
26
+ -t 1:00:00 \
27
+ -o "${LOGDIR}/${job_name}_%j.out" \
28
+ -e "${LOGDIR}/${job_name}_%j.err" \
29
+ --export=ALL \
30
+ --wrap="export PYTHONUNBUFFERED=1; cd ${PULSE_ROOT}; $PYTHON $SCRIPT --fusion early --modalities $mod $COMMON"
31
+ echo "Submitted: $job_name"
32
+ done
33
+
34
+ # ============================================================
35
+ # Part 2: Multi-modality early fusion (4 combos)
36
+ # ============================================================
37
+ EARLY_COMBOS=("mocap,emg" "mocap,imu" "emg,imu" "mocap,emg,imu")
38
+ for mods in "${EARLY_COMBOS[@]}"; do
39
+ mod_tag=$(echo $mods | tr ',' '-')
40
+ job_name="e_${mod_tag}"
41
+ sbatch \
42
+ -J "$job_name" \
43
+ -p gpuA800 \
44
+ --gres=gpu:1 \
45
+ -N 1 -n 1 \
46
+ --cpus-per-task=8 \
47
+ --mem=32G \
48
+ -t 1:00:00 \
49
+ -o "${LOGDIR}/${job_name}_%j.out" \
50
+ -e "${LOGDIR}/${job_name}_%j.err" \
51
+ --export=ALL \
52
+ --wrap="export PYTHONUNBUFFERED=1; cd ${PULSE_ROOT}; $PYTHON $SCRIPT --fusion early --modalities $mods $COMMON"
53
+ echo "Submitted: $job_name"
54
+ done
55
+
56
+ # ============================================================
57
+ # Part 3: Fusion methods x modality sets
58
+ # ============================================================
59
+ FUSIONS=(late attention weighted_late gated_late stacking product moe)
60
+ FUSION_MODS=("mocap,emg,imu" "mocap,imu")
61
+
62
+ for fusion in "${FUSIONS[@]}"; do
63
+ for mods in "${FUSION_MODS[@]}"; do
64
+ mod_tag=$(echo $mods | tr ',' '-')
65
+ job_name="f_${fusion}_${mod_tag}"
66
+ sbatch \
67
+ -J "$job_name" \
68
+ -p gpuA800 \
69
+ --gres=gpu:1 \
70
+ -N 1 -n 1 \
71
+ --cpus-per-task=8 \
72
+ --mem=32G \
73
+ -t 1:00:00 \
74
+ -o "${LOGDIR}/${job_name}_%j.out" \
75
+ -e "${LOGDIR}/${job_name}_%j.err" \
76
+ --export=ALL \
77
+ --wrap="export PYTHONUNBUFFERED=1; cd ${PULSE_ROOT}; $PYTHON $SCRIPT --fusion $fusion --modalities $mods $COMMON"
78
+ echo "Submitted: $job_name"
79
+ done
80
+ done
81
+
82
+ echo ""
83
+ echo "Total: 3 single + 4 early + 14 fusion = 21 jobs submitted!"
84
+ echo "Results will be saved to: $OUTDIR"
experiments/slurm/run_exp1_small2.sh ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # Exp1 small2: per-modality hidden_dim + missing emg+imu fusion experiments
3
+ # hidden_dim=32 base, scaled per modality: mocap(211)->48, imu(161)->48, emg(9)->16
4
+ # Output: results/exp1_small2
5
+
6
+ PYTHON=python
7
+ SCRIPT=${PULSE_ROOT:?PULSE_ROOT must be set to the project root}/experiments/train_exp1.py  # abort before any job is submitted with paths rooted at /
8
+ OUTDIR=${PULSE_ROOT}/results/exp1_small2
9
+ LOGDIR=${OUTDIR}/slurm_logs
10
+ mkdir -p $LOGDIR
11
+
12
+ COMMON="--model transformer --epochs 100 --batch_size 16 --lr 1e-3 --weight_decay 1e-3 --hidden_dim 32 --downsample 5 --patience 15 --seed 42 --output_dir $OUTDIR"
13
+
14
+ # ============================================================
15
+ # Part 1: Single modality baselines (3 jobs)
16
+ # ============================================================
17
+ for mod in mocap emg imu; do
18
+ job_name="s2_${mod}"
19
+ sbatch \
20
+ -J "$job_name" \
21
+ -p gpuA800 \
22
+ --gres=gpu:1 \
23
+ -N 1 -n 1 \
24
+ --cpus-per-task=8 \
25
+ --mem=32G \
26
+ -t 1:00:00 \
27
+ -o "${LOGDIR}/${job_name}_%j.out" \
28
+ -e "${LOGDIR}/${job_name}_%j.err" \
29
+ --export=ALL \
30
+ --wrap="export PYTHONUNBUFFERED=1; cd ${PULSE_ROOT}; $PYTHON $SCRIPT --fusion early --modalities $mod $COMMON"
31
+ echo "Submitted: $job_name"
32
+ done
33
+
34
+ # ============================================================
35
+ # Part 2: Early fusion baselines (3 combos)
36
+ # ============================================================
37
+ EARLY_COMBOS=("emg,imu" "mocap,imu" "mocap,emg,imu")
38
+ for mods in "${EARLY_COMBOS[@]}"; do
39
+ mod_tag=$(echo $mods | tr ',' '-')
40
+ job_name="s2_e_${mod_tag}"
41
+ sbatch \
42
+ -J "$job_name" \
43
+ -p gpuA800 \
44
+ --gres=gpu:1 \
45
+ -N 1 -n 1 \
46
+ --cpus-per-task=8 \
47
+ --mem=32G \
48
+ -t 1:00:00 \
49
+ -o "${LOGDIR}/${job_name}_%j.out" \
50
+ -e "${LOGDIR}/${job_name}_%j.err" \
51
+ --export=ALL \
52
+ --wrap="export PYTHONUNBUFFERED=1; cd ${PULSE_ROOT}; $PYTHON $SCRIPT --fusion early --modalities $mods $COMMON"
53
+ echo "Submitted: $job_name"
54
+ done
55
+
56
+ # ============================================================
57
+ # Part 3: Fusion methods x modality combos (7 methods x 3 combos = 21 jobs)
58
+ # Key addition: emg,imu fusion (was missing in round 1)
59
+ # ============================================================
60
+ FUSIONS=(late attention weighted_late gated_late stacking product moe)
61
+ FUSION_MODS=("emg,imu" "mocap,imu" "mocap,emg,imu")
62
+
63
+ for fusion in "${FUSIONS[@]}"; do
64
+ for mods in "${FUSION_MODS[@]}"; do
65
+ mod_tag=$(echo $mods | tr ',' '-')
66
+ job_name="s2_${fusion}_${mod_tag}"
67
+ sbatch \
68
+ -J "$job_name" \
69
+ -p gpuA800 \
70
+ --gres=gpu:1 \
71
+ -N 1 -n 1 \
72
+ --cpus-per-task=8 \
73
+ --mem=32G \
74
+ -t 1:00:00 \
75
+ -o "${LOGDIR}/${job_name}_%j.out" \
76
+ -e "${LOGDIR}/${job_name}_%j.err" \
77
+ --export=ALL \
78
+ --wrap="export PYTHONUNBUFFERED=1; cd ${PULSE_ROOT}; $PYTHON $SCRIPT --fusion $fusion --modalities $mods $COMMON"
79
+ echo "Submitted: $job_name"
80
+ done
81
+ done
82
+
83
+ echo ""
84
+ echo "Total: 3 single + 3 early + 21 fusion = 27 jobs submitted!"
85
+ echo "Results will be saved to: $OUTDIR"
experiments/slurm/run_exp1_small3.sh ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # Exp1 small3: Data augmentation + Frozen pretrained IMU + Label smoothing
3
+ # Goal: Break the IMU-alone F1=0.771 ceiling with emg+imu fusion
4
+ # Phase 0: pretrain IMU with hidden_dim=48 (matches fusion branch)
5
+ # Baselines: IMU+aug+ls, emg+imu early+aug+ls
6
+ # Group A: 7 fusion + aug + ls (no freeze)
7
+ # Group B: 7 fusion + frozen IMU + ls (no aug) [dep: phase0]
8
+ # Group C: 7 fusion + frozen IMU + aug + ls [dep: phase0]
9
+ # Total: 1 + 2 + 7 + 7 + 7 = 24 jobs
10
+
11
+ PYTHON=python
12
+ SCRIPT=${PULSE_ROOT:?PULSE_ROOT must be set to the project root}/experiments/train_exp1.py  # abort before any job is submitted with paths rooted at /
13
+ OUTDIR=${PULSE_ROOT}/results/exp1_small3
14
+ LOGDIR=${OUTDIR}/slurm_logs
15
+ mkdir -p $LOGDIR
16
+
17
+ COMMON="--model transformer --epochs 100 --batch_size 16 --lr 1e-3 --weight_decay 1e-3 --hidden_dim 32 --downsample 5 --patience 15 --seed 42"
18
+ FUSIONS=(late attention weighted_late gated_late stacking product moe)
19
+
20
+ # ============================================================
21
+ # Phase 0: Pretrain IMU with hidden_dim=48 (matches fusion branch)
22
+ # ============================================================
23
+ PHASE0_JOB=$(sbatch --parsable \
24
+ -J "s3_phase0_imu48" \
25
+ -p gpuA800 \
26
+ --gres=gpu:1 \
27
+ -N 1 -n 1 \
28
+ --cpus-per-task=8 \
29
+ --mem=32G \
30
+ -t 1:00:00 \
31
+ -o "${LOGDIR}/phase0_imu48_%j.out" \
32
+ -e "${LOGDIR}/phase0_imu48_%j.err" \
33
+ --export=ALL \
34
+ --wrap="export PYTHONUNBUFFERED=1; cd ${PULSE_ROOT}; $PYTHON $SCRIPT --model transformer --fusion early --modalities imu --hidden_dim 48 --epochs 100 --batch_size 16 --lr 1e-3 --weight_decay 1e-3 --downsample 5 --patience 15 --seed 42 --output_dir ${OUTDIR}/phase0")
35
+ echo "Phase 0 (IMU h48): job $PHASE0_JOB"
36
+
37
+ PRETRAINED="${OUTDIR}/phase0/transformer_imu_early/model_best.pt"
38
+
39
+ # ============================================================
40
+ # Baselines (no dependency)
41
+ # ============================================================
42
+
43
+ # Baseline 1: IMU alone + augment + label_smoothing
44
+ sbatch \
45
+ -J "s3_bl_imu_aug" \
46
+ -p gpuA800 \
47
+ --gres=gpu:1 \
48
+ -N 1 -n 1 \
49
+ --cpus-per-task=8 \
50
+ --mem=32G \
51
+ -t 1:00:00 \
52
+ -o "${LOGDIR}/bl_imu_aug_%j.out" \
53
+ -e "${LOGDIR}/bl_imu_aug_%j.err" \
54
+ --export=ALL \
55
+ --wrap="export PYTHONUNBUFFERED=1; cd ${PULSE_ROOT}; $PYTHON $SCRIPT --fusion early --modalities imu $COMMON --augment --label_smoothing 0.1 --tag bl_aug --output_dir $OUTDIR"
56
+ echo "Submitted: baseline IMU+aug+ls"
57
+
58
+ # Baseline 2: emg,imu early + augment + label_smoothing
59
+ sbatch \
60
+ -J "s3_bl_ei_aug" \
61
+ -p gpuA800 \
62
+ --gres=gpu:1 \
63
+ -N 1 -n 1 \
64
+ --cpus-per-task=8 \
65
+ --mem=32G \
66
+ -t 1:00:00 \
67
+ -o "${LOGDIR}/bl_ei_aug_%j.out" \
68
+ -e "${LOGDIR}/bl_ei_aug_%j.err" \
69
+ --export=ALL \
70
+ --wrap="export PYTHONUNBUFFERED=1; cd ${PULSE_ROOT}; $PYTHON $SCRIPT --fusion early --modalities emg,imu $COMMON --augment --label_smoothing 0.1 --tag bl_aug --output_dir $OUTDIR"
71
+ echo "Submitted: baseline emg+imu early+aug+ls"
72
+
73
+ # ============================================================
74
+ # Group A: emg+imu x 7 fusion + augment + label_smoothing (no freeze)
75
+ # ============================================================
76
+ for fusion in "${FUSIONS[@]}"; do
77
+ sbatch \
78
+ -J "s3_A_${fusion}" \
79
+ -p gpuA800 \
80
+ --gres=gpu:1 \
81
+ -N 1 -n 1 \
82
+ --cpus-per-task=8 \
83
+ --mem=32G \
84
+ -t 1:00:00 \
85
+ -o "${LOGDIR}/grpA_${fusion}_%j.out" \
86
+ -e "${LOGDIR}/grpA_${fusion}_%j.err" \
87
+ --export=ALL \
88
+ --wrap="export PYTHONUNBUFFERED=1; cd ${PULSE_ROOT}; $PYTHON $SCRIPT --fusion $fusion --modalities emg,imu $COMMON --augment --label_smoothing 0.1 --tag grpA --output_dir $OUTDIR"
89
+ echo "Submitted: Group A $fusion"
90
+ done
91
+
92
+ # ============================================================
93
+ # Group B: emg+imu x 7 fusion + frozen IMU + label_smoothing (no augment)
94
+ # Depends on Phase 0
95
+ # ============================================================
96
+ for fusion in "${FUSIONS[@]}"; do
97
+ sbatch \
98
+ --dependency=afterok:${PHASE0_JOB} \
99
+ -J "s3_B_${fusion}" \
100
+ -p gpuA800 \
101
+ --gres=gpu:1 \
102
+ -N 1 -n 1 \
103
+ --cpus-per-task=8 \
104
+ --mem=32G \
105
+ -t 1:00:00 \
106
+ -o "${LOGDIR}/grpB_${fusion}_%j.out" \
107
+ -e "${LOGDIR}/grpB_${fusion}_%j.err" \
108
+ --export=ALL \
109
+ --wrap="export PYTHONUNBUFFERED=1; cd ${PULSE_ROOT}; $PYTHON $SCRIPT --fusion $fusion --modalities emg,imu $COMMON --label_smoothing 0.1 --pretrained_backbone $PRETRAINED --freeze_backbone_idx 1 --tag grpB --output_dir $OUTDIR"
110
+ echo "Submitted: Group B $fusion (dep: $PHASE0_JOB)"
111
+ done
112
+
113
+ # ============================================================
114
+ # Group C: emg+imu x 7 fusion + frozen IMU + augment + label_smoothing
115
+ # Depends on Phase 0
116
+ # ============================================================
117
+ for fusion in "${FUSIONS[@]}"; do
118
+ sbatch \
119
+ --dependency=afterok:${PHASE0_JOB} \
120
+ -J "s3_C_${fusion}" \
121
+ -p gpuA800 \
122
+ --gres=gpu:1 \
123
+ -N 1 -n 1 \
124
+ --cpus-per-task=8 \
125
+ --mem=32G \
126
+ -t 1:00:00 \
127
+ -o "${LOGDIR}/grpC_${fusion}_%j.out" \
128
+ -e "${LOGDIR}/grpC_${fusion}_%j.err" \
129
+ --export=ALL \
130
+ --wrap="export PYTHONUNBUFFERED=1; cd ${PULSE_ROOT}; $PYTHON $SCRIPT --fusion $fusion --modalities emg,imu $COMMON --augment --label_smoothing 0.1 --pretrained_backbone $PRETRAINED --freeze_backbone_idx 1 --tag grpC --output_dir $OUTDIR"
131
+ echo "Submitted: Group C $fusion (dep: $PHASE0_JOB)"
132
+ done
133
+
134
+ echo ""
135
+ echo "Total: 1 phase0 + 2 baselines + 7 grpA + 7 grpB + 7 grpC = 24 jobs"
136
+ echo "Results: $OUTDIR"
137
+ echo "Phase 0 job ID: $PHASE0_JOB (Groups B & C depend on it)"
experiments/slurm/run_exp1_v3.sh ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # Scene Recognition (Exp1 v3) - Train 14 vols / Test 4 vols (no val)
3
+ # v23,v24 moved from val to train; v3 stays in test
4
+ # Part 1: 9 modality combos × 3 backbones = 27 jobs (early fusion)
5
+ # Part 2: 7 fusion methods × transformer × (3-core + all-5) = 14 jobs
6
+ # Total: 41 jobs
7
+
8
+ PYTHON=python
9
+ BASEDIR=${PULSE_ROOT:?PULSE_ROOT must be set to the project root}  # abort before any job is submitted with paths rooted at /
10
+ SCRIPT=${BASEDIR}/experiments/train_exp1.py
11
+ OUTDIR=${BASEDIR}/results/exp1_v3
12
+ LOGDIR=${OUTDIR}/slurm_logs
13
+ mkdir -p $LOGDIR
14
+
15
+ COMMON="--epochs 100 --batch_size 16 --lr 1e-3 --weight_decay 1e-4 --hidden_dim 128 --downsample 5 --patience 15 --seed 42 --output_dir $OUTDIR"
16
+
17
+ MODS=("mocap" "emg" "eyetrack" "imu" "pressure" "mocap,emg,eyetrack" "mocap,emg,eyetrack,imu" "mocap,emg,eyetrack,pressure" "mocap,emg,eyetrack,imu,pressure")
18
+ MODELS=("cnn" "lstm" "transformer")
19
+
20
+ # Part 1: Modality ablation × 3 backbones
21
+ echo "=== Part 1: Modality Ablation (27 jobs) ==="
22
+ for mods in "${MODS[@]}"; do
23
+ mod_tag=$(echo $mods | tr ',' '-')
24
+ for model in "${MODELS[@]}"; do
25
+ sbatch \
26
+ -J "e1v3_${model}_${mod_tag}" \
27
+ -p gpuA800 \
28
+ --gres=gpu:1 \
29
+ -N 1 -n 1 \
30
+ --cpus-per-task=4 \
31
+ --mem=32G \
32
+ -t 2:00:00 \
33
+ -o "${LOGDIR}/${model}_${mod_tag}_early_%j.out" \
34
+ -e "${LOGDIR}/${model}_${mod_tag}_early_%j.err" \
35
+ --export=ALL \
36
+ --wrap="export PYTHONUNBUFFERED=1; cd ${BASEDIR}; $PYTHON $SCRIPT --model $model --modalities $mods --fusion early $COMMON"
37
+ echo " $model / $mods / early"
38
+ done
39
+ done
40
+
41
+ # Part 2: Fusion methods × transformer
42
+ FUSIONS=("late" "attention" "weighted_late" "gated_late" "stacking" "product" "moe")
43
+ FUSION_MODS=("mocap,emg,eyetrack" "mocap,emg,eyetrack,imu,pressure")
44
+
45
+ echo ""
46
+ echo "=== Part 2: Fusion Ablation (14 jobs) ==="
47
+ for fmods in "${FUSION_MODS[@]}"; do
48
+ fmod_tag=$(echo $fmods | tr ',' '-')
49
+ for fusion in "${FUSIONS[@]}"; do
50
+ sbatch \
51
+ -J "e1v3_tf_${fusion}_${fmod_tag}" \
52
+ -p gpuA800 \
53
+ --gres=gpu:1 \
54
+ -N 1 -n 1 \
55
+ --cpus-per-task=4 \
56
+ --mem=32G \
57
+ -t 2:00:00 \
58
+ -o "${LOGDIR}/transformer_${fmod_tag}_${fusion}_%j.out" \
59
+ -e "${LOGDIR}/transformer_${fmod_tag}_${fusion}_%j.err" \
60
+ --export=ALL \
61
+ --wrap="export PYTHONUNBUFFERED=1; cd ${BASEDIR}; $PYTHON $SCRIPT --model transformer --modalities $fmods --fusion $fusion $COMMON"
62
+ echo " transformer / $fmods / $fusion"
63
+ done
64
+ done
65
+
66
+ echo ""
67
+ echo "Total: 41 jobs | Scene Recognition v3 | Train=14vols, Test=4vols"
68
+ echo "Results: $OUTDIR"
experiments/slurm/run_exp1_v4.sh ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # Scene Recognition (Exp1 v4) - Per-modality projection to 50 dims
3
+ # All modalities projected to 50d via FC before backbone processing
4
+ # Train 14 vols / Test 4 vols (no val)
5
+ # Part 1: 9 modality combos × 3 backbones = 27 jobs (early fusion)
6
+ # Part 2: 7 fusion methods × transformer × (3-core + all-5) = 14 jobs
7
+ # Total: 41 jobs
8
+
9
+ PYTHON=python
10
+ BASEDIR=${PULSE_ROOT:?PULSE_ROOT must be set to the project root}  # abort before any job is submitted with paths rooted at /
11
+ SCRIPT=${BASEDIR}/experiments/train_exp1.py
12
+ OUTDIR=${BASEDIR}/results/exp1_v4
13
+ LOGDIR=${OUTDIR}/slurm_logs
14
+ mkdir -p $LOGDIR
15
+
16
+ COMMON="--epochs 100 --batch_size 16 --lr 1e-3 --weight_decay 1e-4 --hidden_dim 128 --downsample 5 --patience 15 --seed 42 --output_dir $OUTDIR"
17
+
18
+ MODS=("mocap" "emg" "eyetrack" "imu" "pressure" "mocap,emg,eyetrack" "mocap,emg,eyetrack,imu" "mocap,emg,eyetrack,pressure" "mocap,emg,eyetrack,imu,pressure")
19
+ MODELS=("cnn" "lstm" "transformer")
20
+
21
+ # Part 1: Modality ablation × 3 backbones
22
+ echo "=== Part 1: Modality Ablation (27 jobs) ==="
23
+ for mods in "${MODS[@]}"; do
24
+ mod_tag=$(echo $mods | tr ',' '-')
25
+ for model in "${MODELS[@]}"; do
26
+ sbatch \
27
+ -J "e1v4_${model}_${mod_tag}" \
28
+ -p gpuA800 \
29
+ --gres=gpu:1 \
30
+ -N 1 -n 1 \
31
+ --cpus-per-task=4 \
32
+ --mem=32G \
33
+ -t 2:00:00 \
34
+ -o "${LOGDIR}/${model}_${mod_tag}_early_%j.out" \
35
+ -e "${LOGDIR}/${model}_${mod_tag}_early_%j.err" \
36
+ --export=ALL \
37
+ --wrap="export PYTHONUNBUFFERED=1; cd ${BASEDIR}; $PYTHON $SCRIPT --model $model --modalities $mods --fusion early $COMMON"
38
+ echo " $model / $mods / early"
39
+ done
40
+ done
41
+
42
+ # Part 2: Fusion methods × transformer
43
+ FUSIONS=("late" "attention" "weighted_late" "gated_late" "stacking" "product" "moe")
44
+ FUSION_MODS=("mocap,emg,eyetrack" "mocap,emg,eyetrack,imu,pressure")
45
+
46
+ echo ""
47
+ echo "=== Part 2: Fusion Ablation (14 jobs) ==="
48
+ for fmods in "${FUSION_MODS[@]}"; do
49
+ fmod_tag=$(echo $fmods | tr ',' '-')
50
+ for fusion in "${FUSIONS[@]}"; do
51
+ sbatch \
52
+ -J "e1v4_tf_${fusion}_${fmod_tag}" \
53
+ -p gpuA800 \
54
+ --gres=gpu:1 \
55
+ -N 1 -n 1 \
56
+ --cpus-per-task=4 \
57
+ --mem=32G \
58
+ -t 2:00:00 \
59
+ -o "${LOGDIR}/transformer_${fmod_tag}_${fusion}_%j.out" \
60
+ -e "${LOGDIR}/transformer_${fmod_tag}_${fusion}_%j.err" \
61
+ --export=ALL \
62
+ --wrap="export PYTHONUNBUFFERED=1; cd ${BASEDIR}; $PYTHON $SCRIPT --model transformer --modalities $fmods --fusion $fusion $COMMON"
63
+ echo " transformer / $fmods / $fusion"
64
+ done
65
+ done
66
+
67
+ echo ""
68
+ echo "Total: 41 jobs | Scene Recognition v4 | Proj50d | Train=14vols, Test=4vols"
69
+ echo "Results: $OUTDIR"
experiments/slurm/run_exp1_v5.sh ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Scene Recognition (Exp1 v5) - Only imu, mocap, emg
# Per-modality projection to 50d
# Train 14 vols / Test 4 vols
#
# Usage: PULSE_ROOT=/path/to/project bash run_exp1_v5.sh
#
# Submits one Slurm job per (backbone, modality set, fusion) combination:
#   Part 1:  3 single modalities × 3 backbones, early fusion  =  9 jobs
#   Part 2:  4 modality combos   × 3 backbones, early fusion  = 12 jobs
#   Part 3:  7 fusion methods, transformer on imu+mocap+emg   =  7 jobs
# Exits non-zero if PULSE_ROOT is missing, the log directory cannot be
# created, or any submission fails.

# Without this guard an unset PULSE_ROOT makes BASEDIR empty, so results and
# logs would be created under /results at the filesystem root.
: "${PULSE_ROOT:?PULSE_ROOT must point to the project root}"

PYTHON=python
BASEDIR=${PULSE_ROOT}
SCRIPT=${BASEDIR}/experiments/train_exp1.py
OUTDIR=${BASEDIR}/results/exp1_v5
LOGDIR=${OUTDIR}/slurm_logs
mkdir -p -- "$LOGDIR" || {
  echo "ERROR: cannot create log directory: $LOGDIR" >&2
  exit 1
}

COMMON="--epochs 100 --batch_size 16 --lr 1e-3 --weight_decay 1e-4 --hidden_dim 128 --downsample 5 --patience 15 --seed 42 --output_dir $OUTDIR"
MODELS=("cnn" "lstm" "transformer")
FUSIONS=("late" "attention" "weighted_late" "gated_late" "stacking" "product" "moe")

# Running tallies used by the final summary.
submitted=0
failed=0

#######################################
# Submit one training job to Slurm and record whether it was accepted.
# Globals:   LOGDIR, BASEDIR, PYTHON, SCRIPT, COMMON (read)
#            submitted, failed (incremented)
# Arguments: $1 - Slurm job name
#            $2 - backbone passed to --model
#            $3 - comma-separated modality list passed to --modalities
#            $4 - fusion method passed to --fusion
# Outputs:   one progress line on stdout, or a failure line on stderr
#######################################
submit_job() {
  local job_name=$1 model=$2 mods=$3 fusion=$4
  # Log names use dashes in place of the commas of the modality list.
  local log_stem="${LOGDIR}/${model}_${mods//,/-}_${fusion}"

  if sbatch -J "$job_name" -p gpuA800 --gres=gpu:1 -N1 -n1 \
    --cpus-per-task=4 --mem=32G -t 2:00:00 \
    -o "${log_stem}_%j.out" \
    -e "${log_stem}_%j.err" \
    --export=ALL \
    --wrap="export PYTHONUNBUFFERED=1; cd ${BASEDIR}; $PYTHON $SCRIPT --model $model --modalities $mods --fusion $fusion $COMMON"; then
    echo " $model / $mods / $fusion"
    submitted=$((submitted + 1))
  else
    # Do not report a job as queued when the submission itself failed.
    echo " FAILED to submit: $model / $mods / $fusion" >&2
    failed=$((failed + 1))
  fi
}

# Part 1: Single modality (3 mods × 3 backbones = 9 jobs)
echo "=== Part 1: Single Modality (9 jobs) ==="
for mods in "imu" "mocap" "emg"; do
  for model in "${MODELS[@]}"; do
    submit_job "e1v5_${model}_${mods}" "$model" "$mods" "early"
  done
done

# Part 2: Multi-modality early fusion (4 combos × 3 backbones = 12 jobs)
echo ""
echo "=== Part 2: Multi-Modality Early Fusion (12 jobs) ==="
for mods in "imu,mocap" "imu,emg" "mocap,emg" "imu,mocap,emg"; do
  mod_tag=${mods//,/-}
  for model in "${MODELS[@]}"; do
    submit_job "e1v5_${model}_${mod_tag}" "$model" "$mods" "early"
  done
done

# Part 3: Fusion ablation with imu+mocap+emg × transformer (7 jobs)
echo ""
echo "=== Part 3: Fusion Ablation - transformer × imu+mocap+emg (7 jobs) ==="
for fusion in "${FUSIONS[@]}"; do
  submit_job "e1v5_tf_${fusion}" "transformer" "imu,mocap,emg" "$fusion"
done

echo ""
# The total reflects jobs Slurm actually accepted (28 when all succeed).
echo "Total: ${submitted} jobs | 3 modalities: imu(160d→50d), mocap(156d→50d), emg(8d→50d)"
echo "Results: $OUTDIR"

if ((failed > 0)); then
  echo "WARNING: ${failed} job submission(s) failed" >&2
  exit 1
fi