| |
| |
| """Inference example for the LTAF HTF beat classifier (N / A / V).""" |
|
|
| from __future__ import annotations |
|
|
| from typing import List, Tuple |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from huggingface_hub import hf_hub_download |
|
|
| from model import EcgBeatHTFClassifier, BEAT_CLASS_NAMES |
|
|
|
|
# Sampling rate of the input ECG signal, in Hz.
SOURCE_HZ: int = 128
# Beat window length in samples: a 2-second window at SOURCE_HZ.
WINDOW_SAMPLES: int = 2 * SOURCE_HZ
|
|
|
|
def load_model(
    device: str = "cpu",
    *,
    repo_id: str = "rmxjck/ltaf-ecg-beats-classifier-htf",
    filename: str = "best_classifier.pt",
) -> EcgBeatHTFClassifier:
    """Download the HTF beat-classifier checkpoint from the Hub and load it.

    Args:
        device: torch device string passed to ``EcgBeatHTFClassifier.load``.
        repo_id: Hugging Face Hub repository holding the checkpoint. Defaults
            to the published LTAF HTF classifier repo.
        filename: checkpoint file name inside ``repo_id``.

    Returns:
        The model returned by ``EcgBeatHTFClassifier.load``.
    """
    ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
    return EcgBeatHTFClassifier.load(ckpt_path, device=device)
|
|
|
|
def zscore(window: np.ndarray) -> np.ndarray:
    """Standardize *window* along its last axis to zero mean, unit variance.

    A small epsilon (1e-6) is added to the standard deviation so that
    constant rows do not trigger a division by zero. The result is float32.
    """
    eps = 1e-6
    centre = window.mean(axis=-1, keepdims=True)
    spread = window.std(axis=-1, keepdims=True) + eps
    standardized = (window - centre) / spread
    return standardized.astype(np.float32, copy=False)
|
|
|
|
def predict_beat(
    model: EcgBeatHTFClassifier,
    window: np.ndarray,
    rr_history: np.ndarray,
    label_history: np.ndarray | None = None,
    device: str = "cpu",
) -> Tuple[str, float]:
    """Predict the class of a single beat.

    Args:
        model: loaded classifier. ``model.history_k`` sets the expected
            history length and ``model.class_names`` maps output indices to
            class names.
        window: (2, WINDOW_SAMPLES) 2 s window centered on the R-peak. It is
            expected to be z-scored already (see ``zscore``); it is NOT
            normalized here.
        rr_history: (model.history_k,) RR intervals (seconds) for the
            preceding beats. Use 0.0 for unknown / record start.
        label_history: (model.history_k,) int labels of the preceding beats
            (0=N, 1=A, 2=V). Use -1 for unknown. If None, every entry is
            filled with -1 (all unknown).
        device: torch device the inputs are moved to.

    Returns:
        ``(class_name, probability)`` of the arg-max class of the softmax
        over the model output.

    Raises:
        ValueError: if ``window``, ``rr_history`` or ``label_history`` has
            the wrong shape.
    """
    k = model.history_k
    # Accept any array-like (lists, tuples) as well as ndarrays.
    window = np.asarray(window)
    rr_history = np.asarray(rr_history)
    if window.shape != (2, WINDOW_SAMPLES):
        raise ValueError(f"window must be (2, {WINDOW_SAMPLES}), got {window.shape}")
    if rr_history.shape != (k,):
        raise ValueError(f"rr_history must be ({k},), got {rr_history.shape}")
    if label_history is None:
        # -1 is the "unknown" label, matching the record-start convention.
        label_history = np.full(k, -1, dtype=np.int64)
    else:
        label_history = np.asarray(label_history)
        if label_history.shape != (k,):
            raise ValueError(
                f"label_history must be ({k},), got {label_history.shape}"
            )

    # ascontiguousarray: torch.from_numpy rejects negative-stride views
    # (e.g. a reversed slice), so hand it a contiguous copy when needed.
    x_time = torch.from_numpy(np.ascontiguousarray(window)).float().unsqueeze(0).to(device)
    rr = torch.from_numpy(np.ascontiguousarray(rr_history)).float().unsqueeze(0).to(device)
    lbl = torch.from_numpy(np.ascontiguousarray(label_history)).long().unsqueeze(0).to(device)
    with torch.no_grad():
        probs = F.softmax(model(x_time, rr, lbl), dim=-1)[0]
    idx = int(probs.argmax().item())
    return model.class_names[idx], float(probs[idx].item())
