File size: 12,081 Bytes
16d6869 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 | """
PyTorch Dataset for preprocessed ABIDE subjects.
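Each sample is a 4-tuple (bold_windows, adj, label, site_id):
    bold_windows : (W, N) — mean BOLD per ROI at each brain-state snapshot;
                   (W, N) mean |FC| per ROI with use_fc_degree_features=True,
                   or (W, N, N) FC rows with use_fc_row_features=True
    adj          : adjacency for this subject, chosen in this order of precedence:
                   population_adj given          → shared population adj (N, N)
                   use_dynamic_adj_sequence=True → per-window FCs (W, N, N)
                   use_dynamic_adj=True          → mean of per-window FCs (N, N)
                   otherwise                     → subject's mean FC (N, N);
                       (2, N, N) with use_fc_variance, (1, K) with a PCA projection
    label        : () — int64 scalar (0 = TC, 1 = ASD)
    site_id      : () — int64 scalar looked up in site_to_int (-1 for unknown sites)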
The adjacency is left as raw (thresholded) FC values so the model can apply
its own Laplacian normalisation via utils.graph_conv.
"""
from __future__ import annotations
from pathlib import Path
import numpy as np
import torch
from torch.utils.data import Dataset
class ABIDEDataset(Dataset):
    """Map-style dataset over per-subject ``.npz`` archives.

    Each item is ``(bold_windows, adj, label, site_id)``; see ``__getitem__``
    for the shapes produced by each configuration.
    """

    def __init__(
        self,
        npz_paths: list[Path | str],
        population_adj: np.ndarray | None = None,
        use_dynamic_adj: bool = False,
        use_dynamic_adj_sequence: bool = False,
        fc_threshold: float = 0.2,
        max_windows: int | None = None,
        site_fc_mean: dict[str, np.ndarray] | None = None,
        preserve_fc_sign: bool = False,
        site_to_int: dict[str, int] | None = None,
        use_fc_variance: bool = False,
        use_fisher_z: bool = False,
        pca_mean: np.ndarray | None = None,
        pca_components: np.ndarray | None = None,
        use_fc_degree_features: bool = False,
        use_fc_row_features: bool = False,
    ):
        """
        Parameters
        ----------
        npz_paths : paths to per-subject .npz files from preprocess.py
        population_adj : (N, N) pre-computed population-level adjacency.
            If provided, every sample uses this shared adjacency.
        use_dynamic_adj : if True and population_adj is None, use mean of
            per-window FCs; otherwise use mean_fc (full-scan FC).
        use_dynamic_adj_sequence : if True and population_adj is None, return
            per-window FCs with shape (W, N, N). Takes precedence over
            use_dynamic_adj.
        fc_threshold : zero-out edges with |fc| < threshold before returning
        max_windows : truncate all subjects to this many windows so that
            batches have uniform seq_len (takes the first W windows).
            Subjects with fewer windows are padded by repeating their
            last window.
        site_fc_mean : per-site mean FC matrix (N, N) computed from training
            set. Subtracted from each subject's FC before thresholding
            to remove scanner/site connectivity biases (FC-domain
            site normalization). BOLD is already z-scored so
            BOLD-domain corrections have no effect.
        preserve_fc_sign : if True, keep signed FC values in the adjacency
            instead of converting to |FC|. Required for fc_mlp which uses
            signed correlations as direct features (anti-correlations
            between brain networks are diagnostically relevant).
        site_to_int : mapping from site name to the integer returned as
            ``site_id``; sites missing from the mapping yield -1.
        use_fc_variance : static-adjacency path only. If True, return a
            (2, N, N) stack of [thresholded mean FC, temporal std of the
            per-window FCs]. Ignored when the PCA projection is active.
        use_fisher_z : static-adjacency path only. If True, apply
            arctanh to the (site-corrected) mean FC before thresholding.
        pca_mean, pca_components : static-adjacency path only. When both are
            given, the strict upper triangle of the thresholded mean FC is
            centred with pca_mean and projected with pca_components; the
            adjacency slot then holds a (1, K) tensor.
        use_fc_degree_features : if True, replace stored bold_windows (std of
            z-scored BOLD ≈ 1.0) with per-window per-ROI mean absolute
            (site-corrected) FC: np.abs(fc_windows).mean(axis=-1). This
            gives each ROI a scalar ≈ its average connectivity strength
            in that window.
        use_fc_row_features : if True, use per-window FC rows as node
            features instead of scalar BOLD std. Returns (W, N, N) where
            node i's feature vector is its full connectivity profile
            fc_windows[w, i, :]. Takes precedence over
            use_fc_degree_features. Requires model to be built with
            in_features=num_nodes.
        """
        self.npz_paths = [Path(p) for p in npz_paths]
        self.population_adj = (
            torch.FloatTensor(population_adj) if population_adj is not None else None
        )
        self.use_dynamic_adj = use_dynamic_adj
        self.use_dynamic_adj_sequence = use_dynamic_adj_sequence
        self.fc_threshold = fc_threshold
        self.max_windows = max_windows
        self.site_fc_mean = site_fc_mean or {}
        self.preserve_fc_sign = preserve_fc_sign
        self.site_to_int = site_to_int or {}
        self.use_fc_variance = use_fc_variance
        self.use_fisher_z = use_fisher_z
        self.pca_mean = pca_mean
        self.pca_components = pca_components
        self.use_fc_degree_features = use_fc_degree_features
        self.use_fc_row_features = use_fc_row_features
        # Pre-load labels + window counts for fast access without loading full arrays
        self._meta = self._scan_metadata()

    @staticmethod
    def _array(data: np.lib.npyio.NpzFile, primary: str, legacy: str) -> np.ndarray:
        """Return ``data[primary]``, falling back to the ``legacy`` key name.

        Raises
        ------
        KeyError : if neither key is present in the archive.
        """
        if primary in data:
            return data[primary]
        if legacy in data:
            return data[legacy]
        raise KeyError(f"Expected '{primary}' or legacy '{legacy}' in subject archive")

    def _threshold(self, adj_np: np.ndarray, preserve_sign: bool = False) -> np.ndarray:
        """Zero every entry with ``|value| < fc_threshold``.

        Surviving entries keep their sign when ``preserve_sign`` is True and
        are replaced by their absolute value otherwise.
        """
        mask = np.abs(adj_np) >= self.fc_threshold
        if preserve_sign:
            return np.where(mask, adj_np, 0.0)
        return np.where(mask, np.abs(adj_np), 0.0)

    @staticmethod
    def _fisher_z(fc: np.ndarray) -> np.ndarray:
        """Fisher's r-to-z transform: z = arctanh(r).

        Linearises the correlation space — correlations near ±1 are compressed
        in Pearson space but uniform in z-space. Stabilises variance across
        different correlation magnitudes, which matters for linear classifiers.
        Clipped to ±0.9999 to avoid ±inf at perfect correlations.
        """
        return np.arctanh(np.clip(fc, -0.9999, 0.9999))

    @staticmethod
    def _pad_or_truncate_windows(array: np.ndarray, max_windows: int | None) -> np.ndarray:
        """Force the leading (window) axis of ``array`` to ``max_windows``.

        Longer inputs keep their first ``max_windows`` entries; shorter inputs
        are padded by repeating the last window. ``max_windows=None`` returns
        the input unchanged.
        """
        if max_windows is None:
            return array
        if array.shape[0] >= max_windows:
            return array[:max_windows]
        pad_count = max_windows - array.shape[0]
        pad = np.repeat(array[-1:], pad_count, axis=0)
        return np.concatenate([array, pad], axis=0)

    def _scan_metadata(self) -> list[dict]:
        """Read label, subject id, site and window count of every archive."""
        meta = []
        for p in self.npz_paths:
            # NOTE(security): allow_pickle=True can execute arbitrary code when
            # loading untrusted archives; kept because the on-disk format is
            # produced elsewhere and may rely on it. The context manager makes
            # sure the underlying zip file handle is released.
            with np.load(p, allow_pickle=True) as data:
                W = self._array(data, "bold_windows", "window_bold").shape[0]
                if self.max_windows is not None:
                    # Every sample is padded/truncated to max_windows.
                    W = self.max_windows
                meta.append(
                    {
                        "label": int(data["label"]),
                        "subject_id": str(data["subject_id"]),
                        "site": str(data["site"]),
                        "num_windows": W,
                    }
                )
        return meta

    # ------------------------------------------------------------------
    def __len__(self) -> int:
        """Number of subjects."""
        return len(self.npz_paths)

    def _needs_window_fc(self) -> bool:
        """Whether the per-window FC stack is consumed by the current config."""
        if self.use_fc_row_features or self.use_fc_degree_features:
            return True
        # Dynamic adjacency modes are overridden by a shared population adjacency,
        # so the (large) fc_windows array need not be read in that case.
        dynamic = self.use_dynamic_adj_sequence or self.use_dynamic_adj
        return dynamic and self.population_adj is None

    def _load_window_fc(self, data: np.lib.npyio.NpzFile) -> np.ndarray:
        """Load per-window FC as float32, padded/truncated to ``max_windows``."""
        wfc = self._array(data, "fc_windows", "window_fc").astype(np.float32)
        return self._pad_or_truncate_windows(wfc, self.max_windows)

    def _subtract_site_mean(self, fc: np.ndarray, site: str) -> np.ndarray:
        """Subtract the site's mean FC if one is registered for ``site``.

        Works for both (N, N) and (W, N, N) inputs via broadcasting; always
        returns a new array when a correction is applied.
        """
        site_mean = self.site_fc_mean.get(site)
        if site_mean is None:
            return fc
        return fc - site_mean.astype(np.float32)

    def _node_features(
        self, data: np.lib.npyio.NpzFile, wfc: np.ndarray | None, site: str
    ) -> np.ndarray:
        """Build the node-feature sequence for one subject."""
        if self.use_fc_row_features:
            # FC rows as node features: (W, N, N) — node i gets fc_windows[w, i, :]
            return wfc
        if self.use_fc_degree_features:
            # Per-window per-ROI mean |FC| after site correction: (W, N)
            return np.abs(self._subtract_site_mean(wfc, site)).mean(axis=-1)
        bold = self._array(data, "bold_windows", "window_bold").astype(np.float32)
        return self._pad_or_truncate_windows(bold, self.max_windows)

    def _adjacency(
        self, data: np.lib.npyio.NpzFile, wfc: np.ndarray | None, site: str
    ) -> torch.Tensor:
        """Build the adjacency tensor for one subject."""
        if self.population_adj is not None:
            return self.population_adj  # (N, N) shared

        if self.use_dynamic_adj_sequence:
            fc = self._subtract_site_mean(wfc, site)
            return torch.FloatTensor(
                self._threshold(fc, self.preserve_fc_sign).astype(np.float32)
            )  # (W, N, N)

        if self.use_dynamic_adj:
            fc = self._subtract_site_mean(wfc.mean(axis=0), site)
            return torch.FloatTensor(
                self._threshold(fc, self.preserve_fc_sign).astype(np.float32)
            )  # (N, N)

        # Static per-subject mean FC
        mean_np = self._subtract_site_mean(data["mean_fc"].astype(np.float32), site)
        if self.use_fisher_z:
            mean_np = self._fisher_z(mean_np)
        mean_np = self._threshold(mean_np, self.preserve_fc_sign).astype(np.float32)

        if self.pca_mean is not None and self.pca_components is not None:
            # PCA projection: (D,) → (K,) on the strict upper triangle
            n = mean_np.shape[0]
            r, c = np.triu_indices(n, k=1)
            x_vec = mean_np[r, c] - self.pca_mean  # centre
            x_pca = (self.pca_components @ x_vec).astype(np.float32)  # (K,)
            # Returned as (1, K) so that batching yields (B, 1, K)
            return torch.FloatTensor(x_pca).unsqueeze(0)

        if self.use_fc_variance:
            # Second channel: temporal std of FC — captures connection instability.
            # Reuse the window stack if it was already loaded for node features.
            if wfc is None:
                wfc = self._load_window_fc(data)
            std_np = wfc.std(axis=0).astype(np.float32)
            return torch.FloatTensor(np.stack([mean_np, std_np], axis=0))  # (2, N, N)

        return torch.FloatTensor(mean_np)  # (N, N)

    def __getitem__(self, idx: int):
        """Return ``(bold_windows, adj, label, site_id)`` for subject ``idx``.

        bold_windows : float32 (W, N), or (W, N, N) with use_fc_row_features
        adj          : float32 (N, N), (W, N, N), (2, N, N) or (1, K)
                       depending on the adjacency configuration
        label        : int64 scalar read from the archive
        site_id      : int64 scalar from site_to_int, -1 for unmapped sites
        """
        # NOTE(security): see _scan_metadata regarding allow_pickle=True.
        with np.load(self.npz_paths[idx], allow_pickle=True) as data:
            site = str(data["site"])
            wfc = self._load_window_fc(data) if self._needs_window_fc() else None
            bold_windows = self._node_features(data, wfc, site)
            adj = self._adjacency(data, wfc, site)
            label = torch.tensor(int(data["label"]), dtype=torch.long)

        site_id = torch.tensor(self.site_to_int.get(site, -1), dtype=torch.long)
        return torch.FloatTensor(bold_windows), adj, label, site_id

    # ------------------------------------------------------------------
    @property
    def labels(self) -> list[int]:
        """Labels of all subjects, in ``npz_paths`` order."""
        return [m["label"] for m in self._meta]

    @property
    def num_nodes(self) -> int:
        """Number of ROIs, taken from the first subject's ``mean_fc``."""
        with np.load(self.npz_paths[0], allow_pickle=True) as data:
            return data["mean_fc"].shape[0]

    @property
    def num_windows(self) -> int:
        """Window count of the first subject (``max_windows`` when set)."""
        return self._meta[0]["num_windows"]
|