# Source: Hugging Face Hub upload (user Yatsuiii, commit 16d6869, verified).
"""
PyTorch Dataset for preprocessed ABIDE subjects.
Each sample returns:
bold_windows : (W, N) — mean BOLD per ROI at each brain-state snapshot
adj : (N, N) or (W, N, N) — adjacency for this subject
use_dynamic_adj=False → subject's mean FC
use_dynamic_adj=True → mean of per-window FCs
use_dynamic_adj_sequence=True → per-window FCs
use_population_adj=True → shared population adj
label : () — int64 scalar (0 = TC, 1 = ASD)
The adjacency is left as raw (thresholded) FC values so the model can apply
its own Laplacian normalisation via utils.graph_conv.
"""
from __future__ import annotations
from pathlib import Path
import numpy as np
import torch
from torch.utils.data import Dataset
class ABIDEDataset(Dataset):
    """Per-subject dataset over preprocessed ABIDE ``.npz`` archives.

    Each sample is ``(bold_windows, adj, label, site_id)``:

    - ``bold_windows`` : (W, N) mean BOLD per ROI per window — or an
      FC-derived feature map when ``use_fc_degree_features`` /
      ``use_fc_row_features`` is set (see ``__init__``).
    - ``adj`` : adjacency for this subject; shape depends on the flags
      chosen in ``__init__``.
    - ``label`` : () int64 scalar (0 = TC, 1 = ASD).
    - ``site_id`` : () int64 scalar from ``site_to_int`` (-1 if unmapped).

    Adjacencies are returned as raw (thresholded) FC values so the model can
    apply its own Laplacian normalisation via ``utils.graph_conv``.
    """

    def __init__(
        self,
        npz_paths: list[Path | str],
        population_adj: np.ndarray | None = None,
        use_dynamic_adj: bool = False,
        use_dynamic_adj_sequence: bool = False,
        fc_threshold: float = 0.2,
        max_windows: int | None = None,
        site_fc_mean: dict[str, np.ndarray] | None = None,
        preserve_fc_sign: bool = False,
        site_to_int: dict[str, int] | None = None,
        use_fc_variance: bool = False,
        use_fisher_z: bool = False,
        pca_mean: np.ndarray | None = None,
        pca_components: np.ndarray | None = None,
        use_fc_degree_features: bool = False,
        use_fc_row_features: bool = False,
    ):
        """
        Parameters
        ----------
        npz_paths : paths to per-subject .npz files from preprocess.py
        population_adj : (N, N) pre-computed population-level adjacency.
            If provided, every sample uses this shared adjacency and the
            dynamic-adjacency flags below are ignored.
        use_dynamic_adj : if True and population_adj is None, use mean of
            per-window FCs; otherwise use mean_fc (full-scan FC).
        use_dynamic_adj_sequence : if True and population_adj is None, return
            per-window FCs with shape (W, N, N).
        fc_threshold : zero-out edges with |fc| < threshold before returning.
        max_windows : force all subjects to this many windows so batches have
            uniform seq_len — long subjects keep the first ``max_windows``
            windows, short subjects repeat their last window as padding.
        site_fc_mean : per-site mean FC matrix (N, N) computed from the
            training set. Subtracted from each subject's FC before
            thresholding to remove scanner/site connectivity biases
            (FC-domain site normalization). BOLD is already z-scored so
            BOLD-domain corrections have no effect.
        preserve_fc_sign : if True, keep signed FC values in the adjacency
            instead of converting to |FC|. Required for fc_mlp which uses
            signed correlations as direct features (anti-correlations
            between brain networks are diagnostically relevant).
        site_to_int : mapping from site name to integer id; looked up to
            produce each sample's ``site_id`` (-1 for unmapped sites).
        use_fc_variance : static-adjacency path only — stack the temporal
            std of per-window FCs as a second channel, giving adj of shape
            (2, N, N).
        use_fisher_z : apply Fisher's r-to-z transform to mean_fc before
            thresholding (static-adjacency path only).
        pca_mean, pca_components : if both are given (static path only),
            project the upper triangle of the thresholded mean FC onto the
            components, returning adj of shape (1, K).
        use_fc_degree_features : if True, replace stored bold_windows (std of
            z-scored BOLD ≈ 1.0) with per-window per-ROI mean absolute FC:
            np.abs(fc_windows).mean(axis=-1). This gives each ROI a scalar
            ≈ its average connectivity strength in that window — directly
            discriminative between ASD and TD, unlike BOLD std which is
            near-constant after z-scoring.
        use_fc_row_features : if True, use per-window FC rows as node
            features instead of scalar BOLD std. Returns (W, N, N) where
            node i's feature vector is its full connectivity profile
            fc_windows[w, i, :]. This is the standard formulation in brain
            GCN literature (BrainNetCNN, BrainGNN, STAGIN). Requires model
            to be built with in_features=num_nodes.
        """
        self.npz_paths = [Path(p) for p in npz_paths]
        self.population_adj = (
            torch.FloatTensor(population_adj) if population_adj is not None else None
        )
        self.use_dynamic_adj = use_dynamic_adj
        self.use_dynamic_adj_sequence = use_dynamic_adj_sequence
        self.fc_threshold = fc_threshold
        self.max_windows = max_windows
        self.site_fc_mean = site_fc_mean or {}
        self.preserve_fc_sign = preserve_fc_sign
        self.site_to_int = site_to_int or {}
        self.use_fc_variance = use_fc_variance
        self.use_fisher_z = use_fisher_z
        self.pca_mean = pca_mean
        self.pca_components = pca_components
        self.use_fc_degree_features = use_fc_degree_features
        self.use_fc_row_features = use_fc_row_features
        # Cache labels / window counts once so samplers and stratified splits
        # don't have to re-open every archive per access.
        self._meta = self._scan_metadata()

    @staticmethod
    def _array(data: np.lib.npyio.NpzFile, primary: str, legacy: str) -> np.ndarray:
        """Fetch ``primary`` from the archive, falling back to ``legacy``."""
        if primary in data:
            return data[primary]
        if legacy in data:
            return data[legacy]
        raise KeyError(f"Expected '{primary}' or legacy '{legacy}' in subject archive")

    def _threshold(self, adj_np: np.ndarray, preserve_sign: bool = False) -> np.ndarray:
        """Zero out edges with |fc| < fc_threshold; optionally keep the sign."""
        mask = np.abs(adj_np) >= self.fc_threshold
        if preserve_sign:
            return np.where(mask, adj_np, 0.0)
        return np.where(mask, np.abs(adj_np), 0.0)

    @staticmethod
    def _fisher_z(fc: np.ndarray) -> np.ndarray:
        """Fisher's r-to-z transform: z = arctanh(r).
        Linearises the correlation space — correlations near ±1 are compressed
        in Pearson space but uniform in z-space. Stabilises variance across
        different correlation magnitudes, which matters for linear classifiers.
        Clipped to ±0.9999 to avoid ±inf at perfect correlations.
        """
        return np.arctanh(np.clip(fc, -0.9999, 0.9999))

    @staticmethod
    def _pad_or_truncate_windows(array: np.ndarray, max_windows: int | None) -> np.ndarray:
        """Force axis 0 to exactly ``max_windows`` entries: truncate long
        arrays, pad short ones by repeating the last window. No-op when
        ``max_windows`` is None."""
        if max_windows is None:
            return array
        if array.shape[0] >= max_windows:
            return array[:max_windows]
        pad_count = max_windows - array.shape[0]
        pad = np.repeat(array[-1:], pad_count, axis=0)
        return np.concatenate([array, pad], axis=0)

    def _scan_metadata(self) -> list[dict]:
        """Collect label / subject_id / site / window count for each subject."""
        meta = []
        for p in self.npz_paths:
            data = np.load(p, allow_pickle=True)
            W = self._array(data, "bold_windows", "window_bold").shape[0]
            if self.max_windows is not None:
                # Padding/truncation makes every subject exactly this long.
                W = self.max_windows
            meta.append(
                {
                    "label": int(data["label"]),
                    "subject_id": str(data["subject_id"]),
                    "site": str(data["site"]),
                    "num_windows": W,
                }
            )
        return meta

    # ------------------------------------------------------------------
    def __len__(self) -> int:
        return len(self.npz_paths)

    def __getitem__(self, idx: int):
        data = np.load(self.npz_paths[idx], allow_pickle=True)
        site = str(data["site"])

        # fc_windows is needed for FC-derived node features, and for dynamic
        # adjacency — but only when no shared population adjacency overrides
        # the dynamic flags. Skipping it avoids decompressing a large unused
        # array on every access.
        need_window_fc = (
            self.use_fc_row_features
            or self.use_fc_degree_features
            or (
                self.population_adj is None
                and (self.use_dynamic_adj or self.use_dynamic_adj_sequence)
            )
        )
        _wfc_loaded: np.ndarray | None = None
        if need_window_fc:
            _wfc_loaded = self._array(data, "fc_windows", "window_fc").astype(np.float32)

        # ---- node feature sequence ----
        if self.use_fc_row_features and _wfc_loaded is not None:
            # FC rows as node features: (W, N, N) — node i gets fc_windows[w, i, :].
            # This is the standard brain GCN formulation (BrainNetCNN, BrainGNN, STAGIN).
            bold_windows = self._pad_or_truncate_windows(_wfc_loaded, self.max_windows)
        elif self.use_fc_degree_features and _wfc_loaded is not None:
            # Per-window per-ROI mean |FC| after site correction: (W, N).
            wfc = self._pad_or_truncate_windows(_wfc_loaded, self.max_windows)
            if site in self.site_fc_mean:
                wfc = wfc - self.site_fc_mean[site].astype(np.float32)[None]
            # wfc is already padded/truncated above, so no second pass needed.
            bold_windows = np.abs(wfc).mean(axis=-1)
        else:
            bold_windows = self._array(data, "bold_windows", "window_bold").astype(np.float32)
            bold_windows = self._pad_or_truncate_windows(bold_windows, self.max_windows)

        # ---- adjacency ----
        if self.population_adj is not None:
            adj = self.population_adj  # (N, N) shared across all subjects
        elif self.use_dynamic_adj_sequence:
            assert _wfc_loaded is not None
            wfc = self._pad_or_truncate_windows(_wfc_loaded, self.max_windows)
            if site in self.site_fc_mean:
                wfc = wfc - self.site_fc_mean[site].astype(np.float32)[None]
            adj = torch.FloatTensor(
                self._threshold(wfc, self.preserve_fc_sign).astype(np.float32)
            )  # (W, N, N)
        elif self.use_dynamic_adj:
            assert _wfc_loaded is not None
            wfc = self._pad_or_truncate_windows(_wfc_loaded, self.max_windows)
            fc = wfc.mean(axis=0)
            if site in self.site_fc_mean:
                fc = fc - self.site_fc_mean[site].astype(np.float32)
            adj = torch.FloatTensor(
                self._threshold(fc, self.preserve_fc_sign).astype(np.float32)
            )  # (N, N)
        else:
            # Static per-subject mean FC.
            mean_np = data["mean_fc"].astype(np.float32)
            if site in self.site_fc_mean:
                mean_np = mean_np - self.site_fc_mean[site].astype(np.float32)
            if self.use_fisher_z:
                mean_np = self._fisher_z(mean_np)
            mean_np = self._threshold(mean_np, self.preserve_fc_sign).astype(np.float32)
            if self.pca_mean is not None and self.pca_components is not None:
                # PCA projection: (D,) → (K,).
                # Extract the upper triangle the same way the MLP model does.
                n = mean_np.shape[0]
                r, c = np.triu_indices(n, k=1)
                x_vec = mean_np[r, c] - self.pca_mean  # centre
                x_pca = (self.pca_components @ x_vec).astype(np.float32)  # (K,)
                # Return as (1, K) so collate_fn stacks to (B, 1, K); model flattens.
                adj = torch.FloatTensor(x_pca).unsqueeze(0)  # (1, K)
            elif self.use_fc_variance:
                # Second channel: temporal std of FC — captures connection instability.
                wfc = self._array(data, "fc_windows", "window_fc").astype(np.float32)
                wfc = self._pad_or_truncate_windows(wfc, self.max_windows)
                std_np = wfc.std(axis=0).astype(np.float32)
                adj = torch.FloatTensor(np.stack([mean_np, std_np], axis=0))  # (2, N, N)
            else:
                adj = torch.FloatTensor(mean_np)  # (N, N)

        label = torch.tensor(int(data["label"]), dtype=torch.long)
        site_id = torch.tensor(self.site_to_int.get(site, -1), dtype=torch.long)
        return torch.FloatTensor(bold_windows), adj, label, site_id

    # ------------------------------------------------------------------
    @property
    def labels(self) -> list[int]:
        """Class label per subject, in npz_paths order (for stratified splits)."""
        return [m["label"] for m in self._meta]

    @property
    def num_nodes(self) -> int:
        """Number of ROIs N, read from the first subject's mean_fc."""
        data = np.load(self.npz_paths[0], allow_pickle=True)
        return data["mean_fc"].shape[0]

    @property
    def num_windows(self) -> int:
        """Window count of the first subject (uniform across subjects when
        max_windows is set)."""
        return self._meta[0]["num_windows"]