Spaces:
Sleeping
Sleeping
| """Spatio-Temporal Dual Masking (STDM) from CoRe-ECG — reconstruction branch only.""" | |
| from __future__ import annotations | |
| import torch | |
def sample_stdm_masks(
    batch_size: int,
    num_leads: int,
    num_patches: int,
    *,
    p_time: float,
    p_lead: float,
    k_visible: int,
    device: torch.device,
    generator: torch.Generator | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Sample visible (V), reconstruction-target (M), and dropped (D) masks.

    Each mask is ``(B, C, N)`` bool with ``V | M | D`` partitioning the grid
    (exclusive): every cell is True in exactly one of the three masks.

    For every ``(b, n)`` cell, independently:

    * with probability ``p_time`` the patch is fully masked in time, i.e. all
      ``C`` leads go to M;
    * otherwise ``k_visible`` leads chosen uniformly at random go to V, and each
      remaining (hidden) lead independently goes to D with probability
      ``p_lead`` and to M otherwise.

    All randomness is drawn in three batched calls (time mask, lead ordering,
    lead dropping), so the cost does not depend on a Python-level loop over
    ``B * N`` cells and no per-cell host/device synchronisation is needed.

    Parameters
    ----------
    batch_size : B, number of samples.
    num_leads : C, number of leads (size of dim 1).
    num_patches : N, number of temporal patches (size of dim 2).
    p_time : Bernoulli prob for full temporal mask at a patch index.
    p_lead : Bernoulli prob for *dropping* a lead from supervision among hidden leads.
    k_visible : number of leads kept visible in partial (non-full-temporal-mask) mode.
    device : device on which the masks (and the random draws) are created.
    generator : optional RNG passed to every ``torch.rand`` call.

    Returns
    -------
    ``(V, M, D)`` — three bool tensors of shape ``(B, C, N)``.

    Raises
    ------
    ValueError
        If ``p_time`` or ``p_lead`` is outside ``[0, 1]``, or ``k_visible`` is
        outside ``[1, num_leads]``.
    """
    if not 0 <= p_time <= 1 or not 0 <= p_lead <= 1:
        raise ValueError("p_time and p_lead must be in [0, 1]")
    if k_visible < 1 or k_visible > num_leads:
        raise ValueError("k_visible must be in [1, num_leads]")

    grid = (batch_size, num_leads, num_patches)

    # Per (b, n): full temporal mask vs partial. Kept as (B, 1, N) so it
    # broadcasts over the lead dimension.
    u_time = torch.rand(batch_size, num_patches, device=device, generator=generator)
    full_time = (u_time < p_time).unsqueeze(1)
    partial = ~full_time

    # Uniform k-subset of leads per (b, n): rank i.i.d. uniform keys along the
    # lead axis and keep the k_visible smallest ranks. The double argsort turns
    # keys into ranks 0..C-1 (a valid permutation even if two keys tie), so
    # exactly k_visible leads are picked in every column.
    lead_keys = torch.rand(grid, device=device, generator=generator)
    lead_rank = lead_keys.argsort(dim=1).argsort(dim=1)
    picked = lead_rank < k_visible

    # Independent drop decision per lead; only consulted for hidden leads.
    u_drop = torch.rand(grid, device=device, generator=generator)
    drop = u_drop < p_lead

    V = picked & partial
    hidden = ~picked & partial
    D = hidden & drop
    # Fully time-masked patches are reconstruction targets on every lead.
    M = full_time | (hidden & ~drop)
    return V, M, D