| |
| """Train + evaluate frame-level future-signal forecasting (T8 v2). |
| |
| Predicts the raw future signal of one target modality (IMU, EMG, or MoCap) |
| from past T_obs of input modalities. Reports skill score against persistence |
| baseline, broken down by 4 contact-event types. |
| |
| Three configurations supported (driven by --input_modalities and --target_modality): |
| A. Target-only e.g. --input_modalities imu --target_modality imu |
| B. Target + Pressure e.g. --input_modalities imu,pressure --target_modality imu |
| C. Target + Pressure (zeroed) same flags as B plus --zero_pressure_at_eval |
| and --load_checkpoint pointing at the checkpoint saved by B. |
| This loads the same checkpoint trained as B and re-evaluates with the |
| pressure channel forced to zero at test time, isolating pressure's |
| causal contribution net of model capacity. |
| |
| Skill score = 1 - MSE(pred, true) / MSE(persistence, true) |
| where persistence = repeat last observed target frame T_fut times. |
| """ |
| from __future__ import annotations |
| import argparse |
| import json |
| import random |
| import sys |
| import time |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from torch.utils.data import DataLoader |
|
|
# Absolute path of this script; used to extend the import search path below.
THIS = Path(__file__).resolve()
# Prepend three directories to sys.path: this file's directory, that
# directory's parent, and the parent's "table8/code" subdirectory. Every insert
# targets index 0, so the directory inserted last is searched first.
sys.path.insert(0, str(THIS.parent))
sys.path.insert(0, str(THIS.parents[1]))
sys.path.insert(0, str(THIS.parents[1] / "table8" / "code"))


try:
    # Preferred form: import the dataset helpers through the `experiments` package.
    from experiments.dataset_signal_forecast import (
        SignalForecastDataset, collate_signal_forecast,
        build_signal_train_test, EVENT_NAMES,
    )
except ModuleNotFoundError:
    # Fallback: import the same module as a top-level module found on sys.path.
    # NOTE(review): ModuleNotFoundError is also raised when a dependency *inside*
    # experiments.dataset_signal_forecast is missing; in that case this fallback
    # runs too and may hide the original error -- confirm this is acceptable.
    from dataset_signal_forecast import (
        SignalForecastDataset, collate_signal_forecast,
        build_signal_train_test, EVENT_NAMES,
    )
# Model factory; presumably resolved through one of the sys.path entries added
# above -- verify against the repository layout.
from nets.models_forecast import build_forecast_model
|
|
|
|
def set_seed(seed: int):
    """Seed every random number generator this script draws from.

    The same ``seed`` is applied, in order, to Python's ``random`` module,
    NumPy's global RNG, PyTorch's default generator, and the generators of all
    CUDA devices.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
|
|
|
|
def train_epoch(model, loader, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the mean loss.

    The network is trained on the residual to persistence: its regression
    target is ``y - y_last`` rather than ``y`` itself.

    Args:
        model: Module called with a dict of input tensors keyed by modality.
        loader: Iterable yielding ``(x, y, y_last, event_type, extra)`` tuples;
            only the first three items are used here.
        optimizer: Optimizer stepping ``model``'s parameters.
        device: Device every tensor is moved to before the forward pass.

    Returns:
        The squared error averaged over all target elements seen during the
        epoch, or ``0.0`` when ``loader`` yields no batches.
    """
    model.train()
    weighted_loss_sum = 0.0
    element_count = 0
    for inputs, future, last_frame, _event, _extra in loader:
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
        future = future.to(device)
        # Insert an axis at dim 1 so the last observed frame broadcasts against
        # `future` (presumably along its time axis -- confirm with the dataset).
        anchor = last_frame.to(device).unsqueeze(1)
        target = future - anchor

        optimizer.zero_grad()
        error = model(inputs) - target
        batch_mse = error.pow(2).mean()
        batch_mse.backward()
        # Bound the global gradient norm at 1.0 before the parameter update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # Weight each batch by its element count so that a smaller final batch
        # does not skew the epoch average.
        batch_elements = future.numel()
        weighted_loss_sum += batch_mse.item() * batch_elements
        element_count += batch_elements
    return weighted_loss_sum / max(element_count, 1)
