| |
| """ |
| Experiment E: Grasp onset anticipation. |
| |
| Binary classification task derived from the paper's case-study finding that |
| EMG activation and hand motion precede physical contact by ~570--590 ms. |
| |
| Task: given a 1.0s pre-contact sensor window ending at t = contact_onset - |
| 500 ms, classify whether a grasp contact event follows within the next 500 ms. |
| |
Positive samples = "clean" grasp events: summed pressure rises above 5g
(the detector re-arms only after it falls below 1g), stays below 5g over the
baseline [-1500, -1000) ms, and averages under 15g (3x threshold) over the
rise window [-500, 0) ms.
| Negative samples = random 1.0s windows drawn from quiescent periods (no |
| contact above 5g for the following 1.5 s). |
| |
| This turns the paper's anticipatory-coordination analysis into a |
| reproducible benchmark, directly exploiting the unique value of |
| synchronised multi-modal sensing. |
| """ |
|
|
import argparse
import json
import os
import random
import sys
import time

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.metrics import (
    accuracy_score, average_precision_score, f1_score, roc_auc_score,
)
from torch.nn.utils.rnn import (
    pack_padded_sequence, pad_packed_sequence, pad_sequence,
)
from torch.utils.data import DataLoader, Dataset
|
|
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
| from data.dataset import ( |
| DATASET_DIR, MODALITY_FILES, TRAIN_VOLS, TEST_VOLS, |
| load_modality_array, SCENE_LABELS, |
| ) |
|
|
# Length of the sensor window fed to the model, in seconds.
WINDOW_LEN_SEC = 1.0
# Gap between the end of the window and the (candidate) contact onset, in seconds.
LEAD_SEC = 0.5
# Baseline interval as (start, end) seconds *before* onset, i.e. [-1.5 s, -1.0 s).
BASELINE_WINDOW_SEC = (1.5, 1.0)
# Rise interval as (start, end) seconds before onset, i.e. [-0.5 s, 0 s).
RISE_WINDOW_SEC = (0.5, 0.0)
# Summed-pressure level treated as "in contact" (the module docstring calls it 5 g).
CONTACT_THRESHOLD = 5.0
|
|
|
|
| |
| |
| |
|
|
def detect_grasp_events(pressure_csv, sr=100):
    """Return list of contact-onset indices (int) on clean grasp events.

    Reads ``pressure_csv``, drops its first column (presumably a timestamp;
    not verified here) and sums the remaining columns into one pressure trace
    sampled at ``sr`` Hz.

    Scanning left to right, an onset is the first sample above
    ``CONTACT_THRESHOLD`` while the detector is armed. An onset is kept only
    if both conditions hold:

    * the trace stayed below the threshold over the baseline interval;
    * the trace averaged under ``3 * CONTACT_THRESHOLD`` over the rise
      interval.

    Both interval offsets come from ``BASELINE_WINDOW_SEC`` and
    ``RISE_WINDOW_SEC``. After any onset the scan skips ``0.5 * sr`` samples
    and re-arms once the trace falls below 1.0.

    Returns ``[]`` when the CSV cannot be read.
    """
    try:
        frame = pd.read_csv(pressure_csv)
    except Exception:
        return []
    summed = frame.iloc[:, 1:].values.astype(np.float32).sum(axis=1)
    n_samples = len(summed)
    refractory = int(0.5 * sr)

    def _is_clean(onset):
        # Offsets are rounded per onset, exactly as int(onset - sec * sr).
        base_lo = int(onset - BASELINE_WINDOW_SEC[0] * sr)
        base_hi = int(onset - BASELINE_WINDOW_SEC[1] * sr)
        rise_lo = int(onset - RISE_WINDOW_SEC[0] * sr)
        rise_hi = int(onset - RISE_WINDOW_SEC[1] * sr)
        if base_lo < 0 or rise_lo < 0:
            # Not enough history before the onset to judge it.
            return False
        return (summed[base_lo:base_hi].max() < CONTACT_THRESHOLD and
                summed[rise_lo:rise_hi].mean() < 3 * CONTACT_THRESHOLD)

    onsets = []
    armed = True
    idx = 0
    while idx < n_samples:
        level = summed[idx]
        if armed and level > CONTACT_THRESHOLD:
            if _is_clean(idx):
                onsets.append(idx)
            # Disarm and jump ahead so one contact yields a single onset.
            armed = False
            idx += refractory
            continue
        # Hysteresis: re-arm only once pressure is clearly released.
        if level < 1.0:
            armed = True
        idx += 1
    return onsets
