"""
PyTorch Dataset for preprocessed ABIDE subjects.

Each sample returns:
    bold_windows : (W, N) — mean BOLD per ROI at each brain-state snapshot
    adj          : (N, N) or (W, N, N) — adjacency for this subject
                     use_dynamic_adj=False          → subject's mean FC
                     use_dynamic_adj=True           → mean of per-window FCs
                     use_dynamic_adj_sequence=True  → per-window FCs
                     population_adj provided        → shared population adj
    label        : () — int64 scalar (0 = TC, 1 = ASD)
    site_id      : () — int64 scalar from site_to_int (-1 if the site is unmapped)

The adjacency is left as raw (thresholded) FC values so the model can apply
its own Laplacian normalisation via utils.graph_conv.
"""
|
|
| from __future__ import annotations |
|
|
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| from torch.utils.data import Dataset |
|
|
|
|
class ABIDEDataset(Dataset):
    """Dataset over per-subject ``.npz`` archives of preprocessed ABIDE scans.

    Every item is a 4-tuple ``(features, adj, label, site_id)``:

    features : float32 tensor, ``(W, N)`` by default or ``(W, N, N)`` when
               ``use_fc_row_features`` is set.
    adj      : float32 tensor whose shape depends on the adjacency mode
               (see ``__init__``).
    label    : int64 scalar tensor read from the archive's ``label`` entry.
    site_id  : int64 scalar tensor, ``site_to_int[site]`` or ``-1`` when the
               site is not in the mapping.

    Archives are expected to contain ``label``, ``subject_id``, ``site``,
    ``mean_fc``, ``bold_windows`` (legacy key ``window_bold``) and, for the
    FC-window based modes, ``fc_windows`` (legacy key ``window_fc``).
    """

    def __init__(
        self,
        npz_paths: list[Path | str],
        population_adj: np.ndarray | None = None,
        use_dynamic_adj: bool = False,
        use_dynamic_adj_sequence: bool = False,
        fc_threshold: float = 0.2,
        max_windows: int | None = None,
        site_fc_mean: dict[str, np.ndarray] | None = None,
        preserve_fc_sign: bool = False,
        site_to_int: dict[str, int] | None = None,
        use_fc_variance: bool = False,
        use_fisher_z: bool = False,
        pca_mean: np.ndarray | None = None,
        pca_components: np.ndarray | None = None,
        use_fc_degree_features: bool = False,
        use_fc_row_features: bool = False,
    ):
        """
        Parameters
        ----------
        npz_paths       : paths to per-subject .npz files from preprocess.py
        population_adj  : (N, N) pre-computed population-level adjacency.
                          If provided, every sample uses this shared adjacency.
        use_dynamic_adj : if True and population_adj is None, use mean of
                          per-window FCs; otherwise use mean_fc (full-scan FC).
        use_dynamic_adj_sequence : if True and population_adj is None, return
                          per-window FCs with shape (W, N, N). Takes
                          precedence over use_dynamic_adj.
        fc_threshold    : zero-out edges with |fc| < threshold before returning
        max_windows     : force every per-window array to exactly this many
                          windows so that batches have uniform seq_len.
                          Longer subjects keep their first max_windows
                          windows; shorter subjects are padded by repeating
                          their last window. Statistics aggregated over
                          windows (mean / std) are truncated but never padded.
        site_fc_mean    : per-site mean FC matrix (N, N) computed from training
                          set. Subtracted from each subject's FC before
                          thresholding to remove scanner/site connectivity
                          biases (FC-domain site normalization). BOLD is
                          already z-scored so BOLD-domain corrections have no
                          effect.
        preserve_fc_sign: if True, keep signed FC values in the adjacency instead
                          of converting to |FC|. Required for fc_mlp which uses
                          signed correlations as direct features
                          (anti-correlations between brain networks are
                          diagnostically relevant).
        site_to_int     : maps a site name to the integer returned as site_id;
                          sites missing from the mapping yield -1.
        use_fc_variance : static-adjacency path only (and only when PCA is not
                          active): return a (2, N, N) tensor stacking the
                          thresholded mean FC with the per-edge standard
                          deviation of the windowed FC.
        use_fisher_z    : static-adjacency path only: apply arctanh to the
                          (site-corrected) mean FC before thresholding.
        pca_mean, pca_components : static-adjacency path only. When both are
                          given, the strict upper triangle of the thresholded
                          mean FC is centred with pca_mean and projected with
                          pca_components; adj is then returned as (1, K).
        use_fc_degree_features: if True, replace stored bold_windows (std of
                          z-scored BOLD ≈ 1.0) with per-window per-ROI mean
                          absolute FC: np.abs(fc_windows).mean(axis=-1). This
                          gives each ROI a scalar ≈ its average connectivity
                          strength in that window — directly discriminative
                          between ASD and TD, unlike BOLD std which is near-
                          constant after z-scoring.
        use_fc_row_features: if True, use per-window FC rows as node features
                          instead of scalar BOLD std. Returns (W, N, N) where
                          node i's feature vector is its full connectivity
                          profile fc_windows[w, i, :]. This is the standard
                          formulation in brain GCN literature (BrainNetCNN,
                          BrainGNN, STAGIN). Requires model to be built with
                          in_features=num_nodes. Takes precedence over
                          use_fc_degree_features.
        """
        self.npz_paths = [Path(p) for p in npz_paths]
        self.population_adj = (
            torch.FloatTensor(population_adj) if population_adj is not None else None
        )
        self.use_dynamic_adj = use_dynamic_adj
        self.use_dynamic_adj_sequence = use_dynamic_adj_sequence
        self.fc_threshold = fc_threshold
        self.max_windows = max_windows
        self.site_fc_mean = site_fc_mean or {}
        self.preserve_fc_sign = preserve_fc_sign
        self.site_to_int = site_to_int or {}
        self.use_fc_variance = use_fc_variance
        self.use_fisher_z = use_fisher_z
        self.pca_mean = pca_mean
        self.pca_components = pca_components
        self.use_fc_degree_features = use_fc_degree_features
        self.use_fc_row_features = use_fc_row_features

        # One pass over every archive so labels / window counts are available
        # without touching the disk again.
        self._meta = self._scan_metadata()

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _array(data: np.lib.npyio.NpzFile, primary: str, legacy: str) -> np.ndarray:
        """Return ``data[primary]``, falling back to ``data[legacy]``.

        Raises
        ------
        KeyError : if neither key is present in the archive.
        """
        if primary in data:
            return data[primary]
        if legacy in data:
            return data[legacy]
        raise KeyError(f"Expected '{primary}' or legacy '{legacy}' in subject archive")

    def _threshold(self, adj_np: np.ndarray, preserve_sign: bool = False) -> np.ndarray:
        """Zero every entry with ``|value| < fc_threshold``.

        Surviving entries are returned as absolute values unless
        ``preserve_sign`` is True, in which case they keep their sign.
        """
        mask = np.abs(adj_np) >= self.fc_threshold
        if preserve_sign:
            return np.where(mask, adj_np, 0.0)
        return np.where(mask, np.abs(adj_np), 0.0)

    @staticmethod
    def _fisher_z(fc: np.ndarray) -> np.ndarray:
        """Fisher's r-to-z transform: z = arctanh(r).

        Linearises the correlation space — correlations near ±1 are compressed
        in Pearson space but uniform in z-space. Stabilises variance across
        different correlation magnitudes, which matters for linear classifiers.
        Clipped to ±0.9999 to avoid ±inf at perfect correlations.
        """
        return np.arctanh(np.clip(fc, -0.9999, 0.9999))

    @staticmethod
    def _pad_or_truncate_windows(array: np.ndarray, max_windows: int | None) -> np.ndarray:
        """Force the leading (window) axis of ``array`` to ``max_windows``.

        Returns ``array`` unchanged when ``max_windows`` is None, its first
        ``max_windows`` entries when it is long enough, and otherwise the
        array followed by copies of its last window.

        Raises
        ------
        ValueError : if padding is required but ``array`` has no windows, so
                     there is nothing to repeat.
        """
        if max_windows is None:
            return array
        if array.shape[0] >= max_windows:
            return array[:max_windows]
        if array.shape[0] == 0:
            raise ValueError(
                f"cannot pad an array with zero windows to {max_windows} windows"
            )
        pad_count = max_windows - array.shape[0]
        pad = np.repeat(array[-1:], pad_count, axis=0)
        return np.concatenate([array, pad], axis=0)

    def _remove_site_mean(self, fc: np.ndarray, site: str) -> np.ndarray:
        """Subtract the site's mean FC (if one is registered) from ``fc``.

        ``fc`` may be (N, N) or (W, N, N); the (N, N) site mean broadcasts over
        the window axis.
        """
        site_mean = self.site_fc_mean.get(site)
        if site_mean is None:
            return fc
        return fc - site_mean.astype(np.float32)

    def _scan_metadata(self) -> list[dict]:
        """Read label, subject id, site and window count from every archive."""
        meta = []
        for path in self.npz_paths:
            # NOTE: allow_pickle=True lets numpy unpickle object arrays, which
            # can execute arbitrary code — only load archives you trust.
            # The context manager closes the underlying zip file promptly
            # instead of waiting for garbage collection.
            with np.load(path, allow_pickle=True) as data:
                n_windows = self._array(data, "bold_windows", "window_bold").shape[0]
                if self.max_windows is not None:
                    # Items are padded/truncated to max_windows on access.
                    n_windows = self.max_windows
                meta.append(
                    {
                        "label": int(data["label"]),
                        "subject_id": str(data["subject_id"]),
                        "site": str(data["site"]),
                        "num_windows": n_windows,
                    }
                )
        return meta

    # --------------------------------------------------------------- item access

    def __len__(self) -> int:
        """Number of subjects."""
        return len(self.npz_paths)

    def _node_features(
        self,
        data: np.lib.npyio.NpzFile,
        site: str,
        window_fc: np.ndarray | None,
    ) -> np.ndarray:
        """Build the per-window node feature array for one subject."""
        if self.use_fc_row_features:
            assert window_fc is not None
            # (W, N, N): each node's feature vector is its FC row.
            return self._pad_or_truncate_windows(window_fc, self.max_windows)

        if self.use_fc_degree_features:
            assert window_fc is not None
            wfc = self._pad_or_truncate_windows(window_fc, self.max_windows)
            wfc = self._remove_site_mean(wfc, site)
            # (W, N): mean absolute connectivity of each ROI in each window.
            return np.abs(wfc).mean(axis=-1)

        bold = self._array(data, "bold_windows", "window_bold").astype(np.float32)
        return self._pad_or_truncate_windows(bold, self.max_windows)

    def _adjacency(
        self,
        data: np.lib.npyio.NpzFile,
        site: str,
        window_fc: np.ndarray | None,
    ) -> torch.Tensor:
        """Build the adjacency tensor for one subject.

        Priority: shared population adjacency, then per-window sequence, then
        mean of per-window FCs, then the static ``mean_fc`` path (optionally
        Fisher-z transformed, PCA-projected or stacked with FC variance).
        """
        if self.population_adj is not None:
            return self.population_adj

        if self.use_dynamic_adj_sequence:
            assert window_fc is not None
            wfc = self._pad_or_truncate_windows(window_fc, self.max_windows)
            wfc = self._remove_site_mean(wfc, site)
            return torch.FloatTensor(
                self._threshold(wfc, self.preserve_fc_sign).astype(np.float32)
            )

        if self.use_dynamic_adj:
            assert window_fc is not None
            # Truncate only: averaging over padded copies of the last window
            # would over-weight that window for short subjects.
            fc = window_fc[: self.max_windows].mean(axis=0)
            fc = self._remove_site_mean(fc, site)
            return torch.FloatTensor(
                self._threshold(fc, self.preserve_fc_sign).astype(np.float32)
            )

        mean_fc = data["mean_fc"].astype(np.float32)
        mean_fc = self._remove_site_mean(mean_fc, site)
        if self.use_fisher_z:
            mean_fc = self._fisher_z(mean_fc)
        mean_fc = self._threshold(mean_fc, self.preserve_fc_sign).astype(np.float32)

        if self.pca_mean is not None and self.pca_components is not None:
            # Vectorise the strict upper triangle, centre, and project.
            rows, cols = np.triu_indices(mean_fc.shape[0], k=1)
            centred = mean_fc[rows, cols] - self.pca_mean
            projected = (self.pca_components @ centred).astype(np.float32)
            return torch.FloatTensor(projected).unsqueeze(0)

        if self.use_fc_variance:
            wfc = window_fc
            if wfc is None:
                wfc = self._array(data, "fc_windows", "window_fc").astype(np.float32)
            # Truncate only, for the same reason as the dynamic mean above:
            # repeated padding windows would shrink the standard deviation.
            std_fc = wfc[: self.max_windows].std(axis=0).astype(np.float32)
            return torch.FloatTensor(np.stack([mean_fc, std_fc], axis=0))

        return torch.FloatTensor(mean_fc)

    def __getitem__(
        self, idx: int
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Load subject ``idx`` and return ``(features, adj, label, site_id)``."""
        needs_window_fc = (
            self.use_fc_row_features
            or self.use_fc_degree_features
            or self.use_dynamic_adj_sequence
            or self.use_dynamic_adj
        )

        # NOTE: allow_pickle=True can execute arbitrary code from a malicious
        # archive — only load trusted files. Everything needed is read inside
        # the ``with`` so the archive is closed before returning.
        with np.load(self.npz_paths[idx], allow_pickle=True) as data:
            site = str(data["site"])

            window_fc: np.ndarray | None = None
            if needs_window_fc:
                window_fc = self._array(data, "fc_windows", "window_fc").astype(
                    np.float32
                )

            features = self._node_features(data, site, window_fc)
            adj = self._adjacency(data, site, window_fc)
            label = torch.tensor(int(data["label"]), dtype=torch.long)

        site_id = torch.tensor(self.site_to_int.get(site, -1), dtype=torch.long)
        return torch.FloatTensor(features), adj, label, site_id

    # ---------------------------------------------------------------- properties

    @property
    def labels(self) -> list[int]:
        """Labels of all subjects, in ``npz_paths`` order."""
        return [m["label"] for m in self._meta]

    @property
    def num_nodes(self) -> int:
        """Number of ROIs, taken from the first subject's ``mean_fc``."""
        with np.load(self.npz_paths[0], allow_pickle=True) as data:
            return data["mean_fc"].shape[0]

    @property
    def num_windows(self) -> int:
        """Window count of the first subject (``max_windows`` when set)."""
        return self._meta[0]["num_windows"]
|
|