File size: 5,758 Bytes
16d6869
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
"""
Population graph construction for subject-level GCN (Parisot et al. 2017/2018).

Nodes  = subjects
Edges  = phenotypic similarity (sex match × age Gaussian kernel)
Features = PCA-reduced FC upper triangle, fitted on training subjects only
"""

from __future__ import annotations

from pathlib import Path

import numpy as np
import pandas as pd
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler


# ---------------------------------------------------------------------------
# Phenotypic data
# ---------------------------------------------------------------------------

def load_phenotypic(pheno_csv: str | Path, processed_dir: str | Path) -> pd.DataFrame:
    """Load the ABIDE phenotypic CSV and keep only subjects with processed data.

    Parameters
    ----------
    pheno_csv     : path to the phenotypic CSV; must contain SUB_ID, DX_GROUP
                    and SEX columns
    processed_dir : directory holding per-subject ``<SUB_ID>.npz`` files

    Returns
    -------
    DataFrame filtered to available subjects (index reset), with two derived
    columns: ``label`` (1 = ASD, 0 = TD, from DX_GROUP) and ``sex_enc``
    (0 = male, 1 = female, from SEX).
    """
    pheno = pd.read_csv(pheno_csv)
    processed_dir = Path(processed_dir)

    # Keep only .npz files whose stem is a valid integer subject ID; skips
    # strays such as README.npz or editor backups instead of raising
    # ValueError on int() conversion.
    available = {
        int(p.stem) for p in processed_dir.glob("*.npz") if p.stem.isdigit()
    }
    pheno = pheno[pheno["SUB_ID"].isin(available)].copy().reset_index(drop=True)

    # DX_GROUP: 1=ASD → label=1, 2=TD → label=0
    pheno["label"] = (pheno["DX_GROUP"] == 1).astype(int)
    # SEX: 1=Male → 0, 2=Female → 1
    pheno["sex_enc"] = (pheno["SEX"] == 2).astype(int)

    return pheno


# ---------------------------------------------------------------------------
# Node features: FC upper triangle → PCA
# ---------------------------------------------------------------------------

def extract_fc_features(processed_dir: str | Path, subject_ids: list[int]) -> np.ndarray:
    """Stack each subject's vectorized FC matrix into one feature array.

    For every subject ID, loads ``<sid>.npz``, reads the square ``mean_fc``
    matrix and keeps its strictly-upper-triangular entries (k=1).

    Returns a float32 array of shape (N, n_rois*(n_rois-1)/2) — (N, 19900)
    for a 200-ROI atlas.
    """
    root = Path(processed_dir)
    features = []
    for sid in subject_ids:
        archive = np.load(root / f"{sid}.npz", allow_pickle=True)
        fc = archive["mean_fc"].astype(np.float32)
        rows, cols = np.triu_indices(fc.shape[0], k=1)
        features.append(fc[rows, cols])
    return np.stack(features)


def harmonize_combat(
    features: np.ndarray,
    sites: list[str],
    labels: np.ndarray,
    ages: np.ndarray,
    sexes: np.ndarray,
) -> np.ndarray:
    """ComBat site harmonization of the FC upper-triangle features.

    Removes scanner-specific batch effects while preserving the biological
    covariates (age, sex, diagnosis) — site variance is the dominant noise
    source in multi-site fMRI (ABIDE spans 17+ sites with different
    scanners and protocols).

    Returns the harmonized features with the same (N, n_features) shape,
    cast to float32.
    """
    from neuroCombat import neuroCombat

    covariates = pd.DataFrame(
        {"site": sites, "age": ages, "sex": sexes, "dx": labels}
    )
    # neuroCombat works on (features, subjects) arrays, hence the transposes
    # going in and coming out.
    harmonized = neuroCombat(
        dat=features.T,
        covars=covariates,
        batch_col="site",
        continuous_cols=["age"],
        categorical_cols=["sex", "dx"],
    )
    return harmonized["data"].T.astype(np.float32)


