| """Frame-level future *signal* forecasting dataset (T8 v2). |
| |
| Task definition |
| --------------- |
| At a sampled anchor t in a recording: |
| past = sensor frames over [t - T_obs, t] ← input |
| future = target-modality frames over (t, t + T_fut] ← regression target |
| |
| Unlike the v1 ForecastDataset (which targets per-frame verb-fine class), this |
| predicts the raw *signal* values of one chosen target modality. This directly |
| tests the Johansson 1984 / Monzée 2003 hypothesis that cutaneous force |
| feedback drives sub-second motor planning at the *signal* level (motor |
| commands / kinematics), not at the level of slow-changing semantic verbs. |
| |
| Anchor stratification (4 event types based on contact transitions) |
| ------------------------------------------------------------------ |
| For each candidate anchor, we compute pressure_sum on past and future windows |
| and label it by the (past_majority_contact, future_majority_contact) pair: |
| |
| type 0 = non-contact (past low, future low) — control: pressure ~ 0 |
| type 1 = pre-contact (past low, future high) — pressure foretells onset |
| type 2 = steady-grip (past high, future high) — sustained contact dynamics |
| type 3 = release (past high, future low) — letting-go dynamics |
| |
| Per-event-type counts are reported and (optionally) capped to balance. |
| Evaluation is broken down per event type so we can see WHERE pressure helps. |
| """ |
| from __future__ import annotations |
|
|
| import sys |
| from pathlib import Path |
| from typing import Dict, List, Optional, Sequence, Tuple |
|
|
| import numpy as np |
| import torch |
| from torch.utils.data import Dataset |
|
|
| THIS = Path(__file__).resolve() |
| sys.path.insert(0, str(THIS.parent)) |
| sys.path.insert(0, str(THIS.parents[1])) |
|
|
| try: |
| from experiments.dataset_seqpred import ( |
| SAMPLING_RATE_HZ, _load_recording_sensors, |
| TRAIN_VOLS_V3, TEST_VOLS_V3, |
| DEFAULT_DATASET_DIR, DEFAULT_ANNOT_DIR, |
| ) |
| except ModuleNotFoundError: |
| from dataset_seqpred import ( |
| SAMPLING_RATE_HZ, _load_recording_sensors, |
| TRAIN_VOLS_V3, TEST_VOLS_V3, |
| DEFAULT_DATASET_DIR, DEFAULT_ANNOT_DIR, |
| ) |
|
|
|
|
# Human-readable names for the integer anchor event types. The code is chosen
# from the (past, future) majority-contact pair of each window:
#   0 = past low / future low,   1 = past low / future high,
#   2 = past high / future high, 3 = past high / future low.
EVENT_NAMES = {0: "non-contact", 1: "pre-contact", 2: "steady-grip", 3: "release"}
|
|
|
|
class SignalForecastDataset(Dataset):
    """Predict future T_fut frames of `target_modality` from past T_obs of `input_modalities`.

    Recordings are discovered as ``dataset_dir/<volunteer>/s*`` directories for
    which ``annot_dir/<volunteer>/<scene>.json`` exists. Each recording is
    loaded through ``_load_recording_sensors``, downsampled by keeping every
    ``downsample``-th frame, and cut into windows at anchors spaced
    ``anchor_stride_sec`` apart. For an anchor index ``a`` (in downsampled
    frames) the past window is ``[a - T_obs, a)`` and the future window is
    ``[a, a + T_fut)``.

    Sensor arrays are treated as 2-D, with frames on axis 0 and channels on
    axis 1. The summed "pressure" channels are thresholded to label every
    anchor with one of the four event types named in ``EVENT_NAMES``.

    All windows are z-scored per channel. When statistics are not supplied
    they are computed from this dataset's own windows; pass the statistics of
    a training dataset to normalise an evaluation dataset consistently.

    Args:
        volunteers: Volunteer directory names to scan under ``dataset_dir``.
        input_modalities: Modalities returned as model input.
        target_modality: Modality whose future frames are the regression target.
        t_obs_sec: Length of the past window in seconds.
        t_fut_sec: Length of the future window in seconds.
        anchor_stride_sec: Spacing between consecutive anchors in seconds
            (at least one downsampled frame is always used).
        downsample: Keep every ``downsample``-th frame; must be >= 1.
        dataset_dir: Root directory holding one sub-directory per volunteer.
        annot_dir: Root directory holding ``<volunteer>/<scene>.json`` files.
            Only their existence is checked here.
        contact_threshold_g: A frame counts as "in contact" when its summed
            pressure exceeds this value (presumably grams - confirm against
            the sensor documentation).
        per_event_max: Optional cap on the number of anchors kept per event
            type; excess anchors are dropped by seeded random sampling.
        input_stats: Optional ``{modality: (mean, std)}`` used instead of
            statistics computed from this dataset. Modalities missing from
            the mapping are left unnormalised.
        target_stats: Optional ``(mean, std)`` for the target modality.
        future_pressure_stats: Optional ``(mean, std)`` for future pressure.
        expected_input_dims: Optional ``{modality: channels}``; arrays are
            zero-padded or truncated on axis 1 to these widths.
        expected_target_dim: Optional channel count enforced on the target.
        include_future_pressure: Also store the future pressure window.
        rng_seed: Seed for the per-event subsampling.
        log: Print a summary line to stdout and report recordings that fail
            to load on stderr.

    Raises:
        ValueError: If ``downsample`` is < 1 or a window rounds to zero frames.
        RuntimeError: If no anchor could be collected.
    """

    def __init__(
        self,
        volunteers: Sequence[str],
        input_modalities: Sequence[str],
        target_modality: str,
        t_obs_sec: float = 1.5,
        t_fut_sec: float = 0.5,
        anchor_stride_sec: float = 0.25,
        downsample: int = 5,
        dataset_dir: Path = DEFAULT_DATASET_DIR,
        annot_dir: Path = DEFAULT_ANNOT_DIR,
        contact_threshold_g: float = 5.0,
        per_event_max: Optional[int] = None,
        input_stats: Optional[Dict[str, Tuple[np.ndarray, np.ndarray]]] = None,
        target_stats: Optional[Tuple[np.ndarray, np.ndarray]] = None,
        future_pressure_stats: Optional[Tuple[np.ndarray, np.ndarray]] = None,
        expected_input_dims: Optional[Dict[str, int]] = None,
        expected_target_dim: Optional[int] = None,
        include_future_pressure: bool = False,
        rng_seed: int = 0,
        log: bool = True,
    ):
        super().__init__()
        self.input_modalities = list(input_modalities)
        self.target_modality = str(target_modality)
        self.t_obs_sec = float(t_obs_sec)
        self.t_fut_sec = float(t_fut_sec)
        self.anchor_stride_sec = float(anchor_stride_sec)
        self.downsample = int(downsample)
        if self.downsample < 1:
            raise ValueError(f"downsample must be >= 1, got {downsample!r}")
        self.sr = SAMPLING_RATE_HZ // self.downsample
        self.dataset_dir = Path(dataset_dir)
        self.annot_dir = Path(annot_dir)
        self.contact_threshold_g = float(contact_threshold_g)
        self.per_event_max = per_event_max
        self.include_future_pressure = bool(include_future_pressure)
        self.T_obs = int(round(self.t_obs_sec * self.sr))
        self.T_fut = int(round(self.t_fut_sec * self.sr))
        if self.T_obs < 1 or self.T_fut < 1:
            # A zero-length past window would make `target_ds[anchor - 1]`
            # wrap around to the last frame of the recording, and a
            # zero-length future window yields empty regression targets.
            raise ValueError(
                "t_obs_sec and t_fut_sec must each span at least one "
                f"downsampled frame (T_obs={self.T_obs}, T_fut={self.T_fut}, "
                f"sr={self.sr}Hz)"
            )

        self._items: List[dict] = []
        self._modality_dims: Dict[str, int] = (
            dict(expected_input_dims) if expected_input_dims else {}
        )
        self._target_dim: int = int(expected_target_dim) if expected_target_dim else -1
        rng = np.random.default_rng(rng_seed)

        pools = self._collect_pools(volunteers, log)

        # Optional per-event cap; the items are stored grouped by event type.
        for et in (0, 1, 2, 3):
            pool = pools[et]
            if self.per_event_max is not None and len(pool) > self.per_event_max:
                keep = rng.choice(len(pool), size=self.per_event_max, replace=False)
                # Sorting the sampled indices keeps the collection order.
                pool = [pool[i] for i in sorted(keep)]
                pools[et] = pool
            self._items.extend(pool)

        if not self._items:
            raise RuntimeError("SignalForecastDataset: collected 0 anchors.")

        # Z-score normalisation, with externally supplied statistics if given.
        if input_stats is None:
            input_stats = self._compute_input_stats()
        self._input_stats = input_stats
        self._apply_input_stats(input_stats)
        if target_stats is None:
            target_stats = self._compute_target_stats()
        self._target_stats = target_stats
        self._apply_target_stats(target_stats)
        if self.include_future_pressure:
            if future_pressure_stats is None:
                future_pressure_stats = self._compute_fp_stats()
            self._fp_stats = future_pressure_stats
            self._apply_fp_stats(future_pressure_stats)
        else:
            self._fp_stats = None

        if log:
            counts = {EVENT_NAMES[k]: len(pools[k]) for k in (0, 1, 2, 3)}
            print(f"[SignalForecastDataset] vols={len(volunteers)} "
                  f"target={self.target_modality} inputs={self.input_modalities} "
                  f"anchors={len(self._items)} {counts} "
                  f"T_obs={self.T_obs} T_fut={self.T_fut} sr={self.sr}Hz "
                  f"input_dims={self._modality_dims} target_dim={self._target_dim}",
                  flush=True)

    def _collect_pools(self, volunteers: Sequence[str], log: bool) -> Dict[int, List[dict]]:
        """Scan all recordings and return their anchors grouped by event type."""
        pools: Dict[int, List[dict]] = {0: [], 1: [], 2: [], 3: []}
        # Pressure is always loaded because the event labels are derived from
        # it; dict.fromkeys removes duplicates while keeping the order.
        load_mods = list(dict.fromkeys(
            list(self.input_modalities) + [self.target_modality, "pressure"]
        ))

        for vol in volunteers:
            vol_dir = self.dataset_dir / vol
            if not vol_dir.is_dir():
                continue
            for scenario_dir in sorted(vol_dir.glob("s*")):
                if not scenario_dir.is_dir():
                    continue
                scene = scenario_dir.name
                if not (self.annot_dir / vol / f"{scene}.json").exists():
                    continue
                try:
                    sensors_all = _load_recording_sensors(
                        scenario_dir, vol, scene, load_mods
                    )
                except Exception as exc:
                    # Best effort: an unreadable recording is skipped, but the
                    # failure is reported instead of disappearing silently.
                    if log:
                        print(f"[SignalForecastDataset] skipping {vol}/{scene}: "
                              f"failed to load sensors ({exc!r})",
                              file=sys.stderr, flush=True)
                    continue
                if sensors_all is None or any(a is None for a in sensors_all.values()):
                    continue
                for item in self._recording_items(vol, scene, sensors_all):
                    pools[item["event_type"]].append(item)
        return pools

    def _recording_items(self, vol: str, scene: str, sensors_all: dict) -> List[dict]:
        """Cut one loaded recording into anchor items (unnormalised)."""
        pressure_full = sensors_all["pressure"]
        target_full = sensors_all[self.target_modality]
        input_arrs = {m: sensors_all[m] for m in self.input_modalities}

        # The first recording seen fixes the channel count of each modality;
        # later recordings are padded or truncated to match.
        for m, arr in list(input_arrs.items()):
            self._enforce_dim(input_arrs, m, arr, self._modality_dims)
        if self._target_dim < 0:
            self._target_dim = target_full.shape[1]
        else:
            target_full = self._fit_width(target_full, self._target_dim)

        # Use only the frames that every involved stream provides.
        T_avail = min(
            target_full.shape[0],
            pressure_full.shape[0],
            *(a.shape[0] for a in input_arrs.values()),
        )
        if T_avail < (self.T_obs + self.T_fut) * self.downsample:
            return []

        step = self.downsample
        input_ds = {m: arr[:T_avail:step] for m, arr in input_arrs.items()}
        target_ds = target_full[:T_avail:step]
        pressure_ds = pressure_full[:T_avail:step]
        pressure_sum = pressure_ds.sum(axis=1)

        stride = max(1, int(round(self.anchor_stride_sec * self.sr)))
        first_anchor = self.T_obs
        last_anchor = target_ds.shape[0] - self.T_fut
        # `last_anchor == first_anchor` still leaves exactly one valid window.
        if last_anchor < first_anchor:
            return []

        items: List[dict] = []
        for anchor in range(first_anchor, last_anchor + 1, stride):
            start, stop = anchor - self.T_obs, anchor + self.T_fut
            past_slice = {m: arr[start:anchor] for m, arr in input_ds.items()}
            fut_target = target_ds[anchor:stop].copy()
            if any(w.shape[0] != self.T_obs for w in past_slice.values()):
                continue
            if fut_target.shape[0] != self.T_fut:
                continue

            item = {
                "x": past_slice,
                "y": fut_target,
                # Last observed target frame (the frame just before the future).
                "y_last": target_ds[anchor - 1].copy(),
                "event_type": self._event_type(
                    pressure_sum[start:anchor], pressure_sum[anchor:stop]
                ),
                "meta": {"vol": vol, "scene": scene, "anchor_idx": int(anchor)},
            }
            if self.include_future_pressure:
                fut_press = pressure_ds[anchor:stop].copy()
                if fut_press.shape[0] != self.T_fut:
                    continue
                item["fp"] = fut_press
            items.append(item)
        return items

    def _event_type(self, past_p: np.ndarray, fut_p: np.ndarray) -> int:
        """Label an anchor from the majority contact state of both windows."""
        past_high = (past_p > self.contact_threshold_g).mean() > 0.5
        fut_high = (fut_p > self.contact_threshold_g).mean() > 0.5
        if past_high:
            return 2 if fut_high else 3
        return 1 if fut_high else 0

    @staticmethod
    def _fit_width(arr: np.ndarray, width: int) -> np.ndarray:
        """Zero-pad or truncate `arr` on axis 1 so that it has `width` columns."""
        have = arr.shape[1]
        if have == width:
            return arr
        if have < width:
            pad = np.zeros((arr.shape[0], width - have), dtype=np.float32)
            return np.concatenate([arr, pad], axis=1)
        return arr[:, :width]

    @staticmethod
    def _enforce_dim(arrs, m, arr, dim_dict):
        """Register the width of modality `m`, or fit `arrs[m]` to the known one."""
        if m in dim_dict:
            arrs[m] = SignalForecastDataset._fit_width(arr, dim_dict[m])
        else:
            dim_dict[m] = arr.shape[1]

    @staticmethod
    def _zscore_stats(frames: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Return per-channel float32 (mean, std) computed over axis 0."""
        mu = frames.mean(axis=0).astype(np.float32)
        sd = frames.std(axis=0)
        # Near-constant channels get unit scale to avoid dividing by ~0.
        sd = np.where(sd < 1e-6, 1.0, sd)
        return (mu, sd.astype(np.float32))

    def _compute_input_stats(self):
        """Compute ``{modality: (mean, std)}`` over all stored past windows."""
        # Seed from the modalities that are actually stored in the items, so
        # that extra keys in `expected_input_dims` cannot produce an empty
        # list (np.concatenate raises on an empty sequence).
        windows = {m: [] for m in self.input_modalities}
        for it in self._items:
            for m, w in it["x"].items():
                windows[m].append(w)
        return {
            m: self._zscore_stats(np.concatenate(ws, axis=0))
            for m, ws in windows.items()
        }

    def _apply_input_stats(self, stats):
        """Z-score every past window whose modality has an entry in `stats`."""
        for it in self._items:
            for m, w in it["x"].items():
                if m in stats:
                    mu, sd = stats[m]
                    it["x"][m] = ((w - mu) / sd).astype(np.float32)

    def _compute_target_stats(self):
        """Compute (mean, std) over all stored future target frames."""
        return self._zscore_stats(
            np.concatenate([it["y"] for it in self._items], axis=0)
        )

    def _apply_target_stats(self, stats):
        """Z-score the future target and the last observed target frame."""
        mu, sd = stats
        for it in self._items:
            it["y"] = ((it["y"] - mu) / sd).astype(np.float32)
            it["y_last"] = ((it["y_last"] - mu) / sd).astype(np.float32)

    def _compute_fp_stats(self):
        """Compute (mean, std) over all stored future pressure frames."""
        return self._zscore_stats(
            np.concatenate([it["fp"] for it in self._items], axis=0)
        )

    def _apply_fp_stats(self, stats):
        """Z-score the stored future pressure windows."""
        mu, sd = stats
        for it in self._items:
            it["fp"] = ((it["fp"] - mu) / sd).astype(np.float32)

    def __len__(self):
        return len(self._items)

    def __getitem__(self, idx):
        """Return one anchor as tensors.

        Returns ``(x, y, y_last, event_type, meta)``, or
        ``(x, y, y_last, fp, event_type, meta)`` when future pressure is
        included. ``x`` maps each input modality to its past window.
        """
        it = self._items[idx]
        x = {m: torch.from_numpy(np.ascontiguousarray(w)) for m, w in it["x"].items()}
        y = torch.from_numpy(np.ascontiguousarray(it["y"]))
        y_last = torch.from_numpy(np.ascontiguousarray(it["y_last"]))
        et = int(it["event_type"])
        if self.include_future_pressure:
            fp = torch.from_numpy(np.ascontiguousarray(it["fp"]))
            return x, y, y_last, fp, et, it["meta"]
        return x, y, y_last, et, it["meta"]

    @property
    def modality_dims(self):
        """Copy of the ``{modality: channel count}`` mapping for the inputs."""
        return dict(self._modality_dims)

    @property
    def target_dim(self):
        """Channel count of the target modality."""
        return self._target_dim
|
|
|
|
def collate_signal_forecast(batch):
    """Collate `SignalForecastDataset` samples into batched tensors.

    Samples are either 5-tuples ``(x, y, y_last, event_type, meta)`` or
    6-tuples that carry a future-pressure tensor in fourth position; the
    first sample decides which layout is expected. The result mirrors the
    sample layout: tensors are stacked along a new leading axis, event types
    become a ``torch.long`` tensor and the metadata is returned as a list.
    """
    has_pressure = len(batch[0]) == 6
    if has_pressure:
        inputs, futures, last_frames, pressures, event_types, metas = zip(*batch)
    else:
        inputs, futures, last_frames, event_types, metas = zip(*batch)
        pressures = None

    # The modality set (and its order) is taken from the first sample.
    names = list(inputs[0].keys())
    collated = [
        {name: torch.stack([sample[name] for sample in inputs], dim=0) for name in names},
        torch.stack(futures, dim=0),
        torch.stack(last_frames, dim=0),
    ]
    if has_pressure:
        collated.append(torch.stack(pressures, dim=0))
    collated.append(torch.tensor(event_types, dtype=torch.long))
    collated.append(list(metas))
    return tuple(collated)
|
|
|
|
def build_signal_train_test(
    input_modalities, target_modality,
    t_obs_sec=1.5, t_fut_sec=0.5, anchor_stride_sec=0.25,
    downsample=5,
    dataset_dir=DEFAULT_DATASET_DIR, annot_dir=DEFAULT_ANNOT_DIR,
    contact_threshold_g=5.0, per_event_max=None,
    include_future_pressure=False,
    rng_seed=0,
):
    """Build the train and test `SignalForecastDataset` pair.

    The train split is built from ``TRAIN_VOLS_V3`` and computes its own
    normalisation statistics and channel widths. The test split is built from
    ``TEST_VOLS_V3`` and reuses those statistics and widths, so both splits
    share one normalisation. ``per_event_max`` is applied to both splits; the
    test split subsamples with ``rng_seed + 1``.

    Returns:
        ``(train, test)``.
    """
    # Settings that are identical for both splits.
    shared = dict(
        input_modalities=input_modalities,
        target_modality=target_modality,
        t_obs_sec=t_obs_sec,
        t_fut_sec=t_fut_sec,
        anchor_stride_sec=anchor_stride_sec,
        downsample=downsample,
        dataset_dir=dataset_dir,
        annot_dir=annot_dir,
        contact_threshold_g=contact_threshold_g,
        per_event_max=per_event_max,
        include_future_pressure=include_future_pressure,
        log=True,
    )
    train = SignalForecastDataset(TRAIN_VOLS_V3, rng_seed=rng_seed, **shared)
    test = SignalForecastDataset(
        TEST_VOLS_V3,
        input_stats=train._input_stats,
        target_stats=train._target_stats,
        future_pressure_stats=train._fp_stats,
        expected_input_dims=train._modality_dims,
        expected_target_dim=train._target_dim,
        rng_seed=rng_seed + 1,
        **shared,
    )
    return train, test
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build both splits and print the shapes of the first
    # training sample.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--input_modalities", default="imu")
    parser.add_argument("--target_modality", default="imu")
    parser.add_argument("--t_obs", type=float, default=1.5)
    parser.add_argument("--t_fut", type=float, default=0.5)
    cli = parser.parse_args()

    train_set, _test_set = build_signal_train_test(
        input_modalities=cli.input_modalities.split(","),
        target_modality=cli.target_modality,
        t_obs_sec=cli.t_obs,
        t_fut_sec=cli.t_fut,
    )
    past, future, last_frame, event_type, _meta = train_set[0]
    past_shapes = {name: tuple(window.shape) for name, window in past.items()}
    print(
        f"Sample: x={past_shapes} y={tuple(future.shape)} "
        f"y_last={tuple(last_frame.shape)} event_type={event_type}"
    )
|
|