PULSE-code / experiments /tasks /train_exp_anticipate.py
velvet-pine-22's picture
Upload folder using huggingface_hub
b4b2877 verified
#!/usr/bin/env python3
"""
Experiment E: Grasp onset anticipation.
Binary classification task derived from the paper's case-study finding that
EMG activation and hand motion precede physical contact by ~570--590 ms.
Task: given a 1.0s pre-contact sensor window ending at t = contact_onset -
500 ms, classify whether a grasp contact event follows within the next 500 ms.
Positive samples = "clean" grasp events (contact rises from <5g to >5g,
with quiescent baseline over [-1500,-1000]ms and rise over [-500,0]ms).
Negative samples = random 1.0s windows drawn from quiescent periods (no
contact above 5g for the following 1.5 s).
This turns the paper's anticipatory-coordination analysis into a
reproducible benchmark, directly exploiting the unique value of
synchronised multi-modal sensing.
"""
import os
import sys
import json
import time
import random
import argparse
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from sklearn.metrics import (
accuracy_score, f1_score, roc_auc_score, average_precision_score,
)
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from data.dataset import (
DATASET_DIR, MODALITY_FILES, TRAIN_VOLS, TEST_VOLS,
load_modality_array, SCENE_LABELS,
)
# Length of the sensor window fed to the classifier, in seconds.
WINDOW_LEN_SEC: float = 1.0
# Gap between the end of the window and the contact onset, in seconds.
LEAD_SEC: float = 0.5
# (start, end) offsets in seconds *before* the onset; the summed pressure in
# [onset - 1.5 s, onset - 1.0 s) must stay quiescent for a clean grasp.
BASELINE_WINDOW_SEC: tuple = (1.5, 1.0)
# (start, end) offsets in seconds before the onset used for the "rise" check.
RISE_WINDOW_SEC: tuple = (0.5, 0.0)
# Summed-pressure contact threshold, in grams.
CONTACT_THRESHOLD: float = 5.0
# ---------------------------------------------------------------------------
# Event detection
# ---------------------------------------------------------------------------
def detect_grasp_events(pressure_csv, sr=100):
"""Return list of contact-onset indices (int) on clean grasp events."""
try:
df = pd.read_csv(pressure_csv)
except Exception:
return []
vals = df.iloc[:, 1:].values.astype(np.float32) # (T, 50) grams
total = vals.sum(axis=1)
events = []
below = True
T = len(total)
i = 0
while i < T:
if below and total[i] > CONTACT_THRESHOLD:
# detected rise onset; verify clean-grasp conditions
onset = i
b0 = int(onset - BASELINE_WINDOW_SEC[0] * sr)
b1 = int(onset - BASELINE_WINDOW_SEC[1] * sr)
r0 = int(onset - RISE_WINDOW_SEC[0] * sr)
r1 = int(onset - RISE_WINDOW_SEC[1] * sr)
if b0 >= 0 and r0 >= 0:
baseline = total[b0:b1]
rise = total[r0:r1]
if (baseline.max() < CONTACT_THRESHOLD and
rise.mean() < 3 * CONTACT_THRESHOLD):
events.append(onset)
below = False
i += int(0.5 * sr) # skip ahead 0.5 s to avoid double-detect
else:
if total[i] < 1.0:
below = True
i += 1
return events
def sample_negative_windows(total_signal, positives, n_neg, rng, sr=100,
                            win_sec=WINDOW_LEN_SEC, lookahead_sec=1.5):
    """Pick random reference times ``t`` for negative (no-grasp) windows.

    The caller cuts the sensor window ``[t - win - lead, t - lead)`` (with
    ``lead = LEAD_SEC``), i.e. the same geometry as for a positive whose
    contact onset is ``t``. A candidate ``t`` is accepted when

    * it is at least 2 s away from every onset in ``positives``, and
    * ``total_signal`` stays below ``CONTACT_THRESHOLD`` from the end of the
      window up to ``t + lookahead_sec`` -- i.e. over
      ``[t - lead, t + lookahead)``. The ``[t - lead, t)`` gap is included
      because it is exactly the interval the classifier is asked about;
      a contact there would make the "negative" label wrong.

    Args:
        total_signal: 1-D summed-pressure trace.
        positives: positive onset indices to keep away from.
        n_neg: number of negatives wanted.
        rng: random generator exposing ``randint(low, high)`` (high exclusive).
        sr: sampling rate in Hz.
        win_sec: window length in seconds.
        lookahead_sec: contact-free period required after ``t``.

    Returns:
        List of accepted ``t`` values. It may be shorter than ``n_neg``
        because at most ``10 * n_neg`` candidates are drawn.
    """
    n_samples = len(total_signal)
    wlen = int(win_sec * sr)
    lead = int(LEAD_SEC * sr)
    la = int(lookahead_sec * sr)
    # Lower bound keeps the window start (t - wlen - lead) non-negative.
    low = wlen + lead
    high = max(n_samples - la, low + 1)
    tries = 0
    found = []
    while len(found) < n_neg and tries < 10 * n_neg:
        tries += 1
        t = rng.randint(low, high)
        # Reject candidates too close to a detected positive.
        if any(abs(t - p) < 2 * sr for p in positives):
            continue
        # Require no contact from the window end through the lookahead.
        future = total_signal[t - lead:t + la]
        if future.size == 0 or future.max() >= CONTACT_THRESHOLD:
            continue
        found.append(t)
    return found