|
|
|
|
def sample_negative_windows(total_signal, positives, n_neg, rng, sr=100,
                            win_sec=WINDOW_LEN_SEC, lookahead_sec=1.5):
    """Pick random anchor indices whose prediction horizon is contact-free.

    An anchor ``t`` plays the role of a (non-)contact onset. The caller cuts
    the feature window ``[t - wlen - lead, t - lead)``, where
    ``wlen = win_sec * sr`` and ``lead = LEAD_SEC * sr``. An anchor is
    accepted when both conditions hold:

    * it lies at least ``2 * sr`` samples away from every index in
      ``positives``;
    * ``total_signal`` stays below ``CONTACT_THRESHOLD`` from the end of the
      window (``t - lead``) up to ``t + lookahead_sec * sr``. Starting the
      check at the window end (not at ``t``) guarantees that no contact
      begins inside the ``LEAD_SEC`` horizon the classifier must predict.

    Args:
        total_signal: 1-D array of summed pressure per sample.
        positives: indices of positive onsets to keep away from.
        n_neg: number of anchors wanted.
        rng: ``np.random.RandomState`` used for drawing candidates.
        sr: sample rate in Hz.
        win_sec: window length in seconds.
        lookahead_sec: contact-free span required after the anchor, seconds.

    Returns:
        A list of at most ``n_neg`` anchors. It is shorter when ``10 * n_neg``
        draws do not find enough valid anchors.
    """
    T = len(total_signal)
    wlen = int(win_sec * sr)
    lead = int(LEAD_SEC * sr)
    la = int(lookahead_sec * sr)
    low = wlen + lead
    high = max(T - la, low + 1)
    found = []
    tries = 0
    while len(found) < n_neg and tries < 10 * n_neg:
        tries += 1
        t = rng.randint(low, high)
        # Stay clear of real grasp events.
        if any(abs(t - p) < 2 * sr for p in positives):
            continue
        # Horizon + lookahead must be contact-free. An empty slice (signal too
        # short) cannot be verified and is rejected instead of crashing .max().
        horizon = total_signal[t - lead:t + la]
        if horizon.size == 0 or horizon.max() >= CONTACT_THRESHOLD:
            continue
        found.append(t)
    return found