|
|
|
|
def predict_beat_sequence(
    model: EcgBeatHTFClassifier,
    signal: np.ndarray,
    r_peak_samples: List[int],
    device: str = "cpu",
) -> List[Tuple[str, float]]:
    """Predict labels for a sequence of beats, in order.

    Each beat's *predicted* label is fed back as label history for the
    following beats (autoregressive), since true labels are unknown at
    inference time.

    Args:
        model: loaded classifier; ``model.history_k`` sets the history length.
        signal: (N, 2) raw float32 signal sampled at SOURCE_HZ (128 Hz).
        r_peak_samples: R-peak sample indices, processed in the given order.
        device: torch device.

    Returns:
        One ``(class_name, prob)`` per processed beat, in order. Entry ``i``
        always corresponds to ``r_peak_samples[i]``. Processing stops at the
        first beat whose window would extend past the end of ``signal``, so
        the result may be shorter than ``r_peak_samples``.
    """
    from collections import deque

    half = WINDOW_SAMPLES // 2
    k = model.history_k
    n_total = signal.shape[0]
    # Index 0 holds the most recent beat; maxlen=k drops the oldest entry
    # on every appendleft, keeping both histories exactly k long.
    rr_hist = deque([0.0] * k, maxlen=k)
    lbl_hist = deque([-1] * k, maxlen=k)
    results: List[Tuple[str, float]] = []
    prev_sample = None
    for s in r_peak_samples:
        # Clamped at 0: beats closer than `half` samples to the start get a
        # window shifted right rather than centered on the R-peak.
        start = max(0, s - half)
        stop = start + WINDOW_SAMPLES
        if stop > n_total:
            break
        window = zscore(signal[start:stop].T.astype(np.float32, copy=False))
        cls, prob = predict_beat(
            model,
            window,
            np.array(rr_hist, dtype=np.float32),
            np.array(lbl_hist, dtype=np.int64),
            device=device,
        )
        results.append((cls, prob))

        # Record this beat in the history. The first beat has no preceding
        # beat, so its RR is unknown (0.0), but its predicted label is still
        # pushed so that later beats see it in their label history.
        # NOTE(review): rr_hist[0] is the interval ending at the *previous*
        # beat; the current beat's own pre-RR (s - prev_sample) is not part of
        # its input. Confirm this matches the training pipeline.
        rr = 0.0 if prev_sample is None else (s - prev_sample) / SOURCE_HZ
        rr_hist.appendleft(rr)
        lbl_hist.appendleft(model.class_names.index(cls))
        prev_sample = s
    return results
|
|
|
|
def demo():
    """Smoke-test the pipeline: load the published model and run it on noise.

    Inputs are random, so the printed predictions only show that loading,
    single-beat prediction and sequence prediction run end to end.
    """
    print("Loading HTF beat model from HF...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = load_model(device)
    print(f"Loaded {model.__class__.__name__} on {device}")
    print(f"Classes: {model.class_names}")
    print(f"Params: {sum(p.numel() for p in model.parameters()):,}")

    # Fixed seed so repeated runs print the same predictions.
    rng = np.random.default_rng(0)
    k = model.history_k

    print("\n--- single-beat demo (random input) ---")
    fake_window = zscore(rng.standard_normal((2, WINDOW_SAMPLES)).astype(np.float32))
    # Histories are sized from the model rather than hard-coded to 5 so they
    # pass predict_beat's (history_k,) shape check for any checkpoint;
    # np.resize cycles the sample RR pattern to length k.
    base_rr = np.array([0.85, 0.83, 0.87, 0.82, 0.85], dtype=np.float32)
    fake_rr = np.resize(base_rr, k).astype(np.float32)
    fake_lbl = np.zeros(k, dtype=np.int64)
    cls, prob = predict_beat(model, fake_window, fake_rr, fake_lbl, device=device)
    print(f"prediction: {cls} ({prob:.1%})")

    print("\n--- beat sequence demo ---")
    fake_signal = rng.standard_normal((30 * SOURCE_HZ, 2)).astype(np.float32)
    fake_peaks = list(range(150, 30 * SOURCE_HZ - 150, 100))[:10]
    preds = predict_beat_sequence(model, fake_signal, fake_peaks, device=device)
    # zip stops at the shorter list: predict_beat_sequence may return fewer
    # predictions than peaks if a window runs past the end of the signal.
    for i, (peak, (c, p)) in enumerate(zip(fake_peaks, preds), start=1):
        print(f" beat {i} (sample {peak}): {c} ({p:.1%})")
|
|
|
|
if __name__ == "__main__":
    # Run the demo (which downloads the checkpoint via load_model) only when
    # this file is executed as a script, not when it is imported.
    demo()
|
|