Spaces:
Sleeping
Sleeping
| """Inference and visualization helpers for ECG MAE (training dashboards and deployment).""" | |
| from __future__ import annotations | |
| from dataclasses import asdict | |
| from pathlib import Path | |
| from typing import Any | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from torch import Tensor | |
| from mae.config import MAEConfig | |
| from mae.pipeline import MAEReconstructionPipeline | |
| # PTB-XL / WFDB order: I, II, III, AVR, AVL, AVF, V1–V6 (display as standard names). | |
# Display names for the 12 leads, indexed by PTB-XL / WFDB column order:
# limb leads first (I, II, III, aVR, aVL, aVF), then precordial V1-V6.
# figure_reconstruction() indexes into this tuple by lead number.
STANDARD_12_LEAD_NAMES: tuple[str, ...] = (
    "I",
    "II",
    "III",
    "aVR",
    "aVL",
    "aVF",
    "V1",
    "V2",
    "V3",
    "V4",
    "V5",
    "V6",
)
def load_wfdb_record_leads_first(record_path_without_ext: str | Path) -> tuple[np.ndarray, float]:
    """
    Load a WFDB record exactly as ``PTBXLDataset`` does: ``wfdb.rdrecord``
    followed by ``ptbxl_wfdb_to_leads_first``.

    Parameters
    ----------
    record_path_without_ext
        Path **without** extension, e.g. ``.../00001_hr`` so that ``00001_hr.hea``
        and ``00001_hr.dat`` exist beside it (PTB-XL high-rate layout).

    Returns
    -------
    ecg
        ``(12, T)`` float32, leads-first (PTB-XL / WFDB column order).
    fs_hz
        Sampling rate from the record header (falls back to
        ``preprocessor.DEFAULT_FS`` if missing).
    """
    import wfdb
    import preprocessor as pre

    record = Path(record_path_without_ext).resolve()
    # wfdb.rdrecord needs both companion files; fail early with a clear message.
    for ext, kind in ((".hea", "header"), (".dat", "signal file")):
        companion = record.parent / f"{record.name}{ext}"
        if not companion.is_file():
            raise FileNotFoundError(f"Missing WFDB {kind}: {companion}")

    loaded = wfdb.rdrecord(str(record))
    ecg = pre.ptbxl_wfdb_to_leads_first(loaded.p_signal)
    header_fs = getattr(loaded, "fs", None)
    fs_hz = float(loaded.fs) if header_fs else float(pre.DEFAULT_FS)
    return ecg, fs_hz
def _is_git_lfs_pointer_file(path: Path) -> bool:
    """True if ``path`` is a Git LFS stub (text pointer) instead of the real blob."""
    # LFS pointers are small ASCII files with a fixed first line; 128 chars is plenty.
    try:
        with path.open("r", encoding="ascii", errors="replace") as handle:
            first_chunk = handle.read(128)
    except OSError:
        # Unreadable / missing paths are treated as "not a pointer" — the caller
        # will surface its own error when it tries to actually load the file.
        return False
    return first_chunk.startswith("version https://git-lfs.github.com/spec/v1")
