# BrainConnect-ASD / brain_gcn/models/population_gcn.py
# Uploaded by Yatsuiii: "Upload folder using huggingface_hub"
# Commit 16d6869 (verified)
"""
Population-level GCN for subject-level ASD/TD classification.
All subjects are nodes in a single graph — transductive setting.
The model sees all node features (including unlabeled val/test subjects)
during forward passes; loss is masked to training nodes only.
"""
from __future__ import annotations
import torch
import torch.nn.functional as F
from torch import nn
class GraphConv(nn.Module):
    """Graph convolution layer: aggregate over neighbors, then project.

    The layer computes ``linear(adj @ x)``: node features are first mixed
    through the adjacency matrix and the result is passed through a single
    learnable linear map.
    """

    def __init__(self, in_dim: int, out_dim: int, bias: bool = True):
        """Create the projection applied to the aggregated features.

        Args:
            in_dim: Size of each input node feature vector.
            out_dim: Size of each output node feature vector.
            bias: Whether the linear projection carries a bias term.
        """
        super().__init__()
        self.linear = nn.Linear(
            in_features=in_dim,
            out_features=out_dim,
            bias=bias,
        )

    def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        """Aggregate ``x`` through ``adj`` and project the result.

        Args:
            x: Node features; per the original author's note, (N, in_dim).
            adj: Adjacency matrix; per the original author's note it is
                already normalized and has shape (N, N). No normalization
                is performed here.

        Returns:
            The projected, aggregated node features.
        """
        aggregated = adj @ x  # neighborhood aggregation
        projected = self.linear(aggregated)
        return projected
class PopulationGCN(nn.Module):
    """2-layer GCN on the subject population graph.

    Architecture
    ============
    Input → Dropout → GC1 → LayerNorm → ReLU
          → Dropout → GC2 → LayerNorm → ReLU
          → Dropout → Linear → logits (N, num_classes)

    Depth 2 is sufficient: each node aggregates 2-hop neighbors,
    covering subjects with similar age+sex across the whole cohort.

    ``forward`` and ``embed`` share one encoder implementation
    (``_encode``) so the two code paths cannot drift apart.
    """

    def __init__(
        self,
        in_dim: int,
        hidden_dim: int = 64,
        num_classes: int = 2,
        dropout: float = 0.5,
    ):
        """Build the two graph-conv stages, their norms, and the classifier.

        Args:
            in_dim: Size of each input node feature vector.
            hidden_dim: Width of both graph-convolution stages.
            num_classes: Number of output logits per node.
            dropout: Drop probability shared by all three dropout sites.
        """
        super().__init__()
        # Submodule names and creation order are kept stable: they determine
        # the state_dict keys and the order in which seeded initialisation
        # draws random numbers.
        self.gc1 = GraphConv(in_dim, hidden_dim)
        self.gc2 = GraphConv(hidden_dim, hidden_dim)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.head = nn.Linear(hidden_dim, num_classes)
        self.drop = nn.Dropout(dropout)

    def _encode(
        self,
        x: torch.Tensor,
        adj: torch.Tensor,
        *,
        use_dropout: bool,
    ) -> torch.Tensor:
        """Run both GC → LayerNorm → ReLU stages and return the activations.

        Args:
            x: Node features fed to the first graph convolution.
            adj: Adjacency matrix passed unchanged to both convolutions.
            use_dropout: When true, ``self.drop`` is applied to the input of
                each stage; when false, dropout is skipped entirely,
                regardless of the module's train/eval mode.

        Returns:
            Post-GC2 activations (after LayerNorm and ReLU).
        """
        stages = ((self.gc1, self.norm1), (self.gc2, self.norm2))
        for conv, norm in stages:
            if use_dropout:
                x = self.drop(x)
            x = F.relu(norm(conv(x, adj)))
        return x

    def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        """Compute per-node class logits.

        Args:
            x: Node features for every node in the graph.
            adj: Adjacency matrix used by both graph convolutions.

        Returns:
            Logits of shape (N, num_classes).
        """
        hidden = self._encode(x, adj, use_dropout=True)
        # Third dropout site sits between the encoder and the classifier.
        return self.head(self.drop(hidden))  # (N, num_classes)

    @torch.no_grad()
    def embed(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        """Return post-GC2 embeddings for t-SNE / analysis.

        Runs without gradient tracking and never applies dropout, so the
        result is the same whether the module is in train or eval mode.
        """
        return self._encode(x, adj, use_dropout=False)