PhurinutR's picture
followed CoRe-ECG idea
7a63dcf
"""Spatio-Temporal Dual Masking (STDM) from CoRe-ECG — reconstruction branch only."""
from __future__ import annotations
import torch
def sample_stdm_masks(
    batch_size: int,
    num_leads: int,
    num_patches: int,
    *,
    p_time: float,
    p_lead: float,
    k_visible: int,
    device: torch.device,
    generator: torch.Generator | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Draw the three STDM masks over a ``(batch, lead, patch)`` grid.

    Every grid cell ends up in exactly one of the returned boolean masks:
    visible (V), reconstruction target (M), or dropped (D).

    For each ``(b, n)`` pair one Bernoulli(``p_time``) draw decides whether the
    whole patch index is a reconstruction target for all leads. Otherwise
    ``k_visible`` leads are picked uniformly at random to be visible, and each
    remaining lead is independently dropped with probability ``p_lead`` or
    else kept as a reconstruction target.

    Parameters
    ----------
    batch_size : size of the first mask axis (B).
    num_leads : size of the second mask axis (C).
    num_patches : size of the third mask axis (N).
    p_time : probability in [0, 1] that a ``(b, n)`` column is fully masked.
    p_lead : probability in [0, 1] that a hidden lead is dropped, not masked.
    k_visible : leads kept visible in a non-fully-masked column; in [1, num_leads].
    device : device on which masks and random draws are created.
    generator : optional RNG forwarded to every ``torch.rand`` / ``torch.randperm`` call.

    Returns
    -------
    ``(V, M, D)`` : three ``(B, C, N)`` bool tensors on ``device``.

    Raises
    ------
    ValueError : if a probability is outside [0, 1] or ``k_visible`` is out of range.
    """
    if not (0 <= p_time <= 1 and 0 <= p_lead <= 1):
        raise ValueError("p_time and p_lead must be in [0, 1]")
    if k_visible < 1 or num_leads < k_visible:
        raise ValueError("k_visible must be in [1, num_leads]")

    grid_shape = (batch_size, num_leads, num_patches)
    visible = torch.zeros(grid_shape, dtype=torch.bool, device=device)
    target = torch.zeros_like(visible)
    dropped = torch.zeros_like(visible)

    # One Bernoulli draw per (b, n): True means every lead at that patch index
    # becomes a reconstruction target.
    full_time = (
        torch.rand(batch_size, num_patches, device=device, generator=generator)
        < p_time
    )
    # Broadcast (B, 1, N) over the lead axis to mark the fully-masked columns.
    target |= full_time.unsqueeze(1)

    # Walk the remaining columns in row-major (b, n) order so that the random
    # draws are consumed in a fixed sequence: one permutation per column,
    # followed by one uniform vector for its hidden leads.
    for b, row in enumerate(full_time.tolist()):
        for n, fully_masked in enumerate(row):
            if fully_masked:
                continue
            order = torch.randperm(num_leads, device=device, generator=generator)
            shown, hidden = order[:k_visible], order[k_visible:]
            visible[b, shown, n] = True
            if hidden.numel() > 0:
                drop = (
                    torch.rand(hidden.numel(), device=device, generator=generator)
                    < p_lead
                )
                dropped[b, hidden[drop], n] = True
                target[b, hidden[~drop], n] = True

    return visible, target, dropped