def fit_pca(train_feats: np.ndarray, n_components: int = 256) -> tuple[StandardScaler, PCA]:
    """Fit a StandardScaler followed by PCA on the training features only.

    The component count is clipped to what PCA can actually produce:
    at most min(n_samples - 1, n_features) components exist.

    Returns the fitted (scaler, pca) pair, ready for ``apply_pca``.
    """
    scaler = StandardScaler()
    scaled = scaler.fit_transform(train_feats)

    k = min(n_components, scaled.shape[0] - 1, scaled.shape[1])
    pca = PCA(n_components=k, random_state=42)
    pca.fit(scaled)

    explained = pca.explained_variance_ratio_.sum()
    print(f"PCA {k} components → {explained:.1%} variance explained")
    return scaler, pca


def apply_pca(feats: np.ndarray, scaler: StandardScaler, pca: PCA) -> np.ndarray:
    """Project features through a previously fitted scaler + PCA, as float32."""
    scaled = scaler.transform(feats)
    return pca.transform(scaled).astype(np.float32)


# ---------------------------------------------------------------------------
# Population graph
# ---------------------------------------------------------------------------

def build_population_adj(
    subject_df: pd.DataFrame,
    threshold: float = 0.5,
    age_sigma: float | None = None,
    use_site: bool = False,
) -> np.ndarray:
    """Build N×N weighted adjacency from phenotypic similarity.

    Edge weight = sex_match * age_gaussian_sim (* site_match if use_site).
    Edge exists only if weight > threshold.

    Parameters
    ----------
    subject_df  : DataFrame with columns sex_enc, AGE_AT_SCAN, SITE_ID
    threshold   : minimum similarity to keep an edge
    age_sigma   : std dev for Gaussian age kernel (default: std of ages)
    use_site    : include site-match as a multiplier (Parisot original)
                  Disable after ComBat since site effects are removed.
    """
    N = len(subject_df)
    ages  = subject_df["AGE_AT_SCAN"].values.astype(np.float32)
    sexes = subject_df["sex_enc"].values

    if age_sigma is None:
        age_sigma = float(np.std(ages))
    if age_sigma <= 0:
        # All subjects share the same age (std = 0) or a degenerate sigma
        # was passed. The original expression would divide by zero, turning
        # every similarity into NaN and silently dropping ALL edges. With a
        # unit sigma the zero age-differences give age_sim == 1, as intended.
        age_sigma = 1.0

    # Age similarity — Gaussian kernel on pairwise age differences
    diff = ages[:, None] - ages[None, :]
    age_sim = np.exp(-diff**2 / (2 * age_sigma**2))

    # Sex similarity — binary match
    sex_sim = (sexes[:, None] == sexes[None, :]).astype(np.float32)

    W = sex_sim * age_sim

    if use_site:
        sites = np.array(subject_df["SITE_ID"].tolist())   # force plain object array
        site_sim = (sites[:, None] == sites[None, :]).astype(np.float32)
        W = W * site_sim

    # Threshold, then remove self-loops (normalize_adj adds them back).
    adj = np.where(W > threshold, W, 0.0).astype(np.float32)
    np.fill_diagonal(adj, 0.0)

    n_edges = int((adj > 0).sum()) // 2
    n_pairs = N * (N - 1) / 2
    density = n_edges / n_pairs if n_pairs else 0.0   # guard N < 2
    print(f"Population graph: {N} nodes, {n_edges} edges, {density:.1%} density")
    return adj


def normalize_adj(adj: np.ndarray) -> np.ndarray:
    """Symmetric normalization with self-loops: D^{-1/2}(A+I)D^{-1/2}."""
    n = adj.shape[0]
    A_hat = adj + np.eye(n, dtype=np.float32)
    degrees = A_hat.sum(axis=1)
    # Zero-degree rows get scale 0 instead of a divide-by-zero inf.
    scale = np.where(degrees > 0, 1.0 / np.sqrt(degrees), 0.0)
    normalized = scale[:, None] * A_hat * scale[None, :]
    return normalized.astype(np.float32)