#!/usr/bin/env python3
"""
Experiment E: Grasp onset anticipation.

Binary classification task derived from the paper's case-study finding that
EMG activation and hand motion precede physical contact by ~570--590 ms.

Task: given a 1.0 s pre-contact sensor window ending at
t = contact_onset - 500 ms, classify whether a grasp contact event follows
within the next 500 ms.

Positive samples are "clean" grasp events detected on the summed pressure
signal: the total crosses above 5 g after having been released (dropped below
1 g, or at the start of the recording), stays below 5 g throughout the
baseline [-1500, -1000] ms, and has a mean below 15 g over [-500, 0) ms.

Negative samples are random 1.0 s windows drawn from quiescent periods: no
contact above 5 g from the end of the window through the following 2.0 s
(the 500 ms lead gap plus a 1.5 s lookahead), with the would-be onset time
(window end + 500 ms) at least 2 s away from any detected grasp onset.

Model selection: by default the best epoch is chosen on the test split (the
original protocol, which makes the reported test metrics optimistic). Pass
``--val_vols`` to hold out some training volunteers for model selection.

This turns the paper's anticipatory-coordination analysis into a reproducible
benchmark, directly exploiting the unique value of synchronised multi-modal
sensing.
"""
import os
import sys
import json
import time
import random
import argparse
import warnings

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from sklearn.metrics import (
    accuracy_score, f1_score, roc_auc_score, average_precision_score,
)

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from data.dataset import (  # noqa: E402  (relies on the sys.path tweak above)
    DATASET_DIR, MODALITY_FILES, TRAIN_VOLS, TEST_VOLS,
    load_modality_array, SCENE_LABELS,
)

# Sampling rate (Hz) assumed for every aligned sensor stream; matches the
# ``aligned_pressure_100hz.csv`` file name. TODO confirm for other modalities.
SAMPLE_RATE_HZ = 100

WINDOW_LEN_SEC = 1.0
LEAD_SEC = 0.5                     # gap between window end and contact onset
BASELINE_WINDOW_SEC = (1.5, 1.0)   # [-1.5, -1.0] s before onset must stay < threshold
RISE_WINDOW_SEC = (0.5, 0.0)       # [-0.5, 0) s before onset: mean must stay < 3x threshold
CONTACT_THRESHOLD = 5.0            # grams, on the sum of all pressure channels
RELEASE_THRESHOLD = 1.0            # grams; total must drop below this to re-arm detection
NEG_LOOKAHEAD_SEC = 1.5            # contact-free lookahead after a negative's onset time


# ---------------------------------------------------------------------------
# Event detection
# ---------------------------------------------------------------------------

def _detect_events_from_total(total, sr=SAMPLE_RATE_HZ):
    """Return contact-onset indices of clean grasp events in a summed-pressure trace.

    An onset is the first sample above ``CONTACT_THRESHOLD`` after the trace
    was "released" (below ``RELEASE_THRESHOLD``, or at the very start). It is
    kept only if the baseline window stays below the threshold and the mean of
    the pre-onset window stays below ``3 * CONTACT_THRESHOLD``. After every
    crossing (kept or not) detection skips 0.5 s and waits for a new release.

    Args:
        total: 1-D sequence of summed pressure values (grams).
        sr: sampling rate in Hz.

    Returns:
        List of int onset indices, in increasing order.
    """
    total = np.asarray(total)
    n = len(total)
    refractory = int(0.5 * sr)  # skip ahead to avoid double-detecting one grasp
    events = []
    armed = True
    i = 0
    while i < n:
        if armed and total[i] > CONTACT_THRESHOLD:
            onset = i
            b0 = int(onset - BASELINE_WINDOW_SEC[0] * sr)
            b1 = int(onset - BASELINE_WINDOW_SEC[1] * sr)
            r0 = int(onset - RISE_WINDOW_SEC[0] * sr)
            r1 = int(onset - RISE_WINDOW_SEC[1] * sr)
            # Both windows must lie inside the trace and be non-empty
            # (empty slices would make max()/mean() fail).
            if b0 >= 0 and r0 >= 0 and b1 > b0 and r1 > r0:
                baseline = total[b0:b1]
                pre_onset = total[r0:r1]
                if (baseline.max() < CONTACT_THRESHOLD
                        and pre_onset.mean() < 3 * CONTACT_THRESHOLD):
                    events.append(onset)
            armed = False
            i += refractory
        else:
            if total[i] < RELEASE_THRESHOLD:
                armed = True
            i += 1
    return events


