| """ECG masked autoencoder: patch embedding, visibility-restricted encoder, full decoder.""" | |
| from __future__ import annotations | |
| from typing import Any | |
| import torch | |
| import torch.nn as nn | |
| from torch import Tensor | |
| from mae.config import MAEConfig | |
| from mae.decoder import DecoderBlock | |
| from mae.encoder import EncoderAttentionBlock, build_encoder_attn_bias | |
| from mae.losses import masked_reconstruction_loss | |
| from mae.stdm import sample_stdm_masks | |


class ECGDataMAE(nn.Module):
    """Reconstruction-only MAE with STDM masking and a visibility-restricted encoder."""

    def __init__(self, config: MAEConfig):
        super().__init__()
        self.cfg = config
        C, N, P = config.num_leads, config.num_patches, config.patch_size
        L = C * N  # total tokens: one per (lead, patch) cell
        d = config.d_model

        # Tokenizer: each length-P patch is linearly embedded; a learned
        # positional embedding covers the full (lead x patch) grid.
        self.patch_embed = nn.Linear(P, d)
        self.pos_embed = nn.Parameter(torch.zeros(1, L, d))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

        # Encoder: attention restricted to visible tokens via an additive bias.
        self.encoder_blocks = nn.ModuleList(
            [
                EncoderAttentionBlock(d, config.n_heads, config.mlp_ratio, config.dropout)
                for _ in range(config.encoder_layers)
            ]
        )

        # Decoder: attends over the full grid; non-visible positions receive a
        # learned mask token.
        self.decoder_embed = nn.Linear(d, d)
        self.decoder_pos = nn.Parameter(torch.zeros(1, L, d))
        nn.init.trunc_normal_(self.decoder_pos, std=0.02)
        self.decoder_mask_token = nn.Parameter(torch.zeros(1, 1, d))
        nn.init.trunc_normal_(self.decoder_mask_token, std=0.02)
        self.decoder_blocks = nn.ModuleList(
            [
                DecoderBlock(d, config.n_heads, config.mlp_ratio, config.dropout)
                for _ in range(config.decoder_layers)
            ]
        )
        self.decoder_norm = nn.LayerNorm(d)
        self.head = nn.Linear(d, P)  # per-token projection back to patch space

        self.apply(self._init_weights)

    def _init_weights(self, m: nn.Module) -> None:
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    def patchify(self, x: torch.Tensor) -> torch.Tensor:
        """(B, C, T) -> (B, C, N, P), where T == N * P."""
        B, C, T = x.shape
        N, P = self.cfg.num_patches, self.cfg.patch_size
        if C != self.cfg.num_leads or T != N * P:
            raise ValueError(f"expected (B,{self.cfg.num_leads},{N*P}), got {x.shape}")
        return x.reshape(B, C, N, P)
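
    # Shape example (illustrative values): with num_leads=12, num_patches=20,
    # patch_size=50, a (B, 12, 1000) input patchifies to (B, 12, 20, 50).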

    @staticmethod
    def _validate_external_masks(
        V: Tensor, M: Tensor, D: Tensor, *, B: int, C: int, N: int
    ) -> None:
        for name, t in (("V", V), ("M", M), ("D", D)):
            if t.shape != (B, C, N):
                raise ValueError(
                    f"{name} must have shape (B, C, N)={(B, C, N)}, got {tuple(t.shape)}"
                )
            if t.dtype != torch.bool:
                raise ValueError(f"{name} must be bool, got {t.dtype}")
        if (V & M).any() or (V & D).any() or (M & D).any():
            raise ValueError("V, M, D must be mutually exclusive (no cell in more than one set).")
        if not (V | M | D).all():
            raise ValueError(
                "V, M, D must partition the grid: each (b, c, n) in exactly one of V, M, D."
            )

    def forward(
        self,
        x: torch.Tensor,
        *,
        V: Tensor | None = None,
        M: Tensor | None = None,
        D: Tensor | None = None,
        generator: torch.Generator | None = None,
        return_loss: bool = True,
    ) -> dict[str, Any]:
        """
        If ``V``, ``M``, and ``D`` are all omitted, sample STDM masks (the training
        default). If provided, all three must be given: boolean ``(B, C, N)`` masks
        partitioning the patch grid.
        """
        B, C, T = x.shape
        N = self.cfg.num_patches
        patches = self.patchify(x)

        # Either validate a user-supplied V/M/D partition or sample STDM masks.
        any_mask = V is not None or M is not None or D is not None
        if any_mask:
            if V is None or M is None or D is None:
                raise ValueError("Provide all three of V, M, D or omit all to sample STDM.")
            self._validate_external_masks(V, M, D, B=B, C=C, N=N)
        else:
            V, M, D = sample_stdm_masks(
                B,
                C,
                N,
                p_time=self.cfg.p_time,
                p_lead=self.cfg.p_lead,
                k_visible=self.cfg.num_visible_leads,
                device=x.device,
                generator=generator,
            )
        # Tokenize the full grid, then zero out non-visible tokens so the
        # encoder (with its visibility-restricted attention bias) only mixes
        # information among visible patches.
        L = C * N
        P = self.cfg.patch_size
        flat = patches.reshape(B, L, P)
        tok = self.patch_embed(flat) + self.pos_embed
        v_flat = V.reshape(B, L)
        tok = tok * v_flat.unsqueeze(-1).to(tok.dtype)

        attn_bias = build_encoder_attn_bias(v_flat)
        h = tok
        for blk in self.encoder_blocks:
            h = blk(h, v_flat, attn_bias)

        # Decoder input: encoded visible tokens, a learned mask token at every
        # non-visible position, and decoder positional embeddings on top.
        dec_in = self.decoder_embed(h)
        dec_in = dec_in * v_flat.unsqueeze(-1).to(dec_in.dtype)
        dec_in = dec_in + self.decoder_mask_token * (~v_flat).unsqueeze(-1).to(dec_in.dtype)
        dec_in = dec_in + self.decoder_pos

        z = dec_in
        for blk in self.decoder_blocks:
            z = blk(z)
        z = self.decoder_norm(z)
        pred_flat = self.head(z)
        pred = pred_flat.reshape(B, C, N, P)

        out: dict[str, Any] = {"pred": pred, "V": V, "M": M, "D": D, "patches": patches}
        if return_loss:
            # Reconstruction is scored only on M; V is visible and D is excluded.
            out["loss"] = masked_reconstruction_loss(pred, patches, M)
        return out
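

# Minimal usage sketch (illustrative, not part of the model). It assumes
# MAEConfig accepts the fields referenced above as keyword arguments; the
# values below are placeholders chosen only so the shapes line up.
if __name__ == "__main__":
    cfg = MAEConfig(
        num_leads=12,
        num_patches=20,
        patch_size=50,
        d_model=256,
        n_heads=8,
        mlp_ratio=4.0,
        dropout=0.1,
        encoder_layers=6,
        decoder_layers=2,
        p_time=0.5,
        p_lead=0.25,
        num_visible_leads=3,
    )
    model = ECGDataMAE(cfg)
    x = torch.randn(2, cfg.num_leads, cfg.num_patches * cfg.patch_size)  # (B, C, T)
    out = model(x, generator=torch.Generator().manual_seed(0))
    print(out["pred"].shape)   # torch.Size([2, 12, 20, 50])
    print(out["loss"].item())  # scalar reconstruction loss over M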