def load_pipeline(
    checkpoint_path: str | Path,
    device: torch.device | str | None = None,
    config: MAEConfig | None = None,
) -> MAEReconstructionPipeline:
    """
    Load ``MAEReconstructionPipeline`` weights from a training checkpoint.

    Parameters
    ----------
    checkpoint_path
        Path to a ``.pt`` file: either a dict saved by ``save_checkpoint``
        (with ``"model"``/``"config"`` keys) or a raw ``state_dict``.
    device
        Target device; defaults to CUDA when available, else CPU.
    config
        Explicit config; overrides whatever the checkpoint stores.

    Raises
    ------
    RuntimeError
        If the file is a Git LFS pointer stub rather than checkpoint bytes.
    """
    checkpoint_path = Path(checkpoint_path)
    device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
    if _is_git_lfs_pointer_file(checkpoint_path):
        raise RuntimeError(
            f"{checkpoint_path} is a Git LFS pointer file, not checkpoint bytes. "
            "Install Git LFS if needed, then run: git lfs pull\n"
            "Or place a real checkpoint .pt at this path (e.g. from runs/ after training)."
        )
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
    # BUG FIX: non-dict checkpoints are supported for the weights below, so do
    # not assume ``ckpt`` has a ``.get`` method when looking up the config.
    cfg = config
    if cfg is None and isinstance(ckpt, dict):
        cfg = ckpt.get("config")
    if cfg is None:
        cfg = MAEConfig()
    elif isinstance(cfg, dict):
        # Drop unknown keys so checkpoints from slightly different code
        # versions still load.
        allowed = {k: v for k, v in cfg.items() if k in MAEConfig.__dataclass_fields__}
        cfg = MAEConfig(**allowed)
    pipe = MAEReconstructionPipeline(cfg).to(device)
    state = ckpt["model"] if isinstance(ckpt, dict) and "model" in ckpt else ckpt
    pipe.load_state_dict(state, strict=True)
    pipe.eval()
    return pipe
def reconstruct(
    pipeline: MAEReconstructionPipeline,
    ecg: torch.Tensor,
    *,
    V: Tensor | None = None,
    M: Tensor | None = None,
    D: Tensor | None = None,
    generator: torch.Generator | None = None,
    return_loss: bool = True,
) -> dict[str, Any]:
    """
    Run a single forward pass of the wrapped MAE model.

    Parameters
    ----------
    ecg : (B, 12, T) tensor on the same device as ``pipeline``.
    V, M, D : optional bool ``(B, C, N)`` masks partitioning the patch grid. If any is
        given, all three must be supplied (see ``ECGDataMAE.forward``). If omitted,
        STDM is sampled.
    """
    return pipeline.model(
        ecg,
        V=V,
        M=M,
        D=D,
        generator=generator,
        return_loss=return_loss,
    )
def patches_to_signal(pred: np.ndarray, num_leads: int, num_patches: int, patch_size: int) -> np.ndarray:
    """Flatten patched predictions ``(C, N, P)`` into per-lead signals ``(C, N*P)``."""
    samples_per_lead = num_patches * patch_size
    return pred.reshape(num_leads, samples_per_lead)
def mask_m_to_time(mask_m: np.ndarray, patch_size: int) -> np.ndarray:
    """
    Expand a per-patch mask to a per-sample mask.

    Parameters
    ----------
    mask_m
        ``(C, N)`` bool patch mask.
    patch_size
        Samples per patch (``P``).

    Returns
    -------
    ``(C, N*P)`` bool mask, constant within each patch.

    Raises
    ------
    ValueError
        If ``mask_m`` is not 2-dimensional (the original implicit tuple-unpack
        check made explicit; still raises ``ValueError``).
    """
    if mask_m.ndim != 2:
        raise ValueError(f"mask_m must be (C, N), got shape {mask_m.shape}")
    return np.repeat(mask_m, patch_size, axis=1)