|
|
|
|
| |
| |
| |
|
|
class AnticipationDataset(Dataset):
    """Per-event sensor window -> binary label.

    Walks ``DATASET_DIR/<volunteer>/<scenario>`` for each volunteer in
    ``volunteers``. A scenario is used only if it is listed in
    ``SCENE_LABELS``, has an ``aligned_pressure_100hz.csv`` file, and has a
    file for every requested modality. Each such scenario contributes:

    * one positive item (``y=1``) per onset returned by
      ``detect_grasp_events``;
    * up to ``int(n_positives * neg_per_pos)`` negative items (``y=0``) at
      anchors returned by ``sample_negative_windows``.

    Each item's ``'x'`` is the ``WINDOW_LEN_SEC`` of column-stacked modality
    features ending ``LEAD_SEC`` before its anchor. The window is subsampled
    by ``downsample`` and z-scored per feature column.

    Normalisation uses ``stats`` (a ``(mean, std)`` pair, e.g. a training
    set's ``get_stats()``) when given, otherwise statistics of this dataset.

    Raises:
        RuntimeError: if no samples were collected.
    """

    # Sample rate (Hz) used to convert seconds into sample counts; the
    # pressure file name (``*_100hz.csv``) indicates 100 Hz aligned streams.
    _SR = 100

    def __init__(self, volunteers, modalities, downsample=5, stats=None,
                 seed=0, neg_per_pos=1.0):
        self.modalities = modalities
        self.downsample = downsample
        self.items = []
        self._modality_dims = {}
        rng = np.random.RandomState(seed)
        n_pos = 0
        n_neg = 0
        win_samples = int(WINDOW_LEN_SEC * self._SR)
        lead_samples = int(LEAD_SEC * self._SR)
        # Earliest onset that leaves room for a full window plus the lead gap.
        min_onset = int((WINDOW_LEN_SEC + LEAD_SEC) * self._SR)

        for vol in volunteers:
            vol_dir = os.path.join(DATASET_DIR, vol)
            if not os.path.isdir(vol_dir):
                continue
            for scenario in sorted(os.listdir(vol_dir)):
                scenario_dir = os.path.join(vol_dir, scenario)
                if not os.path.isdir(scenario_dir) or scenario not in SCENE_LABELS:
                    continue
                pressure_fp = os.path.join(scenario_dir,
                                           'aligned_pressure_100hz.csv')
                if not os.path.exists(pressure_fp):
                    continue

                combined = self._load_features(vol, scenario, scenario_dir)
                if combined is None:
                    continue
                T_min = combined.shape[0]

                try:
                    pdf = pd.read_csv(pressure_fp)
                    pvals = pdf.iloc[:, 1:].values.astype(np.float32)[:T_min]
                    total = pvals.sum(axis=1)
                except Exception:
                    # Best-effort: an unreadable pressure file skips the scenario.
                    continue

                positives = [p for p in detect_grasp_events(pressure_fp)
                             if p - min_onset >= 0 and p < T_min]

                for p in positives:
                    window = self._cut_window(combined, p, win_samples,
                                              lead_samples)
                    if window is None:
                        continue
                    self.items.append({'x': window, 'y': 1,
                                       'src': f"{vol}/{scenario}@{p}"})
                    n_pos += 1

                n_neg_want = int(len(positives) * neg_per_pos)
                neg_onsets = sample_negative_windows(total, positives,
                                                     n_neg_want, rng)
                for t in neg_onsets:
                    window = self._cut_window(combined, t, win_samples,
                                              lead_samples)
                    if window is None:
                        continue
                    self.items.append({'x': window, 'y': 0,
                                       'src': f"{vol}/{scenario}@{t}-neg"})
                    n_neg += 1

        if len(self.items) == 0:
            raise RuntimeError("No samples collected.")
        print(f" pos={n_pos} neg={n_neg} total={len(self.items)} "
              f"feat_dim={sum(self._modality_dims.values())}")

        # Only stack the whole dataset when the statistics are actually needed.
        if stats is not None:
            mean, std = stats
        else:
            mean, std = self._feature_stats()
        self.mean = mean
        # Guard near-constant columns; np.where avoids mutating caller arrays.
        self.std = np.where(std < 1e-8, 1.0, std)
        for it in self.items:
            z = (it['x'].astype(np.float64) - self.mean) / self.std
            # Missing readings (NaN/inf) become 0, i.e. the feature mean.
            it['x'] = np.nan_to_num(z.astype(np.float32), nan=0.0,
                                    posinf=0.0, neginf=0.0)

    def _load_features(self, vol, scenario, scenario_dir):
        """Load and column-stack every requested modality for one scenario.

        Returns the arrays truncated to the shortest one and concatenated
        along axis 1, or None when a modality file is missing or
        ``load_modality_array`` returns None. The first width seen for a
        modality is recorded in ``self._modality_dims``; later arrays are
        zero-padded or truncated to that width.
        """
        parts = []
        for mod in self.modalities:
            if mod == 'mocap':
                fp = os.path.join(
                    scenario_dir, f"aligned_{vol}{scenario}_s_Q.tsv"
                )
            else:
                fp = os.path.join(scenario_dir, MODALITY_FILES[mod])
            if not os.path.exists(fp):
                return None
            # Presumably a 2-D (time, features) array; only shape[1] is relied on.
            arr = load_modality_array(fp, mod)
            if arr is None:
                return None
            expected = self._modality_dims.get(mod)
            if expected is None:
                self._modality_dims[mod] = arr.shape[1]
            elif arr.shape[1] < expected:
                pad = np.zeros((arr.shape[0], expected - arr.shape[1]),
                               dtype=np.float32)
                arr = np.concatenate([arr, pad], axis=1)
            elif arr.shape[1] > expected:
                arr = arr[:, :expected]
            parts.append(arr)
        T_min = min(p.shape[0] for p in parts)
        return np.concatenate([p[:T_min] for p in parts], axis=1)

    def _cut_window(self, combined, anchor, win_samples, lead_samples):
        """Return the subsampled float32 window ending ``lead_samples`` before ``anchor``.

        Returns None if the window falls outside ``combined`` or keeps fewer
        than 4 steps after subsampling.
        """
        start = anchor - win_samples - lead_samples
        end = anchor - lead_samples
        if start < 0 or end > combined.shape[0]:
            return None
        window = combined[start:end][::self.downsample]
        if window.shape[0] < 4:
            return None
        return window.astype(np.float32)

    def _feature_stats(self):
        """Per-column (mean, std) over every time step of every item.

        Non-finite readings are treated as missing, so one NaN/inf no longer
        turns its whole column's statistics into NaN (which zeroed that
        feature for every sample). Columns with no finite value at all keep
        NaN statistics and therefore normalise to 0, as before.
        """
        # astype makes a float64 copy, so the items themselves are untouched.
        stacked = np.concatenate([it['x'] for it in self.items],
                                 axis=0).astype(np.float64)
        stacked[~np.isfinite(stacked)] = np.nan
        mean = np.nanmean(stacked, axis=0, keepdims=True)
        std = np.nanstd(stacked, axis=0, keepdims=True)
        return mean, std

    def get_stats(self):
        """Return the ``(mean, std)`` normalisation arrays, shape (1, feat_dim)."""
        return (self.mean, self.std)

    @property
    def feat_dim(self):
        """Total feature width summed over the recorded modality widths."""
        return sum(self._modality_dims.values())

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        it = self.items[idx]
        return torch.from_numpy(it['x']), it['y']
