#!/usr/bin/env python3
# SPDX-License-Identifier: MIT
"""Inference example for the LTAF HTF beat classifier (N / A / V)."""

from __future__ import annotations

from typing import List, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download

from model import EcgBeatHTFClassifier, BEAT_CLASS_NAMES

SOURCE_HZ = 128
WINDOW_SAMPLES = 256  # 2 s


def load_model(device: str = "cpu") -> EcgBeatHTFClassifier:
    """Download the checkpoint from the Hugging Face Hub and load it.

    Args:
        device: torch device the model is loaded onto.

    Returns:
        The loaded classifier, put in eval mode.
    """
    ckpt_path = hf_hub_download("rmxjck/ltaf-ecg-beats-classifier-htf", "best_classifier.pt")
    model = EcgBeatHTFClassifier.load(ckpt_path, device=device)
    # torch.no_grad() alone does not switch dropout / batch-norm to inference
    # behaviour, so force eval mode here regardless of what load() does.
    # NOTE(review): assumes EcgBeatHTFClassifier is a torch.nn.Module (demo()
    # calls model.parameters()) — confirm.
    model.eval()
    return model


def zscore(window: np.ndarray) -> np.ndarray:
    """Z-score each channel of ``window`` along its last axis.

    A small epsilon in the denominator keeps flat channels finite.
    The result is float32.
    """
    mean = window.mean(axis=-1, keepdims=True)
    std = window.std(axis=-1, keepdims=True)
    return ((window - mean) / (std + 1e-6)).astype(np.float32, copy=False)


def predict_beat(
    model: EcgBeatHTFClassifier,
    window: np.ndarray,
    rr_history: np.ndarray,
    label_history: np.ndarray | None = None,
    device: str = "cpu",
) -> Tuple[str, float]:
    """Predict on one beat.

    Args:
        window: (2, 256) z-scored 2 s window centered on the R-peak.
        rr_history: (model.history_k,) RR intervals (seconds) of the preceding
            beats, most recent first. Use 0.0 for unknown / record start.
        label_history: (model.history_k,) int — preceding beat labels
            (0=N, 1=A, 2=V), most recent first. Use -1 for unknown.
            If None, every position is treated as unknown (-1).
        device: torch device.

    Returns:
        (class_name, probability) of the most likely class.

    Raises:
        ValueError: if any input has the wrong shape.
    """
    window = np.asarray(window)
    rr_history = np.asarray(rr_history)
    if window.shape != (2, WINDOW_SAMPLES):
        raise ValueError(f"window must be (2, {WINDOW_SAMPLES}), got {window.shape}")
    if rr_history.shape != (model.history_k,):
        raise ValueError(f"rr_history must be ({model.history_k},), got {rr_history.shape}")
    if label_history is None:
        label_history = np.full(model.history_k, -1, dtype=np.int64)
    else:
        label_history = np.asarray(label_history)
        if label_history.shape != (model.history_k,):
            raise ValueError(
                f"label_history must be ({model.history_k},), got {label_history.shape}"
            )

    # Add a leading batch dimension of 1 for the model.
    x_time = torch.from_numpy(window).float().unsqueeze(0).to(device)
    rr = torch.from_numpy(rr_history).float().unsqueeze(0).to(device)
    lbl = torch.from_numpy(label_history).long().unsqueeze(0).to(device)

    with torch.no_grad():
        probs = F.softmax(model(x_time, rr, lbl), dim=-1)[0]
    idx = int(probs.argmax().item())
    return model.class_names[idx], float(probs[idx].item())


def _centered_window(signal: np.ndarray, center: int) -> np.ndarray:
    """Return a (channels, WINDOW_SAMPLES) z-scored window centred on ``center``.

    Only the samples that exist in ``signal`` are z-scored. Positions that fall
    before the start or past the end of the record are left at 0 (the z-scored
    mean). Beats near the record edges therefore keep the R-peak at the window
    centre instead of being shifted or dropped.
    """
    start = center - WINDOW_SAMPLES // 2
    lo = max(start, 0)
    hi = min(start + WINDOW_SAMPLES, signal.shape[0])
    segment = zscore(signal[lo:hi].T.astype(np.float32, copy=False))
    window = np.zeros((segment.shape[0], WINDOW_SAMPLES), dtype=np.float32)
    window[:, lo - start : hi - start] = segment
    return window


