# ecg_reconstruction/mae/model.py — commit c5554ee ("accept used input mask"), by PhurinutR
"""ECG masked autoencoder: patch embedding, visibility-restricted encoder, full decoder."""
from __future__ import annotations
from typing import Any
import torch
import torch.nn as nn
from torch import Tensor
from mae.config import MAEConfig
from mae.decoder import DecoderBlock
from mae.encoder import EncoderAttentionBlock, build_encoder_attn_bias
from mae.losses import masked_reconstruction_loss
from mae.stdm import sample_stdm_masks
class ECGDataMAE(nn.Module):
    """
    Reconstruction-only MAE with STDM masking and visibility-restricted encoder.

    The input ECG ``(B, C, T)`` is split into ``C * N`` patches of length ``P``.
    A boolean partition ``(V, M, D)`` over the ``(B, C, N)`` patch grid selects
    visible (V), masked-for-loss (M), and dropped (D) patches.  Only V patches
    carry signal into the encoder (non-visible tokens are zeroed and excluded
    via an attention bias); the decoder sees every position (a learned mask
    token fills non-visible slots) and reconstructs all patches, with the loss
    computed on M only.
    """

    def __init__(self, config: MAEConfig):
        super().__init__()
        self.cfg = config
        C, N, P = config.num_leads, config.num_patches, config.patch_size
        L = C * N  # total number of patch tokens per sample
        d = config.d_model
        # Per-patch linear embedding plus a learned position per (lead, patch) slot.
        self.patch_embed = nn.Linear(P, d)
        self.pos_embed = nn.Parameter(torch.zeros(1, L, d))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        self.encoder_blocks = nn.ModuleList(
            [
                EncoderAttentionBlock(d, config.n_heads, config.mlp_ratio, config.dropout)
                for _ in range(config.encoder_layers)
            ]
        )
        # Decoder has its own positional table and one shared learned mask token.
        self.decoder_embed = nn.Linear(d, d)
        self.decoder_pos = nn.Parameter(torch.zeros(1, L, d))
        nn.init.trunc_normal_(self.decoder_pos, std=0.02)
        self.decoder_mask_token = nn.Parameter(torch.zeros(1, 1, d))
        nn.init.trunc_normal_(self.decoder_mask_token, std=0.02)
        self.decoder_blocks = nn.ModuleList(
            [
                DecoderBlock(d, config.n_heads, config.mlp_ratio, config.dropout)
                for _ in range(config.decoder_layers)
            ]
        )
        self.decoder_norm = nn.LayerNorm(d)
        self.head = nn.Linear(d, P)  # project each token back to patch_size samples
        self.apply(self._init_weights)

    def _init_weights(self, m: nn.Module) -> None:
        """Truncated-normal init for Linear weights; ones/zeros for LayerNorm."""
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    def patchify(self, x: torch.Tensor) -> torch.Tensor:
        """(B, C, T) -> (B, C, N, P)

        Raises:
            ValueError: if the lead count or time length does not match the
                configured ``num_leads`` / ``num_patches * patch_size``.
        """
        B, C, T = x.shape
        N, P = self.cfg.num_patches, self.cfg.patch_size
        if C != self.cfg.num_leads or T != N * P:
            raise ValueError(f"expected (B,{self.cfg.num_leads},{N*P}), got {x.shape}")
        return x.reshape(B, C, N, P)

    @staticmethod
    def _validate_external_masks(
        V: Tensor, M: Tensor, D: Tensor, *, B: int, C: int, N: int
    ) -> None:
        """Check V/M/D are bool ``(B, C, N)`` masks forming an exact partition."""
        for name, t in ("V", V), ("M", M), ("D", D):
            if t.shape != (B, C, N):
                raise ValueError(f"{name} must have shape (B, C, N)={(B, C, N)}, got {tuple(t.shape)}")
            if t.dtype != torch.bool:
                raise ValueError(f"{name} must be bool, got {t.dtype}")
        # Pairwise disjoint and jointly exhaustive: exactly one set per cell.
        if (V & M).any() or (V & D).any() or (M & D).any():
            raise ValueError("V, M, D must be mutually exclusive (no cell in more than one set).")
        if not (V | M | D).all():
            raise ValueError("V, M, D must partition the grid: each (b, c, n) in exactly one of V, M, D.")

    def forward(
        self,
        x: torch.Tensor,
        *,
        V: Tensor | None = None,
        M: Tensor | None = None,
        D: Tensor | None = None,
        generator: torch.Generator | None = None,
        return_loss: bool = True,
    ) -> dict[str, Any]:
        """
        If ``V``, ``M``, and ``D`` are all omitted, sample STDM masks (training default).
        If provided, all three must be given: boolean ``(B, C, N)`` masks partitioning the patch grid.

        Args:
            x: ECG batch of shape ``(B, num_leads, num_patches * patch_size)``.
            V, M, D: optional visible / masked / dropped patch sets.
            generator: optional RNG used when sampling STDM masks.
            return_loss: when True, also compute the masked reconstruction loss.

        Returns:
            dict with ``pred`` ``(B, C, N, P)``, the masks ``V``/``M``/``D``,
            the ground-truth ``patches``, and (optionally) ``loss``.
        """
        B, C, _ = x.shape
        N = self.cfg.num_patches
        patches = self.patchify(x)
        any_mask = V is not None or M is not None or D is not None
        if any_mask:
            if V is None or M is None or D is None:
                raise ValueError("Provide all three of V, M, D or omit all to sample STDM.")
            self._validate_external_masks(V, M, D, B=B, C=C, N=N)
            # Robustness fix: align externally supplied masks with the input's
            # device so CPU-built masks work with a model/input on GPU.
            V, M, D = V.to(x.device), M.to(x.device), D.to(x.device)
        else:
            V, M, D = sample_stdm_masks(
                B,
                C,
                N,
                p_time=self.cfg.p_time,
                p_lead=self.cfg.p_lead,
                k_visible=self.cfg.num_visible_leads,
                device=x.device,
                generator=generator,
            )
        L = C * N
        P = self.cfg.patch_size
        flat = patches.reshape(B, L, P)
        tok = self.patch_embed(flat) + self.pos_embed
        v_flat = V.reshape(B, L)
        # Zero non-visible tokens so they carry no signal into the encoder;
        # the attention bias additionally blocks attending to them.
        tok = tok * v_flat.unsqueeze(-1).to(tok.dtype)
        attn_bias = build_encoder_attn_bias(v_flat)
        h = tok
        for blk in self.encoder_blocks:
            h = blk(h, v_flat, attn_bias)
        # Decoder input: encoded visible tokens, the mask token elsewhere,
        # plus decoder-specific positions.
        dec_in = self.decoder_embed(h)
        dec_in = dec_in * v_flat.unsqueeze(-1).to(dec_in.dtype)
        dec_in = dec_in + self.decoder_mask_token * (~v_flat).unsqueeze(-1).to(dec_in.dtype)
        dec_in = dec_in + self.decoder_pos
        z = dec_in
        for blk in self.decoder_blocks:
            z = blk(z)
        z = self.decoder_norm(z)
        pred_flat = self.head(z)
        pred = pred_flat.reshape(B, C, N, P)
        out: dict[str, Any] = {"pred": pred, "V": V, "M": M, "D": D, "patches": patches}
        if return_loss:
            out["loss"] = masked_reconstruction_loss(pred, patches, M)
        return out