"""Inference and visualization helpers for ECG MAE (training dashboards and deployment)."""
from __future__ import annotations
from dataclasses import asdict
from pathlib import Path
from typing import Any
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from mae.config import MAEConfig
from mae.pipeline import MAEReconstructionPipeline
# PTB-XL / WFDB order: I, II, III, AVR, AVL, AVF, V1–V6 (display as standard names).
STANDARD_12_LEAD_NAMES: tuple[str, ...] = (
"I",
"II",
"III",
"aVR",
"aVL",
"aVF",
"V1",
"V2",
"V3",
"V4",
"V5",
"V6",
)
def load_wfdb_record_leads_first(record_path_without_ext: str | Path) -> tuple[np.ndarray, float]:
"""
Load a WFDB record the same way as ``PTBXLDataset``: ``wfdb.rdrecord`` then ``ptbxl_wfdb_to_leads_first``.
Parameters
----------
record_path_without_ext
        Path **without** extension, e.g. ``.../00001_hr``, so that ``00001_hr.hea`` and
        ``00001_hr.dat`` exist in the same directory (PTB-XL high-rate layout).
Returns
-------
ecg
``(12, T)`` float32, leads-first (PTB-XL / WFDB column order).
fs_hz
Sampling rate from the record header (falls back to ``preprocessor.DEFAULT_FS`` if missing).
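
    Examples
    --------
    Sketch (the record path is a placeholder for a PTB-XL high-rate record)::

        ecg, fs = load_wfdb_record_leads_first("dataset/ptb_xl/records500/00000/00001_hr")
        ecg.shape  # (12, 5000) for a 10 s record at 500 Hz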
"""
import wfdb
import preprocessor as pre
p = Path(record_path_without_ext).resolve()
parent = p.parent
stem = p.name
hea = parent / f"{stem}.hea"
dat = parent / f"{stem}.dat"
if not hea.is_file():
raise FileNotFoundError(f"Missing WFDB header: {hea}")
if not dat.is_file():
raise FileNotFoundError(f"Missing WFDB signal file: {dat}")
rec = wfdb.rdrecord(str(p))
x = pre.ptbxl_wfdb_to_leads_first(rec.p_signal)
fs = float(rec.fs) if getattr(rec, "fs", None) else float(pre.DEFAULT_FS)
return x, fs
def _is_git_lfs_pointer_file(path: Path) -> bool:
"""True if ``path`` is a Git LFS stub (text pointer) instead of the real blob."""
try:
with path.open("r", encoding="ascii", errors="replace") as f:
head = f.read(128)
except OSError:
return False
return head.startswith("version https://git-lfs.github.com/spec/v1")
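# For reference, a Git LFS pointer file is plain text of the form:
#   version https://git-lfs.github.com/spec/v1
#   oid sha256:<digest>
#   size <bytes>
# so checking the 128-byte prefix against the version line is sufficient.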
def load_pipeline(
checkpoint_path: str | Path,
device: torch.device | str | None = None,
config: MAEConfig | None = None,
) -> MAEReconstructionPipeline:
"""Load ``MAEReconstructionPipeline`` weights from a training checkpoint dict."""
checkpoint_path = Path(checkpoint_path)
device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
if _is_git_lfs_pointer_file(checkpoint_path):
raise RuntimeError(
f"{checkpoint_path} is a Git LFS pointer file, not checkpoint bytes. "
"Install Git LFS if needed, then run: git lfs pull\n"
"Or place a real checkpoint .pt at this path (e.g. from runs/ after training)."
)
ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
cfg = config or ckpt.get("config")
if cfg is None:
cfg = MAEConfig()
elif isinstance(cfg, dict):
allowed = {k: v for k, v in cfg.items() if k in MAEConfig.__dataclass_fields__}
cfg = MAEConfig(**allowed)
pipe = MAEReconstructionPipeline(cfg).to(device)
state = ckpt["model"] if isinstance(ckpt, dict) and "model" in ckpt else ckpt
pipe.load_state_dict(state, strict=True)
pipe.eval()
return pipe
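# Example (sketch; the path is a placeholder for any checkpoint produced by
# ``save_checkpoint`` below):
#
#     pipe = load_pipeline("ckpts/checkpoint_last.pt", device="cpu")
#     cfg = pipe.config  # MAEConfig restored from the checkpoint (or defaults)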
@torch.no_grad()
def reconstruct(
pipeline: MAEReconstructionPipeline,
ecg: torch.Tensor,
*,
V: Tensor | None = None,
M: Tensor | None = None,
D: Tensor | None = None,
generator: torch.Generator | None = None,
return_loss: bool = True,
) -> dict[str, Any]:
"""
Run one forward pass.
Parameters
----------
ecg : (B, 12, T) tensor on the same device as ``pipeline``.
V, M, D : optional bool ``(B, C, N)`` masks partitioning the patch grid. If any is given,
all three must be supplied (see ``ECGDataMAE.forward``). If omitted, STDM is sampled.
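
    Examples
    --------
    Sketch with a sampled STDM partition; the output keys follow their use in
    ``main()`` below::

        g = torch.Generator(device=ecg.device)
        g.manual_seed(0)
        out = reconstruct(pipeline, ecg, generator=g)
        out["pred"]  # per-patch predictions, reshapeable to (B, C, N*P)
        out["M"]     # (B, C, N) bool mask of the patches that were masked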
"""
out = pipeline.model(
ecg,
V=V,
M=M,
D=D,
generator=generator,
return_loss=return_loss,
)
return out
def patches_to_signal(pred: np.ndarray, num_leads: int, num_patches: int, patch_size: int) -> np.ndarray:
"""(C, N, P) -> (C, N*P)"""
return pred.reshape(num_leads, num_patches * patch_size)
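# Example (illustrative shapes): patches_to_signal(pred, 12, 250, 4)
# flattens a (12, 250, 4) prediction into a (12, 1000) signal.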
def mask_m_to_time(mask_m: np.ndarray, patch_size: int) -> np.ndarray:
"""(C, N) bool -> (C, N*P) bool (patch-wise constant mask)."""
    return np.repeat(mask_m, patch_size, axis=1)
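# Example (illustrative): with patch_size=2,
# mask_m_to_time(np.array([[True, False]]), 2) -> [[True, True, False, False]].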
def figure_reconstruction(
original: np.ndarray,
predicted: np.ndarray,
mask_m: np.ndarray,
*,
leads: tuple[int, ...] = tuple(range(12)),
patch_size: int,
title: str = "ECG MAE reconstruction",
):
"""
    Build a matplotlib figure comparing original vs. predicted traces for the selected ``leads``.
``mask_m`` is (C, N) bool; masked segments are shaded in the background.
Y-axis labels use standard 12-lead names in PTB-XL order (indices 0–11).
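
    Examples
    --------
    Sketch (``orig``/``pred`` are ``(12, T)`` arrays and ``m`` is the ``(12, N)``
    mask from ``reconstruct``; ``patch_size=4`` is illustrative)::

        fig = figure_reconstruction(orig, pred, m, leads=(0, 1, 6), patch_size=4)
        fig.savefig("recon.png", dpi=150)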
"""
    _, t = original.shape
    n = mask_m.shape[1]
    assert t == n * patch_size, f"signal length {t} != {n} patches * patch_size {patch_size}"
t_idx = np.arange(t)
fig, axes = plt.subplots(len(leads), 1, figsize=(12, 2.5 * len(leads)), sharex=True)
if len(leads) == 1:
axes = [axes]
mt = mask_m_to_time(mask_m, patch_size)
for ax, lead in zip(axes, leads, strict=True):
m = mt[lead]
ax.fill_between(
t_idx,
original[lead].min(),
original[lead].max(),
where=m,
alpha=0.25,
color="gray",
linewidth=0,
)
ax.plot(t_idx, original[lead], label="original", linewidth=0.8)
ax.plot(t_idx, predicted[lead], label="recon", linewidth=0.8, alpha=0.9)
name = (
STANDARD_12_LEAD_NAMES[lead]
if 0 <= lead < len(STANDARD_12_LEAD_NAMES)
else f"Lead {lead}"
)
ax.set_ylabel(name)
ax.legend(loc="upper right", fontsize=8)
axes[-1].set_xlabel("sample")
fig.suptitle(title)
fig.tight_layout()
return fig
def save_checkpoint(
path: str | Path,
pipeline: nn.Module,
optimizer: torch.optim.Optimizer | None,
epoch: int,
step: int,
config: MAEConfig,
) -> None:
path = Path(path)
path.parent.mkdir(parents=True, exist_ok=True)
torch.save(
{
"model": pipeline.state_dict(),
"optimizer": optimizer.state_dict() if optimizer is not None else None,
"epoch": epoch,
"step": step,
"config": asdict(config) if hasattr(config, "__dataclass_fields__") else config,
},
path,
)
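# Example (sketch; values are placeholders for a training loop):
#
#     save_checkpoint("ckpts/checkpoint_last.pt", pipe, opt, epoch=3, step=1200, config=cfg)
#
# Pass ``optimizer=None`` for an inference-only export.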
def main() -> None:
import argparse
p = argparse.ArgumentParser(description="Run ECG MAE reconstruction and save a figure.")
p.add_argument("--checkpoint", type=Path, default=Path("ckpts/checkpoint_last.pt"))
p.add_argument("--data-root", type=Path, default=Path("dataset/ptb_xl"))
p.add_argument("--split", choices=("train", "val", "test"), default="test")
p.add_argument("--index", type=int, default=0)
p.add_argument("--out", type=Path, default=Path("reconstruction.png"))
p.add_argument("--device", default=None)
args = p.parse_args()
from ptb_xl_dataset import PTBXLDataset
device = torch.device(args.device or ("cuda" if torch.cuda.is_available() else "cpu"))
print(f"Using device: {device}")
pipe = load_pipeline(args.checkpoint, device=device)
cfg = pipe.config
ds = PTBXLDataset(args.data_root, split=args.split, rng_seed=0)
batch = ds[args.index]["ecg"].unsqueeze(0).to(device)
g = torch.Generator(device=device)
g.manual_seed(0)
    out = reconstruct(pipe, batch, generator=g)  # reconstruct is already @torch.no_grad()
orig = batch[0].cpu().numpy()
pred = out["pred"][0].reshape(cfg.num_leads, -1).cpu().numpy()
m = out["M"][0].cpu().numpy()
fig = figure_reconstruction(
orig, pred, m, patch_size=cfg.patch_size, title=f"split={args.split} idx={args.index}"
)
fig.savefig(args.out, dpi=150)
print(f"Saved {args.out}")
if __name__ == "__main__":
main()