Yatsuiii's picture
Upload folder using huggingface_hub
16d6869 verified
"""
Brain Connectivity Masked Autoencoder (BC-MAE).
Architecture (He et al. MAE 2022, adapted for temporal FC windows):
Pre-training
─────────────
Input : (B, W, N) — per-window ROI connectivity strengths (mean |FC| per window)
Mask : random 50% of W windows are hidden
Encoder: Transformer on visible windows only → (B, W_vis, H)
Decoder: Lightweight Transformer on all positions (visible + mask tokens)
→ reconstruction head → (B, W, N)
Loss : MSE on masked windows only
Fine-tuning
────────────
Encoder (loaded from pre-training, optionally frozen)
+ attention pooling over all W windows
+ MLP classifier → (B, 2)
"""
from __future__ import annotations
import torch
import torch.nn.functional as F
from torch import nn
# ---------------------------------------------------------------------------
# Shared encoder
# ---------------------------------------------------------------------------
class BrainFCEncoder(nn.Module):
"""Transformer encoder operating on visible FC windows.
Each time window's ROI connectivity profile (N-dim) is treated as a
"patch" — analogous to image patches in ViT/MAE.
"""
def __init__(
self,
num_rois: int = 200,
num_windows: int = 30,
hidden_dim: int = 128,
num_heads: int = 4,
num_layers: int = 4,
dropout: float = 0.1,
):
super().__init__()
self.hidden_dim = hidden_dim
# Project each window's ROI features to hidden dim
self.patch_embed = nn.Linear(num_rois, hidden_dim)
# Learnable positional embedding — one per window position
self.pos_embed = nn.Parameter(torch.zeros(1, num_windows, hidden_dim))
nn.init.trunc_normal_(self.pos_embed, std=0.02)
encoder_layer = nn.TransformerEncoderLayer(
d_model=hidden_dim,
nhead=num_heads,
dim_feedforward=hidden_dim * 4,
dropout=dropout,
batch_first=True,
norm_first=True,
)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
self.norm = nn.LayerNorm(hidden_dim)
def forward(
self,
x: torch.Tensor,
ids_keep: torch.Tensor | None = None,
) -> torch.Tensor:
"""
Parameters
----------
x : (B, W_visible, N) visible windows
ids_keep : (B, W_visible) original positions of visible windows
"""
B, W_vis, N = x.shape
# Project patches
x = self.patch_embed(x) # (B, W_vis, H)
# Add positional embeddings at the original positions
if ids_keep is not None:
pos = self.pos_embed.expand(B, -1, -1) # (B, W_all, H)
pos_vis = torch.gather(
pos, 1,
ids_keep.unsqueeze(-1).expand(-1, -1, self.hidden_dim) # (B, W_vis, H)
)
else:
pos_vis = self.pos_embed[:, :W_vis, :]
x = x + pos_vis
x = self.norm(self.transformer(x))
return x # (B, W_vis, H)
# ---------------------------------------------------------------------------
# MAE (pre-training)
# ---------------------------------------------------------------------------
class BrainMAE(nn.Module):
"""Masked Autoencoder for brain FC windows."""
def __init__(
self,
num_rois: int = 200,
num_windows: int = 30,
hidden_dim: int = 128,
decoder_dim: int = 64,
num_heads: int = 4,
encoder_layers: int = 4,
decoder_layers: int = 2,
dropout: float = 0.1,
mask_ratio: float = 0.5,
):
super().__init__()
self.num_windows = num_windows
self.num_rois = num_rois
self.mask_ratio = mask_ratio
self.hidden_dim = hidden_dim
self.decoder_dim = decoder_dim
# Encoder (shared with fine-tuning)
self.encoder = BrainFCEncoder(
num_rois=num_rois,
num_windows=num_windows,
hidden_dim=hidden_dim,
num_heads=num_heads,
num_layers=encoder_layers,
dropout=dropout,
)
# Project encoder output to decoder dim
self.enc_to_dec = nn.Linear(hidden_dim, decoder_dim, bias=False)
# Learnable mask token (broadcast across masked positions)
self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_dim))
nn.init.trunc_normal_(self.mask_token, std=0.02)
# Decoder positional embedding (all W positions)
self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_windows, decoder_dim))
nn.init.trunc_normal_(self.decoder_pos_embed, std=0.02)
# Lightweight decoder
dec_layer = nn.TransformerEncoderLayer(
d_model=decoder_dim,
nhead=max(1, decoder_dim // 32),
dim_feedforward=decoder_dim * 4,
dropout=dropout,
batch_first=True,
norm_first=True,
)
self.decoder = nn.TransformerEncoder(dec_layer, num_layers=decoder_layers)
self.decoder_norm = nn.LayerNorm(decoder_dim)
# Reconstruction head: predict ROI connectivity for each window
self.recon_head = nn.Linear(decoder_dim, num_rois)
# ------------------------------------------------------------------
def _random_masking(
self, x: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Randomly mask windows. Returns visible subset, binary mask, restore indices."""
B, W, _ = x.shape
num_keep = int(W * (1 - self.mask_ratio))
# Random shuffle per sample
noise = torch.rand(B, W, device=x.device)
ids_shuffle = torch.argsort(noise, dim=1)
ids_restore = torch.argsort(ids_shuffle, dim=1)
ids_keep = ids_shuffle[:, :num_keep] # (B, num_keep)
x_vis = torch.gather(
x, 1,
ids_keep.unsqueeze(-1).expand(-1, -1, x.shape[-1]) # (B, num_keep, N)
)
# Binary mask: 1 = masked, 0 = visible
mask = torch.ones(B, W, device=x.device)
mask[:, :num_keep] = 0
mask = torch.gather(mask, 1, ids_restore)
return x_vis, mask, ids_restore, ids_keep
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Forward pass for pre-training.
Returns
-------
loss : scalar MSE on masked windows
mask : (B, W) binary mask (1=masked) for logging
"""
B, W, N = x.shape
# Mask
x_vis, mask, ids_restore, ids_keep = self._random_masking(x)
# Encode visible
enc = self.encoder(x_vis, ids_keep=ids_keep) # (B, num_keep, H)
enc = self.enc_to_dec(enc) # (B, num_keep, D)
# Decode: reconstruct all W positions
# Fill masked positions with mask token
num_keep = enc.shape[1]
num_mask = W - num_keep
mask_tokens = self.mask_token.expand(B, num_mask, -1)
# Concatenate visible encoded + mask tokens, then unshuffle
full = torch.cat([enc, mask_tokens], dim=1) # (B, W, D)
full = torch.gather(
full, 1,
ids_restore.unsqueeze(-1).expand(-1, -1, self.decoder_dim)
)
# Add decoder positional embeddings and decode
full = full + self.decoder_pos_embed
dec = self.decoder_norm(self.decoder(full)) # (B, W, D)
# Reconstruct
pred = self.recon_head(dec) # (B, W, N)
# MSE loss on masked windows only
loss = (pred - x).pow(2).mean(dim=-1) # (B, W)
loss = (loss * mask).sum() / (mask.sum() + 1e-8)
return loss, mask
def encode_all(self, x: torch.Tensor) -> torch.Tensor:
"""Encode all W windows (no masking) for downstream tasks."""
return self.encoder(x) # (B, W, H)
# ---------------------------------------------------------------------------
# Fine-tuning classifier
# ---------------------------------------------------------------------------
class BrainFCClassifier(nn.Module):
"""ASD/TD classifier with pre-trained BC-MAE encoder.
Encoder can be frozen (linear probing) or fine-tuned end-to-end.
"""
def __init__(
self,
encoder: BrainFCEncoder,
hidden_dim: int = 128,
num_classes: int = 2,
dropout: float = 0.5,
freeze_encoder: bool = True,
):
super().__init__()
self.encoder = encoder
self.freeze_encoder = freeze_encoder
if freeze_encoder:
for p in self.encoder.parameters():
p.requires_grad_(False)
H = hidden_dim
# Attention pooling over time: which windows discriminate ASD?
self.time_attn = nn.Linear(H, 1)
# Classifier head
self.head = nn.Sequential(
nn.LayerNorm(H),
nn.Linear(H, H // 2),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(H // 2, num_classes),
)
def forward(
self,
x: torch.Tensor,
adj: torch.Tensor | None = None, # kept for interface compatibility
) -> torch.Tensor:
# x: (B, W, N)
if self.freeze_encoder:
with torch.no_grad():
enc = self.encoder(x) # (B, W, H)
else:
enc = self.encoder(x)
# Attention-weighted pooling over time
attn = torch.softmax(self.time_attn(enc).squeeze(-1), dim=1) # (B, W)
pooled = (enc * attn.unsqueeze(-1)).sum(dim=1) # (B, H)
return self.head(pooled)
def unfreeze_encoder(self) -> None:
for p in self.encoder.parameters():
p.requires_grad_(True)
self.freeze_encoder = False