# SPDX-License-Identifier: MIT
"""Self-contained HTF beat classifier for LTAF (N / A / V).

Three parallel streams (time + frequency + history) → fused → MLP head.
Inspired by alberto-rota/PAC-PVC-Beat-Classifier-for-ECGs (HTF ensemble).

Vendored from rmxjck/TSLM-Arena (src/models/ts_llm/ecg_beat_htf.py).
"""

from __future__ import annotations

from pathlib import Path
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F

BEAT_CLASS_NAMES = ["N", "A", "V"]


class _ConvBlock(nn.Module):
    """Conv1d → BatchNorm1d → ReLU → optional MaxPool1d → optional Dropout.

    The convolution uses ``padding=kernel // 2`` and no bias (the batch norm
    that follows supplies the affine shift). Pooling is skipped when
    ``pool <= 1`` and dropout is skipped when ``dropout <= 0``.
    """

    def __init__(
        self,
        in_ch: int,
        out_ch: int,
        kernel: int = 7,
        pool: int = 2,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.conv = nn.Conv1d(
            in_ch, out_ch, kernel_size=kernel, padding=kernel // 2, bias=False
        )
        self.bn = nn.BatchNorm1d(out_ch)
        self.act = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool1d(pool) if pool > 1 else nn.Identity()
        self.drop = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply conv, norm, activation, pooling and dropout in that order."""
        return self.drop(self.pool(self.act(self.bn(self.conv(x)))))


class _CNNTrunk(nn.Module):
    """Stack of conv blocks ending in adaptive average pool.

    Channel width starts at ``base_channels`` and doubles after every block,
    capped at 256. ``out_channels`` exposes the width of the final block (and
    therefore the size of the feature vector returned by ``forward``).
    """

    def __init__(
        self,
        in_channels: int,
        base_channels: int = 32,
        n_blocks: int = 5,
        dropout: float = 0.1,
    ):
        super().__init__()
        layers = []
        ch = in_channels
        out_ch = base_channels
        for _ in range(n_blocks):
            layers.append(_ConvBlock(ch, out_ch, kernel=7, pool=2, dropout=dropout))
            ch = out_ch
            out_ch = min(out_ch * 2, 256)
        self.net = nn.Sequential(*layers)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.out_channels = ch

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``(B, C, T)`` to ``(B, out_channels)`` via global average pooling."""
        return self.pool(self.net(x)).squeeze(-1)


class EcgBeatHTFClassifier(nn.Module):
    """HTF ensemble: time + frequency + history → MLP head.

    Args:
        num_classes: Number of output classes (size of the logit vector).
        class_names: One name per class; copied, so the caller's list (and the
            module-level default) is never mutated.
        n_channels: Number of input channels fed to both CNN trunks.
        window_samples: Stored as metadata only; not enforced on the input.
        history_k: Number of preceding beats described by the history inputs.
        history_use_labels: If True, one-hot encoded preceding labels are
            concatenated to the RR history.
        time_base_channels: Base width of the time-domain trunk (5 blocks).
        freq_base_channels: Base width of the frequency-domain trunk (4 blocks).
        head_hidden: Hidden width of the final MLP head.
        dropout: Dropout probability used in the trunks, history net and head.

    Raises:
        ValueError: If ``len(class_names) != num_classes``.
    """

    def __init__(
        self,
        num_classes: int = 3,
        class_names: List[str] = BEAT_CLASS_NAMES,
        n_channels: int = 2,
        window_samples: int = 256,
        history_k: int = 5,
        history_use_labels: bool = True,
        time_base_channels: int = 32,
        freq_base_channels: int = 32,
        head_hidden: int = 128,
        dropout: float = 0.1,
    ):
        super().__init__()
        # Explicit raise instead of `assert`: asserts vanish under `python -O`.
        if len(class_names) != num_classes:
            raise ValueError(
                f"class_names has {len(class_names)} entries but "
                f"num_classes={num_classes}"
            )
        self.num_classes = num_classes
        self.class_names = list(class_names)
        self.n_channels = n_channels
        self.window_samples = window_samples
        self.history_k = history_k
        self.history_use_labels = history_use_labels

        self.time_trunk = _CNNTrunk(
            n_channels, time_base_channels, n_blocks=5, dropout=dropout
        )
        self.freq_trunk = _CNNTrunk(
            n_channels, freq_base_channels, n_blocks=4, dropout=dropout
        )

        # History vector layout: K RR intervals, optionally followed by
        # K one-hot label vectors flattened to K * num_classes values.
        history_in = history_k
        if history_use_labels:
            history_in += history_k * num_classes
        self.history_net = nn.Sequential(
            nn.Linear(history_in, 64),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(64, 64),
            nn.ReLU(inplace=True),
        )

        fused_dim = self.time_trunk.out_channels + self.freq_trunk.out_channels + 64
        self.head = nn.Sequential(
            nn.Linear(fused_dim, head_hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(head_hidden, num_classes),
        )

    def _compute_freq(self, x: torch.Tensor) -> torch.Tensor:
        """Return the log-magnitude spectrum ``log1p(|rfft(x)|)`` along the last axis."""
        spec = torch.fft.rfft(x, dim=-1)
        return torch.log1p(spec.abs())

    def forward(
        self,
        x_time: torch.Tensor,
        rr_history: torch.Tensor,
        label_history: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute class logits for a batch of beats.

        Args:
            x_time: (B, n_channels, T) — raw ECG window centered on R-peak;
                (B, 2, 256) with the default configuration.
            rr_history: (B, K) — RR intervals to preceding K beats, in seconds.
                Set 0.0 for missing (record start).
            label_history: (B, K) integer tensor — preceding K beat labels
                (0=N, 1=A, 2=V). Use -1 (any negative value) for missing;
                missing entries become all-zero one-hot rows. Used iff
                `history_use_labels=True`. At inference time, can be filled
                with previous predictions (autoregressive) or zeros (recall
                slightly degrades).

        Returns:
            (B, num_classes) logits.

        Raises:
            ValueError: If `history_use_labels=True` and `label_history` is None.
        """
        time_feat = self.time_trunk(x_time)
        freq_feat = self.freq_trunk(self._compute_freq(x_time))

        if self.history_use_labels:
            if label_history is None:
                raise ValueError(
                    "label_history is required when history_use_labels=True"
                )
            batch = label_history.shape[0]
            # Mask that zeroes the one-hot rows of missing (negative) labels.
            valid = (label_history >= 0).float().unsqueeze(-1)
            # one_hot only accepts int64 indices; clamp maps missing labels to
            # a legal index whose row is then zeroed by `valid`.
            idx = label_history.clamp(min=0).long()
            one_hot = F.one_hot(idx, num_classes=self.num_classes).float() * valid
            hist_in = torch.cat([rr_history, one_hot.reshape(batch, -1)], dim=-1)
        else:
            hist_in = rr_history

        hist_feat = self.history_net(hist_in)
        fused = torch.cat([time_feat, freq_feat, hist_feat], dim=-1)
        return self.head(fused)

    @classmethod
    def load(cls, path: str | Path, device: str = "cpu") -> "EcgBeatHTFClassifier":
        """Rebuild a classifier from a checkpoint file and put it in eval mode.

        The checkpoint is a dict that must contain ``state_dict`` plus the keys
        ``num_classes``, ``class_names``, ``n_channels``, ``window_samples``,
        ``history_k`` and ``history_use_labels``. The width/regularisation
        hyper-parameters ``time_base_channels``, ``freq_base_channels``,
        ``head_hidden`` and ``dropout`` are honoured when present; otherwise
        the constructor defaults are used.

        Args:
            path: Checkpoint file written with ``torch.save``.
            device: Device to map tensors to and move the model onto.

        Returns:
            The restored model, on ``device``, in eval mode.

        Raises:
            KeyError: If a required checkpoint key is missing.
        """
        # NOTE(security): weights_only=False unpickles arbitrary Python objects
        # and can execute code; only load checkpoints from trusted sources.
        ckpt = torch.load(path, map_location=device, weights_only=False)

        required = (
            "num_classes",
            "class_names",
            "n_channels",
            "window_samples",
            "history_k",
            "history_use_labels",
        )
        optional = (
            "time_base_channels",
            "freq_base_channels",
            "head_hidden",
            "dropout",
        )
        kwargs = {key: ckpt[key] for key in required}
        # Without these, a checkpoint trained with non-default widths would
        # fail in load_state_dict with a size mismatch.
        kwargs.update({key: ckpt[key] for key in optional if key in ckpt})

        model = cls(**kwargs)
        model.load_state_dict(ckpt["state_dict"])
        model.to(device).eval()
        return model