"""Frame-level future *signal* forecasting dataset (T8 v2). Task definition --------------- At a sampled anchor t in a recording: past = sensor frames over [t - T_obs, t] ← input future = target-modality frames over (t, t + T_fut] ← regression target Unlike the v1 ForecastDataset (which targets per-frame verb-fine class), this predicts the raw *signal* values of one chosen target modality. This directly tests the Johansson 1984 / monzee 2003 hypothesis that cutaneous force feedback drives sub-second motor planning at the *signal* level (motor commands / kinematics), not at the level of slow-changing semantic verbs. Anchor stratification (4 event types based on contact transitions) ------------------------------------------------------------------ For each candidate anchor, we compute pressure_sum on past and future windows and label it by the (past_majority_contact, future_majority_contact) pair: type 0 = non-contact (past low, future low) — control: pressure ~ 0 type 1 = pre-contact (past low, future high) — pressure foretells onset type 2 = steady-grip (past high, future high) — sustained contact dynamics type 3 = release (past high, future low) — letting-go dynamics Per-event-type counts are reported and (optionally) capped to balance. Evaluation is broken down per event type so we can see WHERE pressure helps. """ from __future__ import annotations import sys from pathlib import Path from typing import Dict, List, Optional, Sequence, Tuple import numpy as np import torch from torch.utils.data import Dataset THIS = Path(__file__).resolve() sys.path.insert(0, str(THIS.parent)) sys.path.insert(0, str(THIS.parents[1])) try: from experiments.dataset_seqpred import ( SAMPLING_RATE_HZ, _load_recording_sensors, TRAIN_VOLS_V3, TEST_VOLS_V3, DEFAULT_DATASET_DIR, DEFAULT_ANNOT_DIR, ) except ModuleNotFoundError: from dataset_seqpred import ( SAMPLING_RATE_HZ, _load_recording_sensors, TRAIN_VOLS_V3, TEST_VOLS_V3, DEFAULT_DATASET_DIR, DEFAULT_ANNOT_DIR, ) EVENT_NAMES = {0: "non-contact", 1: "pre-contact", 2: "steady-grip", 3: "release"} class SignalForecastDataset(Dataset): """Predict future T_fut frames of `target_modality` from past T_obs of `input_modalities`.""" def __init__( self, volunteers: Sequence[str], input_modalities: Sequence[str], target_modality: str, t_obs_sec: float = 1.5, t_fut_sec: float = 0.5, anchor_stride_sec: float = 0.25, downsample: int = 5, dataset_dir: Path = DEFAULT_DATASET_DIR, annot_dir: Path = DEFAULT_ANNOT_DIR, contact_threshold_g: float = 5.0, per_event_max: Optional[int] = None, input_stats: Optional[Dict[str, Tuple[np.ndarray, np.ndarray]]] = None, target_stats: Optional[Tuple[np.ndarray, np.ndarray]] = None, future_pressure_stats: Optional[Tuple[np.ndarray, np.ndarray]] = None, expected_input_dims: Optional[Dict[str, int]] = None, expected_target_dim: Optional[int] = None, include_future_pressure: bool = False, rng_seed: int = 0, log: bool = True, ): super().__init__() self.input_modalities = list(input_modalities) self.target_modality = str(target_modality) self.t_obs_sec = float(t_obs_sec) self.t_fut_sec = float(t_fut_sec) self.anchor_stride_sec = float(anchor_stride_sec) self.downsample = int(downsample) self.sr = SAMPLING_RATE_HZ // self.downsample self.dataset_dir = Path(dataset_dir) self.annot_dir = Path(annot_dir) self.contact_threshold_g = float(contact_threshold_g) self.per_event_max = per_event_max self.include_future_pressure = bool(include_future_pressure) self.T_obs = int(round(self.t_obs_sec * self.sr)) self.T_fut = int(round(self.t_fut_sec * self.sr)) self._items: List[dict] = [] self._modality_dims: Dict[str, int] = dict(expected_input_dims) if expected_input_dims else {} self._target_dim: int = int(expected_target_dim) if expected_target_dim else -1 rng = np.random.default_rng(rng_seed) # Modalities to load: union of inputs + target + pressure (for filter) load_mods = list(dict.fromkeys( list(self.input_modalities) + [self.target_modality, "pressure"] )) # Per-event-type pool of candidate anchor records pools: Dict[int, List[dict]] = {0: [], 1: [], 2: [], 3: []} for vol in volunteers: vol_dir = self.dataset_dir / vol if not vol_dir.is_dir(): continue for scenario_dir in sorted(vol_dir.glob("s*")): if not scenario_dir.is_dir(): continue scene = scenario_dir.name annot_path = self.annot_dir / vol / f"{scene}.json" if not annot_path.exists(): continue try: sensors_all = _load_recording_sensors( scenario_dir, vol, scene, load_mods ) except Exception: continue if sensors_all is None or any(a is None for a in sensors_all.values()): continue pressure_full = sensors_all["pressure"] # (T, 50) target_full = sensors_all[self.target_modality] input_arrs = {m: sensors_all[m] for m in self.input_modalities} # Track input modality dims for m, arr in input_arrs.items(): self._enforce_dim(input_arrs, m, arr, self._modality_dims) # Track target dim if self._target_dim < 0: self._target_dim = target_full.shape[1] elif target_full.shape[1] != self._target_dim: if target_full.shape[1] < self._target_dim: pad = np.zeros((target_full.shape[0], self._target_dim - target_full.shape[1]), dtype=np.float32) target_full = np.concatenate([target_full, pad], axis=1) else: target_full = target_full[:, :self._target_dim] T_avail = min(a.shape[0] for a in input_arrs.values()) T_avail = min(T_avail, target_full.shape[0], pressure_full.shape[0]) if T_avail < (self.T_obs + self.T_fut) * self.downsample: continue # Downsample to 20 Hz input_ds = {m: arr[:T_avail:self.downsample] for m, arr in input_arrs.items()} target_ds = target_full[:T_avail:self.downsample] pressure_ds = pressure_full[:T_avail:self.downsample] T_ds = target_ds.shape[0] pressure_sum = pressure_ds.sum(axis=1) # (T_ds,) stride = max(1, int(round(self.anchor_stride_sec * self.sr))) first_anchor = self.T_obs last_anchor = T_ds - self.T_fut if last_anchor <= first_anchor: continue for anchor in range(first_anchor, last_anchor + 1, stride): past_p = pressure_sum[anchor - self.T_obs:anchor] fut_p = pressure_sum[anchor:anchor + self.T_fut] past_high = (past_p > self.contact_threshold_g).mean() > 0.5 fut_high = (fut_p > self.contact_threshold_g).mean() > 0.5 if not past_high and not fut_high: et = 0 elif not past_high and fut_high: et = 1 elif past_high and fut_high: et = 2 else: et = 3 past_slice = {m: arr[anchor - self.T_obs:anchor] for m, arr in input_ds.items()} past_target_last = target_ds[anchor - 1].copy() # (target_dim,) fut_target = target_ds[anchor:anchor + self.T_fut].copy() if any(w.shape[0] != self.T_obs for w in past_slice.values()): continue if fut_target.shape[0] != self.T_fut: continue item = { "x": past_slice, "y": fut_target, "y_last": past_target_last, # for persistence "event_type": int(et), "meta": {"vol": vol, "scene": scene, "anchor_idx": int(anchor)}, } if self.include_future_pressure: fut_press = pressure_ds[anchor:anchor + self.T_fut].copy() if fut_press.shape[0] != self.T_fut: continue item["fp"] = fut_press # (T_fut, 50) pools[et].append(item) # Cap per-event count if requested (uniform downsample for balance) for et, pool in pools.items(): if self.per_event_max is not None and len(pool) > self.per_event_max: idx = rng.choice(len(pool), size=self.per_event_max, replace=False) pools[et] = [pool[i] for i in sorted(idx)] self._items = [it for et in (0, 1, 2, 3) for it in pools[et]] if not self._items: raise RuntimeError("SignalForecastDataset: collected 0 anchors.") # Z-score inputs and target separately if input_stats is None: input_stats = self._compute_input_stats() self._input_stats = input_stats self._apply_input_stats(input_stats) if target_stats is None: target_stats = self._compute_target_stats() self._target_stats = target_stats self._apply_target_stats(target_stats) if self.include_future_pressure: if future_pressure_stats is None: future_pressure_stats = self._compute_fp_stats() self._fp_stats = future_pressure_stats self._apply_fp_stats(future_pressure_stats) else: self._fp_stats = None if log: counts = {EVENT_NAMES[k]: sum(1 for it in self._items if it["event_type"] == k) for k in (0, 1, 2, 3)} print(f"[SignalForecastDataset] vols={len(volunteers)} " f"target={self.target_modality} inputs={self.input_modalities} " f"anchors={len(self._items)} {counts} " f"T_obs={self.T_obs} T_fut={self.T_fut} sr={self.sr}Hz " f"input_dims={self._modality_dims} target_dim={self._target_dim}", flush=True) @staticmethod def _enforce_dim(arrs, m, arr, dim_dict): if m in dim_dict: target = dim_dict[m] if arr.shape[1] != target: if arr.shape[1] < target: pad = np.zeros((arr.shape[0], target - arr.shape[1]), dtype=np.float32) arrs[m] = np.concatenate([arr, pad], axis=1) else: arrs[m] = arr[:, :target] else: dim_dict[m] = arr.shape[1] def _compute_input_stats(self): accs = {m: [] for m in self._modality_dims} for it in self._items: for m, w in it["x"].items(): accs[m].append(w) out = {} for m, ws in accs.items(): cat = np.concatenate(ws, axis=0) mu = cat.mean(axis=0).astype(np.float32) sd = cat.std(axis=0); sd = np.where(sd < 1e-6, 1.0, sd) out[m] = (mu, sd.astype(np.float32)) return out def _apply_input_stats(self, stats): for it in self._items: for m, w in it["x"].items(): if m in stats: mu, sd = stats[m] it["x"][m] = ((w - mu) / sd).astype(np.float32) def _compute_target_stats(self): ys = np.concatenate([it["y"] for it in self._items], axis=0) mu = ys.mean(axis=0).astype(np.float32) sd = ys.std(axis=0); sd = np.where(sd < 1e-6, 1.0, sd) return (mu, sd.astype(np.float32)) def _apply_target_stats(self, stats): mu, sd = stats for it in self._items: it["y"] = ((it["y"] - mu) / sd).astype(np.float32) it["y_last"] = ((it["y_last"] - mu) / sd).astype(np.float32) def _compute_fp_stats(self): fps = np.concatenate([it["fp"] for it in self._items], axis=0) mu = fps.mean(axis=0).astype(np.float32) sd = fps.std(axis=0); sd = np.where(sd < 1e-6, 1.0, sd) return (mu, sd.astype(np.float32)) def _apply_fp_stats(self, stats): mu, sd = stats for it in self._items: it["fp"] = ((it["fp"] - mu) / sd).astype(np.float32) def __len__(self): return len(self._items) def __getitem__(self, idx): it = self._items[idx] x = {m: torch.from_numpy(np.ascontiguousarray(w)) for m, w in it["x"].items()} y = torch.from_numpy(np.ascontiguousarray(it["y"])) # (T_fut, target_dim) y_last = torch.from_numpy(np.ascontiguousarray(it["y_last"])) # (target_dim,) et = int(it["event_type"]) if self.include_future_pressure: fp = torch.from_numpy(np.ascontiguousarray(it["fp"])) # (T_fut, 50) return x, y, y_last, fp, et, it["meta"] return x, y, y_last, et, it["meta"] @property def modality_dims(self): return dict(self._modality_dims) @property def target_dim(self): return self._target_dim def collate_signal_forecast(batch): if len(batch[0]) == 6: # has future pressure xs, ys, ylasts, fps, ets, metas = zip(*batch) mods = list(xs[0].keys()) x_out = {m: torch.stack([x[m] for x in xs], dim=0) for m in mods} y_out = torch.stack(ys, dim=0) yl_out = torch.stack(ylasts, dim=0) fp_out = torch.stack(fps, dim=0) # (B, T_fut, 50) et_out = torch.tensor(ets, dtype=torch.long) return x_out, y_out, yl_out, fp_out, et_out, list(metas) xs, ys, ylasts, ets, metas = zip(*batch) mods = list(xs[0].keys()) x_out = {m: torch.stack([x[m] for x in xs], dim=0) for m in mods} y_out = torch.stack(ys, dim=0) yl_out = torch.stack(ylasts, dim=0) et_out = torch.tensor(ets, dtype=torch.long) return x_out, y_out, yl_out, et_out, list(metas) def build_signal_train_test( input_modalities, target_modality, t_obs_sec=1.5, t_fut_sec=0.5, anchor_stride_sec=0.25, downsample=5, dataset_dir=DEFAULT_DATASET_DIR, annot_dir=DEFAULT_ANNOT_DIR, contact_threshold_g=5.0, per_event_max=None, include_future_pressure=False, rng_seed=0, ): train = SignalForecastDataset( TRAIN_VOLS_V3, input_modalities=input_modalities, target_modality=target_modality, t_obs_sec=t_obs_sec, t_fut_sec=t_fut_sec, anchor_stride_sec=anchor_stride_sec, downsample=downsample, dataset_dir=dataset_dir, annot_dir=annot_dir, contact_threshold_g=contact_threshold_g, per_event_max=per_event_max, include_future_pressure=include_future_pressure, rng_seed=rng_seed, log=True, ) test = SignalForecastDataset( TEST_VOLS_V3, input_modalities=input_modalities, target_modality=target_modality, t_obs_sec=t_obs_sec, t_fut_sec=t_fut_sec, anchor_stride_sec=anchor_stride_sec, downsample=downsample, dataset_dir=dataset_dir, annot_dir=annot_dir, contact_threshold_g=contact_threshold_g, per_event_max=per_event_max, input_stats=train._input_stats, target_stats=train._target_stats, future_pressure_stats=train._fp_stats, expected_input_dims=train._modality_dims, expected_target_dim=train._target_dim, include_future_pressure=include_future_pressure, rng_seed=rng_seed + 1, log=True, ) return train, test if __name__ == "__main__": import argparse ap = argparse.ArgumentParser() ap.add_argument("--input_modalities", default="imu") ap.add_argument("--target_modality", default="imu") ap.add_argument("--t_obs", type=float, default=1.5) ap.add_argument("--t_fut", type=float, default=0.5) args = ap.parse_args() tr, te = build_signal_train_test( input_modalities=args.input_modalities.split(","), target_modality=args.target_modality, t_obs_sec=args.t_obs, t_fut_sec=args.t_fut, ) x, y, y_last, et, meta = tr[0] print(f"Sample: x={ {m: tuple(v.shape) for m,v in x.items()} } y={tuple(y.shape)} y_last={tuple(y_last.shape)} event_type={et}")