# PULSE-code / experiments/data/dataset_signal_forecast.py
# (uploaded by velvet-pine-22 via huggingface_hub, commit b4b2877)
"""Frame-level future *signal* forecasting dataset (T8 v2).
Task definition
---------------
At a sampled anchor frame t (index in the downsampled stream) of a recording:
past = sensor frames [t - T_obs, t) ← input
future = target-modality frames [t, t + T_fut) ← regression target
Unlike the v1 ForecastDataset (which targets the per-frame verb-fine class), this
dataset predicts the raw *signal* values of one chosen target modality. This directly
tests the Johansson 1984 / Monzée 2003 hypothesis that cutaneous force
feedback drives sub-second motor planning at the *signal* level (motor
commands / kinematics), not at the level of slowly changing semantic verbs.
Anchor stratification (4 event types based on contact transitions)
------------------------------------------------------------------
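For each candidate anchor, we compute pressure_sum (the sum over all pressure
channels) on the past and future windows. A window counts as "high" contact when
more than half of its frames have pressure_sum > contact_threshold_g, and the
anchor is labelled by the (past_high, future_high) pair: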
type 0 = non-contact (past low, future low) — control: pressure ~ 0
type 1 = pre-contact (past low, future high) — pressure foretells onset
type 2 = steady-grip (past high, future high) — sustained contact dynamics
type 3 = release (past high, future low) — letting-go dynamics
Per-event-type counts are reported and (optionally) capped to balance.
Evaluation is broken down per event type so we can see WHERE pressure helps.
"""
from __future__ import annotations
import sys
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple
import numpy as np
import torch
from torch.utils.data import Dataset
THIS = Path(__file__).resolve()
sys.path.insert(0, str(THIS.parent))
sys.path.insert(0, str(THIS.parents[1]))
try:
from experiments.dataset_seqpred import (
SAMPLING_RATE_HZ, _load_recording_sensors,
TRAIN_VOLS_V3, TEST_VOLS_V3,
DEFAULT_DATASET_DIR, DEFAULT_ANNOT_DIR,
)
except ModuleNotFoundError:
from dataset_seqpred import (
SAMPLING_RATE_HZ, _load_recording_sensors,
TRAIN_VOLS_V3, TEST_VOLS_V3,
DEFAULT_DATASET_DIR, DEFAULT_ANNOT_DIR,
)
# Human-readable names for the event_type labels assigned per anchor by
# SignalForecastDataset, from whether the past / future windows are in contact:
# 0 = neither, 1 = future only, 2 = both, 3 = past only.
EVENT_NAMES = {0: "non-contact", 1: "pre-contact", 2: "steady-grip", 3: "release"}
class SignalForecastDataset(Dataset):
    """Predict future T_fut frames of `target_modality` from past T_obs of `input_modalities`.

    Every item is cut from one recording at an anchor frame ``a`` of the
    downsampled stream (rate ``SAMPLING_RATE_HZ // downsample``):

    * ``x``          dict modality -> input window, frames ``[a - T_obs, a)``
    * ``y``          target-modality window, frames ``[a, a + T_fut)``
    * ``y_last``     target-modality frame ``a - 1`` (persistence baseline)
    * ``fp``         pressure window ``[a, a + T_fut)``; only present when
                     ``include_future_pressure`` is set
    * ``event_type`` contact-transition label in 0..3 (see ``EVENT_NAMES``)
    * ``meta``       ``{"vol", "scene", "anchor_idx"}``

    ``x``, ``y``/``y_last`` and ``fp`` are z-scored per channel. Statistics and
    channel widths are computed from this dataset unless supplied, so a test
    split should be given the train split's ``_input_stats``, ``_target_stats``,
    ``_fp_stats``, ``_modality_dims`` and ``_target_dim``.
    """

    def __init__(
        self,
        volunteers: Sequence[str],
        input_modalities: Sequence[str],
        target_modality: str,
        t_obs_sec: float = 1.5,
        t_fut_sec: float = 0.5,
        anchor_stride_sec: float = 0.25,
        downsample: int = 5,
        dataset_dir: Path = DEFAULT_DATASET_DIR,
        annot_dir: Path = DEFAULT_ANNOT_DIR,
        contact_threshold_g: float = 5.0,
        per_event_max: Optional[int] = None,
        input_stats: Optional[Dict[str, Tuple[np.ndarray, np.ndarray]]] = None,
        target_stats: Optional[Tuple[np.ndarray, np.ndarray]] = None,
        future_pressure_stats: Optional[Tuple[np.ndarray, np.ndarray]] = None,
        expected_input_dims: Optional[Dict[str, int]] = None,
        expected_target_dim: Optional[int] = None,
        include_future_pressure: bool = False,
        rng_seed: int = 0,
        log: bool = True,
    ):
        """Scan the recordings of ``volunteers`` and build all anchor items.

        Args:
            volunteers: volunteer directory names under ``dataset_dir``.
            input_modalities: modalities observed over the past window.
            target_modality: modality whose future frames are regressed.
            t_obs_sec, t_fut_sec: past / future window lengths in seconds.
            anchor_stride_sec: spacing between consecutive anchors in seconds.
            downsample: keep every ``downsample``-th raw frame.
            dataset_dir, annot_dir: data roots; a scene ``s*`` is only used when
                ``annot_dir/<vol>/<scene>.json`` exists.
            contact_threshold_g: threshold on the summed pressure for "contact".
            per_event_max: optional cap on anchors per event type (uniform
                random subset drawn with ``rng_seed``).
            input_stats, target_stats, future_pressure_stats: ``(mean, std)``
                to reuse for z-scoring; computed from this dataset when None.
            expected_input_dims, expected_target_dim: channel widths to
                zero-pad / truncate to; taken from the first recording when None.
            include_future_pressure: also return the future pressure window.
            rng_seed: seed for the per-event subsampling.
            log: print a summary line after construction.

        Raises:
            ValueError: ``input_modalities`` is empty, ``downsample < 1``, or a
                window rounds to fewer than one frame.
            RuntimeError: no anchors could be collected.
        """
        super().__init__()
        self.input_modalities = list(input_modalities)
        self.target_modality = str(target_modality)
        self.t_obs_sec = float(t_obs_sec)
        self.t_fut_sec = float(t_fut_sec)
        self.anchor_stride_sec = float(anchor_stride_sec)
        self.downsample = int(downsample)
        if not self.input_modalities:
            raise ValueError("SignalForecastDataset: input_modalities must not be empty.")
        if self.downsample < 1:
            raise ValueError(
                f"SignalForecastDataset: downsample must be >= 1, got {self.downsample}."
            )
        self.sr = SAMPLING_RATE_HZ // self.downsample
        self.dataset_dir = Path(dataset_dir)
        self.annot_dir = Path(annot_dir)
        self.contact_threshold_g = float(contact_threshold_g)
        self.per_event_max = per_event_max
        self.include_future_pressure = bool(include_future_pressure)
        self.T_obs = int(round(self.t_obs_sec * self.sr))
        self.T_fut = int(round(self.t_fut_sec * self.sr))
        if self.T_obs < 1 or self.T_fut < 1:
            # T_obs == 0 would make y_last read target frame -1 (the END of the
            # recording) and give empty inputs; T_fut == 0 leaves no target.
            raise ValueError(
                f"SignalForecastDataset: windows must span >= 1 frame at {self.sr} Hz, "
                f"got T_obs={self.T_obs}, T_fut={self.T_fut}."
            )
        self._items: List[dict] = []
        self._modality_dims: Dict[str, int] = dict(expected_input_dims) if expected_input_dims else {}
        self._target_dim: int = int(expected_target_dim) if expected_target_dim else -1
        rng = np.random.default_rng(rng_seed)

        # Modalities to load: union of inputs + target + pressure (pressure is
        # always needed to label the event type).
        load_mods = list(dict.fromkeys(
            list(self.input_modalities) + [self.target_modality, "pressure"]
        ))
        # Anchor spacing in downsampled frames; identical for every recording.
        stride = max(1, int(round(self.anchor_stride_sec * self.sr)))
        # Per-event-type pool of candidate anchor records.
        pools: Dict[int, List[dict]] = {0: [], 1: [], 2: [], 3: []}
        n_load_failures = 0

        for vol in volunteers:
            vol_dir = self.dataset_dir / vol
            if not vol_dir.is_dir():
                continue
            for scenario_dir in sorted(vol_dir.glob("s*")):
                if not scenario_dir.is_dir():
                    continue
                scene = scenario_dir.name
                if not (self.annot_dir / vol / f"{scene}.json").exists():
                    continue
                try:
                    sensors_all = _load_recording_sensors(
                        scenario_dir, vol, scene, load_mods
                    )
                except Exception:
                    # Loading stays best-effort, but failures are counted and
                    # reported below instead of vanishing silently.
                    n_load_failures += 1
                    continue
                if sensors_all is None or any(a is None for a in sensors_all.values()):
                    continue
                self._collect_recording_anchors(vol, scene, sensors_all, stride, pools)

        if log and n_load_failures:
            print(f"[SignalForecastDataset] skipped {n_load_failures} recording(s) "
                  f"that failed to load", flush=True)

        # Cap per-event count if requested (uniform random subset for balance).
        for et, pool in pools.items():
            if self.per_event_max is not None and len(pool) > self.per_event_max:
                idx = rng.choice(len(pool), size=self.per_event_max, replace=False)
                pools[et] = [pool[i] for i in sorted(idx)]
        self._items = [it for et in (0, 1, 2, 3) for it in pools[et]]
        if not self._items:
            raise RuntimeError("SignalForecastDataset: collected 0 anchors.")

        # Z-score inputs and target separately.
        if input_stats is None:
            input_stats = self._compute_input_stats()
        self._input_stats = input_stats
        self._apply_input_stats(input_stats)
        if target_stats is None:
            target_stats = self._compute_target_stats()
        self._target_stats = target_stats
        self._apply_target_stats(target_stats)
        if self.include_future_pressure:
            if future_pressure_stats is None:
                future_pressure_stats = self._compute_fp_stats()
            self._fp_stats = future_pressure_stats
            self._apply_fp_stats(future_pressure_stats)
        else:
            self._fp_stats = None

        if log:
            counts = {EVENT_NAMES[k]: len(pools[k]) for k in (0, 1, 2, 3)}
            print(f"[SignalForecastDataset] vols={len(volunteers)} "
                  f"target={self.target_modality} inputs={self.input_modalities} "
                  f"anchors={len(self._items)} {counts} "
                  f"T_obs={self.T_obs} T_fut={self.T_fut} sr={self.sr}Hz "
                  f"input_dims={self._modality_dims} target_dim={self._target_dim}",
                  flush=True)

    def _collect_recording_anchors(self, vol, scene, sensors_all, stride, pools):
        """Cut one loaded recording into anchor items, appended to ``pools[event_type]``.

        The first recording seen fixes each channel width (unless widths were
        given up front); later recordings are zero-padded / truncated to match.
        Recordings shorter than ``T_obs + T_fut`` downsampled frames are skipped.
        """
        pressure_full = sensors_all["pressure"]
        target_full = sensors_all[self.target_modality]
        input_arrs = {m: sensors_all[m] for m in self.input_modalities}
        for m, arr in list(input_arrs.items()):
            self._enforce_dim(input_arrs, m, arr, self._modality_dims)
        if self._target_dim < 0:
            self._target_dim = target_full.shape[1]
        else:
            target_full = self._conform_width(target_full, self._target_dim)

        T_avail = min(
            min(a.shape[0] for a in input_arrs.values()),
            target_full.shape[0],
            pressure_full.shape[0],
        )
        if T_avail < (self.T_obs + self.T_fut) * self.downsample:
            return

        # Downsample by keeping every `downsample`-th raw frame.
        ds = self.downsample
        input_ds = {m: arr[:T_avail:ds] for m, arr in input_arrs.items()}
        target_ds = target_full[:T_avail:ds]
        pressure_ds = pressure_full[:T_avail:ds]
        pressure_sum = pressure_ds.sum(axis=1)  # (T_ds,)

        first_anchor = self.T_obs
        last_anchor = target_ds.shape[0] - self.T_fut
        # `<` (not `<=`): when last == first there is still exactly one valid anchor.
        if last_anchor < first_anchor:
            return

        for anchor in range(first_anchor, last_anchor + 1, stride):
            past = slice(anchor - self.T_obs, anchor)
            fut = slice(anchor, anchor + self.T_fut)
            et = self._classify_event(pressure_sum[past], pressure_sum[fut])
            past_slice = {m: arr[past] for m, arr in input_ds.items()}
            fut_target = target_ds[fut].copy()
            if any(w.shape[0] != self.T_obs for w in past_slice.values()):
                continue
            if fut_target.shape[0] != self.T_fut:
                continue
            item = {
                "x": past_slice,
                "y": fut_target,
                "y_last": target_ds[anchor - 1].copy(),  # for persistence baseline
                "event_type": int(et),
                "meta": {"vol": vol, "scene": scene, "anchor_idx": int(anchor)},
            }
            if self.include_future_pressure:
                fut_press = pressure_ds[fut].copy()
                if fut_press.shape[0] != self.T_fut:
                    continue
                item["fp"] = fut_press  # (T_fut, n_pressure_channels)
            pools[et].append(item)

    def _classify_event(self, past_p, fut_p):
        """Return the event type (0..3) for past/future summed-pressure windows.

        A window is "high" when more than half of its frames exceed
        ``contact_threshold_g``: 0 = neither, 1 = future only, 2 = both,
        3 = past only.
        """
        thr = self.contact_threshold_g
        past_high = (past_p > thr).mean() > 0.5
        fut_high = (fut_p > thr).mean() > 0.5
        if past_high:
            return 2 if fut_high else 3
        return 1 if fut_high else 0

    @staticmethod
    def _conform_width(arr, width):
        """Zero-pad (float32 zeros) or truncate the channel axis of a (T, C) array to ``width``."""
        cur = arr.shape[1]
        if cur == width:
            return arr
        if cur < width:
            pad = np.zeros((arr.shape[0], width - cur), dtype=np.float32)
            return np.concatenate([arr, pad], axis=1)
        return arr[:, :width]

    @staticmethod
    def _enforce_dim(arrs, m, arr, dim_dict):
        """Record ``m``'s width on first sight, otherwise conform ``arrs[m]`` to it in place."""
        if m in dim_dict:
            arrs[m] = SignalForecastDataset._conform_width(arr, dim_dict[m])
        else:
            dim_dict[m] = arr.shape[1]

    @staticmethod
    def _mean_std(cat):
        """Per-channel float32 ``(mean, std)`` of an (N, C) array; std < 1e-6 is replaced by 1."""
        mu = cat.mean(axis=0).astype(np.float32)
        sd = cat.std(axis=0)
        sd = np.where(sd < 1e-6, 1.0, sd)
        return (mu, sd.astype(np.float32))

    def _compute_input_stats(self):
        # Keyed on the actual input modalities (not `_modality_dims`, which may
        # hold extra caller-supplied keys with no data behind them).
        return {
            m: self._mean_std(np.concatenate([it["x"][m] for it in self._items], axis=0))
            for m in dict.fromkeys(self.input_modalities)
        }

    def _apply_input_stats(self, stats):
        for it in self._items:
            for m, w in it["x"].items():
                if m in stats:
                    mu, sd = stats[m]
                    it["x"][m] = ((w - mu) / sd).astype(np.float32)

    def _compute_target_stats(self):
        return self._mean_std(np.concatenate([it["y"] for it in self._items], axis=0))

    def _apply_target_stats(self, stats):
        mu, sd = stats
        for it in self._items:
            it["y"] = ((it["y"] - mu) / sd).astype(np.float32)
            it["y_last"] = ((it["y_last"] - mu) / sd).astype(np.float32)

    def _compute_fp_stats(self):
        return self._mean_std(np.concatenate([it["fp"] for it in self._items], axis=0))

    def _apply_fp_stats(self, stats):
        mu, sd = stats
        for it in self._items:
            it["fp"] = ((it["fp"] - mu) / sd).astype(np.float32)

    def __len__(self):
        return len(self._items)

    def __getitem__(self, idx):
        """Return ``(x, y, y_last, [fp,] event_type, meta)`` with tensors for x/y/y_last/fp."""
        it = self._items[idx]
        x = {m: torch.from_numpy(np.ascontiguousarray(w)) for m, w in it["x"].items()}
        y = torch.from_numpy(np.ascontiguousarray(it["y"]))  # (T_fut, target_dim)
        y_last = torch.from_numpy(np.ascontiguousarray(it["y_last"]))  # (target_dim,)
        et = int(it["event_type"])
        if self.include_future_pressure:
            fp = torch.from_numpy(np.ascontiguousarray(it["fp"]))  # (T_fut, n_pressure_channels)
            return x, y, y_last, fp, et, it["meta"]
        return x, y, y_last, et, it["meta"]

    @property
    def modality_dims(self):
        return dict(self._modality_dims)

    @property
    def target_dim(self):
        return self._target_dim
def collate_signal_forecast(batch):
    """Stack a list of ``SignalForecastDataset`` samples into one batch.

    Samples are ``(x, y, y_last, et, meta)`` or, when future pressure is
    included, ``(x, y, y_last, fp, et, meta)``; the arity of ``batch[0]``
    selects the layout. The result keeps that layout: ``x`` becomes a dict of
    tensors stacked on a new leading axis, ``y`` / ``y_last`` / ``fp`` are
    stacked the same way, ``et`` becomes an int64 tensor and ``meta`` a list.
    """
    with_fp = len(batch[0]) == 6
    if with_fp:
        xs, ys, ylasts, fps, ets, metas = zip(*batch)
    else:
        xs, ys, ylasts, ets, metas = zip(*batch)

    x_batch = {name: torch.stack([sample[name] for sample in xs], dim=0) for name in xs[0]}
    leading = (x_batch, torch.stack(ys, dim=0), torch.stack(ylasts, dim=0))
    trailing = (torch.tensor(ets, dtype=torch.long), list(metas))
    if with_fp:
        # Future pressure sits between y_last and the event type: (B, T_fut, C).
        return leading + (torch.stack(fps, dim=0),) + trailing
    return leading + trailing
def build_signal_train_test(
    input_modalities, target_modality,
    t_obs_sec=1.5, t_fut_sec=0.5, anchor_stride_sec=0.25,
    downsample=5,
    dataset_dir=DEFAULT_DATASET_DIR, annot_dir=DEFAULT_ANNOT_DIR,
    contact_threshold_g=5.0, per_event_max=None,
    include_future_pressure=False,
    rng_seed=0,
):
    """Build the train (TRAIN_VOLS_V3) and test (TEST_VOLS_V3) signal-forecast splits.

    Both splits share every windowing / filtering option. The test split
    reuses the train split's z-score statistics and channel widths so the two
    are normalized identically, and subsamples with ``rng_seed + 1``.

    Returns:
        ``(train, test)`` SignalForecastDataset instances.
    """
    shared = dict(
        input_modalities=input_modalities,
        target_modality=target_modality,
        t_obs_sec=t_obs_sec,
        t_fut_sec=t_fut_sec,
        anchor_stride_sec=anchor_stride_sec,
        downsample=downsample,
        dataset_dir=dataset_dir,
        annot_dir=annot_dir,
        contact_threshold_g=contact_threshold_g,
        per_event_max=per_event_max,
        include_future_pressure=include_future_pressure,
        log=True,
    )
    train_ds = SignalForecastDataset(TRAIN_VOLS_V3, rng_seed=rng_seed, **shared)
    test_ds = SignalForecastDataset(
        TEST_VOLS_V3,
        input_stats=train_ds._input_stats,
        target_stats=train_ds._target_stats,
        future_pressure_stats=train_ds._fp_stats,
        expected_input_dims=train_ds._modality_dims,
        expected_target_dim=train_ds._target_dim,
        rng_seed=rng_seed + 1,
        **shared,
    )
    return train_ds, test_ds
if __name__ == "__main__":
    # Smoke test: build both splits and print the shapes of one training sample.
    import argparse

    ap = argparse.ArgumentParser(description="Build SignalForecastDataset splits and show one sample.")
    ap.add_argument("--input_modalities", default="imu",
                    help="comma-separated input modalities, e.g. 'imu,pressure'")
    ap.add_argument("--target_modality", default="imu")
    ap.add_argument("--t_obs", type=float, default=1.5)
    ap.add_argument("--t_fut", type=float, default=0.5)
    ap.add_argument("--per_event_max", type=int, default=None)
    ap.add_argument("--include_future_pressure", action="store_true")
    args = ap.parse_args()

    # Tolerate spaces / trailing commas in the modality list ("imu, pressure,").
    input_mods = [m.strip() for m in args.input_modalities.split(",") if m.strip()]
    tr, te = build_signal_train_test(
        input_modalities=input_mods,
        target_modality=args.target_modality.strip(),
        t_obs_sec=args.t_obs, t_fut_sec=args.t_fut,
        per_event_max=args.per_event_max,
        include_future_pressure=args.include_future_pressure,
    )
    # Samples are 5-tuples, or 6-tuples (with fp at index 3) when future
    # pressure is included.
    sample = tr[0]
    x, y, y_last = sample[:3]
    et = sample[-2]
    fp_info = f" fp={tuple(sample[3].shape)}" if len(sample) == 6 else ""
    print(f"Sample: x={ {m: tuple(v.shape) for m, v in x.items()} } y={tuple(y.shape)} "
          f"y_last={tuple(y_last.shape)}{fp_info} event_type={et}")