"""
Brain GCN model definitions.
v2 changes:
- TwoLayerGCN with residual connection replaces single GraphLinear in encoder
- DropEdge applied during training in BrainGCNClassifier.forward() and
  GraphOnlyClassifier.forward()
- GraphOnlyClassifier also upgraded to TwoLayerGCN (was already 2-layer but
without residual or LayerNorm between layers)
"""
from __future__ import annotations
import torch
from torch import nn
from brain_gcn.utils.graph_conv import calculate_laplacian_with_self_loop, drop_edge
from brain_gcn.utils.grl import GradientReversal
# ---------------------------------------------------------------------------
# Building blocks
# ---------------------------------------------------------------------------
class GraphLinear(nn.Module):
"""Apply normalized adjacency, then a learned linear projection."""
def __init__(self, in_features: int, out_features: int, bias: bool = True):
super().__init__()
self.linear = nn.Linear(in_features, out_features, bias=bias)
def forward(self, x: torch.Tensor, adj_norm: torch.Tensor) -> torch.Tensor:
x = torch.bmm(adj_norm, x)
return self.linear(x)
class TwoLayerGCN(nn.Module):
"""2-layer GCN with residual skip connection.
Architecture (Kipf & Welling 2017 + He et al. 2016 residuals):
h1 = ReLU(LayerNorm(GCN1(x)))
h2 = Dropout(ReLU(LayerNorm(GCN2(h1))))
out = h2 + skip(x) # skip is a plain linear projection
    The residual stabilises gradient flow and lets the model blend 2-hop
    aggregation with a plain (no-aggregation) linear projection of the input.
"""
def __init__(self, in_dim: int, hidden_dim: int, dropout: float = 0.1):
super().__init__()
self.gcn1 = GraphLinear(in_dim, hidden_dim)
self.gcn2 = GraphLinear(hidden_dim, hidden_dim)
self.skip = nn.Linear(in_dim, hidden_dim, bias=False)
self.norm1 = nn.LayerNorm(hidden_dim)
self.norm2 = nn.LayerNorm(hidden_dim)
self.drop = nn.Dropout(dropout)
def forward(self, x: torch.Tensor, adj_norm: torch.Tensor) -> torch.Tensor:
h = torch.relu(self.norm1(self.gcn1(x, adj_norm)))
h = self.drop(torch.relu(self.norm2(self.gcn2(h, adj_norm))))
return h + self.skip(x) # residual
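
# A minimal shape sketch for TwoLayerGCN (illustrative, not part of the module;
# the identity matrix stands in for a properly normalised adjacency):
#
#     gcn = TwoLayerGCN(in_dim=1, hidden_dim=64)
#     x = torch.randn(4, 200, 1)                     # (B, N, in_dim)
#     adj_norm = torch.eye(200).expand(4, -1, -1)    # (B, N, N) stand-in
#     out = gcn(x, adj_norm)                         # (B, N, 64)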
# ---------------------------------------------------------------------------
# Encoders
# ---------------------------------------------------------------------------
class GraphTemporalEncoder(nn.Module):
"""Graph-aware temporal encoder for ROI-level window sequences.
Supports two node feature modes:
- Scalar (in_features=1): bold_windows (B, W, N) — BOLD std per window
- FC rows (in_features=N): fc_windows (B, W, N, N) — connectivity profile per node
Vectorized implementation: single batched GCN pass over all windows.
"""
def __init__(self, hidden_dim: int = 64, dropout: float = 0.1, in_features: int = 1):
super().__init__()
self.input_graph = TwoLayerGCN(in_features, hidden_dim, dropout=min(dropout, 0.1))
self.gru = nn.GRU(
input_size=hidden_dim,
hidden_size=hidden_dim,
batch_first=True,
)
self.norm = nn.LayerNorm(hidden_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, bold_windows: torch.Tensor, adj_norm: torch.Tensor) -> torch.Tensor:
# bold_windows: (B, W, N) for scalar features or (B, W, N, N) for FC-row features
if bold_windows.dim() == 4:
# FC-row features: (B, W, N, N) → (B*W, N, N) where last dim is in_features
batch_size, num_windows, num_nodes, _ = bold_windows.shape
x = bold_windows.reshape(batch_size * num_windows, num_nodes, -1)
else:
# Scalar features: (B, W, N) → (B*W, N, 1)
batch_size, num_windows, num_nodes = bold_windows.shape
x = bold_windows.reshape(batch_size * num_windows, num_nodes, 1)
# Handle both 3D (B,N,N) and 4D (B,W,N,N) adjacency
if adj_norm.dim() == 4:
adj_flat = adj_norm.reshape(batch_size * num_windows, num_nodes, num_nodes)
else:
adj_flat = adj_norm.unsqueeze(1).expand(-1, num_windows, -1, -1)
adj_flat = adj_flat.reshape(batch_size * num_windows, num_nodes, num_nodes)
# Single batched GCN pass → (B*W, N, H)
h = self.input_graph(x, adj_flat)
# Reshape back and apply node-major GRU
h = h.reshape(batch_size, num_windows, num_nodes, -1) # (B, W, N, H)
hidden_dim = h.shape[-1]
h = h.permute(0, 2, 1, 3).reshape(batch_size * num_nodes, num_windows, hidden_dim)
h, _ = self.gru(h)
h = h[:, -1, :].reshape(batch_size, num_nodes, -1) # (B, N, H)
return self.dropout(self.norm(h))
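
# Shape sketch for both node-feature modes of GraphTemporalEncoder
# (illustrative values; the identity stands in for a normalised adjacency):
#
#     enc = GraphTemporalEncoder(hidden_dim=64, in_features=1)
#     bold = torch.randn(2, 10, 200)                 # (B, W, N) scalar mode
#     adj_norm = torch.eye(200).expand(2, -1, -1)    # static (B, N, N)
#     h = enc(bold, adj_norm)                        # (B, N, 64)
#
#     enc_fc = GraphTemporalEncoder(hidden_dim=64, in_features=200)
#     fc = torch.randn(2, 10, 200, 200)              # (B, W, N, N) FC-row mode
#     h = enc_fc(fc, adj_norm)                       # (B, N, 64)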
class AttentionReadout(nn.Module):
"""Learn per-ROI attention weights for subject-level graph pooling.
    A single linear projection is sufficient for N=200 nodes; it is more
    interpretable and faster than a 2-layer MLP.
"""
def __init__(self, hidden_dim: int):
super().__init__()
self.score = nn.Linear(hidden_dim, 1)
def forward(self, node_embeddings: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
weights = torch.softmax(self.score(node_embeddings).squeeze(-1), dim=-1)
pooled = torch.sum(node_embeddings * weights.unsqueeze(-1), dim=1)
return pooled, weights
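
# Usage sketch: pooled is a convex combination of node embeddings, and the
# returned weights (B, N) can be inspected for per-ROI saliency:
#
#     readout = AttentionReadout(hidden_dim=64)
#     pooled, w = readout(torch.randn(2, 200, 64))   # (2, 64), (2, 200)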
# ---------------------------------------------------------------------------
# Helpers shared across classifiers
# ---------------------------------------------------------------------------
def make_classifier_head(hidden_dim: int, num_classes: int, dropout: float) -> nn.Sequential:
return nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, num_classes),
)
def graph_readout(
node_embeddings: torch.Tensor,
attention: AttentionReadout | None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
if attention is None:
return node_embeddings.mean(dim=1), None
return attention(node_embeddings)
# ---------------------------------------------------------------------------
# Classifiers
# ---------------------------------------------------------------------------
class BrainGCNClassifier(nn.Module):
"""Subject-level ASD/TD classifier for dynamic brain connectivity.
v2: TwoLayerGCN encoder + DropEdge during training.
"""
def __init__(
self,
hidden_dim: int = 64,
num_classes: int = 2,
dropout: float = 0.5,
readout: str = "attention",
drop_edge_p: float = 0.1,
in_features: int = 1,
):
super().__init__()
if readout not in {"mean", "attention"}:
raise ValueError("readout must be 'mean' or 'attention'")
self.encoder = GraphTemporalEncoder(hidden_dim=hidden_dim, dropout=min(dropout, 0.2), in_features=in_features)
self.readout = readout
self.attention = AttentionReadout(hidden_dim) if readout == "attention" else None
self.head = make_classifier_head(hidden_dim, num_classes, dropout)
self.drop_edge_p = drop_edge_p
def forward(
self,
bold_windows: torch.Tensor,
adj: torch.Tensor,
return_attention: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
# DropEdge: applied before Laplacian normalisation, training only
adj = drop_edge(adj, p=self.drop_edge_p, training=self.training)
adj_norm = calculate_laplacian_with_self_loop(adj)
node_embeddings = self.encoder(bold_windows, adj_norm)
pooled, attention_weights = graph_readout(node_embeddings, self.attention)
logits = self.head(pooled)
if return_attention:
return logits, attention_weights
return logits
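
# End-to-end sketch for BrainGCNClassifier (hypothetical shapes; adj is the raw
# adjacency, which the model DropEdges and normalises internally):
#
#     model = BrainGCNClassifier(hidden_dim=64)
#     bold = torch.randn(2, 10, 200)                 # (B, W, N)
#     adj = torch.rand(2, 200, 200)                  # (B, N, N)
#     logits, attn = model(bold, adj, return_attention=True)  # (2, 2), (2, 200)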
class GraphOnlyClassifier(nn.Module):
"""GCN baseline — each ROI's average window signal as node input.
v2: upgraded to TwoLayerGCN with residual + DropEdge.
"""
def __init__(
self,
hidden_dim: int = 64,
num_classes: int = 2,
dropout: float = 0.5,
readout: str = "attention",
drop_edge_p: float = 0.1,
):
super().__init__()
if readout not in {"mean", "attention"}:
raise ValueError("readout must be 'mean' or 'attention'")
self.gcn = TwoLayerGCN(1, hidden_dim, dropout=min(dropout, 0.1))
self.norm = nn.LayerNorm(hidden_dim)
self.dropout = nn.Dropout(dropout)
self.attention = AttentionReadout(hidden_dim) if readout == "attention" else None
self.head = make_classifier_head(hidden_dim, num_classes, dropout)
self.drop_edge_p = drop_edge_p
def forward(
self,
bold_windows: torch.Tensor,
adj: torch.Tensor,
return_attention: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
adj = drop_edge(adj, p=self.drop_edge_p, training=self.training)
adj_norm = calculate_laplacian_with_self_loop(adj)
if adj_norm.dim() == 4:
adj_norm = adj_norm.mean(dim=1)
x = bold_windows.mean(dim=1).unsqueeze(-1) # (B, N, 1)
x = self.dropout(self.norm(self.gcn(x, adj_norm)))
pooled, attention_weights = graph_readout(x, self.attention)
logits = self.head(pooled)
if return_attention:
return logits, attention_weights
return logits
class TemporalGRUClassifier(nn.Module):
"""Temporal baseline — GRU over ROI vectors, no graph message passing."""
def __init__(
self,
hidden_dim: int = 64,
num_classes: int = 2,
dropout: float = 0.5,
):
super().__init__()
self.input_proj = nn.LazyLinear(hidden_dim)
self.gru = nn.GRU(hidden_dim, hidden_dim, batch_first=True)
self.norm = nn.LayerNorm(hidden_dim)
self.dropout = nn.Dropout(dropout)
self.head = make_classifier_head(hidden_dim, num_classes, dropout)
def forward(
self,
bold_windows: torch.Tensor,
adj: torch.Tensor,
return_attention: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, None]:
x = torch.relu(self.input_proj(bold_windows))
x, _ = self.gru(x)
x = self.dropout(self.norm(x[:, -1, :]))
logits = self.head(x)
if return_attention:
return logits, None
return logits
class ConnectivityMLPClassifier(nn.Module):
"""Static FC baseline — upper triangle of adjacency matrix as features."""
def __init__(
self,
hidden_dim: int = 64,
num_classes: int = 2,
dropout: float = 0.5,
):
super().__init__()
self.net = nn.Sequential(
nn.LazyLinear(hidden_dim),
nn.LayerNorm(hidden_dim),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, num_classes),
)
@staticmethod
def _fc_features(adj: torch.Tensor) -> torch.Tensor:
"""Extract features from adj tensor (various shapes):
(B, N, N) → (B, N*(N-1)/2) signed mean FC upper triangle
(B, 2, N, N) → (B, N*(N-1)) mean FC || std FC concatenated
(B, 1, K) → (B, K) pre-computed PCA features (pass-through)
(B, W, N, N) → (B, N*(N-1)/2) dynamic seq: averaged over windows first
"""
if adj.dim() == 3:
if adj.size(1) == 1:
# PCA projection already computed in dataset — just flatten
return adj.squeeze(1) # (B, K)
# (B, N, N) — standard case
row, col = torch.triu_indices(adj.size(-2), adj.size(-1), offset=1,
device=adj.device)
            return adj[:, row, col]  # (B, N*(N-1)/2); 19900 for N=200
if adj.dim() == 4:
if adj.size(1) == 2:
# [mean_fc, std_fc] channels
row, col = torch.triu_indices(adj.size(-2), adj.size(-1), offset=1,
device=adj.device)
x_mean = adj[:, 0, row, col]
x_std = adj[:, 1, row, col]
                return torch.cat([x_mean, x_std], dim=-1)  # (B, N*(N-1)); 2*19900 for N=200
# Dynamic window sequence: average then extract
adj = adj.mean(dim=1) # (B, N, N)
row, col = torch.triu_indices(adj.size(-2), adj.size(-1), offset=1,
device=adj.device)
return adj[:, row, col]
raise ValueError(f"Unexpected adj shape: {tuple(adj.shape)}")
def forward(
self,
bold_windows: torch.Tensor,
adj: torch.Tensor,
return_attention: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, None]:
x = self._fc_features(adj)
logits = self.net(x)
if return_attention:
return logits, None
return logits
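
# Feature-count check for _fc_features with N=200 ROIs (follows directly from
# the docstring formulas):
#     (B, 200, 200)    -> (B, 19900)   since 200*199/2 = 19900
#     (B, 2, 200, 200) -> (B, 39800)   mean FC || std FC
#     (B, W, 200, 200) -> (B, 19900)   after averaging over the W windows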
class BrainModeNetwork(nn.Module):
"""
Novel architecture: Brain Mode Network (BMN).
Learns K 'brain modes' — directions in ROI space (v_k ∈ R^N).
Projects the N×N FC matrix into a compact K×K 'mode interaction matrix':
M_kl = v_k^T · FC · v_l
Diagonal M_kk measures connectivity energy along mode k (Rayleigh quotient).
Off-diagonal M_kl captures cross-mode coupling between networks.
    With K=16 modes and N=200 ROIs: 136 static features (plus K=16 dynamic
    ones) instead of 19,900.
Inductive bias: each mode can specialize to a brain network community
(e.g. DMN, FPN, SMN) — the model learns which communities matter for ASD.
Orthogonality regularization keeps modes diverse (callable via
orthogonality_loss(), weight controlled externally in the training task).
"""
def __init__(
self,
num_nodes: int,
num_modes: int = 16,
hidden_dim: int = 64,
num_classes: int = 2,
dropout: float = 0.5,
mode_init: torch.Tensor | None = None,
):
super().__init__()
self.num_modes = num_modes
self.num_nodes = num_nodes
# Learnable modes: K × N — default initialization is near-orthonormal via QR.
# Caller may pass a (K, N) tensor from discriminative_init() instead.
if mode_init is not None:
modes_init = mode_init.clone().float()
else:
            rand_init = torch.randn(num_nodes, num_modes)
            Q, _ = torch.linalg.qr(rand_init)  # reduced QR: (N, K) orthonormal columns
            modes_init = Q.T.contiguous()      # (K, N); assumes num_modes <= num_nodes
self.modes = nn.Parameter(modes_init)
# Features: K(K+1)/2 from static M + K from temporal std(A_k)
num_fc_features = num_modes * (num_modes + 1) // 2
num_total_features = num_fc_features + num_modes # static + dynamic
self.classifier = nn.Sequential(
nn.LayerNorm(num_total_features),
nn.Linear(num_total_features, hidden_dim),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, num_classes),
)
def forward(
self,
bold_windows: torch.Tensor,
adj: torch.Tensor,
return_attention: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, None]:
# adj: (B, N, N) signed FC matrix; also accept (B, W, N, N) → avg over W
if adj.dim() == 4:
adj = adj.mean(dim=1) # (B, N, N)
# ── Static stream: mode interaction matrix ──────────────────────────
# M_kl = v_k^T · FC · v_l → (B, K, K)
M = torch.einsum('kn,bnm,lm->bkl', self.modes, adj, self.modes)
# Extract upper triangle (including diagonal): K(K+1)/2 features
r, c = torch.triu_indices(self.num_modes, self.num_modes,
offset=0, device=adj.device)
fc_features = M[:, r, c] # (B, K(K+1)/2)
# ── Dynamic stream: temporal variability of mode activity ───────────
# A_k(t) = v_k · bold(t) → A: (B, W, K)
# std(A_k) captures how much each network fluctuates over time.
# This is genuinely new information not present in static mean FC.
A = torch.einsum('kn,bwn->bwk', self.modes, bold_windows) # (B, W, K)
        dyn_features = A.std(dim=1)  # (B, K); assumes W >= 2 (std of one window is NaN)
features = torch.cat([fc_features, dyn_features], dim=-1) # (B, K(K+1)/2+K)
logits = self.classifier(features)
if return_attention:
return logits, None
return logits
def orthogonality_loss(self) -> torch.Tensor:
"""Penalise non-orthonormal modes: ||V_norm @ V_norm^T - I||_F^2 / K^2.
Encourages each mode to capture a distinct connectivity direction.
Dividing by K^2 keeps the loss scale independent of num_modes.
"""
V_norm = self.modes / (self.modes.norm(dim=1, keepdim=True) + 1e-8)
gram = V_norm @ V_norm.T # (K, K)
I = torch.eye(self.num_modes, device=gram.device, dtype=gram.dtype)
return ((gram - I) ** 2).mean()
@staticmethod
def discriminative_init(
train_fc_asd: "np.ndarray",
train_fc_td: "np.ndarray",
num_modes: int,
) -> "torch.Tensor":
"""Initialize modes from SVD of the ASD-TD mean FC difference matrix.
The k-th left singular vector of (mean_FC_ASD − mean_FC_TD) is the k-th
most discriminative direction in ROI space — the direction along which the
two classes differ most. Starting here gives the optimizer a head start
and reduces the number of epochs needed to learn discriminative modes.
Parameters
----------
train_fc_asd : (n_asd, N, N) FC matrices for ASD training subjects
train_fc_td : (n_td, N, N) FC matrices for TD training subjects
num_modes : K — number of singular vectors to keep
Returns
-------
modes : (K, N) float32 tensor — orthonormal initial modes
"""
import numpy as np
mu_asd = train_fc_asd.mean(axis=0) # (N, N)
mu_td = train_fc_td.mean(axis=0) # (N, N)
delta = mu_asd - mu_td # ASD-TD difference
# SVD of the difference matrix: left singular vectors are ROI directions
# that best explain the connectivity difference between groups.
U, _, _ = np.linalg.svd(delta, full_matrices=True)
K = min(num_modes, U.shape[1])
modes = U[:, :K].T.astype(np.float32) # (K, N)
        # If K > available singular vectors (shouldn't happen for N=200, K<<200),
        # pad with random directions orthogonalised (Gram-Schmidt) against the
        # SVD modes and against each other
        if num_modes > K:
            extra = np.random.randn(num_modes - K, U.shape[0]).astype(np.float32)
            for i in range(len(extra)):
                for row in np.concatenate([modes, extra[:i]], axis=0):
                    extra[i] -= np.dot(extra[i], row) * row
                n = np.linalg.norm(extra[i])
                if n > 1e-8:
                    extra[i] /= n
            modes = np.concatenate([modes, extra], axis=0)
return torch.from_numpy(modes)
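
# Usage sketch for BrainModeNetwork (hypothetical data; fc_asd/fc_td are numpy
# arrays of training-set FC matrices and labels is a (B,) LongTensor):
#
#     init = BrainModeNetwork.discriminative_init(fc_asd, fc_td, num_modes=16)
#     bmn = BrainModeNetwork(num_nodes=200, num_modes=16, mode_init=init)
#     logits = bmn(torch.randn(2, 10, 200), torch.rand(2, 200, 200))  # (2, 2)
#     loss = torch.nn.functional.cross_entropy(logits, labels) \
#            + 0.1 * bmn.orthogonality_loss()   # the 0.1 weight is illustrative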
class AdversarialBrainModeNetwork(nn.Module):
"""Brain Mode Network with adversarial site deconfounding.
Combines the compact mode-interaction representation of BrainModeNetwork
with the Gradient Reversal Layer (GRL) of Ganin et al. 2016 to push
the learned modes towards site-invariant directions.
Architecture:
bold_windows, FC
→ mode interaction M_kl = v_k^T · FC · v_l (K×K)
→ flatten upper triangle + temporal std (K(K+1)/2 + K features)
→ shared_encoder (MLP)
↙ ↘
asd_head grl(α) → site_head
(minimize ASD CE) (modes unlearn scanner fingerprint)
    The discriminative_init() staticmethod re-exposed from BrainModeNetwork
    still applies: we start from ASD-TD difference directions and then
    adversarially remove site confounds while preserving the diagnosis signal.
"""
def __init__(
self,
num_nodes: int,
num_modes: int = 32,
hidden_dim: int = 64,
num_classes: int = 2,
num_sites: int = 17,
dropout: float = 0.5,
mode_init: "torch.Tensor | None" = None,
):
super().__init__()
self.num_modes = num_modes
self.num_nodes = num_nodes
# Shared mode parameters (same as BrainModeNetwork)
if mode_init is not None:
modes_init = mode_init.clone().float()
else:
            rand_init = torch.randn(num_nodes, num_modes)
            Q, _ = torch.linalg.qr(rand_init)  # reduced QR, as in BrainModeNetwork
            modes_init = Q.T.contiguous()
self.modes = nn.Parameter(modes_init)
num_fc_features = num_modes * (num_modes + 1) // 2
num_total_features = num_fc_features + num_modes # static + dynamic
# Shared encoder
self.encoder = nn.Sequential(
nn.LayerNorm(num_total_features),
nn.Linear(num_total_features, hidden_dim),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.ReLU(),
nn.Dropout(dropout),
)
# ASD head
self.asd_head = nn.Linear(hidden_dim, num_classes)
# Adversarial site branch
        self.grl = GradientReversal(alpha=0.0)  # alpha set externally each epoch
self.site_head = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim // 2),
nn.ReLU(),
nn.Linear(hidden_dim // 2, num_sites),
)
def _encode(self, bold_windows: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
"""Compute mode features and pass through shared encoder."""
if adj.dim() == 4:
adj = adj.mean(dim=1)
M = torch.einsum('kn,bnm,lm->bkl', self.modes, adj, self.modes)
r, c = torch.triu_indices(self.num_modes, self.num_modes,
offset=0, device=adj.device)
fc_features = M[:, r, c]
A = torch.einsum('kn,bwn->bwk', self.modes, bold_windows)
dyn_features = A.std(dim=1)
features = torch.cat([fc_features, dyn_features], dim=-1)
return self.encoder(features)
def forward(
self,
bold_windows: torch.Tensor,
adj: torch.Tensor,
return_site_logits: bool = False,
) -> "torch.Tensor | tuple[torch.Tensor, torch.Tensor]":
h = self._encode(bold_windows, adj)
asd_logits = self.asd_head(h)
if return_site_logits:
site_logits = self.site_head(self.grl(h))
return asd_logits, site_logits
return asd_logits
def orthogonality_loss(self) -> torch.Tensor:
"""Identical to BrainModeNetwork.orthogonality_loss()."""
V_norm = self.modes / (self.modes.norm(dim=1, keepdim=True) + 1e-8)
gram = V_norm @ V_norm.T
I = torch.eye(self.num_modes, device=gram.device, dtype=gram.dtype)
return ((gram - I) ** 2).mean()
    # Expose discriminative_init (same logic as BrainModeNetwork). The staticmethod
    # wrapper keeps instance access from binding self as the first argument.
    discriminative_init = staticmethod(BrainModeNetwork.discriminative_init)
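
# One adversarial training step, sketched. Assumes GradientReversal exposes a
# mutable alpha (its constructor takes alpha=0.0 and it is set externally each
# epoch); labels, sites, and alpha come from the training loop:
#
#     import torch.nn.functional as F
#     model.grl.alpha = alpha                        # annealed 0 -> 1
#     asd_logits, site_logits = model(bold, adj, return_site_logits=True)
#     loss = F.cross_entropy(asd_logits, labels) \
#          + F.cross_entropy(site_logits, sites)     # GRL reverses encoder grads
#     loss.backward()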
class AdversarialConnectivityMLP(nn.Module):
"""FC-based classifier with adversarial site deconfounding (Ganin et al. 2016).
Architecture:
FC upper triangle (signed)
→ shared_encoder # learns site-invariant features
↙ ↘
asd_head grl(α) → site_head
(minimize ASD CE) (encoder maximises site CE via reversed grads)
During training the encoder is pulled in two directions:
- Minimise ASD classification loss (learn diagnosis signal)
- Maximise site classification loss (unlearn scanner fingerprint)
alpha is annealed 0→1 via ganin_alpha() so site deconfounding
ramps up gradually after the ASD signal is first established.
"""
def __init__(
self,
hidden_dim: int = 256,
num_classes: int = 2,
num_sites: int = 17,
dropout: float = 0.5,
):
super().__init__()
# Shared encoder — LazyLinear handles variable FC input size
self.encoder = nn.Sequential(
nn.LazyLinear(hidden_dim),
nn.LayerNorm(hidden_dim),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.ReLU(),
nn.Dropout(dropout),
)
# ASD classification head
self.asd_head = nn.Linear(hidden_dim, num_classes)
# Site adversarial branch
self.grl = GradientReversal(alpha=0.0) # alpha set externally each epoch
self.site_head = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim // 2),
nn.ReLU(),
nn.Linear(hidden_dim // 2, num_sites),
)
def forward(
self,
bold_windows: torch.Tensor,
adj: torch.Tensor,
return_site_logits: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
x = ConnectivityMLPClassifier._fc_features(adj)
features = self.encoder(x)
asd_logits = self.asd_head(features)
if return_site_logits:
site_logits = self.site_head(self.grl(features))
return asd_logits, site_logits
return asd_logits
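
# ganin_alpha() is defined outside this module; a sketch of the standard
# annealing schedule from Ganin et al. 2016, which it presumably follows
# (signature and gamma default are assumptions):
#
#     import math
#     def ganin_alpha(epoch: int, num_epochs: int, gamma: float = 10.0) -> float:
#         p = epoch / max(1, num_epochs - 1)         # training progress in [0, 1]
#         return 2.0 / (1.0 + math.exp(-gamma * p)) - 1.0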
# ---------------------------------------------------------------------------
# Factory
# ---------------------------------------------------------------------------
def build_model(
model_name: str,
hidden_dim: int = 64,
num_classes: int = 2,
num_sites: int = 1,
num_nodes: int = 200,
num_modes: int = 16,
dropout: float = 0.5,
readout: str = "attention",
drop_edge_p: float = 0.1,
mode_init: "torch.Tensor | None" = None,
in_features: int = 1,
) -> nn.Module:
if model_name == "graph_temporal":
return BrainGCNClassifier(hidden_dim, num_classes, dropout, readout, drop_edge_p, in_features=in_features)
if model_name == "gcn":
return GraphOnlyClassifier(hidden_dim, num_classes, dropout, readout, drop_edge_p)
if model_name == "gru":
return TemporalGRUClassifier(hidden_dim, num_classes, dropout)
if model_name == "fc_mlp":
return ConnectivityMLPClassifier(hidden_dim, num_classes, dropout)
if model_name == "adv_fc_mlp":
return AdversarialConnectivityMLP(hidden_dim, num_classes, num_sites, dropout)
if model_name == "dynamic_fc_attn":
from brain_gcn.models.dynamic_fc import DynamicFCAttention
return DynamicFCAttention(
num_rois=num_nodes,
hidden_dim=hidden_dim,
dropout=dropout,
)
if model_name == "brain_mode":
return BrainModeNetwork(num_nodes, num_modes, hidden_dim, num_classes, dropout,
mode_init=mode_init)
if model_name == "adv_brain_mode":
return AdversarialBrainModeNetwork(num_nodes, num_modes, hidden_dim, num_classes,
num_sites, dropout, mode_init=mode_init)
# Advanced models — lazy import to avoid circular dependency
from brain_gcn.models.advanced_models import (
GATClassifier, TransformerClassifier, CNN3DClassifier, GraphSAGEClassifier,
)
if model_name == "gat":
return GATClassifier(hidden_dim, dropout=dropout)
if model_name == "transformer":
return TransformerClassifier(hidden_dim, dropout=dropout)
if model_name == "cnn3d":
return CNN3DClassifier(hidden_dim, dropout=dropout)
if model_name == "graphsage":
return GraphSAGEClassifier(hidden_dim, dropout=dropout)
raise ValueError(f"Unknown model_name: {model_name}")