File size: 3,177 Bytes
16d6869
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
"""
Preprocess ABIDE subjects into cached .npz files.

Each .npz contains:
    bold         (T, N)      — z-scored BOLD time series
    mean_fc      (N, N)      — full-scan Pearson FC
    bold_windows (W, N)      — std of BOLD per window (local signal power; node features)
    fc_windows   (W, N, N)   — per-window Pearson FC (dynamic adjacency)
    window_bold  (W, N)      — same array as bold_windows, saved under a second key
    window_fc    (W, N, N)   — same array as fc_windows, saved under a second key
    label        scalar int  — 0 = TC, 1 = ASD
    subject_id   str
    site         str

Run once via ABIDEDataModule.prepare_data(); subsequent runs load from cache.
"""

from __future__ import annotations

import logging
from pathlib import Path

import numpy as np

from .functional_connectivity import compute_fc, sliding_fc_windows

log = logging.getLogger(__name__)


def zscore(bold: np.ndarray) -> np.ndarray:
    """
    Standardise every column of ``bold`` to zero mean and unit variance.

    Mean and standard deviation are taken along axis 0, so each column
    (one ROI time series in the module's (T, N) layout) is normalised
    independently of the others.  Columns whose standard deviation is
    below 1e-8 are divided by 1.0 instead, which avoids dividing by ~0
    and leaves a constant column as all zeros.

    Returns a new float32 array with the same shape as the input; the
    input itself is not modified.
    """
    col_mean = bold.mean(axis=0, keepdims=True)
    col_std = bold.std(axis=0, keepdims=True)

    # Neutralise (near-)constant columns so the division below is safe.
    nearly_flat = col_std < 1e-8
    col_std[nearly_flat] = 1.0

    centred = bold - col_mean
    return (centred / col_std).astype(np.float32)


def preprocess_subject(
    subject: dict,
    processed_dir: Path,
    window_len: int = 50,
    step: int = 5,
    overwrite: bool = False,
) -> Path | None:
    """
    Process one subject dict and cache the result as ``<subject_id>.npz``.

    Pipeline: z-score BOLD → full-scan FC + sliding windows → save .npz.

    Parameters
    ----------
    subject
        Mapping that provides the keys ``"subject_id"``, ``"bold"``,
        ``"label"`` and ``"site"``.
    processed_dir
        Existing directory in which the ``.npz`` file is written.
    window_len, step
        Forwarded to ``sliding_fc_windows``.  A scan with fewer than
        ``window_len + step`` time points is skipped.
    overwrite
        When False (default) an already existing output file is returned
        as-is without recomputation.

    Returns
    -------
    Path to the cached ``.npz``, or None if the subject was skipped
    (BOLD array is not 2-D, or the scan is too short).

    Notes
    -----
    The archive is written to a sibling ``*.npz.tmp`` file and renamed
    onto its final name only after the write has completed.  An
    interrupted run therefore never leaves a truncated ``.npz`` behind
    that the cache check above would later mistake for a valid entry.
    """
    subject_id = subject["subject_id"]
    out_path = processed_dir / f"{subject_id}.npz"

    if out_path.exists() and not overwrite:
        return out_path

    bold = subject["bold"]                  # expected layout: (T, N)
    if bold.ndim != 2:
        log.warning(
            "Subject %s: expected a 2-D (T, N) BOLD array, got shape %s — skipping.",
            subject_id, bold.shape,
        )
        return None

    n_timepoints = bold.shape[0]
    if n_timepoints < window_len + step:
        log.warning(
            "Subject %s: %d TRs is too short for window_len=%d + step=%d — skipping.",
            subject_id, n_timepoints, window_len, step,
        )
        return None

    bold = zscore(bold)
    mean_fc = compute_fc(bold)
    bold_windows, fc_windows = sliding_fc_windows(bold, window_len=window_len, step=step)

    # Write through an open file handle: given a handle, NumPy does not
    # append ".npz" to the name, so the temp file keeps its ".tmp" suffix.
    tmp_path = out_path.with_name(out_path.name + ".tmp")
    try:
        with tmp_path.open("wb") as fh:
            np.savez_compressed(
                fh,
                bold=bold,
                mean_fc=mean_fc,
                bold_windows=bold_windows,
                fc_windows=fc_windows,
                # Same arrays under a second pair of key names.
                # NOTE(review): presumably kept for loaders that read the
                # older names — confirm before removing.
                window_bold=bold_windows,
                window_fc=fc_windows,
                label=np.int64(subject["label"]),
                subject_id=subject_id,
                site=subject["site"],
            )
        # Publish the finished archive in a single rename step.
        tmp_path.replace(out_path)
    finally:
        # Only still present if the write or the rename failed.
        if tmp_path.exists():
            tmp_path.unlink()

    return out_path


def preprocess_all(
    subjects: list[dict],
    processed_dir: str | Path,
    window_len: int = 50,
    step: int = 5,
    overwrite: bool = False,
) -> list[Path]:
    """
    Run ``preprocess_subject`` over every entry of ``subjects``.

    The output directory is created (including parents) if it does not
    exist yet.  ``window_len``, ``step`` and ``overwrite`` are passed
    through unchanged to ``preprocess_subject``.  Progress is logged
    after every 50 subjects and once more at the end.

    Returns the paths that ``preprocess_subject`` returned, in input
    order, leaving out subjects for which it returned None.
    """
    out_dir = Path(processed_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    saved: list[Path] = []
    for count, entry in enumerate(subjects, start=1):
        result = preprocess_subject(
            entry,
            out_dir,
            window_len=window_len,
            step=step,
            overwrite=overwrite,
        )
        if result is not None:
            saved.append(result)

        # Periodic heartbeat so long runs show progress in the log.
        if count % 50 == 0:
            log.info("Preprocessed %d / %d subjects.", count, len(subjects))

    log.info("Preprocessing done: %d / %d subjects saved.", len(saved), len(subjects))
    return saved