def figure_reconstruction(
    original: np.ndarray,
    predicted: np.ndarray,
    mask_m: np.ndarray,
    *,
    leads: tuple[int, ...] = tuple(range(12)),
    patch_size: int,
    title: str = "ECG MAE reconstruction",
):
    """
    Build a matplotlib figure: a few leads with original vs prediction.

    Parameters
    ----------
    original, predicted
        ``(C, T)`` arrays of ground-truth and reconstructed signals.
    mask_m
        ``(C, N)`` bool patch mask; masked segments are shaded in the background.
    leads
        Lead indices to plot, one subplot per lead.
    patch_size
        Samples per patch; must satisfy ``T == N * patch_size``.
    title
        Figure suptitle.

    Raises
    ------
    ValueError
        If the time axis is inconsistent with the patch grid.

    Notes
    -----
    Y-axis labels use standard 12-lead names in PTB-XL order (indices 0-11).
    """
    _, t = original.shape
    n = mask_m.shape[1]
    # Validate with a real exception instead of ``assert`` (stripped under -O).
    if t != n * patch_size:
        raise ValueError(
            f"time axis ({t}) != num_patches ({n}) * patch_size ({patch_size})"
        )
    t_idx = np.arange(t)
    fig, axes = plt.subplots(len(leads), 1, figsize=(12, 2.5 * len(leads)), sharex=True)
    if len(leads) == 1:
        # plt.subplots returns a bare Axes (not an array) for a single row.
        axes = [axes]
    mt = mask_m_to_time(mask_m, patch_size)
    for ax, lead in zip(axes, leads, strict=True):
        m = mt[lead]
        # Shade masked patches behind both traces so reconstruction quality in
        # the hidden regions is visually obvious.
        ax.fill_between(
            t_idx,
            original[lead].min(),
            original[lead].max(),
            where=m,
            alpha=0.25,
            color="gray",
            linewidth=0,
        )
        ax.plot(t_idx, original[lead], label="original", linewidth=0.8)
        ax.plot(t_idx, predicted[lead], label="recon", linewidth=0.8, alpha=0.9)
        name = (
            STANDARD_12_LEAD_NAMES[lead]
            if 0 <= lead < len(STANDARD_12_LEAD_NAMES)
            else f"Lead {lead}"
        )
        ax.set_ylabel(name)
        ax.legend(loc="upper right", fontsize=8)
    axes[-1].set_xlabel("sample")
    fig.suptitle(title)
    fig.tight_layout()
    return fig
def save_checkpoint(
    path: str | Path,
    pipeline: nn.Module,
    optimizer: torch.optim.Optimizer | None,
    epoch: int,
    step: int,
    config: MAEConfig,
) -> None:
    """
    Serialize training state (weights, optimizer, progress, config) to ``path``.

    Parent directories are created as needed. Dataclass configs are stored as
    plain dicts so checkpoints remain loadable without the config class.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    payload = {
        "model": pipeline.state_dict(),
        "optimizer": None if optimizer is None else optimizer.state_dict(),
        "epoch": epoch,
        "step": step,
        "config": asdict(config) if hasattr(config, "__dataclass_fields__") else config,
    }
    torch.save(payload, target)
def main() -> None:
    """CLI entry point: reconstruct one PTB-XL record and save a comparison figure."""
    import argparse

    parser = argparse.ArgumentParser(description="Run ECG MAE reconstruction and save a figure.")
    parser.add_argument("--checkpoint", type=Path, default=Path("ckpts/checkpoint_last.pt"))
    parser.add_argument("--data-root", type=Path, default=Path("dataset/ptb_xl"))
    parser.add_argument("--split", choices=("train", "val", "test"), default="test")
    parser.add_argument("--index", type=int, default=0)
    parser.add_argument("--out", type=Path, default=Path("reconstruction.png"))
    parser.add_argument("--device", default=None)
    args = parser.parse_args()

    # Imported lazily so merely importing this module never pulls in the dataset.
    from ptb_xl_dataset import PTBXLDataset

    device = torch.device(args.device or ("cuda" if torch.cuda.is_available() else "cpu"))
    print(f"Using device: {device}")

    pipe = load_pipeline(args.checkpoint, device=device)
    cfg = pipe.config

    dataset = PTBXLDataset(args.data_root, split=args.split, rng_seed=0)
    ecg_batch = dataset[args.index]["ecg"].unsqueeze(0).to(device)

    # Fixed seed so the sampled mask (and hence the figure) is reproducible.
    rng = torch.Generator(device=device)
    rng.manual_seed(0)
    with torch.no_grad():
        result = reconstruct(pipe, ecg_batch, generator=rng)

    original = ecg_batch[0].cpu().numpy()
    recon = result["pred"][0].reshape(cfg.num_leads, -1).cpu().numpy()
    patch_mask = result["M"][0].cpu().numpy()

    fig = figure_reconstruction(
        original,
        recon,
        patch_mask,
        patch_size=cfg.patch_size,
        title=f"split={args.split} idx={args.index}",
    )
    fig.savefig(args.out, dpi=150)
    print(f"Saved {args.out}")


if __name__ == "__main__":
    main()