rmxjck's picture
Initial release
a888b0c verified
#!/usr/bin/env python3
# SPDX-License-Identifier: MIT
"""Inference example for the LTAF HTF beat classifier (N / A / V)."""
from __future__ import annotations
from typing import List, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from model import EcgBeatHTFClassifier, BEAT_CLASS_NAMES
# Sampling rate of the input ECG, in Hz (RR intervals are samples / SOURCE_HZ).
SOURCE_HZ = 128
# Length of each beat window fed to the model: 2 s of signal -> 256 samples.
WINDOW_SAMPLES = 2 * SOURCE_HZ
def load_model(device: str = "cpu") -> EcgBeatHTFClassifier:
    """Fetch the released checkpoint from the Hugging Face Hub and load it.

    Args:
        device: torch device string passed through to
            ``EcgBeatHTFClassifier.load``.

    Returns:
        The classifier restored from ``best_classifier.pt``.
    """
    repo_id = "rmxjck/ltaf-ecg-beats-classifier-htf"
    checkpoint_name = "best_classifier.pt"
    local_checkpoint = hf_hub_download(repo_id, checkpoint_name)
    return EcgBeatHTFClassifier.load(local_checkpoint, device=device)
def zscore(window: np.ndarray) -> np.ndarray:
    """Standardize *window* to zero mean and unit variance along the last axis.

    Mean and (population) standard deviation are computed per row over the
    last axis. A 1e-6 term is added to the standard deviation so a flat row
    maps to zeros instead of dividing by zero. The result is float32.
    """
    centered = window - window.mean(axis=-1, keepdims=True)
    scale = window.std(axis=-1, keepdims=True) + 1e-6
    return (centered / scale).astype(np.float32, copy=False)
def predict_beat(
    model: EcgBeatHTFClassifier,
    window: np.ndarray,
    rr_history: np.ndarray,
    label_history: np.ndarray | None = None,
    device: str = "cpu",
) -> Tuple[str, float]:
    """Predict the class of a single beat.

    Args:
        model: loaded classifier. ``model.history_k`` sets the expected
            history length and ``model.class_names`` maps the output index
            to a class name.
        window: (2, WINDOW_SAMPLES) z-scored window centered on the R-peak
            (2 s at 128 Hz).
        rr_history: (history_k,) RR intervals (seconds) for the preceding
            beats. Use 0.0 for unknown / record start.
        label_history: (history_k,) int labels of the preceding beats
            (0=N, 1=A, 2=V). Use -1 for unknown. If None, every slot is
            filled with -1 (unknown).
        device: torch device.

    Returns:
        ``(class_name, probability)`` for the arg-max class.

    Raises:
        ValueError: if ``window``, ``rr_history`` or ``label_history`` has
            the wrong shape.
    """
    # Coerce to contiguous arrays of the dtypes the model consumes. This
    # accepts lists / strided views, which torch.from_numpy would reject.
    window = np.ascontiguousarray(window, dtype=np.float32)
    rr_history = np.ascontiguousarray(rr_history, dtype=np.float32)
    k = model.history_k
    if window.shape != (2, WINDOW_SAMPLES):
        raise ValueError(f"window must be (2, {WINDOW_SAMPLES}), got {window.shape}")
    if rr_history.shape != (k,):
        raise ValueError(f"rr_history must be ({k},), got {rr_history.shape}")
    if label_history is None:
        label_history = np.full(k, -1, dtype=np.int64)
    else:
        label_history = np.ascontiguousarray(label_history, dtype=np.int64)
        if label_history.shape != (k,):
            raise ValueError(
                f"label_history must be ({k},), got {label_history.shape}"
            )
    # Add a leading batch dimension of 1 to every input.
    x_time = torch.from_numpy(window).unsqueeze(0).to(device)
    rr = torch.from_numpy(rr_history).unsqueeze(0).to(device)
    lbl = torch.from_numpy(label_history).unsqueeze(0).to(device)
    # NOTE(review): assumes the model is already in eval mode (e.g. set by
    # EcgBeatHTFClassifier.load) — confirm; this function does not call .eval().
    with torch.no_grad():
        probs = F.softmax(model(x_time, rr, lbl), dim=-1)[0]
    idx = int(probs.argmax().item())
    return model.class_names[idx], float(probs[idx].item())
