| |
| """Self-contained HTF beat classifier for LTAF (N / A / V). |
| |
| Three parallel streams (time + frequency + history) → fused → MLP head. |
| Inspired by alberto-rota/PAC-PVC-Beat-Classifier-for-ECGs (HTF ensemble). |
| |
| Vendored from rmxjck/TSLM-Arena (src/models/ts_llm/ecg_beat_htf.py). |
| """ |
|
|
| from __future__ import annotations |
|
|
| from pathlib import Path |
| from typing import List |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
# Default beat-class names, used as the default `class_names` of the classifier.
# The list position is presumably the integer label id (0=N, 1=A, 2=V) —
# TODO confirm against the dataset/label-encoding code.
BEAT_CLASS_NAMES = ["N", "A", "V"]
|
|
|
|
class _ConvBlock(nn.Module):
    """Conv1d → BatchNorm1d → ReLU → optional max-pool → optional dropout.

    Args:
        in_ch: Number of input channels of the convolution.
        out_ch: Number of output channels of the convolution.
        kernel: Convolution kernel size; the padding is ``kernel // 2``.
        pool: Max-pool window size. Values not greater than 1 replace the
            pooling stage with an identity.
        dropout: Dropout probability. Values not greater than 0 replace the
            dropout stage with an identity.
    """

    def __init__(self, in_ch: int, out_ch: int, kernel: int = 7,
                 pool: int = 2, dropout: float = 0.0):
        super().__init__()
        # The convolution carries no bias term; batch norm follows directly.
        self.conv = nn.Conv1d(
            in_ch,
            out_ch,
            kernel_size=kernel,
            padding=kernel // 2,
            bias=False,
        )
        self.bn = nn.BatchNorm1d(out_ch)
        self.act = nn.ReLU(inplace=True)

        if pool > 1:
            self.pool = nn.MaxPool1d(pool)
        else:
            self.pool = nn.Identity()

        if dropout > 0:
            self.drop = nn.Dropout(dropout)
        else:
            self.drop = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through conv, norm, activation, pool and dropout."""
        out = self.conv(x)
        out = self.bn(out)
        out = self.act(out)
        out = self.pool(out)
        return self.drop(out)
|
|
|
|
class _CNNTrunk(nn.Module):
    """Chain of ``_ConvBlock`` stages followed by adaptive average pooling.

    The first stage outputs ``base_channels`` channels; every later stage
    doubles the previous width, capped at 256. Each stage uses kernel size 7
    and a pooling factor of 2. ``out_channels`` records the width of the last
    stage (or ``in_channels`` when no stage is built).

    Args:
        in_channels: Channel count of the input tensor.
        base_channels: Output width of the first stage.
        n_blocks: Number of conv stages.
        dropout: Dropout probability passed to every stage.
    """

    def __init__(self, in_channels: int, base_channels: int = 32,
                 n_blocks: int = 5, dropout: float = 0.1):
        super().__init__()

        # Work out the output width of every stage first.
        widths = []
        width = base_channels
        for _ in range(n_blocks):
            widths.append(width)
            width = min(width * 2, 256)

        # Stage i consumes the output of stage i-1; the first one consumes
        # the raw input channels.
        inputs = [in_channels, *widths[:-1]]
        stages = [
            _ConvBlock(c_in, c_out, kernel=7, pool=2, dropout=dropout)
            for c_in, c_out in zip(inputs, widths)
        ]

        self.net = nn.Sequential(*stages)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.out_channels = widths[-1] if widths else in_channels

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the pooled features with the length-1 last axis removed."""
        pooled = self.pool(self.net(x))
        return pooled.squeeze(-1)
