"""Frame-level future *signal* forecasting dataset (T8 v2).
Task definition
---------------
At a sampled anchor t in a recording:
past = sensor frames over [t - T_obs, t) ← input
future = target-modality frames over [t, t + T_fut) ← regression target
Unlike the v1 ForecastDataset (which targets per-frame verb-fine class), this
predicts the raw *signal* values of one chosen target modality. This directly
tests the Johansson 1984 / Monzée 2003 hypothesis that cutaneous force
feedback drives sub-second motor planning at the *signal* level (motor
commands / kinematics), not at the level of slow-changing semantic verbs.
Anchor stratification (4 event types based on contact transitions)
------------------------------------------------------------------
For each candidate anchor, we compute pressure_sum on past and future windows
and label it by the (past_majority_contact, future_majority_contact) pair:
type 0 = non-contact (past low, future low) — control: pressure ~ 0
type 1 = pre-contact (past low, future high) — pressure foretells onset
type 2 = steady-grip (past high, future high) — sustained contact dynamics
type 3 = release (past high, future low) — letting-go dynamics
Per-event-type counts are reported and (optionally) capped to balance.
Evaluation is broken down per event type so we can see WHERE pressure helps.
"""
from __future__ import annotations
import sys
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple
import numpy as np
import torch
from torch.utils.data import Dataset
# Absolute path of this file; its directory and that directory's parent are
# pushed onto sys.path so the `dataset_seqpred` import that follows can be
# resolved either as `experiments.dataset_seqpred` or as a plain sibling module.
THIS = Path(__file__).resolve()
sys.path.insert(0, str(THIS.parent))
# Inserted last, so the parent directory ends up first on sys.path.
sys.path.insert(0, str(THIS.parents[1]))
try:
from experiments.dataset_seqpred import (
SAMPLING_RATE_HZ, _load_recording_sensors,
TRAIN_VOLS_V3, TEST_VOLS_V3,
DEFAULT_DATASET_DIR, DEFAULT_ANNOT_DIR,
)
except ModuleNotFoundError:
from dataset_seqpred import (
SAMPLING_RATE_HZ, _load_recording_sensors,
TRAIN_VOLS_V3, TEST_VOLS_V3,
DEFAULT_DATASET_DIR, DEFAULT_ANNOT_DIR,
)
# Human-readable label for each anchor event-type id (0-3); used as the keys of
# the per-event anchor counts printed by SignalForecastDataset.
EVENT_NAMES = {0: "non-contact", 1: "pre-contact", 2: "steady-grip", 3: "release"}
class SignalForecastDataset(Dataset):
    """Predict future T_fut frames of `target_modality` from past T_obs of `input_modalities`.

    Every sample is built eagerly in ``__init__``. For each volunteer directory
    under ``dataset_dir`` and each ``s*`` scenario directory that has a matching
    ``<annot_dir>/<vol>/<scene>.json`` file, the recording is loaded through
    ``_load_recording_sensors``, strided by ``downsample`` and cut into windows
    around anchors spaced ``anchor_stride_sec`` apart:

    * ``x``      -- per input modality, frames ``[anchor - T_obs, anchor)``
    * ``y``      -- target-modality frames ``[anchor, anchor + T_fut)``
    * ``y_last`` -- target frame ``anchor - 1`` (persistence baseline)
    * ``fp``     -- pressure frames ``[anchor, anchor + T_fut)``; only present
      when ``include_future_pressure`` is true

    Each anchor gets an event type 0-3 (see ``EVENT_NAMES``) from whether more
    than half of the past / future frames have a summed pressure above
    ``contact_threshold_g``. Inputs, target and (optionally) future pressure
    are z-scored per channel, either with the statistics passed in or with
    statistics computed from this dataset's own items.

    Args:
        volunteers: Volunteer directory names to scan under ``dataset_dir``.
        input_modalities: Modalities returned in ``x``.
        target_modality: Modality whose future frames form ``y``.
        t_obs_sec: Length of the observed (past) window in seconds.
        t_fut_sec: Length of the forecast (future) window in seconds.
        anchor_stride_sec: Spacing between consecutive anchors in seconds.
        downsample: Keep every ``downsample``-th frame of each recording.
        dataset_dir: Root directory holding ``<vol>/s*`` recordings.
        annot_dir: Root directory holding ``<vol>/<scene>.json``; a recording
            without that file is skipped (the file itself is not read here).
        contact_threshold_g: Threshold on the per-frame sum over pressure
            channels above which a frame counts as "contact" (presumably
            grams, going by the name -- confirm against the sensor docs).
        per_event_max: If set, randomly subsample each event-type pool down to
            at most this many anchors.
        input_stats: Optional ``{modality: (mean, std)}`` to normalise ``x``.
        target_stats: Optional ``(mean, std)`` to normalise ``y`` / ``y_last``.
        future_pressure_stats: Optional ``(mean, std)`` to normalise ``fp``.
        expected_input_dims: Optional ``{modality: width}``; recordings whose
            width differs are zero-padded or truncated to match.
        expected_target_dim: Optional target width, enforced the same way.
        include_future_pressure: Also store and return the ``fp`` window.
        rng_seed: Seed for the ``per_event_max`` subsampling.
        log: Print skipped recordings and a one-line summary.

    Raises:
        ValueError: If ``downsample`` is below 1 or either window rounds to
            zero frames at the downsampled rate.
        RuntimeError: If no anchor could be collected at all.
    """

    def __init__(
        self,
        volunteers: Sequence[str],
        input_modalities: Sequence[str],
        target_modality: str,
        t_obs_sec: float = 1.5,
        t_fut_sec: float = 0.5,
        anchor_stride_sec: float = 0.25,
        downsample: int = 5,
        dataset_dir: Path = DEFAULT_DATASET_DIR,
        annot_dir: Path = DEFAULT_ANNOT_DIR,
        contact_threshold_g: float = 5.0,
        per_event_max: Optional[int] = None,
        input_stats: Optional[Dict[str, Tuple[np.ndarray, np.ndarray]]] = None,
        target_stats: Optional[Tuple[np.ndarray, np.ndarray]] = None,
        future_pressure_stats: Optional[Tuple[np.ndarray, np.ndarray]] = None,
        expected_input_dims: Optional[Dict[str, int]] = None,
        expected_target_dim: Optional[int] = None,
        include_future_pressure: bool = False,
        rng_seed: int = 0,
        log: bool = True,
    ):
        super().__init__()
        self.input_modalities = list(input_modalities)
        self.target_modality = str(target_modality)
        self.t_obs_sec = float(t_obs_sec)
        self.t_fut_sec = float(t_fut_sec)
        self.anchor_stride_sec = float(anchor_stride_sec)
        self.downsample = int(downsample)
        if self.downsample < 1:
            # Would otherwise surface as ZeroDivisionError / a negative rate.
            raise ValueError(f"downsample must be >= 1, got {downsample!r}")
        self.sr = SAMPLING_RATE_HZ // self.downsample
        self.dataset_dir = Path(dataset_dir)
        self.annot_dir = Path(annot_dir)
        self.contact_threshold_g = float(contact_threshold_g)
        self.per_event_max = per_event_max
        self.include_future_pressure = bool(include_future_pressure)
        self.T_obs = int(round(self.t_obs_sec * self.sr))
        self.T_fut = int(round(self.t_fut_sec * self.sr))
        if self.T_obs < 1 or self.T_fut < 1:
            # A zero-length window yields empty arrays (NaN statistics), and
            # T_obs == 0 makes `y_last` wrap around to index -1.
            raise ValueError(
                "t_obs_sec and t_fut_sec must each cover at least one frame at "
                f"{self.sr} Hz (got T_obs={self.T_obs}, T_fut={self.T_fut})"
            )
        self._items: List[dict] = []
        self._modality_dims: Dict[str, int] = dict(expected_input_dims) if expected_input_dims else {}
        self._target_dim: int = int(expected_target_dim) if expected_target_dim else -1
        rng = np.random.default_rng(rng_seed)

        # Modalities to load: union of inputs + target + pressure (for filter)
        load_mods = list(dict.fromkeys(
            list(self.input_modalities) + [self.target_modality, "pressure"]
        ))

        # Per-event-type pool of candidate anchor records
        pools: Dict[int, List[dict]] = {0: [], 1: [], 2: [], 3: []}
        for vol in volunteers:
            vol_dir = self.dataset_dir / vol
            if not vol_dir.is_dir():
                continue
            for scenario_dir in sorted(vol_dir.glob("s*")):
                if not scenario_dir.is_dir():
                    continue
                scene = scenario_dir.name
                annot_path = self.annot_dir / vol / f"{scene}.json"
                if not annot_path.exists():
                    continue
                try:
                    sensors_all = _load_recording_sensors(
                        scenario_dir, vol, scene, load_mods
                    )
                except Exception as exc:
                    # Best-effort: one unreadable recording must not abort the
                    # build, but the skip is reported so it is not invisible.
                    if log:
                        print(f"[SignalForecastDataset] skipping {vol}/{scene}: "
                              f"load failed ({exc!r})", flush=True)
                    continue
                if sensors_all is None or any(a is None for a in sensors_all.values()):
                    continue
                for item in self._recording_items(vol, scene, sensors_all):
                    pools[item["event_type"]].append(item)

        # Cap per-event count if requested (uniform downsample for balance)
        for et, pool in pools.items():
            if self.per_event_max is not None and len(pool) > self.per_event_max:
                idx = rng.choice(len(pool), size=self.per_event_max, replace=False)
                # Sorted so the kept anchors stay in collection order.
                pools[et] = [pool[i] for i in sorted(idx)]
        self._items = [it for et in (0, 1, 2, 3) for it in pools[et]]
        if not self._items:
            raise RuntimeError("SignalForecastDataset: collected 0 anchors.")

        # Z-score inputs and target separately
        if input_stats is None:
            input_stats = self._compute_input_stats()
        self._input_stats = input_stats
        self._apply_input_stats(input_stats)
        if target_stats is None:
            target_stats = self._compute_target_stats()
        self._target_stats = target_stats
        self._apply_target_stats(target_stats)
        if self.include_future_pressure:
            if future_pressure_stats is None:
                future_pressure_stats = self._compute_fp_stats()
            self._fp_stats = future_pressure_stats
            self._apply_fp_stats(future_pressure_stats)
        else:
            self._fp_stats = None

        if log:
            counts = {EVENT_NAMES[k]: sum(1 for it in self._items if it["event_type"] == k)
                      for k in (0, 1, 2, 3)}
            print(f"[SignalForecastDataset] vols={len(volunteers)} "
                  f"target={self.target_modality} inputs={self.input_modalities} "
                  f"anchors={len(self._items)} {counts} "
                  f"T_obs={self.T_obs} T_fut={self.T_fut} sr={self.sr}Hz "
                  f"input_dims={self._modality_dims} target_dim={self._target_dim}",
                  flush=True)

    def _recording_items(self, vol, scene, sensors_all) -> List[dict]:
        """Cut one loaded recording into un-normalised anchor items.

        Also records / enforces the per-modality and target widths, which
        happens even when the recording turns out to be too short to use.
        """
        pressure_full = sensors_all["pressure"]  # (T, 50) per the original note
        target_full = sensors_all[self.target_modality]
        input_arrs = {m: sensors_all[m] for m in self.input_modalities}

        # Track input modality dims (snapshot: _enforce_dim rewrites entries)
        for m, arr in list(input_arrs.items()):
            self._enforce_dim(input_arrs, m, arr, self._modality_dims)

        # Track target dim
        if self._target_dim < 0:
            self._target_dim = target_full.shape[1]
        elif target_full.shape[1] != self._target_dim:
            target_full = self._fit_width(target_full, self._target_dim)

        # Use only the span on which every loaded stream has frames.
        T_avail = min(a.shape[0] for a in input_arrs.values())
        T_avail = min(T_avail, target_full.shape[0], pressure_full.shape[0])
        if T_avail < (self.T_obs + self.T_fut) * self.downsample:
            return []

        # Downsample to self.sr Hz by keeping every `downsample`-th frame
        step = self.downsample
        input_ds = {m: arr[:T_avail:step] for m, arr in input_arrs.items()}
        target_ds = target_full[:T_avail:step]
        pressure_ds = pressure_full[:T_avail:step]
        T_ds = target_ds.shape[0]
        pressure_sum = pressure_ds.sum(axis=1)  # (T_ds,)

        stride = max(1, int(round(self.anchor_stride_sec * self.sr)))
        first_anchor = self.T_obs
        last_anchor = T_ds - self.T_fut
        # Strict '<': when the two are equal there is exactly one valid anchor
        # (a recording just long enough for a single past + future window).
        if last_anchor < first_anchor:
            return []

        items: List[dict] = []
        for anchor in range(first_anchor, last_anchor + 1, stride):
            past_p = pressure_sum[anchor - self.T_obs:anchor]
            fut_p = pressure_sum[anchor:anchor + self.T_fut]
            # "High" = more than half of the window's frames are in contact.
            past_high = (past_p > self.contact_threshold_g).mean() > 0.5
            fut_high = (fut_p > self.contact_threshold_g).mean() > 0.5
            et = self._event_type(past_high, fut_high)

            past_slice = {m: arr[anchor - self.T_obs:anchor]
                          for m, arr in input_ds.items()}
            past_target_last = target_ds[anchor - 1].copy()  # (target_dim,)
            fut_target = target_ds[anchor:anchor + self.T_fut].copy()
            if any(w.shape[0] != self.T_obs for w in past_slice.values()):
                continue
            if fut_target.shape[0] != self.T_fut:
                continue
            item = {
                "x": past_slice,
                "y": fut_target,
                "y_last": past_target_last,  # for persistence
                "event_type": int(et),
                "meta": {"vol": vol, "scene": scene, "anchor_idx": int(anchor)},
            }
            if self.include_future_pressure:
                fut_press = pressure_ds[anchor:anchor + self.T_fut].copy()
                if fut_press.shape[0] != self.T_fut:
                    continue
                item["fp"] = fut_press  # (T_fut, 50)
            items.append(item)
        return items

    @staticmethod
    def _event_type(past_high, fut_high) -> int:
        """Map (past contact, future contact) to an event id 0-3 (see EVENT_NAMES)."""
        if past_high:
            return 2 if fut_high else 3  # steady-grip / release
        return 1 if fut_high else 0  # pre-contact / non-contact

    @staticmethod
    def _fit_width(arr, width):
        """Zero-pad (float32) or truncate the columns of `arr` to exactly `width`."""
        if arr.shape[1] < width:
            pad = np.zeros((arr.shape[0], width - arr.shape[1]), dtype=np.float32)
            return np.concatenate([arr, pad], axis=1)
        return arr[:, :width]

    @staticmethod
    def _mean_std(frames):
        """Per-channel float32 (mean, std) over axis 0; a std below 1e-6 becomes 1.0."""
        mu = frames.mean(axis=0).astype(np.float32)
        sd = frames.std(axis=0)
        # Guard (near-)constant channels against division by ~0 when z-scoring.
        sd = np.where(sd < 1e-6, 1.0, sd)
        return mu, sd.astype(np.float32)

    @staticmethod
    def _enforce_dim(arrs, m, arr, dim_dict):
        """Record modality `m`'s width in `dim_dict`, or coerce `arrs[m]` to the recorded one."""
        if m not in dim_dict:
            dim_dict[m] = arr.shape[1]
            return
        width = dim_dict[m]
        if arr.shape[1] != width:
            arrs[m] = SignalForecastDataset._fit_width(arr, width)

    def _compute_input_stats(self):
        """Return ``{modality: (mean, std)}`` over every past window in the dataset."""
        # Accumulate by the modalities actually present in the items: a stale
        # key in `_modality_dims` (e.g. from `expected_input_dims`) must not
        # leave an empty list for np.concatenate to fail on.
        accs: Dict[str, list] = {}
        for it in self._items:
            for m, w in it["x"].items():
                accs.setdefault(m, []).append(w)
        return {m: self._mean_std(np.concatenate(ws, axis=0)) for m, ws in accs.items()}

    def _apply_input_stats(self, stats):
        """Z-score each item's past windows in place; modalities missing from `stats` are left as is."""
        for it in self._items:
            for m, w in it["x"].items():
                if m in stats:
                    mu, sd = stats[m]
                    it["x"][m] = ((w - mu) / sd).astype(np.float32)

    def _compute_target_stats(self):
        """Return (mean, std) over all future target frames."""
        return self._mean_std(np.concatenate([it["y"] for it in self._items], axis=0))

    def _apply_target_stats(self, stats):
        """Z-score ``y`` and ``y_last`` of every item with the same statistics."""
        mu, sd = stats
        for it in self._items:
            it["y"] = ((it["y"] - mu) / sd).astype(np.float32)
            it["y_last"] = ((it["y_last"] - mu) / sd).astype(np.float32)

    def _compute_fp_stats(self):
        """Return (mean, std) over all future pressure frames."""
        return self._mean_std(np.concatenate([it["fp"] for it in self._items], axis=0))

    def _apply_fp_stats(self, stats):
        """Z-score the future-pressure window of every item."""
        mu, sd = stats
        for it in self._items:
            it["fp"] = ((it["fp"] - mu) / sd).astype(np.float32)

    def __len__(self):
        return len(self._items)

    def __getitem__(self, idx):
        """Return ``(x, y, y_last, et, meta)``, or ``(x, y, y_last, fp, et, meta)`` with future pressure."""
        it = self._items[idx]
        x = {m: torch.from_numpy(np.ascontiguousarray(w)) for m, w in it["x"].items()}
        y = torch.from_numpy(np.ascontiguousarray(it["y"]))  # (T_fut, target_dim)
        y_last = torch.from_numpy(np.ascontiguousarray(it["y_last"]))  # (target_dim,)
        et = int(it["event_type"])
        if self.include_future_pressure:
            fp = torch.from_numpy(np.ascontiguousarray(it["fp"]))  # (T_fut, 50)
            return x, y, y_last, fp, et, it["meta"]
        return x, y, y_last, et, it["meta"]

    @property
    def modality_dims(self):
        """Copy of the ``{input modality: width}`` mapping in effect."""
        return dict(self._modality_dims)

    @property
    def target_dim(self):
        """Width of the target modality (-1 if never determined)."""
        return self._target_dim
