| """ |
| Population-level GCN for subject-level ASD/TD classification. |
| |
| All subjects are nodes in a single graph — transductive setting. |
| The model sees all node features (including unlabeled val/test subjects) |
| during forward passes; loss is masked to training nodes only. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import torch |
| import torch.nn.functional as F |
| from torch import nn |
|
|
|
|
class GraphConv(nn.Module):
    """One graph-convolution layer: mix node features over ``adj``, then project.

    The output is ``linear(adj @ x)``. Aggregation happens *before* the affine
    map, so the bias is added once per output row after neighbor mixing. No
    normalization of ``adj`` is done here; the caller's weighting is used as-is.

    Args:
        in_dim: Size of each input node feature vector.
        out_dim: Size of each output node feature vector.
        bias: Whether the projection includes an additive bias term.
    """

    def __init__(self, in_dim: int, out_dim: int, bias: bool = True):
        super().__init__()
        # The attribute name is part of the state_dict keys ("linear.weight",
        # "linear.bias"), so it must stay "linear" for checkpoint compatibility.
        self.linear = nn.Linear(in_features=in_dim, out_features=out_dim, bias=bias)

    def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        """Aggregate neighbor features with ``adj`` and project them.

        Args:
            x: Node feature matrix whose last dim is ``in_dim``; presumably
                (N, in_dim) — TODO confirm against callers.
            adj: Propagation matrix left-multiplied onto ``x``; presumably an
                (N, N) adjacency normalized by the caller — verify upstream.

        Returns:
            ``self.linear(adj @ x)``, last dim ``out_dim``.
        """
        neighborhood = adj @ x
        return self.linear(neighborhood)
|
|
|
|
class PopulationGCN(nn.Module):
    """Two-layer graph convolutional classifier over the subject population graph.

    Architecture
    ============
    Input → Dropout → GC1 → LayerNorm → ReLU
          → Dropout → GC2 → LayerNorm → ReLU
          → Dropout → Linear → logits (N, num_classes)

    With two graph-conv layers, each node's representation depends on nodes up
    to two hops away in ``adj``. The design intent is that this reaches
    subjects with similar age+sex across the cohort; the graph itself is
    built outside this module (confirm the construction upstream).

    Args:
        in_dim: Per-node input feature size.
        hidden_dim: Width of both graph-conv layers and of the head's input.
        num_classes: Number of output logits per node.
        dropout: Drop probability. A single ``nn.Dropout`` module is reused at
            all three dropout sites.
    """

    def __init__(
        self,
        in_dim: int,
        hidden_dim: int = 64,
        num_classes: int = 2,
        dropout: float = 0.5,
    ):
        super().__init__()
        # Submodule names are state_dict keys; keep them stable. Keep the
        # construction order too: nn.Linear initialization is random, so
        # reordering would change the initial weights under a fixed seed.
        self.gc1 = GraphConv(in_dim, hidden_dim)
        self.gc2 = GraphConv(hidden_dim, hidden_dim)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.head = nn.Linear(hidden_dim, num_classes)
        self.drop = nn.Dropout(dropout)

    @staticmethod
    def _conv_block(
        conv: nn.Module,
        norm: nn.Module,
        h: torch.Tensor,
        adj: torch.Tensor,
    ) -> torch.Tensor:
        """One encoder stage shared by forward and embed: conv → LayerNorm → ReLU."""
        return F.relu(norm(conv(h, adj)))

    def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        """Compute per-node class logits.

        Args:
            x: Node features for every subject in the graph; presumably
                (N, in_dim) — TODO confirm.
            adj: Graph propagation matrix passed to both conv layers;
                presumably (N, N) — verify against the graph builder.

        Returns:
            Logits whose last dim is ``num_classes`` (one row per node).
            Dropout is active only in training mode.
        """
        hidden = self._conv_block(self.gc1, self.norm1, self.drop(x), adj)
        hidden = self._conv_block(self.gc2, self.norm2, self.drop(hidden), adj)
        return self.head(self.drop(hidden))

    @torch.no_grad()
    def embed(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        """Return the post-GC2 (LayerNorm + ReLU) embeddings for t-SNE / analysis.

        This is the same encoder path as ``forward`` with all dropout skipped
        and the classification head omitted. It runs without autograd.
        """
        first = self._conv_block(self.gc1, self.norm1, x, adj)
        return self._conv_block(self.gc2, self.norm2, first, adj)
|
|