|
|
|
|
class EcgBeatHTFClassifier(nn.Module):
    """HTF ensemble: time + frequency + history → MLP head.

    Three feature streams are computed and concatenated:

    * time stream: a 5-block ``_CNNTrunk`` over the raw window;
    * frequency stream: a 4-block ``_CNNTrunk`` over the log-magnitude
      spectrum of the same window;
    * history stream: a small MLP over the RR-interval history and,
      optionally, one-hot encoded preceding beat labels.

    Args:
        num_classes: Number of output classes.
        class_names: Human-readable class names; must have ``num_classes``
            entries. A copy is stored on the instance.
        n_channels: Number of input channels (ECG leads) fed to both trunks.
        window_samples: Window length in samples. Stored on the instance;
            nothing in this class reads it back.
        history_k: Number of preceding beats described by the history inputs.
        history_use_labels: Whether preceding beat labels are part of the
            history stream input.
        time_base_channels: Base channel width of the time trunk.
        freq_base_channels: Base channel width of the frequency trunk.
        head_hidden: Hidden width of the classification head.
        dropout: Dropout probability used in trunks, history MLP and head.

    Raises:
        ValueError: If ``len(class_names) != num_classes``.
    """

    def __init__(
        self,
        num_classes: int = 3,
        class_names: List[str] = BEAT_CLASS_NAMES,
        n_channels: int = 2,
        window_samples: int = 256,
        history_k: int = 5,
        history_use_labels: bool = True,
        time_base_channels: int = 32,
        freq_base_channels: int = 32,
        head_hidden: int = 128,
        dropout: float = 0.1,
    ):
        super().__init__()
        # Explicit raise instead of `assert`: asserts are removed under -O.
        if len(class_names) != num_classes:
            raise ValueError(
                f"class_names has {len(class_names)} entries but "
                f"num_classes is {num_classes}"
            )
        self.num_classes = num_classes
        # Copy so the (shared) default list is never aliased by the instance.
        self.class_names = list(class_names)
        self.n_channels = n_channels
        self.window_samples = window_samples
        self.history_k = history_k
        self.history_use_labels = history_use_labels

        self.time_trunk = _CNNTrunk(n_channels, time_base_channels,
                                    n_blocks=5, dropout=dropout)
        self.freq_trunk = _CNNTrunk(n_channels, freq_base_channels,
                                    n_blocks=4, dropout=dropout)

        # History input: K RR intervals, plus K one-hot label vectors when
        # label history is enabled.
        history_in = history_k
        if history_use_labels:
            history_in += history_k * num_classes
        self.history_net = nn.Sequential(
            nn.Linear(history_in, 64), nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(64, 64), nn.ReLU(inplace=True),
        )

        # The trailing 64 is the output width of `history_net`.
        fused_dim = self.time_trunk.out_channels + self.freq_trunk.out_channels + 64
        self.head = nn.Sequential(
            nn.Linear(fused_dim, head_hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(head_hidden, num_classes),
        )

    def _compute_freq(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``log1p`` of the magnitude of the real FFT over the last axis."""
        spec = torch.fft.rfft(x, dim=-1)
        return torch.log1p(spec.abs())

    def forward(
        self,
        x_time: torch.Tensor,
        rr_history: torch.Tensor,
        label_history: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Classify a batch of beats.

        Args:
            x_time: (B, 2, 256) — raw 2-lead ECG window centered on R-peak.
            rr_history: (B, K) — RR intervals to preceding K beats, in
                seconds. Set 0.0 for missing (record start).
            label_history: (B, K) int64 — preceding K beat labels (0=N, 1=A,
                2=V). Use -1 for missing. Used iff `history_use_labels=True`.
                At inference time, can be filled with previous predictions
                (autoregressive) or zeros (recall slightly degrades).

        Returns:
            (B, num_classes) logits.

        Raises:
            ValueError: If `history_use_labels=True` and `label_history` is
                None.
        """
        time_feat = self.time_trunk(x_time)
        freq_feat = self.freq_trunk(self._compute_freq(x_time))

        if self.history_use_labels:
            # Explicit raise instead of `assert`: asserts are removed under -O.
            if label_history is None:
                raise ValueError(
                    "label_history is required when history_use_labels=True"
                )
            B, K = label_history.shape
            # Negative labels mark missing beats: clamp them to a valid index
            # for one_hot, then zero their rows with the validity mask.
            valid = (label_history >= 0).float().unsqueeze(-1)
            idx = label_history.clamp(min=0)
            one_hot = F.one_hot(idx, num_classes=self.num_classes).float() * valid
            hist_in = torch.cat([rr_history, one_hot.reshape(B, -1)], dim=-1)
        else:
            hist_in = rr_history
        hist_feat = self.history_net(hist_in)

        fused = torch.cat([time_feat, freq_feat, hist_feat], dim=-1)
        return self.head(fused)

    @classmethod
    def load(cls, path: str | Path, device: str = "cpu") -> "EcgBeatHTFClassifier":
        """Rebuild a model from a checkpoint file and return it in eval mode.

        The checkpoint must provide ``num_classes``, ``class_names``,
        ``n_channels``, ``window_samples``, ``history_k``,
        ``history_use_labels`` and ``state_dict``. If it also stores
        ``time_base_channels``, ``freq_base_channels``, ``head_hidden`` or
        ``dropout``, those are forwarded to the constructor so the rebuilt
        layers match the saved weights; otherwise constructor defaults apply.

        Args:
            path: Checkpoint file to read.
            device: Device the tensors are mapped to and the model is moved to.

        Returns:
            The restored model, on ``device``, in eval mode.
        """
        # SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary
        # Python objects, which can execute code. Only load checkpoints from
        # trusted sources.
        ckpt = torch.load(path, map_location=device, weights_only=False)

        # Architecture hyperparameters that change layer shapes; older
        # checkpoints may not carry them, so they are optional.
        optional_keys = (
            "time_base_channels",
            "freq_base_channels",
            "head_hidden",
            "dropout",
        )
        extra = {key: ckpt[key] for key in optional_keys if key in ckpt}

        model = cls(
            num_classes=ckpt["num_classes"],
            class_names=ckpt["class_names"],
            n_channels=ckpt["n_channels"],
            window_samples=ckpt["window_samples"],
            history_k=ckpt["history_k"],
            history_use_labels=ckpt["history_use_labels"],
            **extra,
        )
        model.load_state_dict(ckpt["state_dict"])
        model.to(device).eval()
        return model
|
|