File size: 2,327 Bytes
16d6869
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
"""
Population-level GCN for subject-level ASD/TD classification.

All subjects are nodes in a single graph — transductive setting.
The model sees all node features (including unlabeled val/test subjects)
during forward passes; loss is masked to training nodes only.
"""

from __future__ import annotations

import torch
import torch.nn.functional as F
from torch import nn


class GraphConv(nn.Module):
    """One graph-convolution layer: propagate over the graph, then project.

    The output is ``linear(adj @ x)``: node features are first mixed according
    to ``adj`` and only afterwards passed through an affine map. No
    normalization of ``adj`` happens here — the caller supplies it already
    normalized.
    """

    def __init__(self, in_dim: int, out_dim: int, bias: bool = True):
        super().__init__()
        # Keep the attribute name `linear`: it determines the state_dict keys.
        self.linear = nn.Linear(in_dim, out_dim, bias=bias)

    def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        # x: (N, in_dim); adj: pre-normalized (N, N).
        propagated = adj @ x
        return self.linear(propagated)


class PopulationGCN(nn.Module):
    """2-layer GCN on the subject population graph.

    Architecture
    ============
    Input → Dropout → GC1 → LayerNorm → ReLU
          → Dropout → GC2 → LayerNorm → ReLU
          → Dropout → Linear → logits (N, num_classes)

    With two graph convolutions, each node's representation mixes information
    from neighbours up to two hops away in the population graph. What counts
    as a neighbour (presumably phenotypic similarity such as age/sex — TODO
    confirm) is decided by whoever builds ``adj``, outside this module.

    Args:
        in_dim: Number of input features per subject node.
        hidden_dim: Width of both graph-convolution layers (and of the
            embeddings returned by ``embed``).
        num_classes: Number of logits produced per node.
        dropout: Drop probability used before GC1, before GC2 and before
            the linear head.
    """

    def __init__(
        self,
        in_dim: int,
        hidden_dim: int = 64,
        num_classes: int = 2,
        dropout: float = 0.5,
    ):
        super().__init__()
        self.gc1   = GraphConv(in_dim, hidden_dim)
        self.gc2   = GraphConv(hidden_dim, hidden_dim)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.head  = nn.Linear(hidden_dim, num_classes)
        self.drop  = nn.Dropout(dropout)

    def _encode(
        self,
        x: torch.Tensor,
        adj: torch.Tensor,
        *,
        apply_dropout: bool,
    ) -> torch.Tensor:
        """Run the shared GC1 → GC2 encoder and return post-GC2 ReLU features.

        ``forward`` and ``embed`` both go through this single definition so
        the analysis embeddings can never drift from the trained model's
        encoder. Dropout before each graph conv is applied only when
        ``apply_dropout`` is true (and, as usual for ``nn.Dropout``, only has
        an effect in training mode).
        """
        if apply_dropout:
            x = self.drop(x)
        x = F.relu(self.norm1(self.gc1(x, adj)))
        if apply_dropout:
            x = self.drop(x)
        return F.relu(self.norm2(self.gc2(x, adj)))

    def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        """Return per-node logits.

        Args:
            x: Node feature matrix, (N, in_dim).
            adj: Pre-normalized adjacency, (N, N).

        Returns:
            Logits of shape (N, num_classes) for every node in the graph;
            any train-node masking of the loss happens in the caller.
        """
        h = self._encode(x, adj, apply_dropout=True)
        return self.head(self.drop(h))                # (N, num_classes)

    @torch.no_grad()
    def embed(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        """Return post-GC2 embeddings, (N, hidden_dim), for t-SNE / analysis.

        Dropout is skipped regardless of train/eval mode, so the result is
        deterministic, and no autograd graph is recorded.
        """
        return self._encode(x, adj, apply_dropout=False)