def detect_grasp_events(pressure_csv, sr=SAMPLE_RATE_HZ):
    """Return contact-onset indices (int) of clean grasp events in a pressure CSV.

    The first CSV column is skipped; the remaining columns are summed per row
    to form the total pressure trace (grams). Unreadable or non-numeric files
    yield an empty list.
    """
    try:
        df = pd.read_csv(pressure_csv)
        total = df.iloc[:, 1:].values.astype(np.float32).sum(axis=1)
    except (OSError, ValueError):
        return []
    return _detect_events_from_total(total, sr=sr)


def sample_negative_windows(total_signal, positives, n_neg, rng, sr=SAMPLE_RATE_HZ,
                            win_sec=WINDOW_LEN_SEC, lookahead_sec=NEG_LOOKAHEAD_SEC):
    """Pick random onset-equivalent times ``t`` for contact-free negative windows.

    Each returned ``t`` is the time a contact would occur for a positive, i.e.
    the window used downstream is ``[t - win - lead, t - lead)``. A candidate
    is accepted when:

    * it is at least 2 s away from every index in ``positives``;
    * the summed pressure stays below ``CONTACT_THRESHOLD`` over
      ``[t - lead, t + lookahead)`` -- i.e. from the end of the window,
      through the lead gap, and over the full lookahead;
    * it has not been returned already.

    Args:
        total_signal: 1-D summed pressure trace (grams).
        positives: grasp onset indices to stay away from.
        n_neg: number of negatives wanted.
        rng: ``numpy.random.RandomState``-like object; ``randint(lo, hi)``
            must exclude ``hi``.
        sr: sampling rate in Hz.
        win_sec: window length in seconds.
        lookahead_sec: contact-free horizon after ``t`` in seconds.

    Returns:
        Up to ``n_neg`` distinct int indices (fewer if ``10 * n_neg`` draws do
        not find enough, or if the trace is too short to hold a full window
        plus lookahead).
    """
    total_signal = np.asarray(total_signal)
    n_samples = len(total_signal)
    wlen = int(win_sec * sr)
    lead = int(LEAD_SEC * sr)
    la = int(lookahead_sec * sr)
    low = wlen + lead          # window start t - wlen - lead must be >= 0
    high = n_samples - la      # exclusive; keeps [t, t + la) inside the trace
    if n_neg <= 0 or high <= low:
        return []

    found = []
    seen = set()
    tries = 0
    while len(found) < n_neg and tries < 10 * n_neg:
        tries += 1
        t = int(rng.randint(low, high))
        if t in seen:
            continue
        if any(abs(t - p) < 2 * sr for p in positives):
            continue
        # Include the lead gap: contact there would mean a grasp *does*
        # follow the window within LEAD_SEC, contradicting the negative label.
        if total_signal[t - lead:t + la].max() >= CONTACT_THRESHOLD:
            continue
        seen.add(t)
        found.append(t)
    return found


# ---------------------------------------------------------------------------
# Dataset
# ---------------------------------------------------------------------------

