"""
Dynamic FC Temporal Attention model for ASD/TD classification.

Architecture (STAGIN-inspired, simplified):

    Input  : (B, W, N) — per-window ROI connectivity strength (mean |FC| per ROI)
    Step 1 : Linear projection N → H
    Step 2 : Learnable positional encoding over W time steps
    Step 3 : Transformer encoder (multi-head self-attention over windows)
    Step 4 : Attention-weighted pooling over W → subject embedding (H,)
    Step 5 : MLP classifier → 2

Why this works:
    ASD shows altered *dynamic* connectivity — not just different mean FC but
    different temporal patterns of connectivity fluctuation across brain states.
    The self-attention learns which window combinations are most discriminative.
"""

from __future__ import annotations

import torch
import torch.nn.functional as F
from torch import nn


class DynamicFCAttention(nn.Module):
    def __init__(
        self,
        num_rois: int = 200,
        max_windows: int = 30,
        hidden_dim: int = 128,
        num_heads: int = 4,
        num_layers: int = 2,
        dropout: float = 0.5,
        num_classes: int = 2,
    ):
        super().__init__()
        assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"

        # Project ROI connectivity strengths to hidden dim
        self.input_proj = nn.Sequential(
            nn.Linear(num_rois, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout * 0.5),
        )

        # Learnable positional encoding — one vector per window
        self.pos_embed = nn.Parameter(torch.randn(1, max_windows, hidden_dim) * 0.02)

        # Transformer encoder: self-attention over time windows
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 2,
            dropout=dropout * 0.5,
            batch_first=True,
            norm_first=True,  # pre-norm for stability
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Attention pooling over time: learn which windows matter
        self.time_attn = nn.Linear(hidden_dim, 1)

        # Classifier head
        self.head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LayerNorm(hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, num_classes),
        )
    def forward(
        self,
        bold_windows: torch.Tensor,
        adj: torch.Tensor | None = None,  # unused — kept for interface compatibility
        return_attention: bool = False,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        # bold_windows: (B, W, N) — mean |FC| per ROI per time window
        B, W, N = bold_windows.shape

        # Project each window's ROI features to hidden dim
        x = self.input_proj(bold_windows)  # (B, W, H)

        # Add positional encoding (truncated to the actual number of windows W)
        x = x + self.pos_embed[:, :W, :]

        # Self-attention over time windows
        x = self.transformer(x)  # (B, W, H)

        # Attention-weighted pooling: which windows are most discriminative?
        attn = torch.softmax(self.time_attn(x).squeeze(-1), dim=1)  # (B, W)
        embedding = (x * attn.unsqueeze(-1)).sum(dim=1)  # (B, H)

        logits = self.head(embedding)
        if return_attention:
            return logits, attn
        return logits
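

# Minimal usage sketch (illustrative, not part of the original module): runs one
# forward pass on random data with the default hyperparameters to confirm the
# shapes described in the docstring. Batch size, window count, and ROI count
# below are arbitrary assumptions, not values prescribed by the model.
if __name__ == "__main__":
    model = DynamicFCAttention(num_rois=200, max_windows=30)
    dummy_windows = torch.randn(8, 24, 200)  # (B=8, W=24, N=200) random stand-in for mean |FC| per ROI
    logits, attn = model(dummy_windows, return_attention=True)
    print(logits.shape)  # torch.Size([8, 2]) — class logits per subject
    print(attn.shape)    # torch.Size([8, 24]) — per-window pooling weights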