# ---------------------------------------------------------------------------
# Dataset
# ---------------------------------------------------------------------------
class AnticipationDataset(Dataset):
    """Per-event sensor window -> binary label.

    For every ``<DATASET_DIR>/<volunteer>/<scenario>`` directory whose
    scenario is in ``SCENE_LABELS`` and that has an
    ``aligned_pressure_100hz.csv`` plus every requested modality file:

    1. each modality is loaded with ``load_modality_array`` and the arrays are
       concatenated along the feature axis (time truncated to the shortest);
    2. one window ``[p - win - lead, p - lead)`` is cut per clean grasp onset
       ``p`` from ``detect_grasp_events`` (label 1);
    3. windows with the same geometry are cut at the reference times returned
       by ``sample_negative_windows`` (label 0);
    4. every window is downsampled by ``downsample`` and z-normalised with
       per-feature mean/std. These come from this split, ignoring non-finite
       values, or from ``stats`` so a test split reuses the train statistics.

    Sample offsets use a hard-coded 100 samples/s, matching the 100 Hz
    pressure file name.

    Args:
        volunteers: volunteer directory names under ``DATASET_DIR``.
        modalities: modality names; ``'mocap'`` uses a per-scenario TSV name,
            others are looked up in ``MODALITY_FILES``.
        downsample: temporal stride applied to each window.
        stats: optional ``(mean, std)`` pair, e.g. from ``get_stats()``.
        seed: seed of the RandomState used for negative sampling.
        neg_per_pos: negatives requested per retained positive, per scenario.

    Raises:
        RuntimeError: if no samples were collected.
    """

    def __init__(self, volunteers, modalities, downsample=5, stats=None,
                 seed=0, neg_per_pos=1.0):
        self.modalities = modalities
        self.downsample = downsample
        self.items = []
        # modality name -> feature width fixed by the first file seen for it
        self._modality_dims = {}
        rng = np.random.RandomState(seed)
        n_pos = 0
        n_neg = 0
        for vol in volunteers:
            vol_dir = os.path.join(DATASET_DIR, vol)
            if not os.path.isdir(vol_dir):
                continue
            for scenario in sorted(os.listdir(vol_dir)):
                scenario_dir = os.path.join(vol_dir, scenario)
                if (not os.path.isdir(scenario_dir)
                        or scenario not in SCENE_LABELS):
                    continue
                pressure_fp = os.path.join(scenario_dir,
                                           'aligned_pressure_100hz.csv')
                if not os.path.exists(pressure_fp):
                    continue
                combined = self._load_combined(vol, scenario, scenario_dir)
                if combined is None:
                    continue
                added_pos, added_neg = self._add_scenario_windows(
                    f"{vol}/{scenario}", pressure_fp, combined, rng,
                    neg_per_pos)
                n_pos += added_pos
                n_neg += added_neg
        if len(self.items) == 0:
            raise RuntimeError("No samples collected.")
        print(f" pos={n_pos} neg={n_neg} total={len(self.items)} "
              f"feat_dim={sum(self._modality_dims.values())}")
        self._normalize(stats)

    def _load_combined(self, vol, scenario, scenario_dir):
        """Load all requested modalities of one scenario, stacked feature-wise.

        A modality whose width differs from the first file seen for it is
        zero-padded or truncated to that width. Returns the concatenated array
        truncated to the shortest modality, or None if any file is missing or
        ``load_modality_array`` returns None.
        """
        parts = []
        for mod in self.modalities:
            if mod == 'mocap':
                fp = os.path.join(
                    scenario_dir, f"aligned_{vol}{scenario}_s_Q.tsv"
                )
            else:
                fp = os.path.join(scenario_dir, MODALITY_FILES[mod])
            if not os.path.exists(fp):
                return None
            arr = load_modality_array(fp, mod)
            if arr is None:
                return None
            expected = self._modality_dims.setdefault(mod, arr.shape[1])
            if arr.shape[1] < expected:
                pad = np.zeros((arr.shape[0], expected - arr.shape[1]),
                               dtype=np.float32)
                arr = np.concatenate([arr, pad], axis=1)
            elif arr.shape[1] > expected:
                arr = arr[:, :expected]
            parts.append(arr)
        t_min = min(p.shape[0] for p in parts)
        return np.concatenate([p[:t_min] for p in parts], axis=1)

    def _add_scenario_windows(self, tag, pressure_fp, combined, rng,
                              neg_per_pos):
        """Append positive then negative windows of one scenario to ``items``.

        Returns ``(n_positive_added, n_negative_added)``; ``(0, 0)`` when the
        pressure CSV cannot be read.
        """
        t_min = combined.shape[0]
        try:
            pdf = pd.read_csv(pressure_fp)
            pvals = pdf.iloc[:, 1:].values.astype(np.float32)[:t_min]
            total = pvals.sum(axis=1)
        except Exception:
            # Best effort: an unreadable pressure file just skips the scenario.
            return 0, 0

        # Window = [contact - (win + lead), contact - lead], at 100 samples/s.
        win_samples = int(WINDOW_LEN_SEC * 100)
        lead_samples = int(LEAD_SEC * 100)
        positives = [p for p in detect_grasp_events(pressure_fp)
                     if p - int((WINDOW_LEN_SEC + LEAD_SEC) * 100) >= 0
                     and p < t_min]

        added_pos = 0
        for p in positives:
            window = self._cut_window(combined, p, win_samples, lead_samples)
            if window is None:
                continue
            self.items.append({'x': window, 'y': 1, 'src': f"{tag}@{p}"})
            added_pos += 1

        n_neg_want = int(len(positives) * neg_per_pos)
        added_neg = 0
        for t in sample_negative_windows(total, positives, n_neg_want, rng):
            window = self._cut_window(combined, t, win_samples, lead_samples)
            if window is None:
                continue
            self.items.append({'x': window, 'y': 0,
                               'src': f"{tag}@{t}-neg"})
            added_neg += 1
        return added_pos, added_neg

    def _cut_window(self, combined, ref, win_samples, lead_samples):
        """Return ``combined[ref - win - lead : ref - lead]`` as float32,
        downsampled, or None if out of range or shorter than 4 steps."""
        start = ref - win_samples - lead_samples
        end = ref - lead_samples
        if start < 0 or end > combined.shape[0]:
            return None
        window = combined[start:end:self.downsample]
        if window.shape[0] < 4:
            return None
        return window.astype(np.float32)

    def _normalize(self, stats):
        """Set ``self.mean``/``self.std`` and z-normalise every item in place.

        Statistics ignore NaN/inf entries, so a single bad sample no longer
        turns a whole feature channel into NaN (and thus all zeros). Non-finite
        means become 0 and non-finite or near-zero stds become 1. Remaining
        non-finite normalised values are replaced by 0.
        """
        if stats is not None:
            mean, std = stats
        else:
            import warnings
            stacked = np.concatenate([it['x'] for it in self.items],
                                     axis=0).astype(np.float64)
            stacked[~np.isfinite(stacked)] = np.nan
            with warnings.catch_warnings():
                # All-NaN columns warn ("Mean of empty slice"); handled below.
                warnings.simplefilter('ignore', category=RuntimeWarning)
                mean = np.nanmean(stacked, axis=0, keepdims=True)
                std = np.nanstd(stacked, axis=0, keepdims=True)
        # New arrays, so caller-supplied stats are never mutated.
        self.mean = np.where(np.isfinite(mean), mean, 0.0)
        self.std = np.where(np.isfinite(std) & (std >= 1e-8), std, 1.0)
        for it in self.items:
            z = ((it['x'].astype(np.float64) - self.mean) /
                 self.std).astype(np.float32)
            it['x'] = np.nan_to_num(z, nan=0.0, posinf=0.0, neginf=0.0)

    def get_stats(self):
        """Return the ``(mean, std)`` used for normalisation."""
        return (self.mean, self.std)

    @property
    def feat_dim(self):
        """Total number of input features across all modalities."""
        return sum(self._modality_dims.values())

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        it = self.items[idx]
        return torch.from_numpy(it['x']), it['y']
