File size: 12,081 Bytes
16d6869
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
"""
PyTorch Dataset for preprocessed ABIDE subjects.

Each sample returns:
    bold_windows : (W, N)     — mean BOLD per ROI at each brain-state snapshot
    adj         : (N, N) or (W, N, N) — adjacency for this subject
                               use_dynamic_adj=False → subject's mean FC
                               use_dynamic_adj=True  → mean of per-window FCs
                               use_dynamic_adj_sequence=True → per-window FCs
                               population_adj provided → shared population adj
    label       : ()         — int64 scalar  (0 = TC, 1 = ASD)

The adjacency is left as raw (thresholded) FC values so the model can apply
its own Laplacian normalisation via utils.graph_conv.
"""

from __future__ import annotations

from pathlib import Path

import numpy as np
import torch
from torch.utils.data import Dataset


class ABIDEDataset(Dataset):
    """PyTorch dataset over preprocessed per-subject ABIDE ``.npz`` archives.

    Each sample is a 4-tuple ``(bold_windows, adj, label, site_id)``:

        bold_windows : (W, N) mean BOLD per ROI per window, or (W, N, N) FC
                       rows when ``use_fc_row_features`` is set
        adj          : adjacency — shape depends on the configured mode
                       (shared population matrix, per-subject mean FC,
                       mean of per-window FCs, or a (W, N, N) sequence)
        label        : int64 scalar (0 = TC, 1 = ASD)
        site_id      : int64 scalar from ``site_to_int``; -1 for unknown sites

    Adjacency values are left as raw (thresholded) FC so the model can apply
    its own Laplacian normalisation via utils.graph_conv.
    """

    def __init__(
        self,
        npz_paths: list[Path | str],
        population_adj: np.ndarray | None = None,
        use_dynamic_adj: bool = False,
        use_dynamic_adj_sequence: bool = False,
        fc_threshold: float = 0.2,
        max_windows: int | None = None,
        site_fc_mean: dict[str, np.ndarray] | None = None,
        preserve_fc_sign: bool = False,
        site_to_int: dict[str, int] | None = None,
        use_fc_variance: bool = False,
        use_fisher_z: bool = False,
        pca_mean: np.ndarray | None = None,
        pca_components: np.ndarray | None = None,
        use_fc_degree_features: bool = False,
        use_fc_row_features: bool = False,
    ):
        """
        Parameters
        ----------
        npz_paths       : paths to per-subject .npz files from preprocess.py
        population_adj  : (N, N) pre-computed population-level adjacency.
                          If provided, every sample uses this shared adjacency.
        use_dynamic_adj : if True and population_adj is None, use mean of
                          per-window FCs; otherwise use mean_fc (full-scan FC).
        use_dynamic_adj_sequence : if True and population_adj is None, return
                          per-window FCs with shape (W, N, N).
        fc_threshold    : zero-out edges with |fc| < threshold before returning
        max_windows     : truncate all subjects to this many windows so that
                          batches have uniform seq_len (takes the first W
                          windows; shorter subjects are padded by repeating
                          their last window)
        site_fc_mean    : per-site mean FC matrix (N, N) computed from training
                          set. Subtracted from each subject's FC before
                          thresholding to remove scanner/site connectivity
                          biases (FC-domain site normalization). BOLD is
                          already z-scored so BOLD-domain corrections have no
                          effect.
        preserve_fc_sign: if True, keep signed FC values in the adjacency
                          instead of converting to |FC|. Required for fc_mlp
                          which uses signed correlations as direct features
                          (anti-correlations between brain networks are
                          diagnostically relevant).
        site_to_int     : site-name → integer id map; unknown sites map to -1.
        use_fc_variance : if True (static-adjacency mode only), stack the
                          temporal std of per-window FCs as a second channel,
                          returning a (2, N, N) adjacency.
        use_fisher_z    : apply Fisher r-to-z to mean FC before thresholding
                          (static-adjacency mode only).
        pca_mean / pca_components : if both given (static-adjacency mode
                          only), project the upper-triangle FC vector through
                          a fitted PCA and return a (1, K) feature vector
                          instead of a matrix.
        use_fc_degree_features: if True, replace stored bold_windows (std of
                          z-scored BOLD ≈ 1.0) with per-window per-ROI mean
                          absolute FC: np.abs(fc_windows).mean(axis=-1). This
                          gives each ROI a scalar ≈ its average connectivity
                          strength in that window — directly discriminative
                          between ASD and TD, unlike BOLD std which is near-
                          constant after z-scoring.
        use_fc_row_features: if True, use per-window FC rows as node features
                          instead of scalar BOLD std. Returns (W, N, N) where
                          node i's feature vector is its full connectivity
                          profile fc_windows[w, i, :]. This is the standard
                          formulation in brain GCN literature (BrainNetCNN,
                          BrainGNN, STAGIN). Requires model to be built with
                          in_features=num_nodes.
        """
        self.npz_paths = [Path(p) for p in npz_paths]
        self.population_adj = (
            torch.FloatTensor(population_adj) if population_adj is not None else None
        )
        self.use_dynamic_adj = use_dynamic_adj
        self.use_dynamic_adj_sequence = use_dynamic_adj_sequence
        self.fc_threshold = fc_threshold
        self.max_windows = max_windows
        self.site_fc_mean = site_fc_mean or {}
        self.preserve_fc_sign = preserve_fc_sign
        self.site_to_int = site_to_int or {}
        self.use_fc_variance = use_fc_variance
        self.use_fisher_z = use_fisher_z
        self.pca_mean = pca_mean
        self.pca_components = pca_components
        self.use_fc_degree_features = use_fc_degree_features
        self.use_fc_row_features = use_fc_row_features

        # Pre-scan labels / window counts so __len__, samplers and stratified
        # splits never have to re-open archives per access.
        self._meta = self._scan_metadata()

    @staticmethod
    def _array(data: np.lib.npyio.NpzFile, primary: str, legacy: str) -> np.ndarray:
        """Fetch an array under its current key, falling back to the legacy key."""
        if primary in data:
            return data[primary]
        if legacy in data:
            return data[legacy]
        raise KeyError(f"Expected '{primary}' or legacy '{legacy}' in subject archive")

    def _threshold(self, adj_np: np.ndarray, preserve_sign: bool = False) -> np.ndarray:
        """Zero edges with |fc| below the threshold; optionally keep the sign."""
        mask = np.abs(adj_np) >= self.fc_threshold
        if preserve_sign:
            return np.where(mask, adj_np, 0.0)
        return np.where(mask, np.abs(adj_np), 0.0)

    def _site_correct(self, fc: np.ndarray, site: str) -> np.ndarray:
        """Subtract this site's training-set mean FC, if one was provided.

        Broadcasts the (N, N) reference over a leading window axis when ``fc``
        is a (W, N, N) stack. No-op for sites without a reference matrix.
        """
        ref = self.site_fc_mean.get(site)
        if ref is None:
            return fc
        ref = ref.astype(np.float32)
        if fc.ndim == 3:
            ref = ref[None]
        return fc - ref

    @staticmethod
    def _fisher_z(fc: np.ndarray) -> np.ndarray:
        """Fisher's r-to-z transform: z = arctanh(r).

        Linearises the correlation space — correlations near ±1 are compressed
        in Pearson space but uniform in z-space. Stabilises variance across
        different correlation magnitudes, which matters for linear classifiers.
        Clipped to ±0.9999 to avoid ±inf at perfect correlations.
        """
        return np.arctanh(np.clip(fc, -0.9999, 0.9999))

    @staticmethod
    def _pad_or_truncate_windows(array: np.ndarray, max_windows: int | None) -> np.ndarray:
        """Force the leading (window) axis to ``max_windows``.

        Longer subjects keep their first ``max_windows`` windows; shorter ones
        are padded by repeating the last window. ``None`` leaves the array
        untouched.
        """
        if max_windows is None:
            return array
        if array.shape[0] >= max_windows:
            return array[:max_windows]
        pad_count = max_windows - array.shape[0]
        pad = np.repeat(array[-1:], pad_count, axis=0)
        return np.concatenate([array, pad], axis=0)

    def _scan_metadata(self) -> list[dict]:
        """Read label / subject id / site / window count for every subject.

        Uses a context manager so each NpzFile's underlying zip handle is
        closed immediately — otherwise every archive keeps a file descriptor
        open for the lifetime of the dataset.
        """
        meta = []
        for p in self.npz_paths:
            with np.load(p, allow_pickle=True) as data:
                W = self._array(data, "bold_windows", "window_bold").shape[0]
                if self.max_windows is not None:
                    W = self.max_windows
                meta.append(
                    {
                        "label": int(data["label"]),
                        "subject_id": str(data["subject_id"]),
                        "site": str(data["site"]),
                        "num_windows": W,
                    }
                )
        return meta

    # ------------------------------------------------------------------
    def __len__(self) -> int:
        return len(self.npz_paths)

    def _static_adj(self, data: np.lib.npyio.NpzFile, site: str) -> torch.Tensor:
        """Build the static (per-subject mean FC) adjacency.

        Order matters: site correction → optional Fisher-z → threshold.
        Optionally PCA-projects the upper triangle or stacks the temporal
        std of per-window FCs as a second channel.
        """
        mean_np = self._site_correct(data["mean_fc"].astype(np.float32), site)
        if self.use_fisher_z:
            mean_np = self._fisher_z(mean_np)
        mean_np = self._threshold(mean_np, self.preserve_fc_sign).astype(np.float32)

        if self.pca_mean is not None and self.pca_components is not None:
            # PCA projection: upper triangle (D,) → (K,). Extract the triangle
            # the same way the MLP model does so the fitted PCA lines up.
            n = mean_np.shape[0]
            r, c = np.triu_indices(n, k=1)
            x_vec = mean_np[r, c] - self.pca_mean                       # centre
            x_pca = (self.pca_components @ x_vec).astype(np.float32)    # (K,)
            # (1, K) so the default collate stacks to (B, 1, K); model flattens
            return torch.FloatTensor(x_pca).unsqueeze(0)

        if self.use_fc_variance:
            # Second channel: temporal std of FC — captures connection instability
            wfc = self._array(data, "fc_windows", "window_fc").astype(np.float32)
            wfc = self._pad_or_truncate_windows(wfc, self.max_windows)
            std_np = wfc.std(axis=0).astype(np.float32)
            return torch.FloatTensor(np.stack([mean_np, std_np], axis=0))  # (2, N, N)

        return torch.FloatTensor(mean_np)                               # (N, N)

    def __getitem__(self, idx: int):
        # The context manager closes the archive's file handle on exit; all
        # arrays read below are materialised in memory, so the tensors built
        # from them remain valid afterwards.
        with np.load(self.npz_paths[idx], allow_pickle=True) as data:
            site = str(data["site"])

            # fc_windows feeds both FC-derived node features and every
            # dynamic-adjacency mode — load it at most once.
            wfc_raw: np.ndarray | None = None
            if (
                self.use_fc_row_features
                or self.use_fc_degree_features
                or self.use_dynamic_adj_sequence
                or self.use_dynamic_adj
            ):
                wfc_raw = self._array(data, "fc_windows", "window_fc").astype(np.float32)

            # ---- node feature sequence -----------------------------------
            if self.use_fc_row_features and wfc_raw is not None:
                # FC rows as node features: (W, N, N) — node i gets
                # fc_windows[w, i, :] (BrainNetCNN / BrainGNN / STAGIN style).
                bold_windows = self._pad_or_truncate_windows(wfc_raw, self.max_windows)
            elif self.use_fc_degree_features and wfc_raw is not None:
                # Per-window per-ROI mean |FC| after site correction → (W, N).
                wfc = self._site_correct(
                    self._pad_or_truncate_windows(wfc_raw, self.max_windows), site
                )
                bold_windows = np.abs(wfc).mean(axis=-1)
            else:
                bold_windows = self._array(data, "bold_windows", "window_bold").astype(np.float32)
                bold_windows = self._pad_or_truncate_windows(bold_windows, self.max_windows)

            # ---- adjacency -----------------------------------------------
            if self.population_adj is not None:
                adj = self.population_adj                          # (N, N) shared

            elif self.use_dynamic_adj_sequence:
                assert wfc_raw is not None
                wfc = self._site_correct(
                    self._pad_or_truncate_windows(wfc_raw, self.max_windows), site
                )
                adj = torch.FloatTensor(
                    self._threshold(wfc, self.preserve_fc_sign).astype(np.float32)
                )                                                  # (W, N, N)

            elif self.use_dynamic_adj:
                assert wfc_raw is not None
                fc = self._pad_or_truncate_windows(wfc_raw, self.max_windows).mean(axis=0)
                fc = self._site_correct(fc, site)
                adj = torch.FloatTensor(
                    self._threshold(fc, self.preserve_fc_sign).astype(np.float32)
                )                                                  # (N, N)

            else:
                adj = self._static_adj(data, site)

            label = torch.tensor(int(data["label"]), dtype=torch.long)
            site_id = torch.tensor(self.site_to_int.get(site, -1), dtype=torch.long)

        return torch.FloatTensor(bold_windows), adj, label, site_id

    # ------------------------------------------------------------------
    @property
    def labels(self) -> list[int]:
        """Diagnosis labels in dataset order (0 = TC, 1 = ASD)."""
        return [m["label"] for m in self._meta]

    @property
    def num_nodes(self) -> int:
        """Number of ROIs, read from the first subject's mean FC matrix."""
        with np.load(self.npz_paths[0], allow_pickle=True) as data:
            return data["mean_fc"].shape[0]

    @property
    def num_windows(self) -> int:
        """Window count of the first subject (uniform when max_windows is set)."""
        return self._meta[0]["num_windows"]