|
|
|
|
@torch.no_grad()
def evaluate(model, loader, device, t_fut: int, target_dim: int,
             zero_pressure: bool = False):
    """Score the model against the persistence baseline, per event type.

    The model predicts a residual that is added back onto the last observed
    target frame; the persistence baseline simply repeats that frame. Squared
    errors are averaged over the last tensor axis, accumulated per event type
    and per future step, and turned into skill scores
    ``1 - MSE(model) / MSE(persistence)``.

    Args:
        model: Module called with a dict of input tensors keyed by modality.
        loader: Iterable yielding ``(x, y, y_last, event_type, extra)`` tuples.
            ``event_type`` must be a CPU tensor of integer labels in ``0..3``.
        device: Device the inputs and targets are moved to.
        t_fut: Number of future steps, i.e. the size of the second axis of the
            per-sample error (assumes ``y`` is (batch, t_fut, dim) -- TODO
            confirm against the dataset).
        target_dim: Unused; kept so existing call sites keep working.
        zero_pressure: When true, replace ``x["pressure"]`` (if present) with
            zeros before the forward pass.

    Returns:
        A dict keyed by event name (looked up in ``EVENT_NAMES``) plus
        ``"overall"``. Each value holds ``n_anchors``, ``mse_model``,
        ``mse_persist``, ``skill_score`` and ``per_h_skill`` (one entry per
        future step). Whenever the persistence MSE is not above ``1e-9`` --
        which includes event types with no anchors at all -- the corresponding
        skill is reported as ``0.0``, and an empty event type reports
        ``n_anchors == 0``.
    """
    model.eval()

    n_events = 4          # event-type labels 0..3
    overall = n_events    # extra row pooling every anchor
    sse_m = np.zeros((n_events + 1, t_fut), dtype=np.float64)
    sse_p = np.zeros((n_events + 1, t_fut), dtype=np.float64)
    n_anchors = np.zeros(n_events + 1, dtype=np.int64)

    for x, y, y_last, et, _ in loader:
        x = {m: v.to(device) for m, v in x.items()}
        if zero_pressure and "pressure" in x:
            x["pressure"] = torch.zeros_like(x["pressure"])
        y = y.to(device)
        y_last = y_last.to(device).unsqueeze(1)
        pred_full = model(x) + y_last          # undo the residual parameterisation
        persist = y_last.expand_as(y)
        m_np = ((pred_full - y) ** 2).mean(dim=-1).cpu().numpy()
        p_np = ((persist - y) ** 2).mean(dim=-1).cpu().numpy()
        labels = et.numpy().astype(np.int64)

        # Unbuffered scatter-add: repeated labels within a batch all contribute.
        np.add.at(sse_m, labels, m_np)
        np.add.at(sse_p, labels, p_np)
        np.add.at(n_anchors, labels, 1)
        sse_m[overall] += m_np.sum(axis=0, dtype=np.float64)
        sse_p[overall] += p_np.sum(axis=0, dtype=np.float64)
        n_anchors[overall] += labels.shape[0]

    eps = 1e-9
    out = {}
    for e in range(n_events + 1):
        count = int(n_anchors[e])
        denom = max(count, 1)                  # avoid 0/0 for empty event types
        per_h_m = sse_m[e] / denom
        per_h_p = sse_p[e] / denom
        mse_m = per_h_m.mean()
        mse_p = per_h_p.mean()
        skill = 1.0 - (mse_m / mse_p) if mse_p > eps else 0.0

        # Same degenerate-baseline rule as the scalar skill: where the
        # persistence error is (numerically) zero, report 0.0 rather than a
        # meaningless ratio.
        usable = per_h_p > eps
        per_h_skill = np.zeros(t_fut, dtype=np.float64)
        per_h_skill[usable] = 1.0 - per_h_m[usable] / per_h_p[usable]

        name = EVENT_NAMES.get(e, "overall") if e < n_events else "overall"
        out[name] = {
            "n_anchors": count,
            "mse_model": float(mse_m),
            "mse_persist": float(mse_p),
            "skill_score": float(skill),
            "per_h_skill": [float(s) for s in per_h_skill],
        }
    return out