def collate_fn(batch):
    """Zero-pad a list of ``(sequence, label)`` pairs into one batch.

    Returns:
        padded: sequences stacked as (batch, longest_len, ...) with zeros
            after each sequence's end.
        labels: LongTensor of the labels.
        mask: bool tensor (batch, longest_len), True on real time steps.
        lengths: LongTensor of the original sequence lengths.
    """
    sequences, labels = zip(*batch)
    lengths = torch.LongTensor([seq.shape[0] for seq in sequences])
    padded = pad_sequence(sequences, batch_first=True, padding_value=0.0)
    steps = torch.arange(padded.shape[1])
    mask = steps[None, :] < lengths[:, None]
    return padded, torch.LongTensor(labels), mask, lengths
# ---------------------------------------------------------------------------
# Model (binary classifier, reuse Transformer backbone idea)
# ---------------------------------------------------------------------------
class BinaryClassifier(nn.Module):
    """Two-way sequence classifier with masked mean pooling.

    ``forward(x, mask)`` takes ``x`` of shape (batch, time, feat_dim) and a
    boolean ``mask`` of shape (batch, time) that is True on real steps, as
    produced by ``collate_fn``. Per-step encodings are averaged over the
    masked-in steps and mapped by an MLP head to (batch, 2) logits.

    Args:
        feat_dim: number of input features per time step.
        hidden_dim: encoder width (per direction for the LSTM).
        n_layers: number of Transformer encoder layers, or of stacked LSTM
            layers (the LSTM previously ignored this and always used 2, which
            is still the default).
        n_heads: attention heads (Transformer only).
        dropout: dropout probability in encoder and head.
        backbone: ``'transformer'`` or ``'lstm'``.
        max_len: longest sequence covered by the learned positional embedding
            (Transformer only).

    Raises:
        ValueError: for an unknown ``backbone``; from ``forward`` when a
            Transformer input is longer than ``max_len``.
    """

    def __init__(self, feat_dim, hidden_dim=128, n_layers=2, n_heads=4,
                 dropout=0.2, backbone='transformer', max_len=256):
        super().__init__()
        self.backbone = backbone
        if backbone == 'transformer':
            self.in_proj = nn.Linear(feat_dim, hidden_dim)
            # Learned absolute positional embedding, sliced to the batch length.
            self.pos = nn.Parameter(torch.zeros(1, max_len, hidden_dim))
            nn.init.trunc_normal_(self.pos, std=0.02)
            layer = nn.TransformerEncoderLayer(
                d_model=hidden_dim, nhead=n_heads,
                dim_feedforward=4 * hidden_dim, dropout=dropout,
                batch_first=True, activation='gelu',
            )
            self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
            self.head = nn.Sequential(
                nn.LayerNorm(hidden_dim),
                nn.Linear(hidden_dim, hidden_dim), nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim, 2),
            )
        elif backbone == 'lstm':
            # nn.LSTM only applies dropout between stacked layers; a non-zero
            # value with a single layer merely triggers a warning.
            self.lstm = nn.LSTM(feat_dim, hidden_dim, num_layers=n_layers,
                                batch_first=True, bidirectional=True,
                                dropout=dropout if n_layers > 1 else 0.0)
            self.head = nn.Sequential(
                nn.LayerNorm(2 * hidden_dim),
                nn.Linear(2 * hidden_dim, hidden_dim), nn.GELU(),
                nn.Dropout(dropout), nn.Linear(hidden_dim, 2),
            )
        else:
            raise ValueError(backbone)

    def _encode_lstm(self, x, mask):
        """Run the BiLSTM, packing the batch when it contains padding.

        Packing keeps the backward direction from starting on padded zeros.
        Assumes ``mask`` is a prefix mask (True then False), as built by
        ``collate_fn``.
        """
        if bool(mask.all()):
            h, _ = self.lstm(x)
            return h
        from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
        # pack_padded_sequence needs CPU lengths >= 1; all-padding rows are
        # zeroed by the pooling mask anyway.
        lengths = mask.sum(dim=1).clamp(min=1).cpu()
        packed = pack_padded_sequence(x, lengths, batch_first=True,
                                      enforce_sorted=False)
        out, _ = self.lstm(packed)
        h, _ = pad_packed_sequence(out, batch_first=True,
                                   total_length=x.size(1))
        return h

    def forward(self, x, mask):
        if self.backbone == 'transformer':
            T = x.size(1)
            if T > self.pos.size(1):
                raise ValueError(
                    f"sequence length {T} exceeds max_len {self.pos.size(1)}")
            h = self.in_proj(x) + self.pos[:, :T, :]
            h = self.encoder(h, src_key_padding_mask=~mask)
        else:
            h = self._encode_lstm(x, mask)
        # Masked mean over time (clamp guards against all-padding rows).
        m = mask.unsqueeze(-1).float()
        pooled = (h * m).sum(dim=1) / m.sum(dim=1).clamp(min=1.0)
        return self.head(pooled)