class AnticipationDataset(Dataset):
    """Per-event sensor windows with binary "grasp follows" labels.

    For every ``<DATASET_DIR>/<volunteer>/<scenario>`` directory (scenario in
    ``SCENE_LABELS``) that has ``aligned_pressure_100hz.csv`` and all requested
    modality files, the modalities are concatenated along the feature axis
    (truncated to the shortest stream), clean grasp onsets are detected on the
    summed pressure, and one window per positive plus ``neg_per_pos`` random
    quiescent windows per positive are extracted. Windows are downsampled by
    ``downsample`` and z-scored with ``stats`` (or with statistics computed on
    this dataset when ``stats`` is None).

    Args:
        volunteers: volunteer directory names under ``DATASET_DIR``.
        modalities: modality keys; ``'mocap'`` is resolved by file-name
            pattern, all others through ``MODALITY_FILES``.
        downsample: keep every ``downsample``-th sample of each window.
        stats: optional ``(mean, std)`` pair, e.g. from the training set.
        seed: seed for negative-window sampling.
        neg_per_pos: negatives requested per positive event.
        modality_dims: optional ``{modality: feature_width}`` mapping (e.g.
            ``train_ds.get_modality_dims()``). Recordings are zero-padded or
            truncated to these widths so features line up with ``stats`` and
            the model. When omitted, the first loaded recording of each
            modality defines its width.

    Raises:
        RuntimeError: if no samples were collected.
        ValueError: if ``stats`` does not match the feature width.
    """

    def __init__(self, volunteers, modalities, downsample=5, stats=None,
                 seed=0, neg_per_pos=1.0, modality_dims=None):
        self.modalities = modalities
        self.downsample = downsample
        self.items = []
        self._modality_dims = (
            {m: int(modality_dims[m]) for m in modalities if m in modality_dims}
            if modality_dims else {}
        )
        rng = np.random.RandomState(seed)
        sr = SAMPLE_RATE_HZ
        win_samples = int(WINDOW_LEN_SEC * sr)
        lead_samples = int(LEAD_SEC * sr)
        n_pos = 0
        n_neg = 0

        for vol in volunteers:
            vol_dir = os.path.join(DATASET_DIR, vol)
            if not os.path.isdir(vol_dir):
                continue
            for scenario in sorted(os.listdir(vol_dir)):
                scenario_dir = os.path.join(vol_dir, scenario)
                if not os.path.isdir(scenario_dir) or scenario not in SCENE_LABELS:
                    continue
                pressure_fp = os.path.join(scenario_dir, 'aligned_pressure_100hz.csv')
                if not os.path.exists(pressure_fp):
                    continue

                combined = self._load_combined(vol, scenario, scenario_dir)
                if combined is None:
                    continue
                t_min = combined.shape[0]

                # Read the pressure once; detection only looks backwards in
                # time, so detecting on the truncated trace yields exactly the
                # events that fall inside the sensor streams.
                try:
                    pdf = pd.read_csv(pressure_fp)
                    pvals = pdf.iloc[:, 1:].values.astype(np.float32)[:t_min]
                except (OSError, ValueError):
                    continue
                total = pvals.sum(axis=1)
                events = _detect_events_from_total(total, sr=sr)
                positives = [p for p in events
                             if p - win_samples - lead_samples >= 0]

                # Window = [contact - (win + lead), contact - lead)
                for p in positives:
                    window = self._make_window(combined, p, win_samples, lead_samples)
                    if window is None:
                        continue
                    self.items.append({'x': window, 'y': 1,
                                       'src': f"{vol}/{scenario}@{p}"})
                    n_pos += 1

                # Negatives keep clear of *every* detected grasp onset, including
                # those too early in the recording to serve as positives.
                n_neg_want = int(len(positives) * neg_per_pos)
                neg_onsets = sample_negative_windows(total, events, n_neg_want, rng, sr=sr)
                for t in neg_onsets:
                    window = self._make_window(combined, t, win_samples, lead_samples)
                    if window is None:
                        continue
                    self.items.append({'x': window, 'y': 0,
                                       'src': f"{vol}/{scenario}@{t}-neg"})
                    n_neg += 1

        if not self.items:
            raise RuntimeError("No samples collected.")
        print(f"  pos={n_pos} neg={n_neg} total={len(self.items)} "
              f"feat_dim={sum(self._modality_dims.values())}")

        # Normalize
        if stats is not None:
            self.mean, self.std = stats
        else:
            self.mean, self.std = self._compute_stats(self.items)
        if np.shape(self.mean)[-1] != self.feat_dim:
            raise ValueError(
                f"stats have {np.shape(self.mean)[-1]} features but this dataset "
                f"has {self.feat_dim}; pass modality_dims from the stats' dataset"
            )
        for it in self.items:
            x = ((it['x'].astype(np.float64) - self.mean) / self.std).astype(np.float32)
            it['x'] = np.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)

    def _load_combined(self, vol, scenario, scenario_dir):
        """Load all modalities for one recording and concatenate them feature-wise.

        Each modality is padded/truncated to its registered width (registering
        it on first sight). Returns None if any modality file is missing or
        ``load_modality_array`` returns None.
        """
        parts = []
        for mod in self.modalities:
            if mod == 'mocap':
                fp = os.path.join(scenario_dir, f"aligned_{vol}{scenario}_s_Q.tsv")
            else:
                fp = os.path.join(scenario_dir, MODALITY_FILES[mod])
            if not os.path.exists(fp):
                return None
            arr = load_modality_array(fp, mod)
            if arr is None:
                return None
            expected = self._modality_dims.get(mod)
            if expected is None:
                self._modality_dims[mod] = arr.shape[1]
            elif arr.shape[1] < expected:
                pad = np.zeros((arr.shape[0], expected - arr.shape[1]), dtype=np.float32)
                arr = np.concatenate([arr, pad], axis=1)
            elif arr.shape[1] > expected:
                arr = arr[:, :expected]
            parts.append(arr)
        t_min = min(p.shape[0] for p in parts)
        return np.concatenate([p[:t_min] for p in parts], axis=1)

    def _make_window(self, combined, onset, win_samples, lead_samples):
        """Return the downsampled float32 window ending ``lead_samples`` before ``onset``.

        Returns None when the window falls outside ``combined`` or has fewer
        than 4 rows after downsampling.
        """
        start = onset - win_samples - lead_samples
        end = onset - lead_samples
        if start < 0 or end > combined.shape[0]:
            return None
        window = combined[start:end:self.downsample]
        if window.shape[0] < 4:
            return None
        return window.astype(np.float32)

    @staticmethod
    def _compute_stats(items):
        """Return per-feature ``(mean, std)``, each shaped ``(1, D)``, over all window rows.

        Non-finite entries are ignored so a single NaN/inf cannot wipe out a
        whole feature; features with no finite data get mean 0, and
        (near-)constant or undefined features get std 1.
        """
        stacked = np.concatenate([it['x'] for it in items], axis=0).astype(np.float64)
        stacked[~np.isfinite(stacked)] = np.nan
        with warnings.catch_warnings():
            # nanmean/nanstd warn on all-NaN columns; those are patched below.
            warnings.simplefilter('ignore', category=RuntimeWarning)
            mean = np.nanmean(stacked, axis=0, keepdims=True)
            std = np.nanstd(stacked, axis=0, keepdims=True)
        mean = np.where(np.isfinite(mean), mean, 0.0)
        std = np.where(np.isfinite(std) & (std >= 1e-8), std, 1.0)
        return mean, std

    def get_stats(self):
        """Return the ``(mean, std)`` normalisation pair used by this dataset."""
        return (self.mean, self.std)

    def get_modality_dims(self):
        """Return a copy of the ``{modality: feature_width}`` layout of this dataset."""
        return dict(self._modality_dims)

    @property
    def feat_dim(self):
        return sum(self._modality_dims.values())

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        it = self.items[idx]
        return torch.from_numpy(it['x']), it['y']


