# BrainConnect-ASD / brain_gcn/models/advanced_models.py
# Uploaded by Yatsuiii via huggingface_hub (commit 16d6869, verified).
"""
Advanced model architectures for brain connectivity analysis.
New models:
- Graph Attention Networks (GAT)
- Transformer-based temporal encoder
- 3D-CNN for spatiotemporal features
- GraphSAGE (sampling-aggregating)
"""
from __future__ import annotations
import torch
from torch import nn
import torch.nn.functional as F
from brain_gcn.utils.graph_conv import calculate_laplacian_with_self_loop, drop_edge
from brain_gcn.models.brain_gcn import AttentionReadout
# ---------------------------------------------------------------------------
# Graph Attention Networks (GAT)
# ---------------------------------------------------------------------------
class GraphAttentionLayer(nn.Module):
    """Multi-head graph attention layer (scaled dot-product, edge-masked).

    Attention is computed between all node pairs, then restricted to graph
    edges by adding a large negative value to non-edge scores before the
    softmax. The mask is binary: any non-zero adjacency entry counts as an
    edge; edge weights do not scale the attention.

    Args:
        in_dim: Input feature dimension per node.
        out_dim: Output feature dimension per node (split across heads).
        num_heads: Number of attention heads; must divide ``out_dim``.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, in_dim: int, out_dim: int, num_heads: int = 4, dropout: float = 0.1):
        super().__init__()
        self.num_heads = num_heads
        self.out_dim = out_dim
        assert out_dim % num_heads == 0, "out_dim must be divisible by num_heads"
        self.head_dim = out_dim // num_heads
        self.query = nn.Linear(in_dim, out_dim)
        self.key = nn.Linear(in_dim, out_dim)
        self.value = nn.Linear(in_dim, out_dim)
        self.fc_out = nn.Linear(out_dim, out_dim)
        self.dropout = nn.Dropout(dropout)
        # Plain Python float instead of a CPU-resident tensor constant:
        # avoids any device mismatch concerns and is numerically identical.
        self.scale = float(self.head_dim) ** 0.5

    def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        """Apply masked multi-head attention over the graph.

        Args:
            x: Node features, shape ``(batch, nodes, in_dim)``.
            adj: Adjacency, shape ``(batch, nodes, nodes)`` or ``(nodes, nodes)``.

        Returns:
            Updated node features, shape ``(batch, nodes, out_dim)``.
        """
        batch, nodes, _ = x.shape
        # Fix: a shared 2-D adjacency previously broadcast incorrectly against
        # the (batch, heads, nodes, nodes) score tensor; lift it to 3-D first.
        if adj.dim() == 2:
            adj = adj.unsqueeze(0)
        Q = self.query(x)  # (batch, nodes, out_dim)
        K = self.key(x)
        V = self.value(x)
        # Split heads: (batch, heads, nodes, head_dim)
        Q = Q.reshape(batch, nodes, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.reshape(batch, nodes, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.reshape(batch, nodes, self.num_heads, self.head_dim).transpose(1, 2)
        # Attention scores: (batch, heads, nodes, nodes)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        # Mask non-edges with a large negative value. A fully isolated node
        # degenerates to uniform attention after the softmax.
        scores = scores + (adj.unsqueeze(1) == 0).float() * -1e9
        attn = self.dropout(F.softmax(scores, dim=-1))
        out = torch.matmul(attn, V)  # (batch, heads, nodes, head_dim)
        # Merge heads back; use explicit batch/nodes rather than relying on
        # the stale pre-transpose shape of `out`.
        out = out.transpose(1, 2).reshape(batch, nodes, self.out_dim)
        return self.fc_out(out)
class GATEncoder(nn.Module):
    """Two-layer Graph Attention Network.

    Each attention layer is followed by LayerNorm, ReLU, and dropout.
    """

    def __init__(self, in_dim: int, hidden_dim: int, num_heads: int = 4, dropout: float = 0.1):
        super().__init__()
        self.layer1 = GraphAttentionLayer(in_dim, hidden_dim, num_heads=num_heads, dropout=dropout)
        self.layer2 = GraphAttentionLayer(hidden_dim, hidden_dim, num_heads=num_heads, dropout=dropout)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        """Encode node features through both attention blocks.

        Args:
            x: Node features, ``(batch, nodes, in_dim)``.
            adj: Adjacency matrix shared by both layers.

        Returns:
            Node embeddings of shape ``(batch, nodes, hidden_dim)``.
        """
        hidden = x
        for attn, norm in ((self.layer1, self.norm1), (self.layer2, self.norm2)):
            hidden = self.dropout(F.relu(norm(attn(hidden, adj))))
        return hidden
# ---------------------------------------------------------------------------
# Transformer-based Temporal Encoder
# ---------------------------------------------------------------------------
class TransformerTemporalEncoder(nn.Module):
    """Transformer encoder over each node's temporal signal.

    Every node's window sequence is embedded and encoded independently
    (the batch and node dimensions are flattened together), and the final
    time-step representation summarizes that node.
    """

    def __init__(self, hidden_dim: int = 64, num_heads: int = 4, num_layers: int = 2, dropout: float = 0.1):
        super().__init__()
        self.embedding = nn.Linear(1, hidden_dim)
        layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            dropout=dropout,
            batch_first=True,
            activation='relu',
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.norm = nn.LayerNorm(hidden_dim)

    def forward(self, bold_windows: torch.Tensor) -> torch.Tensor:
        """Encode temporal windows into per-node embeddings.

        Args:
            bold_windows: Signals of shape ``(batch, windows, nodes)``.

        Returns:
            Node embeddings of shape ``(batch, nodes, hidden_dim)``.
        """
        n_batch, n_win, n_nodes = bold_windows.shape
        # One sequence per (sample, node) pair: (B*N, W, 1)
        sequences = bold_windows.permute(0, 2, 1).reshape(n_batch * n_nodes, n_win, 1)
        tokens = self.embedding(sequences)          # (B*N, W, hidden_dim)
        encoded = self.norm(self.transformer(tokens))
        # Last time step acts as the sequence summary token.
        summary = encoded[:, -1, :]
        return summary.reshape(n_batch, n_nodes, -1)
# ---------------------------------------------------------------------------
# 3D-CNN for Spatiotemporal Features
# ---------------------------------------------------------------------------
class CNN3D(nn.Module):
    """3D-CNN over stacked connectivity matrices (time x nodes x nodes).

    Three conv stages with BatchNorm/ReLU; the first two stages also apply
    max-pooling and spatial dropout. A global average pool collapses the
    spatiotemporal volume into one ``hidden_dim`` feature vector.
    """

    def __init__(self, hidden_dim: int = 64, dropout: float = 0.1):
        super().__init__()
        # Intermediate channel widths scale with hidden_dim, floored so the
        # network keeps some capacity even for very small hidden sizes.
        ch1 = max(8, hidden_dim // 4)
        ch2 = max(16, hidden_dim // 2)
        self.conv1 = nn.Conv3d(1, ch1, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.conv2 = nn.Conv3d(ch1, ch2, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.conv3 = nn.Conv3d(ch2, hidden_dim, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.pool = nn.MaxPool3d(kernel_size=2, stride=2)
        self.dropout = nn.Dropout3d(dropout)
        self.norm1 = nn.BatchNorm3d(ch1)
        self.norm2 = nn.BatchNorm3d(ch2)
        self.norm3 = nn.BatchNorm3d(hidden_dim)

    def forward(self, fc_windows: torch.Tensor) -> torch.Tensor:
        """Extract features from a sequence of connectivity matrices.

        Args:
            fc_windows: Volume of shape ``(batch, windows, nodes, nodes)``.
                Both ``windows`` and ``nodes`` must survive two halvings
                (i.e. be at least 4) for the pooling stages.

        Returns:
            Feature vector of shape ``(batch, hidden_dim)``.
        """
        # Single input channel: (batch, 1, windows, nodes, nodes)
        volume = fc_windows.unsqueeze(1)
        stages = (
            (self.conv1, self.norm1, True),
            (self.conv2, self.norm2, True),
            (self.conv3, self.norm3, False),
        )
        for conv, norm, downsample in stages:
            volume = F.relu(norm(conv(volume)))
            if downsample:
                volume = self.dropout(self.pool(volume))
        # Global average pool over time and both spatial axes.
        return volume.mean(dim=(2, 3, 4))
# ---------------------------------------------------------------------------
# GraphSAGE (Sampling and Aggregating)
# ---------------------------------------------------------------------------
class GraphSAGELayer(nn.Module):
    """GraphSAGE layer using degree-normalized mean aggregation.

    Neighbor features are averaged via the row-normalized adjacency, linearly
    projected, and summed with a projection of the node's own features,
    followed by LayerNorm, ReLU, and dropout.

    Args:
        in_dim: Input feature dimension per node.
        out_dim: Output feature dimension per node.
        dropout: Dropout probability on the output features.
    """

    def __init__(self, in_dim: int, out_dim: int, dropout: float = 0.1):
        super().__init__()
        self.agg_weight = nn.Linear(in_dim, out_dim)
        self.self_weight = nn.Linear(in_dim, out_dim)
        self.norm = nn.LayerNorm(out_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        """Aggregate neighbors and combine with self features.

        Args:
            x: Node features, ``(batch, nodes, in_dim)``.
            adj: Adjacency, ``(batch, nodes, nodes)`` or ``(nodes, nodes)``.

        Returns:
            Updated node features, ``(batch, nodes, out_dim)``.
        """
        if adj.dim() == 2:
            adj = adj.unsqueeze(0)
        # Row-normalize so aggregation is a mean; clamp guards isolated nodes
        # against division by zero.
        degree = adj.sum(dim=-1, keepdim=True).clamp(min=1)
        adj_norm = adj / degree
        # Fix: torch.bmm does not broadcast, so a shared (1, N, N) adjacency
        # failed against a batched x with batch > 1; matmul broadcasts.
        neighbor_agg = torch.matmul(adj_norm, x)  # (batch, nodes, in_dim)
        h = self.agg_weight(neighbor_agg) + self.self_weight(x)
        h = F.relu(self.norm(h))
        return self.dropout(h)
class GraphSAGEEncoder(nn.Module):
    """Stack of two GraphSAGE layers sharing one adjacency matrix."""

    def __init__(self, in_dim: int, hidden_dim: int, dropout: float = 0.1):
        super().__init__()
        self.layer1 = GraphSAGELayer(in_dim, hidden_dim, dropout=dropout)
        self.layer2 = GraphSAGELayer(hidden_dim, hidden_dim, dropout=dropout)

    def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        """Apply both SAGE layers in sequence and return node embeddings."""
        return self.layer2(self.layer1(x, adj), adj)
# ---------------------------------------------------------------------------
# Classifier Heads
# ---------------------------------------------------------------------------
def make_head(hidden_dim: int, num_classes: int = 2, dropout: float = 0.5) -> nn.Sequential:
    """Build a small MLP classification head.

    Args:
        hidden_dim: Width of the input and hidden layer.
        num_classes: Number of output logits (default: binary).
        dropout: Dropout probability before the final projection.

    Returns:
        ``nn.Sequential`` mapping ``(batch, hidden_dim)`` to ``(batch, num_classes)``.
    """
    layers = [
        nn.Linear(hidden_dim, hidden_dim),
        nn.LayerNorm(hidden_dim),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden_dim, num_classes),
    ]
    return nn.Sequential(*layers)
# ---------------------------------------------------------------------------
# Complete Models
# ---------------------------------------------------------------------------
class GATClassifier(nn.Module):
    """Graph Attention Network classifier.

    Each temporal window is encoded independently with a GAT encoder on the
    (Laplacian-normalized) adjacency; per-window node embeddings are averaged,
    pooled over nodes with attention, and classified.
    """

    def __init__(self, hidden_dim: int = 64, num_heads: int = 4, dropout: float = 0.5):
        super().__init__()
        # Encoder keeps a lighter dropout than the classification head.
        self.encoder = GATEncoder(1, hidden_dim, num_heads=num_heads, dropout=min(dropout, 0.2))
        self.attention = AttentionReadout(hidden_dim)
        self.head = make_head(hidden_dim, dropout=dropout)

    def forward(self, bold_windows: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        """Classify a batch of subjects.

        Args:
            bold_windows: Signals of shape ``(batch, windows, nodes)``.
            adj: Adjacency matrix (shared or per-sample).

        Returns:
            Class logits of shape ``(batch, num_classes)``.
        """
        batch, windows, nodes = bold_windows.shape
        adj_norm = calculate_laplacian_with_self_loop(adj)
        # Hoisted out of the window loop: the normalized adjacency (and its
        # dimensionality) is loop-invariant.
        if adj_norm.dim() != 3:
            adj_norm = adj_norm.unsqueeze(0)
        embeddings_list = [
            self.encoder(bold_windows[:, w, :].unsqueeze(-1), adj_norm)  # (batch, nodes, 1) in
            for w in range(windows)
        ]
        # Average over windows: (batch, nodes, hidden_dim)
        h = torch.stack(embeddings_list, dim=1).mean(dim=1)
        pooled, _ = self.attention(h)
        return self.head(pooled)
class TransformerClassifier(nn.Module):
    """Transformer-based classifier for temporal brain signals."""

    def __init__(self, hidden_dim: int = 64, num_heads: int = 4, dropout: float = 0.5):
        super().__init__()
        # Encoder keeps a lighter dropout than the classification head.
        self.temporal_encoder = TransformerTemporalEncoder(hidden_dim, num_heads=num_heads, dropout=min(dropout, 0.2))
        self.attention = AttentionReadout(hidden_dim)
        self.head = make_head(hidden_dim, dropout=dropout)

    def forward(self, bold_windows: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        """Classify from temporal signals; ``adj`` is accepted for interface
        parity with the graph models but is not used here."""
        node_embeddings = self.temporal_encoder(bold_windows)  # (batch, nodes, hidden_dim)
        pooled, _ = self.attention(node_embeddings)
        return self.head(pooled)
class CNN3DClassifier(nn.Module):
    """3D-CNN classifier over connectivity dynamics."""

    def __init__(self, hidden_dim: int = 64, dropout: float = 0.5):
        super().__init__()
        # CNN keeps a lighter dropout than the classification head.
        self.cnn = CNN3D(hidden_dim, dropout=min(dropout, 0.2))
        self.head = make_head(hidden_dim, dropout=dropout)

    def forward(self, bold_windows: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        """Classify from connectivity volumes.

        A 4-D ``adj`` (batch, windows, nodes, nodes) is treated as dynamic
        connectivity and used directly; a 3-D static ``adj`` is tiled across
        the window dimension of ``bold_windows``.
        """
        if adj.dim() == 4:
            volume = adj
        else:
            num_windows = bold_windows.shape[1]
            volume = adj.unsqueeze(1).expand(-1, num_windows, -1, -1)
        features = self.cnn(volume)  # (batch, hidden_dim)
        return self.head(features)
class GraphSAGEClassifier(nn.Module):
    """GraphSAGE-based classifier.

    Each temporal window is encoded independently with a GraphSAGE encoder on
    the (Laplacian-normalized) adjacency; per-window node embeddings are
    averaged, pooled over nodes with attention, and classified.
    """

    def __init__(self, hidden_dim: int = 64, dropout: float = 0.5):
        super().__init__()
        # Encoder keeps a lighter dropout than the classification head.
        self.encoder = GraphSAGEEncoder(1, hidden_dim, dropout=min(dropout, 0.2))
        self.attention = AttentionReadout(hidden_dim)
        self.head = make_head(hidden_dim, dropout=dropout)

    def forward(self, bold_windows: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        """Classify a batch of subjects.

        Args:
            bold_windows: Signals of shape ``(batch, windows, nodes)``.
            adj: Adjacency matrix (shared or per-sample).

        Returns:
            Class logits of shape ``(batch, num_classes)``.
        """
        batch, windows, nodes = bold_windows.shape
        adj_norm = calculate_laplacian_with_self_loop(adj)
        # Hoisted out of the window loop: the normalized adjacency (and its
        # dimensionality) is loop-invariant.
        if adj_norm.dim() != 3:
            adj_norm = adj_norm.unsqueeze(0)
        per_window = [
            self.encoder(bold_windows[:, w, :].unsqueeze(-1), adj_norm)  # (batch, nodes, 1) in
            for w in range(windows)
        ]
        h = torch.stack(per_window, dim=1).mean(dim=1)  # (batch, nodes, hidden_dim)
        pooled, _ = self.attention(h)
        return self.head(pooled)