Spaces:
Running
Running
File size: 8,987 Bytes
7a63dcf 574782a 7bf9f76 0974ffe 574782a 0974ffe 7bf9f76 574782a 7bf9f76 574782a 0974ffe 7bf9f76 0974ffe 7bf9f76 574782a 0974ffe 574782a 0974ffe 7bf9f76 0974ffe 574782a 7bf9f76 574782a 7bf9f76 0974ffe 574782a 7bf9f76 0974ffe 7bf9f76 0974ffe 574782a 7bf9f76 0974ffe 7bf9f76 574782a 0974ffe 7bf9f76 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 | """Spatio-Temporal Dual Masking (STDM) from CoRe-ECG — reconstruction branch only."""
from __future__ import annotations
import torch
def sample_stdm_masks(
batch_size: int,
num_leads: int,
num_patches: int,
*,
p_time: float,
p_lead: float,
k_visible: int,
device: torch.device,
generator: torch.Generator | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Sample visible (V), reconstruction-target (M), and dropped (D) masks.
Each mask is ``(B, C, N)`` bool with ``V | M | D`` partitioning the grid (exclusive).
Parameters
----------
k_visible : number of leads kept visible in partial (non–full-temporal-mask) mode.
p_time : Bernoulli prob for full temporal mask at a patch index.
p_lead : Bernoulli prob for *dropping* a lead from supervision among hidden leads.
"""
if not 0 <= p_time <= 1 or not 0 <= p_lead <= 1:
raise ValueError("p_time and p_lead must be in [0, 1]")
if k_visible < 1 or k_visible > num_leads:
raise ValueError("k_visible must be in [1, num_leads]")
C, N = num_leads, num_patches
V = torch.zeros(batch_size, C, N, dtype=torch.bool, device=device)
M = torch.zeros_like(V)
D = torch.zeros_like(V)
# Per (b, n): full temporal mask vs partial
u_time = torch.rand(batch_size, N, device=device, generator=generator)
full_time = u_time < p_time
for b in range(batch_size):
for n in range(N):
if full_time[b, n]:
M[b, :, n] = True
continue
# Partial: choose k_visible leads uniformly
perm = torch.randperm(C, device=device, generator=generator)
vis = perm[:k_visible]
hid = perm[k_visible:]
V[b, vis, n] = True
if hid.numel() == 0:
continue
u = torch.rand(hid.numel(), device=device, generator=generator)
dropped = hid[u < p_lead]
masked = hid[u >= p_lead]
D[b, dropped, n] = True
M[b, masked, n] = True
return V, M, D
def _sample_non_overlapping_band_starts(
num_patches: int,
band_width: int,
num_bands: int,
device: torch.device,
generator: torch.Generator | None,
) -> list[int]:
"""Sample ``num_bands`` start indices in ``[0, num_patches - band_width]`` with disjoint intervals."""
max_start = num_patches - band_width
if max_start < 0:
raise ValueError(
f"band_width={band_width} exceeds num_patches={num_patches}; cannot place temporal bands."
)
for _ in range(128):
starts = [
int(torch.randint(0, max_start + 1, (1,), device=device, generator=generator).item())
for _ in range(num_bands)
]
ok = True
for i in range(num_bands):
for j in range(i + 1, num_bands):
a0, a1 = starts[i], starts[i] + band_width
b0, b1 = starts[j], starts[j] + band_width
if not (a1 <= b0 or b1 <= a0):
ok = False
break
if not ok:
break
if ok:
return starts
# Deterministic fallback: pack bands from the left with a gap if needed.
gap = max(1, (max_start + 1 - num_bands * band_width) // max(1, num_bands - 1)) if num_bands > 1 else 0
starts = [min(i * (band_width + gap), max_start) for i in range(num_bands)]
return starts
def _simple_horizontal_masks(
batch_size: int,
num_leads: int,
num_patches: int,
num_masked_leads: int,
device: torch.device,
generator: torch.Generator | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Mask ``num_masked_leads`` random leads at every time patch (horizontal stripes)."""
C, N = num_leads, num_patches
perm = torch.rand(batch_size, C, device=device, generator=generator).argsort(dim=1)
rank_cols = torch.arange(num_masked_leads, device=device).view(1, -1).expand(batch_size, -1)
masked_lead = torch.gather(perm, 1, rank_cols)
is_masked_lead = torch.zeros(batch_size, C, dtype=torch.bool, device=device)
is_masked_lead.scatter_(1, masked_lead, True)
M = is_masked_lead.unsqueeze(-1).expand(-1, -1, N)
V = ~M
D = torch.zeros(batch_size, C, N, dtype=torch.bool, device=device)
return V, M, D
def _simple_vertical_masks(
    batch_size: int,
    num_leads: int,
    num_patches: int,
    *,
    band_width: int,
    num_bands: int,
    device: torch.device,
    generator: torch.Generator | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Build "vertical stripe" masks: whole time bands hidden across every lead.

    For each batch item, ``num_bands`` band start indices are drawn independently with
    ``_sample_non_overlapping_band_starts``. Patches ``[start, start + band_width)`` are
    marked as reconstruction targets (M) for all leads. Everything else is visible
    (V = ~M), and nothing is dropped (D is all-False).

    Returns ``(V, M, D)``, each a ``(B, C, N)`` bool tensor.
    """
    target = torch.zeros(batch_size, num_leads, num_patches, dtype=torch.bool, device=device)
    for item in range(batch_size):
        item_target = target[item]  # (C, N) view; writes land in ``target``
        band_starts = _sample_non_overlapping_band_starts(
            num_patches, band_width, num_bands, device, generator
        )
        for start in band_starts:
            item_target[:, start : start + band_width] = True
    visible = ~target
    dropped = torch.zeros_like(target)
    return visible, target, dropped
def _no_mask_autoencoder(
batch_size: int,
num_leads: int,
num_patches: int,
device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Autoencoder-style branch: all tokens visible and all tokens supervised.
``V`` and ``M`` are all-True, ``D`` is all-False.
"""
V = torch.ones(batch_size, num_leads, num_patches, dtype=torch.bool, device=device)
M = torch.ones_like(V)
D = torch.zeros_like(V)
return V, M, D
def get_mask(
    batch_size: int,
    num_leads: int,
    num_patches: int,
    *,
    p_time: float,
    p_lead: float,
    k_visible: int,
    device: torch.device,
    generator: torch.Generator | None = None,
    p_stdm: float = 0.25,
    p_no_mask_autoencoder: float = 0.25,
    p_simple_horizontal: float = 0.25,
    p_simple_vertical: float = 0.25,
    num_masked_leads: int = 2,
    vertical_band_width: int = 11,
    num_vertical_bands: int = 2,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Per batch item, choose one of four mask types (probabilities sum to 1 by default):

    - ``p_stdm`` (0.25): ``sample_stdm_masks`` (full STDM).
    - ``p_no_mask_autoencoder`` (0.25): all patches visible and supervised (AE-style).
    - ``p_simple_horizontal`` (0.25): mask ``num_masked_leads`` random leads at all time patches.
    - ``p_simple_vertical`` (0.25): mask ``num_vertical_bands`` time bands of width
      ``vertical_band_width`` patches; all leads are **M** inside each band.

    Branches whose probability is 0 are never generated. Their parameters are not
    validated and need not fit the input; for example, ``num_patches < vertical_band_width``
    is accepted when ``p_simple_vertical == 0``.

    Returns ``(V, M, D)``, each a dense ``(B, C, N)`` bool tensor.
    For STDM/horizontal/vertical branches, ``V | M | D`` partitions the grid.
    For the AE-style branch, ``V`` and ``M`` intentionally overlap (both all-True) and ``D``
    is all-False.

    Raises
    ------
    ValueError
        If a probability is outside [0, 1] (or NaN), the probabilities do not sum to 1
        (within 1e-6), or a parameter of an enabled branch is invalid.
    """
    probs = (p_stdm, p_no_mask_autoencoder, p_simple_horizontal, p_simple_vertical)
    # Written as ``not (0 <= p <= 1)`` so that NaN is rejected too.
    if not all(0 <= p <= 1 for p in probs):
        raise ValueError("mask probabilities must be in [0, 1]")
    if abs(sum(probs) - 1.0) > 1e-6:
        raise ValueError(
            "p_stdm + p_no_mask_autoencoder + p_simple_horizontal + p_simple_vertical "
            f"must equal 1, got {sum(probs)}"
        )
    if p_simple_horizontal > 0 and (num_masked_leads < 1 or num_masked_leads > num_leads):
        raise ValueError("num_masked_leads must be in [1, num_leads]")
    if p_simple_vertical > 0:
        if num_vertical_bands < 1:
            raise ValueError("num_vertical_bands must be >= 1")
        if vertical_band_width < 1:
            raise ValueError("vertical_band_width must be >= 1")

    # Generate candidates for the whole batch in a fixed order (STDM, horizontal, vertical,
    # AE). With every branch enabled, a seeded ``generator`` is consumed exactly as before.
    # Disabled branches consume no randomness.
    candidates: dict[str, tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = {}
    if p_stdm > 0:
        candidates["stdm"] = sample_stdm_masks(
            batch_size,
            num_leads,
            num_patches,
            p_time=p_time,
            p_lead=p_lead,
            k_visible=k_visible,
            device=device,
            generator=generator,
        )
    if p_simple_horizontal > 0:
        candidates["horizontal"] = _simple_horizontal_masks(
            batch_size, num_leads, num_patches, num_masked_leads, device, generator
        )
    if p_simple_vertical > 0:
        candidates["vertical"] = _simple_vertical_masks(
            batch_size,
            num_leads,
            num_patches,
            band_width=vertical_band_width,
            num_bands=num_vertical_bands,
            device=device,
            generator=generator,
        )
    if p_no_mask_autoencoder > 0:
        candidates["ae"] = _no_mask_autoencoder(batch_size, num_leads, num_patches, device)

    u = torch.rand(batch_size, device=device, generator=generator)

    # Enabled branches in selection order. Branch i owns u in [lo_i, lo_i + p_i).
    selection_order = (
        ("stdm", p_stdm),
        ("ae", p_no_mask_autoencoder),
        ("horizontal", p_simple_horizontal),
        ("vertical", p_simple_vertical),
    )
    enabled = [(p, candidates[name]) for name, p in selection_order if p > 0]

    # The last enabled branch is the catch-all for u >= its lower bound. It also absorbs
    # the (<= 1e-6) mass left when the probabilities sum to slightly less than 1, so a
    # disabled branch can never be selected.
    V, M, D = enabled[-1][1]
    lo = 0.0
    for p, (V_b, M_b, D_b) in enabled[:-1]:
        hi = lo + p
        sel = ((u >= lo) & (u < hi)).view(batch_size, 1, 1)
        V = torch.where(sel, V_b, V)
        M = torch.where(sel, M_b, M)
        D = torch.where(sel, D_b, D)
        lo = hi
    # With a single enabled branch no ``torch.where`` runs. The horizontal ``M`` is an
    # expanded view, so densify the outputs to keep them safe for in-place writes.
    return V.contiguous(), M.contiguous(), D.contiguous()
|