def collate_fn(batch):
    """Pad a list of ``(seq, label)`` pairs (``seq`` shaped ``(T, D)``) into a batch.

    Returns:
        padded: ``(B, T_max, D)`` float tensor, zero-padded.
        labels: ``(B,)`` long tensor.
        mask: ``(B, T_max)`` bool tensor, True on real (non-padded) steps.
        lens: ``(B,)`` long tensor of sequence lengths.
    """
    seqs, ys = zip(*batch)
    lens = torch.tensor([s.shape[0] for s in seqs], dtype=torch.long)
    padded = pad_sequence(seqs, batch_first=True, padding_value=0.0)
    mask = torch.arange(padded.shape[1]).unsqueeze(0) < lens.unsqueeze(1)
    return padded, torch.tensor(ys, dtype=torch.long), mask, lens


# ---------------------------------------------------------------------------
# Model (binary classifier, reuse Transformer backbone idea)
# ---------------------------------------------------------------------------

class BinaryClassifier(nn.Module):
    """Sequence encoder + masked mean pooling + MLP head producing 2 logits.

    Args:
        feat_dim: input feature width.
        hidden_dim: encoder width.
        n_layers: number of Transformer encoder layers / LSTM layers.
        n_heads: attention heads (Transformer only).
        dropout: dropout rate.
        backbone: ``'transformer'`` or ``'lstm'``.
        max_len: length of the learned positional table (Transformer only);
            longer inputs raise ``ValueError``.
    """

    def __init__(self, feat_dim, hidden_dim=128, n_layers=2, n_heads=4, dropout=0.2,
                 backbone='transformer', max_len=256):
        super().__init__()
        self.backbone = backbone
        if backbone == 'transformer':
            self.in_proj = nn.Linear(feat_dim, hidden_dim)
            self.pos = nn.Parameter(torch.zeros(1, max_len, hidden_dim))
            nn.init.trunc_normal_(self.pos, std=0.02)
            layer = nn.TransformerEncoderLayer(
                d_model=hidden_dim, nhead=n_heads, dim_feedforward=4 * hidden_dim,
                dropout=dropout, batch_first=True, activation='gelu',
            )
            self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
            head_in = hidden_dim
        elif backbone == 'lstm':
            # nn.LSTM only applies inter-layer dropout (and warns otherwise)
            # when there is more than one layer.
            self.lstm = nn.LSTM(feat_dim, hidden_dim, num_layers=n_layers,
                                batch_first=True, bidirectional=True,
                                dropout=dropout if n_layers > 1 else 0.0)
            head_in = 2 * hidden_dim
        else:
            raise ValueError(backbone)
        self.head = nn.Sequential(
            nn.LayerNorm(head_in),
            nn.Linear(head_in, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 2),
        )

    def forward(self, x, mask):
        """Return ``(B, 2)`` logits for ``x`` ``(B, T, feat_dim)`` and bool ``mask`` ``(B, T)``."""
        if self.backbone == 'transformer':
            seq_len = x.size(1)
            if seq_len > self.pos.size(1):
                raise ValueError(
                    f"sequence length {seq_len} exceeds positional table "
                    f"length {self.pos.size(1)}; increase max_len"
                )
            h = self.in_proj(x) + self.pos[:, :seq_len, :]
            h = self.encoder(h, src_key_padding_mask=~mask)
        else:
            # NOTE: padding is not packed, so the reverse direction would see
            # pad steps first; harmless here because every window has the same
            # length, but pack sequences if variable lengths are introduced.
            h, _ = self.lstm(x)
        m = mask.unsqueeze(-1).to(h.dtype)
        pooled = (h * m).sum(dim=1) / m.sum(dim=1).clamp(min=1.0)
        return self.head(pooled)


