| |
| """Train + evaluate T8 v3 — privileged future-pressure conditioning (Option B). |
| |
| Compared to train_signal_forecast.py: |
| - Inputs: past 1.5s of `input_modalities` (e.g. just target modality) |
| + future T_fut s of pressure (privileged side channel) |
| - Output: future T_fut s of `target_modality` |
| - Comparison baseline (A_priv): existing `_no_pressure` runs from T8 v2. |
| - This run is the B_priv group; lift = skill(B_priv) - skill(A_priv). |
| |
| If lift >> 0, future pressure trajectory carries information about future |
| kinematics that past kinematics alone do not encode. This directly tests |
| the Johansson 1984 hypothesis at the algorithmic level. |
| """ |
| from __future__ import annotations |
| import argparse |
| import json |
| import random |
| import sys |
| import time |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from torch.utils.data import DataLoader |
|
|
# Absolute, symlink-resolved location of this script.
THIS = Path(__file__).resolve()
# Put the script's grandparent directory first and its own directory second
# on the import path — presumably so the `data` and `nets` packages imported
# below resolve without installation (confirm against the repository layout).
sys.path[:0] = [str(THIS.parents[1]), str(THIS.parent)]
|
|
| from data.dataset_signal_forecast import ( |
| SignalForecastDataset, collate_signal_forecast, |
| build_signal_train_test, EVENT_NAMES, |
| ) |
| from nets.models_forecast_priv import DAFFuturePressure |
|
|
|
|
def set_seed(seed: int):
    """Seed every random number generator this script relies on.

    Args:
        seed: Value passed, in order, to Python's `random`, NumPy's global
            generator, torch's default generator, and all CUDA generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
|
|
|
|
def train_epoch(model, loader, optimizer, device):
    """Run one optimisation pass over `loader` and return the mean training MSE.

    The model is fit to the residual `y - y_last`, not to `y` itself.

    Args:
        model: Module invoked as `model(x, fp)`; switched to train mode here.
        loader: Iterable yielding 6-tuples `(x, y, y_last, fp, event, extra)`
            where `x` is a dict of tensors. The last two items are ignored.
        optimizer: Optimizer stepped once per batch.
        device: Device every tensor is moved to before the forward pass.

    Returns:
        The per-batch MSE losses averaged with weights `y.numel()`; 0.0 when
        the loader yields no batches.
    """
    model.train()
    loss_sum, elem_count = 0.0, 0
    for inputs, future, last_obs, future_pressure, _event, _extra in loader:
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
        future = future.to(device)
        # Insert an axis at dim 1 so the last observation can broadcast against
        # `future` (presumably along the horizon axis — confirm shapes).
        anchor = last_obs.to(device).unsqueeze(1)
        future_pressure = future_pressure.to(device)

        optimizer.zero_grad()
        prediction = model(inputs, future_pressure)
        loss = (prediction - (future - anchor)).pow(2).mean()
        loss.backward()
        # Cap the global gradient norm at 1.0 before the parameter update.
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        batch_elems = future.numel()
        loss_sum += loss.item() * batch_elems
        elem_count += batch_elems
    return loss_sum / max(elem_count, 1)
|
|
|
|
@torch.no_grad()
def evaluate(model, loader, device, t_fut, target_dim):
    """Score the model against a persistence baseline, per event type and overall.

    For every batch the model's residual prediction is added back onto
    `y_last`, and both it and the persistence forecast (`y_last` repeated
    across `y`) are scored by squared error averaged over the last axis.
    Errors are accumulated per horizon step in rows 0-3 (indexed by the event
    type `et`) and in row 4 (every anchor).

    Args:
        model: Module invoked as `model(x, fp)`; switched to eval mode here.
        loader: Iterable yielding 6-tuples `(x, y, y_last, fp, et, extra)`
            where `x` is a dict of tensors and `et` holds one integer event
            index per sample (assumed to lie in 0..3 — confirm against the
            dataset).
        device: Device tensors are moved to before the forward pass.
        t_fut: Number of horizon steps; sets the width of the accumulators.
        target_dim: Unused; kept so existing callers keep working.

    Returns:
        Dict keyed by event name (plus "overall"), each value holding
        `n_anchors`, `mse_model`, `mse_persist`, `skill_score` and
        `per_h_skill`. Skill is `1 - mse_model / mse_persist`, and is reported
        as 0.0 wherever the persistence MSE is not above 1e-9 — for the
        scalar score and, consistently, for each horizon step. `n_anchors` is
        the true anchor count, so an event class with no samples reports 0.
    """
    model.eval()
    n_events = 4            # rows 0..3 hold the individual event types
    overall_row = n_events  # row 4 aggregates every anchor
    sse_m = np.zeros((n_events + 1, t_fut), dtype=np.float64)
    sse_p = np.zeros((n_events + 1, t_fut), dtype=np.float64)
    n_pairs = np.zeros((n_events + 1, t_fut), dtype=np.int64)

    for x, y, y_last, fp, et, _ in loader:
        x = {m: v.to(device) for m, v in x.items()}
        y = y.to(device)
        y_last = y_last.to(device).unsqueeze(1)
        fp = fp.to(device)
        # The model predicts a residual; add the last observation back on.
        pred_full = model(x, fp) + y_last
        persist = y_last.expand_as(y)
        m_np = ((pred_full - y) ** 2).mean(dim=-1).cpu().numpy().astype(np.float64)
        p_np = ((persist - y) ** 2).mean(dim=-1).cpu().numpy().astype(np.float64)
        et_idx = et.cpu().numpy().astype(np.int64).reshape(-1)

        # Unbuffered scatter-add: repeated event indices within a batch must
        # each contribute, which plain fancy-index `+=` would not guarantee.
        np.add.at(sse_m, et_idx, m_np)
        np.add.at(sse_p, et_idx, p_np)
        np.add.at(n_pairs, et_idx, 1)
        sse_m[overall_row] += m_np.sum(axis=0)
        sse_p[overall_row] += p_np.sum(axis=0)
        n_pairs[overall_row] += m_np.shape[0]

    out = {}
    for e in range(n_events + 1):
        counts = n_pairs[e]
        denom = np.maximum(counts, 1)  # avoid 0/0 for classes with no anchors
        mse_m_h = sse_m[e] / denom
        mse_p_h = sse_p[e] / denom
        mse_m = mse_m_h.mean()
        mse_p = mse_p_h.mean()
        skill = 1.0 - (mse_m / mse_p) if mse_p > 1e-9 else 0.0
        # Same degenerate-baseline rule as the scalar skill: a horizon whose
        # persistence error is ~0 (including empty classes) scores 0.0, not
        # a spurious 1.0 or a huge negative value.
        valid = mse_p_h > 1e-9
        per_h_skill = np.where(
            valid, 1.0 - mse_m_h / np.where(valid, mse_p_h, 1.0), 0.0
        ).tolist()
        name = EVENT_NAMES.get(e, "overall") if e < n_events else "overall"
        out[name] = {
            "n_anchors": int(counts.max()) if counts.size else 0,
            "mse_model": float(mse_m),
            "mse_persist": float(mse_p),
            "skill_score": float(skill),
            "per_h_skill": [float(s) for s in per_h_skill],
        }
    return out
|
|
|
|
def main():
    """Train DAFFuturePressure from CLI options and write the best result to disk.

    Steps: parse arguments, seed the RNGs, build the train/test datasets with
    future pressure enabled, train with AdamW and cosine annealing, evaluate
    after every epoch, keep the checkpoint with the highest overall skill, and
    stop early after `--patience` epochs without improvement. The best weights
    go to `<output_dir>/model_best.pt` and a summary to
    `<output_dir>/results.json`.

    Exits through `argparse` (status 2) if `--input_modalities` names no
    modality once whitespace and empty entries are removed.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--input_modalities", required=True,
                    help="comma-separated; pressure NOT included unless you want past pressure too")
    ap.add_argument("--target_modality", required=True, choices=["imu", "emg", "mocap"])
    ap.add_argument("--t_obs", type=float, default=1.5)
    ap.add_argument("--t_fut", type=float, default=0.5)
    ap.add_argument("--anchor_stride", type=float, default=0.25)
    ap.add_argument("--per_event_max", type=int, default=8000)
    ap.add_argument("--epochs", type=int, default=25)
    ap.add_argument("--batch_size", type=int, default=64)
    ap.add_argument("--lr", type=float, default=3e-4)
    ap.add_argument("--weight_decay", type=float, default=1e-4)
    ap.add_argument("--d_model", type=int, default=128)
    ap.add_argument("--dropout", type=float, default=0.1)
    ap.add_argument("--num_workers", type=int, default=2)
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--patience", type=int, default=6)
    ap.add_argument("--output_dir", required=True)
    args = ap.parse_args()

    set_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Tolerate "imu, emg" and trailing commas: trim each name, drop empties.
    inputs = [name.strip() for name in args.input_modalities.split(",") if name.strip()]
    if not inputs:
        ap.error("--input_modalities must name at least one modality")
    print(f"device={device} seed={args.seed} model=DAF-priv "
          f"inputs={inputs} target={args.target_modality} "
          f"t_obs={args.t_obs} t_fut={args.t_fut}", flush=True)

    train_ds, test_ds = build_signal_train_test(
        input_modalities=inputs,
        target_modality=args.target_modality,
        t_obs_sec=args.t_obs, t_fut_sec=args.t_fut,
        anchor_stride_sec=args.anchor_stride,
        per_event_max=args.per_event_max,
        include_future_pressure=True,
        rng_seed=args.seed,
    )
    target_dim = train_ds.target_dim
    print(f"train={len(train_ds)} test={len(test_ds)} target_dim={target_dim}",
          flush=True)

    tr_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
                           num_workers=args.num_workers,
                           collate_fn=collate_signal_forecast, drop_last=False)
    te_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False,
                           num_workers=args.num_workers,
                           collate_fn=collate_signal_forecast)

    # NOTE(review): future_pressure_dim is hard-coded to 50 — confirm it
    # matches the pressure channel count produced by the dataset.
    model = DAFFuturePressure(
        train_ds.modality_dims, target_dim=target_dim,
        t_obs=train_ds.T_obs, t_fut=train_ds.T_fut,
        future_pressure_dim=50,
        d_model=args.d_model, dropout=args.dropout,
    ).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"params={n_params:,}", flush=True)

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr,
                                  weight_decay=args.weight_decay)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=args.epochs, eta_min=args.lr * 0.05
    )

    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    best_skill = -1e9
    best_epoch, best_eval = 0, None
    patience_counter = 0
    # NOTE(review): checkpoint selection and early stopping both use the test
    # loader, so the reported best skill is selected on the test split.
    for ep in range(1, args.epochs + 1):
        t0 = time.time()
        tr_loss = train_epoch(model, tr_loader, optimizer, device)
        ev = evaluate(model, te_loader, device,
                      t_fut=train_ds.T_fut, target_dim=target_dim)
        sched.step()
        skill = ev["overall"]["skill_score"]
        print(f" E{ep:2d} | tr_mse {tr_loss:.4f} | te_skill {skill:+.4f} "
              f"| pre {ev['pre-contact']['skill_score']:+.3f} "
              f"steady {ev['steady-grip']['skill_score']:+.3f} "
              f"release {ev['release']['skill_score']:+.3f} "
              f"non {ev['non-contact']['skill_score']:+.3f} "
              f"| {time.time()-t0:.1f}s", flush=True)
        if skill > best_skill:
            best_skill = skill
            best_epoch = ep
            best_eval = ev
            # Store CPU copies so the checkpoint loads on machines without a GPU.
            torch.save({k: v.cpu() for k, v in model.state_dict().items()},
                       out_dir / "model_best.pt")
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= args.patience:
                print(f" early stop at epoch {ep} (best {best_epoch})", flush=True)
                break

    out = {
        "method": "daf_priv",
        "input_modalities": inputs,
        "target_modality": args.target_modality,
        "future_pressure": True,
        "seed": args.seed, "n_params": n_params,
        "T_obs": train_ds.T_obs, "T_fut": train_ds.T_fut, "target_dim": target_dim,
        "best_epoch": int(best_epoch), "best_skill": float(best_skill),
        "eval": best_eval, "args": vars(args),
    }
    with open(out_dir / "results.json", "w", encoding="utf-8") as f:
        json.dump(out, f, indent=2)
    print(f"\n[done] best skill={best_skill:+.4f} at epoch {best_epoch}", flush=True)
|
|
|
|
# Run the training entry point only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
|
|