def collate_signal_forecast(batch):
    """Collate SignalForecastDataset samples into batched tensors.

    Samples are 5-tuples ``(x, y, y_last, et, meta)`` or, when the first sample
    has six fields, 6-tuples that additionally carry a future-pressure tensor
    in fourth position. Tensors are stacked along a new leading batch axis,
    event types become one ``torch.long`` tensor and metas are returned as a
    plain list. The output tuple has the same arity and order as the samples.
    """
    with_pressure = len(batch[0]) == 6  # has future pressure
    if with_pressure:
        pasts, futures, lasts, pressures, events, metas = zip(*batch)
    else:
        pasts, futures, lasts, events, metas = zip(*batch)

    # Modality names are taken from the first sample.
    names = list(pasts[0].keys())
    x_batch = {
        name: torch.stack([sample[name] for sample in pasts], dim=0)
        for name in names
    }
    y_batch = torch.stack(futures, dim=0)
    y_last_batch = torch.stack(lasts, dim=0)

    if not with_pressure:
        event_batch = torch.tensor(events, dtype=torch.long)
        return x_batch, y_batch, y_last_batch, event_batch, list(metas)

    pressure_batch = torch.stack(pressures, dim=0)  # (B, T_fut, 50)
    event_batch = torch.tensor(events, dtype=torch.long)
    return x_batch, y_batch, y_last_batch, pressure_batch, event_batch, list(metas)
