"""Inference and visualization helpers for ECG MAE (training dashboards and deployment).""" from __future__ import annotations from dataclasses import asdict from pathlib import Path from typing import Any import matplotlib.pyplot as plt import numpy as np import torch import torch.nn as nn from torch import Tensor from mae.config import MAEConfig from mae.pipeline import MAEReconstructionPipeline # PTB-XL / WFDB order: I, II, III, AVR, AVL, AVF, V1–V6 (display as standard names). STANDARD_12_LEAD_NAMES: tuple[str, ...] = ( "I", "II", "III", "aVR", "aVL", "aVF", "V1", "V2", "V3", "V4", "V5", "V6", ) def load_wfdb_record_leads_first(record_path_without_ext: str | Path) -> tuple[np.ndarray, float]: """ Load a WFDB record the same way as ``PTBXLDataset``: ``wfdb.rdrecord`` then ``ptbxl_wfdb_to_leads_first``. Parameters ---------- record_path_without_ext Path **without** extension, e.g. ``.../00001_hr`` so that ``00001_hr.hea`` and ``00001_hr.dat`` exist beside it (PTB-XL high-rate layout). Returns ------- ecg ``(12, T)`` float32, leads-first (PTB-XL / WFDB column order). fs_hz Sampling rate from the record header (falls back to ``preprocessor.DEFAULT_FS`` if missing). """ import wfdb import preprocessor as pre p = Path(record_path_without_ext).resolve() parent = p.parent stem = p.name hea = parent / f"{stem}.hea" dat = parent / f"{stem}.dat" if not hea.is_file(): raise FileNotFoundError(f"Missing WFDB header: {hea}") if not dat.is_file(): raise FileNotFoundError(f"Missing WFDB signal file: {dat}") rec = wfdb.rdrecord(str(p)) x = pre.ptbxl_wfdb_to_leads_first(rec.p_signal) fs = float(rec.fs) if getattr(rec, "fs", None) else float(pre.DEFAULT_FS) return x, fs def _is_git_lfs_pointer_file(path: Path) -> bool: """True if ``path`` is a Git LFS stub (text pointer) instead of the real blob.""" try: with path.open("r", encoding="ascii", errors="replace") as f: head = f.read(128) except OSError: return False return head.startswith("version https://git-lfs.github.com/spec/v1") def load_pipeline( checkpoint_path: str | Path, device: torch.device | str | None = None, config: MAEConfig | None = None, ) -> MAEReconstructionPipeline: """Load ``MAEReconstructionPipeline`` weights from a training checkpoint dict.""" checkpoint_path = Path(checkpoint_path) device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu")) if _is_git_lfs_pointer_file(checkpoint_path): raise RuntimeError( f"{checkpoint_path} is a Git LFS pointer file, not checkpoint bytes. " "Install Git LFS if needed, then run: git lfs pull\n" "Or place a real checkpoint .pt at this path (e.g. from runs/ after training)." ) ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False) cfg = config or ckpt.get("config") if cfg is None: cfg = MAEConfig() elif isinstance(cfg, dict): allowed = {k: v for k, v in cfg.items() if k in MAEConfig.__dataclass_fields__} cfg = MAEConfig(**allowed) pipe = MAEReconstructionPipeline(cfg).to(device) state = ckpt["model"] if isinstance(ckpt, dict) and "model" in ckpt else ckpt pipe.load_state_dict(state, strict=True) pipe.eval() return pipe @torch.no_grad() def reconstruct( pipeline: MAEReconstructionPipeline, ecg: torch.Tensor, *, V: Tensor | None = None, M: Tensor | None = None, D: Tensor | None = None, generator: torch.Generator | None = None, return_loss: bool = True, ) -> dict[str, Any]: """ Run one forward pass. Parameters ---------- ecg : (B, 12, T) tensor on the same device as ``pipeline``. V, M, D : optional bool ``(B, C, N)`` masks partitioning the patch grid. If any is given, all three must be supplied (see ``ECGDataMAE.forward``). If omitted, STDM is sampled. """ out = pipeline.model( ecg, V=V, M=M, D=D, generator=generator, return_loss=return_loss, ) return out def patches_to_signal(pred: np.ndarray, num_leads: int, num_patches: int, patch_size: int) -> np.ndarray: """(C, N, P) -> (C, N*P)""" return pred.reshape(num_leads, num_patches * patch_size) def mask_m_to_time(mask_m: np.ndarray, patch_size: int) -> np.ndarray: """(C, N) bool -> (C, N*P) bool (patch-wise constant mask).""" c, n = mask_m.shape return np.repeat(mask_m, patch_size, axis=1) def figure_reconstruction( original: np.ndarray, predicted: np.ndarray, mask_m: np.ndarray, *, leads: tuple[int, ...] = tuple(range(12)), patch_size: int, title: str = "ECG MAE reconstruction", ): """ Build a matplotlib figure: a few leads with original vs prediction. ``mask_m`` is (C, N) bool; masked segments are shaded in the background. Y-axis labels use standard 12-lead names in PTB-XL order (indices 0–11). """ c, t = original.shape n = mask_m.shape[1] assert t == n * patch_size t_idx = np.arange(t) fig, axes = plt.subplots(len(leads), 1, figsize=(12, 2.5 * len(leads)), sharex=True) if len(leads) == 1: axes = [axes] mt = mask_m_to_time(mask_m, patch_size) for ax, lead in zip(axes, leads, strict=True): m = mt[lead] ax.fill_between( t_idx, original[lead].min(), original[lead].max(), where=m, alpha=0.25, color="gray", linewidth=0, ) ax.plot(t_idx, original[lead], label="original", linewidth=0.8) ax.plot(t_idx, predicted[lead], label="recon", linewidth=0.8, alpha=0.9) name = ( STANDARD_12_LEAD_NAMES[lead] if 0 <= lead < len(STANDARD_12_LEAD_NAMES) else f"Lead {lead}" ) ax.set_ylabel(name) ax.legend(loc="upper right", fontsize=8) axes[-1].set_xlabel("sample") fig.suptitle(title) fig.tight_layout() return fig def save_checkpoint( path: str | Path, pipeline: nn.Module, optimizer: torch.optim.Optimizer | None, epoch: int, step: int, config: MAEConfig, ) -> None: path = Path(path) path.parent.mkdir(parents=True, exist_ok=True) torch.save( { "model": pipeline.state_dict(), "optimizer": optimizer.state_dict() if optimizer is not None else None, "epoch": epoch, "step": step, "config": asdict(config) if hasattr(config, "__dataclass_fields__") else config, }, path, ) def main() -> None: import argparse p = argparse.ArgumentParser(description="Run ECG MAE reconstruction and save a figure.") p.add_argument("--checkpoint", type=Path, default=Path("ckpts/checkpoint_last.pt")) p.add_argument("--data-root", type=Path, default=Path("dataset/ptb_xl")) p.add_argument("--split", choices=("train", "val", "test"), default="test") p.add_argument("--index", type=int, default=0) p.add_argument("--out", type=Path, default=Path("reconstruction.png")) p.add_argument("--device", default=None) args = p.parse_args() from ptb_xl_dataset import PTBXLDataset device = torch.device(args.device or ("cuda" if torch.cuda.is_available() else "cpu")) print(f"Using device: {device}") pipe = load_pipeline(args.checkpoint, device=device) cfg = pipe.config ds = PTBXLDataset(args.data_root, split=args.split, rng_seed=0) batch = ds[args.index]["ecg"].unsqueeze(0).to(device) g = torch.Generator(device=device) g.manual_seed(0) with torch.no_grad(): out = reconstruct(pipe, batch, generator=g) orig = batch[0].cpu().numpy() pred = out["pred"][0].reshape(cfg.num_leads, -1).cpu().numpy() m = out["M"][0].cpu().numpy() fig = figure_reconstruction( orig, pred, m, patch_size=cfg.patch_size, title=f"split={args.split} idx={args.index}" ) fig.savefig(args.out, dpi=150) print(f"Saved {args.out}") if __name__ == "__main__": main()