#!/usr/bin/env python3
# SPDX-License-Identifier: MIT
"""Inference example for the LTAF ECG rhythm classifier.

Two modes:

- Single-window: pass one (2, 1280) z-scored 10 s @ 128 Hz numpy window.
- TTA-7 (recommended, +4 pp F1): pass a longer (2, L) signal slice; the
  function draws 7 random 10 s windows from it and soft-votes.

Usage:
    .venv/bin/python inference.py
"""
from __future__ import annotations

from pathlib import Path
from typing import Tuple

import numpy as np
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download

from model import RHYTHM_CLASS_NAMES, RhythmResNet1D

WINDOW_SECONDS = 10
SOURCE_HZ = 128
WINDOW_SAMPLES = WINDOW_SECONDS * SOURCE_HZ  # 1280


def load_model(device: str = "cpu") -> RhythmResNet1D:
    """Download the classifier checkpoint from the HF Hub and load it.

    Args:
        device: Torch device string passed through to ``RhythmResNet1D.load``.

    Returns:
        The loaded ``RhythmResNet1D``.
    """
    ckpt_path = hf_hub_download(
        "rmxjck/ltaf-ecg-rhythm-classifier",
        "best_classifier.pt",
    )
    return RhythmResNet1D.load(ckpt_path, device=device)


def zscore(window: np.ndarray) -> np.ndarray:
    """Per-channel z-score a (C, L) array along its last axis.

    ``1e-6`` is added to the standard deviation so a constant (flat) channel
    maps to zeros instead of dividing by zero. The result is float32.
    """
    mean = window.mean(axis=-1, keepdims=True)
    std = window.std(axis=-1, keepdims=True)
    return ((window - mean) / (std + 1e-6)).astype(np.float32, copy=False)


def _as_float_tensor(array: np.ndarray, device: str) -> torch.Tensor:
    """Convert a numpy array to a contiguous float32 tensor on ``device``."""
    # torch.from_numpy rejects negative strides (e.g. flipped views), so make
    # the buffer contiguous first; this is a no-op for already-contiguous float32.
    return torch.from_numpy(np.ascontiguousarray(array, dtype=np.float32)).to(device)


def predict_single(
    model: RhythmResNet1D,
    window: np.ndarray,
    device: str = "cpu",
) -> Tuple[str, float]:
    """Classify one (2, 1280) window that the caller has already z-scored.

    Args:
        model: Loaded classifier.
        window: Array of shape ``(2, WINDOW_SAMPLES)``; it is not normalised here.
        device: Device the model lives on.

    Returns:
        ``(class_name, prob)`` — the arg-max class and its softmax probability.

    Raises:
        ValueError: If ``window`` does not have shape ``(2, WINDOW_SAMPLES)``.
    """
    if window.shape != (2, WINDOW_SAMPLES):
        raise ValueError(f"Expected (2, {WINDOW_SAMPLES}), got {window.shape}")
    x = _as_float_tensor(window[np.newaxis], device)  # add batch dim -> (1, 2, L)
    with torch.no_grad():
        probs = F.softmax(model(x), dim=-1)[0]
    idx = int(probs.argmax().item())
    return model.class_names[idx], float(probs[idx].item())


def predict_tta(
    model: RhythmResNet1D,
    long_signal: np.ndarray,
    n_views: int = 7,
    device: str = "cpu",
    seed: int = 42,
) -> Tuple[str, float, np.ndarray]:
    """TTA soft-voting prediction over a longer (2, L) signal.

    Draws ``n_views`` random 10 s windows from ``long_signal`` (L >= 1280),
    z-scores each independently, runs them through the model in one batch,
    and averages the softmax probabilities.

    Args:
        model: Loaded classifier.
        long_signal: Raw (un-normalised) array of shape ``(2, L)``.
        n_views: Number of random windows to vote over; must be >= 1.
        device: Device the model lives on.
        seed: Seed for the window-start RNG, so results are reproducible.

    Returns:
        ``(class_name, prob, full_probs)`` where ``full_probs`` holds the
        averaged probability for every class the model outputs.

    Raises:
        ValueError: If the signal is not 2-D with 2 channels, is shorter than
            one window, or ``n_views`` < 1.
    """
    if long_signal.ndim != 2:
        raise ValueError(
            f"Expected a 2-D (channels, samples) signal, got shape {long_signal.shape}"
        )
    n_ch, n_samples = long_signal.shape
    if n_ch != 2:
        raise ValueError(f"Expected 2-channel signal, got {n_ch}")
    if n_samples < WINDOW_SAMPLES:
        raise ValueError(f"Need at least {WINDOW_SAMPLES} samples, got {n_samples}")
    if n_views < 1:
        # n_views == 0 would otherwise average over nothing and return NaNs.
        raise ValueError(f"n_views must be >= 1, got {n_views}")

    rng = np.random.default_rng(seed)
    # ``high`` is exclusive; +1 allows the last full window (start = L - 1280).
    starts = rng.integers(0, n_samples - WINDOW_SAMPLES + 1, size=n_views)
    views = np.stack(
        [zscore(long_signal[:, s:s + WINDOW_SAMPLES]) for s in starts]
    )  # (n_views, 2, WINDOW_SAMPLES)

    x = _as_float_tensor(views, device)
    # NOTE(review): batching the views assumes the model is in eval mode
    # (no batch-dependent layers active) — confirm RhythmResNet1D.load sets it.
    with torch.no_grad():
        probs_avg = F.softmax(model(x), dim=-1).mean(dim=0)
    idx = int(probs_avg.argmax().item())
    return model.class_names[idx], float(probs_avg[idx].item()), probs_avg.cpu().numpy()


def demo() -> None:
    """Load the model and run both inference modes on random noise."""
    print("Loading model from HF...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = load_model(device)
    print(f"Loaded {model.__class__.__name__} on {device}")
    print(f"Classes: {model.class_names}")
    print(f"Params: {sum(p.numel() for p in model.parameters()):,}")

    # Synthetic example: random noise (will get garbage prediction).
    print("\n--- single-window demo (random input) ---")
    fake_window = zscore(np.random.randn(2, WINDOW_SAMPLES).astype(np.float32))
    cls, prob = predict_single(model, fake_window, device=device)
    print(f"prediction: {cls} ({prob:.1%})")

    print("\n--- TTA-7 demo (random 30 s input) ---")
    fake_long = np.random.randn(2, 30 * SOURCE_HZ).astype(np.float32)
    cls, prob, full = predict_tta(model, fake_long, n_views=7, device=device)
    print(f"prediction: {cls} ({prob:.1%})")
    print(f"all class probs: {dict(zip(model.class_names, [round(p, 3) for p in full.tolist()]))}")


if __name__ == "__main__":
    demo()