"""Spatio-Temporal Dual Masking (STDM) from CoRe-ECG — reconstruction branch only."""

from __future__ import annotations

import torch


def sample_stdm_masks(
    batch_size: int,
    num_leads: int,
    num_patches: int,
    *,
    p_time: float,
    p_lead: float,
    k_visible: int,
    device: torch.device,
    generator: torch.Generator | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Sample visible (V), reconstruction-target (M), and dropped (D) masks.

    Each mask is a ``(B, C, N)`` bool tensor indexed as (batch, lead, patch),
    and the three masks partition the grid: every cell is set in exactly one
    of V, M, D.

    For every ``(b, n)`` patch column, independently:

    * with probability ``p_time`` the column is *fully temporally masked*:
      all of its leads go to M;
    * otherwise ``k_visible`` leads, chosen uniformly at random without
      replacement, go to V; each remaining (hidden) lead independently goes
      to D with probability ``p_lead`` and to M otherwise.

    All sampling is vectorized: there is no per-column Python loop and no
    host/device synchronization when ``device`` is an accelerator.

    Parameters
    ----------
    batch_size, num_leads, num_patches :
        Grid dimensions B, C, N.
    p_time :
        Bernoulli probability that a patch column is fully temporally masked.
    p_lead :
        Bernoulli probability of *dropping* a hidden lead from supervision.
    k_visible :
        Number of leads kept visible in partial (non-full-temporal-mask) columns.
    device :
        Device on which the random draws and the masks are created.
    generator :
        Optional RNG for reproducibility; torch requires it to live on ``device``.

    Returns
    -------
    tuple of torch.Tensor
        ``(V, M, D)``, each of shape ``(batch_size, num_leads, num_patches)``
        and dtype ``torch.bool``.

    Raises
    ------
    ValueError
        If ``p_time`` or ``p_lead`` is outside ``[0, 1]`` (NaN included), or if
        ``k_visible`` is not in ``[1, num_leads]``.
    """
    if not 0 <= p_time <= 1 or not 0 <= p_lead <= 1:
        raise ValueError("p_time and p_lead must be in [0, 1]")
    if k_visible < 1 or k_visible > num_leads:
        raise ValueError("k_visible must be in [1, num_leads]")

    B, C, N = batch_size, num_leads, num_patches

    # Per (b, n): full temporal mask vs partial. Kept as (B, 1, N) so it
    # broadcasts across the lead axis below.
    full_time = (
        torch.rand(B, N, device=device, generator=generator) < p_time
    ).unsqueeze(1)
    partial = ~full_time

    # Uniform random k-subset of leads per column: rank i.i.d. uniform scores
    # along the lead axis (double argsort yields a permutation of 0..C-1, so
    # exactly k_visible ranks fall below the cutoff). Same distribution as
    # ``torch.randperm(C)[:k_visible]`` per column.
    scores = torch.rand(B, C, N, device=device, generator=generator)
    ranks = scores.argsort(dim=1).argsort(dim=1)
    V = partial & (ranks < k_visible)
    hidden = partial & (ranks >= k_visible)

    # Each hidden lead is dropped from supervision with probability p_lead;
    # the remaining hidden leads become reconstruction targets. rand() is in
    # [0, 1), so p_lead == 0 drops nothing and p_lead == 1 drops everything.
    drop = torch.rand(B, C, N, device=device, generator=generator) < p_lead
    D = hidden & drop
    M = full_time | (hidden & ~drop)
    return V, M, D