def predict_beat_sequence(
    model: EcgBeatHTFClassifier,
    signal: np.ndarray,
    r_peak_samples: List[int],
    device: str = "cpu",
) -> List[Tuple[str, float]]:
    """Predict labels for a sequence of beats in order.

    Uses each beat's *predicted* label as the history input for subsequent
    beats (autoregressive), since the true labels aren't known at inference.

    Each beat's window starts ``WINDOW_SAMPLES // 2`` samples before its
    R-peak, clamped to the start of the record, so beats closer than that to
    the start are not centered. A window that would run past the end of the
    record is truncated, z-scored, then zero-padded on the right. As a
    result, every beat receives a prediction.

    Args:
        model: loaded classifier. ``history_k`` sets the history length and
            ``class_names`` maps predicted names back to label indices.
        signal: (N, 2) raw float32 signal at 128 Hz.
        r_peak_samples: R-peak sample indices, each in ``[0, N)``. Expected
            in ascending order; RR intervals are the differences of
            consecutive entries and their sign is not checked.
        device: torch device.

    Returns:
        list of (class_name, prob), one per entry of ``r_peak_samples``,
        in the same order.

    Raises:
        ValueError: if ``signal`` is not (N, 2) or any R-peak index lies
            outside ``[0, N)``.
    """
    signal = np.asarray(signal)
    if signal.ndim != 2 or signal.shape[1] != 2:
        raise ValueError(f"signal must be (N, 2), got {signal.shape}")
    n_total = signal.shape[0]
    # Validate every peak before running any inference.
    bad = [s for s in r_peak_samples if not 0 <= s < n_total]
    if bad:
        raise ValueError(
            f"R-peak indices must lie in [0, {n_total}); out-of-range: {bad[:5]}"
        )

    half = WINDOW_SAMPLES // 2
    k = model.history_k
    rr_buf = [0.0] * k  # index 0 = most recent; 0.0 = unknown
    lbl_buf = [-1] * k  # index 0 = most recent; -1 = unknown
    out: List[Tuple[str, float]] = []
    prev_sample = None
    for s in r_peak_samples:
        ws = max(0, s - half)
        we = min(ws + WINDOW_SAMPLES, n_total)
        window = zscore(signal[ws:we].T.astype(np.float32, copy=False))
        missing = WINDOW_SAMPLES - window.shape[1]
        if missing > 0:
            # After z-scoring, 0 is the per-channel mean, so zero padding
            # adds no offset.
            window = np.pad(window, ((0, 0), (0, missing)))
        cls, prob = predict_beat(
            model,
            window,
            np.array(rr_buf, dtype=np.float32),
            np.array(lbl_buf, dtype=np.int64),
            device=device,
        )
        out.append((cls, prob))
        # Update history buffers for the next beat (newest first, length k).
        # NOTE(review): the RR interval ending at this beat is pushed only
        # after it is classified. So for beat i, slot j pairs the label of
        # beat i-1-j with the RR interval preceding that same beat. Confirm
        # this matches how the model was trained.
        if prev_sample is not None:
            rr_buf = ([(s - prev_sample) / SOURCE_HZ] + rr_buf)[:k]
        lbl_buf = ([model.class_names.index(cls)] + lbl_buf)[:k]
        prev_sample = s
    return out
def demo():
    """Smoke-test the released model on synthetic input.

    Downloads the checkpoint and prints basic model info. It then classifies
    one random beat window and a short random two-channel record. The inputs
    are noise, so the predictions only show that the pipeline runs end to
    end; they are not meaningful.
    """
    print("Loading HTF beat model from HF...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = load_model(device)
    print(f"Loaded {model.__class__.__name__} on {device}")
    print(f"Classes: {model.class_names}")
    print(f"Params: {sum(p.numel() for p in model.parameters()):,}")
    k = model.history_k
    # Fixed seed so repeated runs feed the model identical inputs.
    rng = np.random.default_rng(0)

    # Synthetic example
    print("\n--- single-beat demo (random input) ---")
    fake_window = zscore(rng.standard_normal((2, WINDOW_SAMPLES)).astype(np.float32))
    # Size the histories from the model: predict_beat rejects any other length.
    fake_rr = np.resize(np.array([0.85, 0.83, 0.87, 0.82, 0.85], dtype=np.float32), k)
    fake_lbl = np.zeros(k, dtype=np.int64)  # k prior normal beats (0 = N)
    cls, prob = predict_beat(model, fake_window, fake_rr, fake_lbl, device=device)
    print(f"prediction: {cls} ({prob:.1%})")

    print("\n--- beat sequence demo ---")
    fake_signal = rng.standard_normal((30 * SOURCE_HZ, 2)).astype(np.float32)
    fake_peaks = list(range(150, 30 * SOURCE_HZ - 150, 100))[:10]
    preds = predict_beat_sequence(model, fake_signal, fake_peaks, device=device)
    for i, (peak, (c, p)) in enumerate(zip(fake_peaks, preds), start=1):
        print(f" beat {i} (sample {peak}): {c} ({p:.1%})")
if __name__ == "__main__":
    # Executing the file as a script runs the demo; importing it does not.
    demo()