# ---------------------------------------------------------------------------
# Train / Eval
# ---------------------------------------------------------------------------
def set_seed(seed):
    """Seed Python's ``random``, NumPy's global RNG and torch (CPU and all
    CUDA devices) with the same value, in that order."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
def run_experiment(args):
set_seed(args.seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
modalities = args.modalities.split(',')
print(f"Backbone: {args.backbone} | Modalities: {modalities} | Seed: {args.seed}")
print("Loading train...")
train_ds = AnticipationDataset(TRAIN_VOLS, modalities,
downsample=args.downsample, seed=args.seed)
stats = train_ds.get_stats()
print("Loading test...")
test_ds = AnticipationDataset(TEST_VOLS, modalities,
downsample=args.downsample,
stats=stats, seed=args.seed + 100)
train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
collate_fn=collate_fn, num_workers=0, drop_last=True)
test_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False,
collate_fn=collate_fn, num_workers=0)
model = BinaryClassifier(train_ds.feat_dim, hidden_dim=args.hidden_dim,
dropout=args.dropout, backbone=args.backbone).to(device)
n_params = sum(p.numel() for p in model.parameters())
print(f"Params: {n_params:,}")
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
weight_decay=args.weight_decay)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode='min', factor=0.5, patience=5, min_lr=1e-6,
)
mod_str = '-'.join(modalities)
exp_name = f"antic_{args.backbone}_{mod_str}_seed{args.seed}"
if args.tag:
exp_name += f"_{args.tag}"
out_dir = os.path.join(args.output_dir, exp_name)
os.makedirs(out_dir, exist_ok=True)
best_f1 = 0.0
best_metrics = None
best_state = None
best_epoch = 0
patience_counter = 0
for epoch in range(1, args.epochs + 1):
t0 = time.time()
model.train()
tr_loss, tr_n = 0.0, 0
for x, y, mask, _ in train_loader:
x, y, mask = x.to(device), y.to(device), mask.to(device)
optimizer.zero_grad()
logits = model(x, mask)
loss = criterion(logits, y)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
tr_loss += loss.item() * y.size(0)
tr_n += y.size(0)
tr_loss /= max(tr_n, 1)
# Eval
model.eval()
all_logits, all_y = [], []
te_loss, te_n = 0.0, 0
with torch.no_grad():
for x, y, mask, _ in test_loader:
x, y, mask = x.to(device), y.to(device), mask.to(device)
logits = model(x, mask)
loss = criterion(logits, y)
te_loss += loss.item() * y.size(0)
te_n += y.size(0)
all_logits.append(logits.cpu())
all_y.append(y.cpu())
all_logits = torch.cat(all_logits, dim=0).numpy()
all_y = torch.cat(all_y, dim=0).numpy()
preds = all_logits.argmax(axis=1)
probs = torch.softmax(torch.from_numpy(all_logits), dim=1)[:, 1].numpy()
acc = accuracy_score(all_y, preds)
f1 = f1_score(all_y, preds, average='binary', zero_division=0)
try:
auc = roc_auc_score(all_y, probs)
except Exception:
auc = 0.5
try:
ap = average_precision_score(all_y, probs)
except Exception:
ap = 0.5
scheduler.step(te_loss / max(te_n, 1))
print(f" E{epoch:3d} | tr {tr_loss:.4f} | te {te_loss/max(te_n,1):.4f} "
f"acc {acc:.3f} f1 {f1:.3f} auc {auc:.3f} ap {ap:.3f} | "
f"{time.time()-t0:.1f}s")
if f1 > best_f1:
best_f1 = f1
best_metrics = {'acc': float(acc), 'f1': float(f1),
'auc': float(auc), 'ap': float(ap)}
best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
best_epoch = epoch
patience_counter = 0
else:
patience_counter += 1
if patience_counter >= args.patience:
print(f" Early stop (best epoch {best_epoch})")
break
if best_state is not None:
torch.save(best_state, os.path.join(out_dir, 'model_best.pt'))
results = {
'experiment': exp_name,
'backbone': args.backbone,
'modalities': modalities,
'seed': args.seed,
'best_epoch': best_epoch,
'best_test_metrics': best_metrics,
'train_size': len(train_ds),
'test_size': len(test_ds),
'train_pos_frac': float(np.mean([it['y'] for it in train_ds.items])),
'test_pos_frac': float(np.mean([it['y'] for it in test_ds.items])),
'feat_dim': train_ds.feat_dim,
'window_sec': WINDOW_LEN_SEC,
'lead_sec': LEAD_SEC,
'args': vars(args),
}
with open(os.path.join(out_dir, 'results.json'), 'w') as f:
json.dump(results, f, indent=2)
print(f"Saved: {out_dir}/results.json")
return results
def main():
p = argparse.ArgumentParser()
p.add_argument('--backbone', type=str, default='transformer',
choices=['transformer', 'lstm'])
p.add_argument('--modalities', type=str, default='emg,imu')
p.add_argument('--epochs', type=int, default=50)
p.add_argument('--batch_size', type=int, default=32)
p.add_argument('--lr', type=float, default=5e-4)
p.add_argument('--weight_decay', type=float, default=1e-4)
p.add_argument('--hidden_dim', type=int, default=128)
p.add_argument('--dropout', type=float, default=0.2)
p.add_argument('--downsample', type=int, default=5)
p.add_argument('--patience', type=int, default=10)
p.add_argument('--seed', type=int, default=42)
p.add_argument('--output_dir', type=str, required=True)
p.add_argument('--tag', type=str, default='')
args = p.parse_args()
run_experiment(args)
if __name__ == '__main__':
main()