|
|
|
|
def _build_arg_parser():
    """Declare the command-line interface of the train/eval script."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", required=True, choices=["daf", "futr", "deepconvlstm"])
    ap.add_argument("--input_modalities", required=True,
                    help="e.g. 'imu' or 'imu,pressure'")
    ap.add_argument("--target_modality", required=True, choices=["imu", "emg", "mocap"])
    ap.add_argument("--t_obs", type=float, default=1.5)
    ap.add_argument("--t_fut", type=float, default=0.5)
    ap.add_argument("--anchor_stride", type=float, default=0.25)
    ap.add_argument("--per_event_max", type=int, default=8000,
                    help="Cap each event-type pool to this many anchors (per split). "
                         "Use a large number to keep all anchors.")
    ap.add_argument("--epochs", type=int, default=25)
    ap.add_argument("--batch_size", type=int, default=64)
    ap.add_argument("--lr", type=float, default=3e-4)
    ap.add_argument("--weight_decay", type=float, default=1e-4)
    ap.add_argument("--d_model", type=int, default=128)
    ap.add_argument("--dropout", type=float, default=0.1)
    ap.add_argument("--num_workers", type=int, default=2)
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--patience", type=int, default=5)
    ap.add_argument("--zero_pressure_at_eval", action="store_true",
                    help="Eval-only: zero out the pressure input (causal-ablation control).")
    ap.add_argument("--load_checkpoint", type=str, default=None,
                    help="Skip training, load checkpoint and run only eval (for control C).")
    ap.add_argument("--output_dir", required=True)
    return ap


def _parse_modalities(spec):
    """Split a comma-separated modality list, dropping blanks and padding."""
    return [name.strip() for name in spec.split(",") if name.strip()]


def _result_record(args, inputs, n_params, train_ds, target_dim, extra, eval_metrics):
    """Assemble the results.json payload shared by both run modes."""
    record = {
        "method": args.model,
        "input_modalities": inputs,
        "target_modality": args.target_modality,
        "seed": args.seed,
        "n_params": n_params,
        "T_obs": train_ds.T_obs, "T_fut": train_ds.T_fut, "target_dim": target_dim,
    }
    record.update(extra)          # mode-specific keys keep their original position
    record["eval"] = eval_metrics
    record["args"] = vars(args)
    return record


def _write_results(out_dir, record):
    """Write ``record`` to ``<out_dir>/results.json`` as indented JSON."""
    with open(out_dir / "results.json", "w", encoding="utf-8") as f:
        json.dump(record, f, indent=2)


def main():
    """Command-line entry point: train a forecaster, or evaluate a checkpoint.

    Without ``--load_checkpoint`` the model is trained for up to ``--epochs``
    epochs. It is evaluated on the test loader after every epoch, the state
    dict with the best overall skill score is saved to
    ``<output_dir>/model_best.pt``, and training stops early after
    ``--patience`` epochs without improvement.

    With ``--load_checkpoint`` training is skipped: the state dict is loaded
    and evaluated once, optionally with the pressure input zeroed
    (``--zero_pressure_at_eval``).

    Either way the metrics are written to ``<output_dir>/results.json``.
    """
    ap = _build_arg_parser()
    args = ap.parse_args()

    set_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Tolerate "imu, pressure": stray whitespace would otherwise end up inside
    # the modality name.
    inputs = _parse_modalities(args.input_modalities)
    print(f"device={device} seed={args.seed} model={args.model} "
          f"inputs={inputs} target={args.target_modality} "
          f"t_obs={args.t_obs} t_fut={args.t_fut} "
          f"zero_pressure_at_eval={args.zero_pressure_at_eval}", flush=True)
    if args.zero_pressure_at_eval and args.load_checkpoint is None:
        # The training path always evaluates with the real pressure signal, so
        # without a checkpoint the flag would be recorded in results.json while
        # having no effect on the reported numbers.
        print("[warn] --zero_pressure_at_eval only takes effect together with "
              "--load_checkpoint; training and per-epoch evaluation use the "
              "unmodified pressure input.", file=sys.stderr, flush=True)

    train_ds, test_ds = build_signal_train_test(
        input_modalities=inputs,
        target_modality=args.target_modality,
        t_obs_sec=args.t_obs, t_fut_sec=args.t_fut,
        anchor_stride_sec=args.anchor_stride,
        per_event_max=args.per_event_max,
        rng_seed=args.seed,
    )
    target_dim = train_ds.target_dim
    print(f"train={len(train_ds)} test={len(test_ds)} target_dim={target_dim}",
          flush=True)

    tr_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
                           num_workers=args.num_workers, collate_fn=collate_signal_forecast,
                           drop_last=False)
    te_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False,
                           num_workers=args.num_workers, collate_fn=collate_signal_forecast)

    # The factory's `num_classes` argument receives the target dimensionality.
    model = build_forecast_model(
        args.model, train_ds.modality_dims,
        num_classes=target_dim,
        t_obs=train_ds.T_obs, t_fut=train_ds.T_fut,
        d_model=args.d_model, dropout=args.dropout,
    ).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"params={n_params:,}", flush=True)

    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # ---- Eval-only mode ---------------------------------------------------
    if args.load_checkpoint is not None:
        print(f"loading checkpoint {args.load_checkpoint}", flush=True)
        # SECURITY: torch.load unpickles the file; only load checkpoints from a
        # trusted source (e.g. the model_best.pt written by this script).
        sd = torch.load(args.load_checkpoint, map_location=device)
        model.load_state_dict(sd)
        ev = evaluate(model, te_loader, device,
                      t_fut=train_ds.T_fut, target_dim=target_dim,
                      zero_pressure=args.zero_pressure_at_eval)
        out = _result_record(
            args, inputs, n_params, train_ds, target_dim,
            extra={
                "best_epoch": -1, "mode": "eval_only",
                "zero_pressure_at_eval": bool(args.zero_pressure_at_eval),
                "loaded_from": args.load_checkpoint,
            },
            eval_metrics=ev,
        )
        _write_results(out_dir, out)
        print(f"[done] overall skill_score = {ev['overall']['skill_score']:.4f}", flush=True)
        for e in ("non-contact", "pre-contact", "steady-grip", "release"):
            print(f" {e:14s} skill={ev[e]['skill_score']:+.4f} (n={ev[e]['n_anchors']})", flush=True)
        return

    # ---- Training mode ----------------------------------------------------
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=args.lr * 0.05)

    best_skill = -1e9
    best_epoch = 0
    best_eval = None
    patience_counter = 0
    for ep in range(1, args.epochs + 1):
        t0 = time.time()
        tr_loss = train_epoch(model, tr_loader, optimizer, device)
        ev = evaluate(model, te_loader, device,
                      t_fut=train_ds.T_fut, target_dim=target_dim,
                      zero_pressure=False)
        sched.step()
        skill = ev["overall"]["skill_score"]
        print(f" E{ep:2d} | tr_mse {tr_loss:.4f} | te_skill {skill:+.4f} "
              f"| pre {ev['pre-contact']['skill_score']:+.3f} "
              f"steady {ev['steady-grip']['skill_score']:+.3f} "
              f"release {ev['release']['skill_score']:+.3f} "
              f"non {ev['non-contact']['skill_score']:+.3f} "
              f"| {time.time()-t0:.1f}s", flush=True)
        if skill > best_skill:
            best_skill = skill
            best_epoch = ep
            best_eval = ev
            # Store CPU tensors so the checkpoint loads on machines without CUDA.
            torch.save({k: v.cpu() for k, v in model.state_dict().items()},
                       out_dir / "model_best.pt")
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= args.patience:
                print(f" early stop at epoch {ep} (best {best_epoch})", flush=True)
                break

    out = _result_record(
        args, inputs, n_params, train_ds, target_dim,
        extra={"best_epoch": int(best_epoch), "best_skill": float(best_skill)},
        eval_metrics=best_eval,
    )
    _write_results(out_dir, out)
    print(f"\n[done] best skill={best_skill:+.4f} at epoch {best_epoch}", flush=True)
    print(f"saved to {out_dir}/results.json", flush=True)
|
|
|
|
# Run the train/eval pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|