def build_signal_train_test(
    input_modalities, target_modality,
    t_obs_sec=1.5, t_fut_sec=0.5, anchor_stride_sec=0.25,
    downsample=5,
    dataset_dir=DEFAULT_DATASET_DIR, annot_dir=DEFAULT_ANNOT_DIR,
    contact_threshold_g=5.0, per_event_max=None,
    include_future_pressure=False,
    rng_seed=0,
):
    """Build the (train, test) pair of SignalForecastDataset objects.

    The train split is built over ``TRAIN_VOLS_V3`` and derives its own
    normalisation statistics and channel widths. The test split is built over
    ``TEST_VOLS_V3`` with the same windowing options, but reuses the train
    split's statistics and widths and seeds its subsampling with
    ``rng_seed + 1``.

    Returns:
        ``(train, test)``.
    """
    # Options that are identical for both splits.
    shared = dict(
        input_modalities=input_modalities,
        target_modality=target_modality,
        t_obs_sec=t_obs_sec,
        t_fut_sec=t_fut_sec,
        anchor_stride_sec=anchor_stride_sec,
        downsample=downsample,
        dataset_dir=dataset_dir,
        annot_dir=annot_dir,
        contact_threshold_g=contact_threshold_g,
        per_event_max=per_event_max,
        include_future_pressure=include_future_pressure,
        log=True,
    )
    train = SignalForecastDataset(TRAIN_VOLS_V3, rng_seed=rng_seed, **shared)
    # Values fitted on train are handed to test so both splits share one scale.
    fitted_on_train = dict(
        input_stats=train._input_stats,
        target_stats=train._target_stats,
        future_pressure_stats=train._fp_stats,
        expected_input_dims=train._modality_dims,
        expected_target_dim=train._target_dim,
    )
    test = SignalForecastDataset(
        TEST_VOLS_V3, rng_seed=rng_seed + 1, **shared, **fitted_on_train
    )
    return train, test
if __name__ == "__main__":
    # Smoke test: build both splits from the command-line options and print
    # the tensor shapes of the first training sample.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--input_modalities", default="imu")
    parser.add_argument("--target_modality", default="imu")
    parser.add_argument("--t_obs", type=float, default=1.5)
    parser.add_argument("--t_fut", type=float, default=0.5)
    cli = parser.parse_args()

    # --input_modalities is a comma-separated list.
    requested_inputs = cli.input_modalities.split(",")
    train_set, _test_set = build_signal_train_test(
        input_modalities=requested_inputs,
        target_modality=cli.target_modality,
        t_obs_sec=cli.t_obs,
        t_fut_sec=cli.t_fut,
    )
    # Five fields per sample because future pressure is not requested here.
    first_sample = train_set[0]
    x, y, y_last, et, meta = first_sample
    print(f"Sample: x={ {m: tuple(v.shape) for m,v in x.items()} } y={tuple(y.shape)} y_last={tuple(y_last.shape)} event_type={et}")
|