# ---------------------------------------------------------------------------
# Train / Eval
# ---------------------------------------------------------------------------

def set_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


_METRIC_KEYS = ('acc', 'f1', 'auc', 'ap')


def _evaluate(model, loader, criterion, device):
    """Run ``model`` over ``loader`` in eval mode.

    Returns:
        dict with mean ``loss`` and ``acc``/``f1``/``auc``/``ap`` (floats).
        AUC/AP fall back to 0.5 when sklearn rejects the labels (e.g. a
        single class present).
    """
    model.eval()
    all_logits, all_y = [], []
    loss_sum, n_seen = 0.0, 0
    with torch.no_grad():
        for x, y, mask, _ in loader:
            x, y, mask = x.to(device), y.to(device), mask.to(device)
            logits = model(x, mask)
            loss_sum += criterion(logits, y).item() * y.size(0)
            n_seen += y.size(0)
            all_logits.append(logits.cpu())
            all_y.append(y.cpu())
    logits = torch.cat(all_logits, dim=0)
    y_true = torch.cat(all_y, dim=0).numpy()
    preds = logits.argmax(dim=1).numpy()
    probs = torch.softmax(logits, dim=1)[:, 1].numpy()
    try:
        auc = roc_auc_score(y_true, probs)
    except ValueError:
        auc = 0.5
    try:
        ap = average_precision_score(y_true, probs)
    except ValueError:
        ap = 0.5
    return {
        'loss': loss_sum / max(n_seen, 1),
        'acc': float(accuracy_score(y_true, preds)),
        'f1': float(f1_score(y_true, preds, average='binary', zero_division=0)),
        'auc': float(auc),
        'ap': float(ap),
    }