|
|
|
|
def collate_fn(batch):
    """Stack ``(sequence, label)`` samples into a padded mini-batch.

    Sequences are right-padded with zeros to the longest one in the batch.
    Returns ``(padded, labels, mask, lengths)``:

    * ``padded`` is batch-first;
    * ``labels`` and ``lengths`` are int64 tensors;
    * ``mask[b, t]`` is True for the real (non-padded) steps of sample ``b``.
    """
    sequences, labels = zip(*batch)
    lengths = torch.LongTensor([seq.size(0) for seq in sequences])
    padded = pad_sequence(sequences, batch_first=True, padding_value=0.0)
    step_index = torch.arange(padded.size(1))
    mask = step_index[None, :] < lengths[:, None]
    return padded, torch.LongTensor(labels), mask, lengths
|
|
|
|
| |
| |
| |
|
|
class BinaryClassifier(nn.Module):
    """Sequence classifier with a Transformer or BiLSTM backbone and a 2-way head.

    ``forward(x, mask)`` takes ``x`` of shape (batch, time, feat_dim) and a
    boolean ``mask`` of shape (batch, time) that is True on real steps. It
    mean-pools the backbone outputs over the valid steps and returns logits
    of shape (batch, 2).

    Args:
        feat_dim: input feature width.
        hidden_dim: model width (per direction for the LSTM).
        n_layers: number of encoder layers (Transformer) or stacked LSTM layers.
        n_heads: attention heads (Transformer only).
        dropout: dropout used in the encoder/LSTM and the head.
        backbone: ``'transformer'`` or ``'lstm'``.
        max_len: longest sequence the learned positional table supports
            (Transformer only).

    Raises:
        ValueError: for an unknown ``backbone``.
    """

    def __init__(self, feat_dim, hidden_dim=128, n_layers=2, n_heads=4,
                 dropout=0.2, backbone='transformer', max_len=256):
        super().__init__()
        self.backbone = backbone
        if backbone == 'transformer':
            self.in_proj = nn.Linear(feat_dim, hidden_dim)
            # Learned positional embeddings; forward() rejects longer inputs.
            self.pos = nn.Parameter(torch.zeros(1, max_len, hidden_dim))
            nn.init.trunc_normal_(self.pos, std=0.02)
            layer = nn.TransformerEncoderLayer(
                d_model=hidden_dim, nhead=n_heads,
                dim_feedforward=4 * hidden_dim, dropout=dropout,
                batch_first=True, activation='gelu',
            )
            self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
            head_in = hidden_dim
        elif backbone == 'lstm':
            # nn.LSTM only applies dropout between stacked layers.
            self.lstm = nn.LSTM(feat_dim, hidden_dim, num_layers=n_layers,
                                batch_first=True, bidirectional=True,
                                dropout=dropout if n_layers > 1 else 0.0)
            head_in = 2 * hidden_dim
        else:
            raise ValueError(backbone)
        # Registered after the backbone (same order and keys as before).
        self.head = nn.Sequential(
            nn.LayerNorm(head_in),
            nn.Linear(head_in, hidden_dim), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(hidden_dim, 2),
        )

    def forward(self, x, mask):
        if self.backbone == 'transformer':
            T = x.size(1)
            max_len = self.pos.size(1)
            if T > max_len:
                raise ValueError(
                    f"sequence length {T} exceeds max_len={max_len}")
            h = self.in_proj(x) + self.pos[:, :T, :]
            h = self.encoder(h, src_key_padding_mask=~mask)
        else:
            # Pack so the backward direction starts at each sequence's real
            # end instead of reading padding. Assumes a prefix mask (valid
            # steps first), as produced by right-padding.
            lengths = mask.sum(dim=1).clamp(min=1).cpu()
            packed = pack_padded_sequence(x, lengths, batch_first=True,
                                          enforce_sorted=False)
            out, _ = self.lstm(packed)
            h, _ = pad_packed_sequence(out, batch_first=True,
                                       total_length=x.size(1))
        m = mask.unsqueeze(-1).float()
        pooled = (h * m).sum(dim=1) / m.sum(dim=1).clamp(min=1.0)
        return self.head(pooled)
