File size: 4,129 Bytes
464d595
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# SPDX-License-Identifier: MIT
"""Self-contained RhythmResNet1D for LTAF rhythm classification.

Vendored from rmxjck/TSLM-Arena (src/models/ts_llm/ecg_rhythm_scratch.py)
so the model can be loaded with no external project imports.
"""

from __future__ import annotations

from pathlib import Path
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F


# Default label list for RhythmResNet1D: six entries, matching its default
# num_classes=6. NOTE(review): presumably logit index i maps to entry i, and the
# codes look like PhysioNet rhythm-annotation names (e.g. NSR = normal sinus
# rhythm, AFIB = atrial fibrillation) — confirm against the training pipeline.
RHYTHM_CLASS_NAMES = [
    "NSR",
    "AFIB",
    "SBR",
    "AB",
    "SVTA",
    "B",
]


class _BasicBlock1D(nn.Module):
    """Two-conv residual block with optional stride-2 downsample."""

    def __init__(self, in_c: int, out_c: int, kernel: int = 7, stride: int = 1,
                 dropout: float = 0.1):
        super().__init__()
        pad = kernel // 2
        self.conv1 = nn.Conv1d(in_c, out_c, kernel_size=kernel, stride=stride,
                               padding=pad, bias=False)
        self.bn1 = nn.BatchNorm1d(out_c)
        self.conv2 = nn.Conv1d(out_c, out_c, kernel_size=kernel, stride=1,
                               padding=pad, bias=False)
        self.bn2 = nn.BatchNorm1d(out_c)
        self.drop = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        if stride != 1 or in_c != out_c:
            self.proj = nn.Sequential(
                nn.Conv1d(in_c, out_c, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm1d(out_c),
            )
        else:
            self.proj = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = self.proj(x)
        h = F.relu(self.bn1(self.conv1(x)), inplace=True)
        h = self.drop(h)
        h = self.bn2(self.conv2(h))
        return F.relu(h + identity, inplace=True)


class RhythmResNet1D(nn.Module):
    """1D ResNet rhythm classifier: a strided stem followed by 4 residual stages.

    The stem (stride-2 conv, BN, ReLU, 2x max-pool) shortens the time axis
    about 4x. Stage 0 keeps the stem's resolution and width
    (``base_channels``). Stages 1-3 each halve the time axis in their first
    block and double the channel count, capped at 512. The final feature map
    is average-pooled over time and fed to a two-layer MLP head that emits
    ``num_classes`` logits.

    Args:
        num_classes: Number of output logits.
        class_names: Label names; must have exactly ``num_classes`` entries.
            A copy is stored on the instance.
        n_channels: Number of input channels expected in ``x``.
        base_channels: Width of the stem and of stage 0.
        blocks_per_stage: Residual blocks in each of the 4 stages.
        stem_kernel: Kernel size of the stem convolution.
        block_kernel: Kernel size inside each residual block. It must be odd
            for the residual shapes to line up.
        dropout: Dropout probability used in the blocks and in the head.

    Raises:
        ValueError: If ``len(class_names) != num_classes``.
    """

    def __init__(
        self,
        num_classes: int = 6,
        class_names: List[str] = RHYTHM_CLASS_NAMES,
        n_channels: int = 2,
        base_channels: int = 64,
        blocks_per_stage: int = 2,
        stem_kernel: int = 15,
        block_kernel: int = 7,
        dropout: float = 0.2,
    ):
        super().__init__()
        # Explicit check instead of ``assert`` so it also runs under ``python -O``.
        if len(class_names) != num_classes:
            raise ValueError(
                f"class_names has {len(class_names)} entries but "
                f"num_classes={num_classes}"
            )
        self.num_classes = num_classes
        # Copy so the shared module-level default list is never aliased.
        self.class_names = list(class_names)
        self.n_channels = n_channels
        self.base_channels = base_channels
        self.blocks_per_stage = blocks_per_stage
        self.stem_kernel = stem_kernel
        self.block_kernel = block_kernel
        self.dropout = dropout

        self.stem = nn.Sequential(
            nn.Conv1d(n_channels, base_channels, kernel_size=stem_kernel,
                      stride=2, padding=stem_kernel // 2, bias=False),
            nn.BatchNorm1d(base_channels),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(2),
        )
        # Module order defines the state_dict keys (stages.0, stages.1, ...);
        # keep it stable so existing checkpoints keep loading.
        stages = []
        in_c = base_channels
        out_c = base_channels
        for stage_idx in range(4):
            for block_idx in range(blocks_per_stage):
                # Only the first block of stages 1-3 downsamples in time.
                stride = 2 if (block_idx == 0 and stage_idx > 0) else 1
                stages.append(_BasicBlock1D(in_c, out_c, kernel=block_kernel,
                                            stride=stride, dropout=dropout))
                in_c = out_c
            out_c = min(out_c * 2, 512)
        self.stages = nn.Sequential(*stages)
        self.pool = nn.AdaptiveAvgPool1d(1)
        # ``in_c`` is now the channel width of the last stage.
        self.head = nn.Sequential(
            nn.Linear(in_c, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(128, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute class logits.

        Args:
            x: Tensor in ``nn.Conv1d`` layout ``(batch, n_channels, time)``.

        Returns:
            Logits of shape ``(batch, num_classes)``.
        """
        h = self.stem(x)
        h = self.stages(h)
        # Pool to length 1, then drop only the time axis (keeps batch of 1).
        feat = self.pool(h).squeeze(-1)
        return self.head(feat)

    @classmethod
    def load(cls, path: str | Path, device: str = "cpu") -> "RhythmResNet1D":
        """Rebuild a model from a checkpoint file and return it in eval mode.

        The checkpoint must be a dict with ``num_classes``, ``class_names``,
        ``n_channels`` and ``state_dict`` keys. Optional architecture keys are
        ``base_channels``, ``blocks_per_stage``, ``stem_kernel``,
        ``block_kernel`` and ``dropout``. When absent, they fall back to the
        constructor defaults, so older checkpoints still load.

        Args:
            path: Checkpoint file path.
            device: Used as ``map_location`` and as the model's target device.

        Returns:
            The model on ``device``, in eval mode.

        Raises:
            KeyError: If a required checkpoint key is missing.
        """
        # SECURITY: weights_only=False unpickles arbitrary objects and can run
        # code. Only load checkpoints from trusted sources.
        ckpt = torch.load(path, map_location=device, weights_only=False)
        model = cls(
            num_classes=ckpt["num_classes"],
            class_names=ckpt["class_names"],
            n_channels=ckpt["n_channels"],
            base_channels=ckpt.get("base_channels", 64),
            blocks_per_stage=ckpt.get("blocks_per_stage", 2),
            # Kernel sizes determine conv weight shapes. Without them, a
            # checkpoint trained with non-default kernels fails load_state_dict.
            stem_kernel=ckpt.get("stem_kernel", 15),
            block_kernel=ckpt.get("block_kernel", 7),
            dropout=ckpt.get("dropout", 0.2),
        )
        model.load_state_dict(ckpt["state_dict"])
        model.to(device).eval()
        return model