def run_experiment(args):
    """Train and evaluate one anticipation model; write ``results.json``.

    The best epoch is the one with the highest F1 on the selection split:
    the volunteers in ``args.val_vols`` (removed from training) when given,
    otherwise the test split. The LR scheduler follows the selection-split
    loss. Returns the results dict that is also written to disk.
    """
    set_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    modalities = args.modalities.split(',')
    print(f"Backbone: {args.backbone} | Modalities: {modalities} | Seed: {args.seed}")

    val_vols = [v.strip() for v in (getattr(args, 'val_vols', '') or '').split(',')
                if v.strip()]
    leaked = sorted(set(val_vols) & set(TEST_VOLS))
    if leaked:
        raise ValueError(f"--val_vols overlaps TEST_VOLS: {leaked}")
    train_vols = [v for v in TRAIN_VOLS if v not in val_vols]
    if not train_vols:
        raise ValueError("--val_vols removes every training volunteer")

    print("Loading train...")
    train_ds = AnticipationDataset(train_vols, modalities,
                                   downsample=args.downsample, seed=args.seed)
    stats = train_ds.get_stats()
    dims = train_ds.get_modality_dims()
    val_ds = None
    if val_vols:
        print("Loading val...")
        val_ds = AnticipationDataset(val_vols, modalities, downsample=args.downsample,
                                     stats=stats, seed=args.seed + 200,
                                     modality_dims=dims)
    print("Loading test...")
    test_ds = AnticipationDataset(TEST_VOLS, modalities, downsample=args.downsample,
                                  stats=stats, seed=args.seed + 100,
                                  modality_dims=dims)

    # Dropping the last partial batch is only safe when at least one full
    # batch exists; otherwise the model would silently never be trained.
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
                              collate_fn=collate_fn, num_workers=0,
                              drop_last=len(train_ds) > args.batch_size)
    test_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False,
                             collate_fn=collate_fn, num_workers=0)
    val_loader = None
    if val_ds is not None:
        val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False,
                                collate_fn=collate_fn, num_workers=0)

    model = BinaryClassifier(train_ds.feat_dim, hidden_dim=args.hidden_dim,
                             dropout=args.dropout, backbone=args.backbone).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Params: {n_params:,}")

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
                                 weight_decay=args.weight_decay)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5, min_lr=1e-6,
    )

    mod_str = '-'.join(modalities)
    exp_name = f"antic_{args.backbone}_{mod_str}_seed{args.seed}"
    if args.tag:
        exp_name += f"_{args.tag}"
    out_dir = os.path.join(args.output_dir, exp_name)
    os.makedirs(out_dir, exist_ok=True)

    selection_split = 'val' if val_loader is not None else 'test'
    best_sel_f1 = -1.0  # below any real F1, so epoch 1 is always recorded
    best_metrics = None
    best_val_metrics = None
    best_state = None
    best_epoch = 0
    patience_counter = 0

    for epoch in range(1, args.epochs + 1):
        t0 = time.time()
        model.train()
        tr_loss, tr_n = 0.0, 0
        for x, y, mask, _ in train_loader:
            x, y, mask = x.to(device), y.to(device), mask.to(device)
            optimizer.zero_grad()
            loss = criterion(model(x, mask), y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            tr_loss += loss.item() * y.size(0)
            tr_n += y.size(0)
        tr_loss /= max(tr_n, 1)

        test_m = _evaluate(model, test_loader, criterion, device)
        val_m = (_evaluate(model, val_loader, criterion, device)
                 if val_loader is not None else None)
        sel_m = val_m if val_m is not None else test_m
        scheduler.step(sel_m['loss'])

        val_str = (f"val {val_m['loss']:.4f} f1 {val_m['f1']:.3f} | "
                   if val_m is not None else "")
        print(f"  E{epoch:3d} | tr {tr_loss:.4f} | {val_str}te {test_m['loss']:.4f} "
              f"acc {test_m['acc']:.3f} f1 {test_m['f1']:.3f} "
              f"auc {test_m['auc']:.3f} ap {test_m['ap']:.3f} | "
              f"{time.time() - t0:.1f}s")

        if sel_m['f1'] > best_sel_f1:
            best_sel_f1 = sel_m['f1']
            best_metrics = {k: test_m[k] for k in _METRIC_KEYS}
            best_val_metrics = ({k: val_m[k] for k in _METRIC_KEYS}
                                if val_m is not None else None)
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            best_epoch = epoch
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= args.patience:
                print(f"  Early stop (best epoch {best_epoch})")
                break

    if best_state is not None:
        torch.save(best_state, os.path.join(out_dir, 'model_best.pt'))

    results = {
        'experiment': exp_name,
        'backbone': args.backbone,
        'modalities': modalities,
        'seed': args.seed,
        'best_epoch': best_epoch,
        'best_test_metrics': best_metrics,
        'model_selection_split': selection_split,
        'best_val_metrics': best_val_metrics,
        'train_size': len(train_ds),
        'val_size': len(val_ds) if val_ds is not None else 0,
        'test_size': len(test_ds),
        'train_pos_frac': float(np.mean([it['y'] for it in train_ds.items])),
        'test_pos_frac': float(np.mean([it['y'] for it in test_ds.items])),
        'feat_dim': train_ds.feat_dim,
        'window_sec': WINDOW_LEN_SEC,
        'lead_sec': LEAD_SEC,
        'args': vars(args),
    }
    with open(os.path.join(out_dir, 'results.json'), 'w') as f:
        json.dump(results, f, indent=2)
    print(f"Saved: {out_dir}/results.json")
    return results


def main():
    p = argparse.ArgumentParser()
    p.add_argument('--backbone', type=str, default='transformer',
                   choices=['transformer', 'lstm'])
    p.add_argument('--modalities', type=str, default='emg,imu')
    p.add_argument('--epochs', type=int, default=50)
    p.add_argument('--batch_size', type=int, default=32)
    p.add_argument('--lr', type=float, default=5e-4)
    p.add_argument('--weight_decay', type=float, default=1e-4)
    p.add_argument('--hidden_dim', type=int, default=128)
    p.add_argument('--dropout', type=float, default=0.2)
    p.add_argument('--downsample', type=int, default=5)
    p.add_argument('--patience', type=int, default=10)
    p.add_argument('--seed', type=int, default=42)
    p.add_argument('--output_dir', type=str, required=True)
    p.add_argument('--tag', type=str, default='')
    p.add_argument('--val_vols', type=str, default='',
                   help="comma-separated training volunteers held out for model "
                        "selection; empty selects on the test split (legacy)")
    args = p.parse_args()
    run_experiment(args)


if __name__ == '__main__':
    main()