|
|
|
|
| |
| |
| |
|
|
def set_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU and every CUDA device) with ``seed``."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)
|
|
|
|
def _train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader``; return the mean training loss.

    Gradients are clipped to a global norm of 1.0 before each step. The loss
    is weighted by batch size and is 0.0 when ``loader`` yields no batches.
    """
    model.train()
    loss_sum, n_seen = 0.0, 0
    for x, y, mask, _ in loader:
        x, y, mask = x.to(device), y.to(device), mask.to(device)
        optimizer.zero_grad()
        loss = criterion(model(x, mask), y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        loss_sum += loss.item() * y.size(0)
        n_seen += y.size(0)
    return loss_sum / max(n_seen, 1)


def _evaluate(model, loader, criterion, device):
    """Score ``model`` on ``loader`` without gradient tracking.

    Returns ``(mean_loss, metrics)``. ``metrics`` holds float
    ``acc``/``f1``/``auc``/``ap``, computed for the positive class (label 1).
    AUC and AP fall back to 0.5 when scikit-learn rejects the labels, e.g.
    when only one class is present.
    """
    model.eval()
    logit_chunks, label_chunks = [], []
    loss_sum, n_seen = 0.0, 0
    with torch.no_grad():
        for x, y, mask, _ in loader:
            x, y, mask = x.to(device), y.to(device), mask.to(device)
            logits = model(x, mask)
            loss_sum += criterion(logits, y).item() * y.size(0)
            n_seen += y.size(0)
            logit_chunks.append(logits.cpu())
            label_chunks.append(y.cpu())
    logits = torch.cat(logit_chunks, dim=0)
    labels = torch.cat(label_chunks, dim=0).numpy()
    preds = logits.numpy().argmax(axis=1)
    probs = torch.softmax(logits, dim=1)[:, 1].numpy()
    try:
        auc = roc_auc_score(labels, probs)
    except ValueError:
        auc = 0.5
    try:
        ap = average_precision_score(labels, probs)
    except ValueError:
        ap = 0.5
    metrics = {
        'acc': float(accuracy_score(labels, preds)),
        'f1': float(f1_score(labels, preds, average='binary', zero_division=0)),
        'auc': float(auc),
        'ap': float(ap),
    }
    return loss_sum / max(n_seen, 1), metrics


def run_experiment(args):
    """Train and evaluate one grasp-anticipation model and save the results.

    ``args`` is the namespace produced by ``main()``. Steps:

    1. Build the train and test ``AnticipationDataset``s; the test set is
       normalised with the train statistics.
    2. Train ``BinaryClassifier`` with Adam and label-smoothed cross-entropy.
    3. Write ``model_best.pt`` (when a best state exists) and
       ``results.json`` under ``<output_dir>/<exp_name>``.

    Returns the results dict.

    NOTE(review): checkpoint selection, the LR schedule and early stopping
    are all driven by the TEST split (there is no validation split), so
    ``best_test_metrics`` is optimistically biased. Prefer a held-out
    validation split for reported numbers.
    """
    set_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")
    modalities = args.modalities.split(',')
    print(f"Backbone: {args.backbone} | Modalities: {modalities} | Seed: {args.seed}")

    print("Loading train...")
    train_ds = AnticipationDataset(TRAIN_VOLS, modalities,
                                   downsample=args.downsample, seed=args.seed)
    stats = train_ds.get_stats()
    print("Loading test...")
    test_ds = AnticipationDataset(TEST_VOLS, modalities,
                                  downsample=args.downsample,
                                  stats=stats, seed=args.seed + 100)

    # drop_last avoids a ragged final batch. With fewer samples than one
    # batch it would leave zero training batches and an untrained model.
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
                              collate_fn=collate_fn, num_workers=0,
                              drop_last=len(train_ds) > args.batch_size)
    test_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False,
                             collate_fn=collate_fn, num_workers=0)

    model = BinaryClassifier(train_ds.feat_dim, hidden_dim=args.hidden_dim,
                             dropout=args.dropout, backbone=args.backbone).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Params: {n_params:,}")

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
                                 weight_decay=args.weight_decay)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5, min_lr=1e-6,
    )

    mod_str = '-'.join(modalities)
    exp_name = f"antic_{args.backbone}_{mod_str}_seed{args.seed}"
    if args.tag:
        exp_name += f"_{args.tag}"
    out_dir = os.path.join(args.output_dir, exp_name)
    os.makedirs(out_dir, exist_ok=True)

    # Start below any attainable F1 so the first epoch is always recorded.
    # Otherwise a run whose F1 stays 0 saves no model and reports null metrics.
    best_f1 = -1.0
    best_metrics = None
    best_state = None
    best_epoch = 0
    patience_counter = 0

    for epoch in range(1, args.epochs + 1):
        t0 = time.time()
        tr_loss = _train_one_epoch(model, train_loader, optimizer, criterion,
                                   device)
        te_loss, metrics = _evaluate(model, test_loader, criterion, device)
        scheduler.step(te_loss)

        print(f" E{epoch:3d} | tr {tr_loss:.4f} | te {te_loss:.4f} "
              f"acc {metrics['acc']:.3f} f1 {metrics['f1']:.3f} "
              f"auc {metrics['auc']:.3f} ap {metrics['ap']:.3f} | "
              f"{time.time()-t0:.1f}s")
        if metrics['f1'] > best_f1:
            best_f1 = metrics['f1']
            best_metrics = dict(metrics)
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            best_epoch = epoch
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= args.patience:
                print(f" Early stop (best epoch {best_epoch})")
                break

    if best_state is not None:
        torch.save(best_state, os.path.join(out_dir, 'model_best.pt'))

    results = {
        'experiment': exp_name,
        'backbone': args.backbone,
        'modalities': modalities,
        'seed': args.seed,
        'best_epoch': best_epoch,
        'best_test_metrics': best_metrics,
        'train_size': len(train_ds),
        'test_size': len(test_ds),
        'train_pos_frac': float(np.mean([it['y'] for it in train_ds.items])),
        'test_pos_frac': float(np.mean([it['y'] for it in test_ds.items])),
        'feat_dim': train_ds.feat_dim,
        'window_sec': WINDOW_LEN_SEC,
        'lead_sec': LEAD_SEC,
        'args': vars(args),
    }
    with open(os.path.join(out_dir, 'results.json'), 'w') as f:
        json.dump(results, f, indent=2)
    print(f"Saved: {out_dir}/results.json")
    return results
|
|
|
|
def main():
    """Parse command-line options and launch a single anticipation run."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--backbone', type=str, default='transformer',
                        choices=['transformer', 'lstm'])
    parser.add_argument('--modalities', type=str, default='emg,imu')
    # Numeric training / model options, registered in their original order.
    numeric_options = (
        ('--epochs', int, 50),
        ('--batch_size', int, 32),
        ('--lr', float, 5e-4),
        ('--weight_decay', float, 1e-4),
        ('--hidden_dim', int, 128),
        ('--dropout', float, 0.2),
        ('--downsample', int, 5),
        ('--patience', int, 10),
        ('--seed', int, 42),
    )
    for flag, kind, default in numeric_options:
        parser.add_argument(flag, type=kind, default=default)
    parser.add_argument('--output_dir', type=str, required=True)
    parser.add_argument('--tag', type=str, default='')
    run_experiment(parser.parse_args())
|
|
|
|
# Script entry point: parse CLI arguments and run one experiment.
if __name__ == '__main__':
    main()
|
|