def predict_beat_sequence(
    model: EcgBeatHTFClassifier,
    signal: np.ndarray,
    r_peak_samples: List[int],
    device: str = "cpu",
) -> List[Tuple[str, float]]:
    """Predict labels for a sequence of beats in order.

    Uses each beat's *predicted* label as the history input for subsequent
    beats (autoregressive), since the true labels aren't known at inference.

    Args:
        signal: (N, 2) raw float32 signal at 128 Hz.
        r_peak_samples: list of R-peak sample indices, each in [0, N).
        device: torch device.

    Returns:
        list of (class_name, prob), exactly one per entry of
        ``r_peak_samples``, in the same order.

    Raises:
        ValueError: if ``signal`` is not 2-D or an R-peak index lies outside it.
    """
    if signal.ndim != 2:
        raise ValueError(f"signal must be (N, 2), got {signal.shape}")
    k = model.history_k
    n_total = signal.shape[0]
    # Built once instead of calling class_names.index() per beat.
    class_to_idx = {name: i for i, name in enumerate(model.class_names)}

    # Index 0 is always the most recent preceding beat.
    rr_buf: List[float] = [0.0] * k
    lbl_buf: List[int] = [-1] * k
    out: List[Tuple[str, float]] = []
    prev_sample = None
    for s in r_peak_samples:
        if not 0 <= s < n_total:
            raise ValueError(f"R-peak sample {s} outside signal of length {n_total}")
        window = _centered_window(signal, s)
        cls, prob = predict_beat(
            model,
            window,
            np.array(rr_buf, dtype=np.float32),
            np.array(lbl_buf, dtype=np.int64),
            device=device,
        )
        out.append((cls, prob))

        # Push this beat into the history for the next one. Its label is
        # always known (we just predicted it). Its RR interval is unknown
        # (0.0) for the first beat. Pushing both together keeps rr_buf[i]
        # and lbl_buf[i] describing the same preceding beat.
        # NOTE(review): rr_buf[0] is the interval ending at the previous beat;
        # the current beat's own RR is not included — confirm this matches
        # the features used at training time.
        new_rr = 0.0 if prev_sample is None else (s - prev_sample) / SOURCE_HZ
        rr_buf = ([new_rr] + rr_buf)[:k]
        lbl_buf = ([class_to_idx[cls]] + lbl_buf)[:k]
        prev_sample = s
    return out


def demo() -> None:
    """Download the model and run it on synthetic single-beat and sequence inputs."""
    print("Loading HTF beat model from HF...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = load_model(device)
    print(f"Loaded {model.__class__.__name__} on {device}")
    print(f"Classes: {model.class_names}")
    print(f"Params: {sum(p.numel() for p in model.parameters()):,}")

    # Synthetic example
    print("\n--- single-beat demo (random input) ---")
    fake_window = zscore(np.random.randn(2, WINDOW_SAMPLES).astype(np.float32))
    fake_rr = np.array([0.85, 0.83, 0.87, 0.82, 0.85], dtype=np.float32)
    fake_lbl = np.array([0, 0, 0, 0, 0], dtype=np.int64)  # 5 prior normal beats
    cls, prob = predict_beat(model, fake_window, fake_rr, fake_lbl, device=device)
    print(f"prediction: {cls} ({prob:.1%})")

    print("\n--- beat sequence demo ---")
    fake_signal = np.random.randn(30 * SOURCE_HZ, 2).astype(np.float32)
    fake_peaks = list(range(150, 30 * SOURCE_HZ - 150, 100))[:10]
    preds = predict_beat_sequence(model, fake_signal, fake_peaks, device=device)
    for i, (peak, (c, p)) in enumerate(zip(fake_peaks, preds), start=1):
        print(f"  beat {i} (sample {peak}): {c} ({p:.1%})")


if __name__ == "__main__":
    demo()