# Flow Matching fMRI Encoder

A two-stage architecture for predicting fMRI BOLD responses from naturalistic video stimuli (Friends TV show + Movie10 dataset) using Conditional Flow Matching.

---

## Overview

The pipeline predicts brain activity in two sequential stages:

| Stage | Model | Goal |
|---|---|---|
| **1** | `MultiSubjectConvLinearEncoder` | Predict a **Mean Anchor** — a deterministic per-voxel fMRI estimate shared across subjects |
| **2** | `CFM` (Conditional Flow Matching) | Learn a **per-subject neural vector field** that refines the Mean Anchor into a sharper, stochastic fMRI prediction |

The design mirrors the MedARC approach: Stage 1 provides a stable conditional mean $\mu$; Stage 2 integrates a continuous normalizing flow $\phi_t$ conditioned on $\mu$ to sample from the true posterior over voxel activations.

---

## Architecture

### Stage 1 — Mean Anchor (`medarc_architecture.py`)

```
Features (N, T, D_i) → DepthConv1d + Linear → embed (N, T, d)
        ↓
[Shared Decoder] + [Subject Decoders]
        ↓
fMRI Prediction (N, S, T, V)
```

- **`MultiSubjectConvLinearEncoder`**: Projects each feature stream through a `LinearConv` (depthwise conv + linear) to a shared embedding dimension `d = 192`.
- Global average pooling across the feature stack combines multi-model features.
- A shared linear decoder and per-subject linear decoders combine additively to predict `V` voxels for each of `S` subjects.
- Trained with MSE loss against ground-truth BOLD.
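The additive shared + per-subject decoding step can be sketched as follows. This is a minimal NumPy illustration of the shapes in the diagram above, not the actual implementation: the real `MultiSubjectConvLinearEncoder` also applies the `LinearConv` projection and pooling before this point, and the weight names here are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
N, T, d, S, V = 2, 10, 192, 4, 1000  # batch, timepoints, embed dim d=192, subjects, voxels

embed = rng.standard_normal((N, T, d))          # pooled multi-model embedding
W_shared = rng.standard_normal((d, V)) * 0.01   # shared linear decoder (hypothetical name)
W_subj = rng.standard_normal((S, d, V)) * 0.01  # one linear decoder per subject

# Shared and per-subject predictions combine additively:
shared_pred = embed @ W_shared                          # (N, T, V)
subj_pred = np.einsum("ntd,sdv->nstv", embed, W_subj)   # (N, S, T, V)
fmri_pred = shared_pred[:, None] + subj_pred            # (N, S, T, V)
print(fmri_pred.shape)  # (2, 4, 10, 1000)
```

The additive split lets the shared decoder absorb structure common to all subjects, while the per-subject decoders only need to model individual deviations.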
### Stage 2 — Neural Vector Field (`matcha_architecture.py`)

```
x1 (B, V, T) → proj_in (V → d) → latent x1 (B, d, T)
mu (B, V, T) → proj_in (V → d) → latent mu (B, d, T)
        ↓
OT-CFM loss (vector field u_t)
Matcha-TTS U-Net estimator
[Conformer / Transformer blocks]
        ↓
latent pred (B, d, T) → proj_out (d → V) → fMRI (B, V, T)
```

- **Latent Bottleneck**: fMRI voxels (`V ≈ 1000`) are projected down to a dense latent dimension (`d = 128`) before any convolution, reducing the first-layer parameter count from ~6M to ~98K and preventing gradient collapse.
- **`CFM`** wraps a Matcha-TTS style **1D U-Net** (`Decoder`) with ResNet-1D blocks and Conformer/Transformer attention at each scale.
- At inference, noise $z \sim \mathcal{N}(0, I)$ is integrated from $t = 0$ to $t = 1$ over 25 Euler steps conditioned on $\mu$.
- An auxiliary reconstruction loss (weight `0.1`) on `proj_in → proj_out` trains the projection pair jointly with the vector field.

---

## Data

| Source | Content | Usage |
|---|---|---|
| **Friends** (seasons 1–7) | fMRI BOLD + multimodal features | Train (S1), Val (S6, S7) |
| **Movie10** | fMRI BOLD + multimodal features | Supplementary val (Figures, Life, Bourne, Wolf) |

Subjects used: **1, 2, 3, 5**.

### Feature Models

The encoder can ingest intermediate activations from any combination of:

| Key | Model |
|---|---|
| `internvl3_8b`, `internvl3_14b` | InternVL3 vision-language model |
| `qwen-2-5-omni-3b`, `qwen-2-5-omni-7b` | Qwen2.5-Omni audio-video model |
| `whisper` | OpenAI Whisper (audio) |
| `llama_3.2_1b`, `llama_3.2_3b` | LLaMA 3.2 (text) |
| `vjepa2` | V-JEPA 2 (video) |

Active features are set in `config.yml` under `include_features`.
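The OT-CFM objective and the Euler sampler described above can be sketched in NumPy. This follows the standard OT-CFM formulation (interpolant $x_t$ and target field $u_t$ with `sigma_min`); the toy `vector_field` below is a hypothetical stand-in for the Matcha-TTS U-Net estimator, which in the real pipeline is a trained network conditioned on $\mu$.

```python
import numpy as np

rng = np.random.default_rng(0)
B, d, T = 2, 128, 10   # batch, latent dim d=128, timepoints
sigma_min = 1e-4

x1 = rng.standard_normal((B, d, T))   # latent fMRI target
x0 = rng.standard_normal((B, d, T))   # noise z ~ N(0, I)
t = rng.uniform(size=(B, 1, 1))

# OT-CFM training pair: interpolant x_t and the target vector field u_t
x_t = (1.0 - (1.0 - sigma_min) * t) * x0 + t * x1
u_t = x1 - (1.0 - sigma_min) * x0

def vector_field(x, t, mu):
    # Toy stand-in for the U-Net: a field that pulls x toward the anchor mu.
    return mu - x

# Inference: Euler integration of dx/dt = v(x, t, mu) from t=0 to t=1
mu = rng.standard_normal((B, d, T))   # Stage 1 Mean Anchor (latent)
n_steps = 25
x = rng.standard_normal((B, d, T))
dt = 1.0 / n_steps
res_start = np.linalg.norm(x - mu)
for i in range(n_steps):
    x = x + dt * vector_field(x, i * dt, mu)
res_end = np.linalg.norm(x - mu)
print(res_end / res_start)  # (1 - 1/25)**25 ≈ 0.36: each step contracts toward mu
```

With the toy field the residual shrinks geometrically toward $\mu$; the trained vector field instead transports noise to a sample around the anchor rather than collapsing onto it.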
---

## File Structure

```
flow_matching/
├── config.yml               # Main training config (GPU, full data)
├── debug_config.yml         # Fast local debug config (CPU, tiny data)
├── environment.yml          # Conda environment spec
│
├── src/
│   ├── training.py              # Two-stage training loop + evaluation
│   ├── matcha_architecture.py   # CFM + Matcha-TTS U-Net decoder
│   ├── medarc_architecture.py   # Stage 1 MultiSubjectConvLinearEncoder
│   ├── data.py                  # Algonauts2025 dataset + loaders
│   ├── metric.py                # Pearson's r voxel-wise scoring
│   ├── visualize.py             # Loss curve plotting
│   └── inference.py             # Standalone inference helper
│
├── test/
│   ├── overfit_test.py      # Tiny-batch overfit sanity check for Stage 2
│   ├── check_pearson.py     # Load checkpoints and plot per-voxel Pearson's r heatmaps
│   └── debug_training.py    # End-to-end smoke test
│
├── experiments/
│   └── *.ipynb              # Analysis notebooks (RSA, OOD, brain region plots)
│
└── Matcha-TTS/              # Vendored Matcha-TTS source (U-Net + solver)
```

---

## Training

### Full training (server)

```bash
cd flow_matching
python src/training.py --cfg-path config.yml
```

Checkpoints are written to `output/two_stage_encoding/`:

- `stage1_best.pt` — best Stage 1 model by validation Pearson's r
- `stage2_epoch_N.pt` — Stage 2 snapshot every 5 epochs

### Local debug (CPU, tiny model)

```bash
python src/training.py --cfg-path debug_config.yml
```

---

## Evaluation

### Pearson's r heatmaps

Loads all available Stage 1 and Stage 2 checkpoints, evaluates on the configured validation set, and saves per-subject per-voxel Pearson's r heatmaps to `output/two_stage_encoding/heatmaps/`.

```bash
python test/check_pearson.py
```

**Output per checkpoint:**

```
Stage 1 Overall Pearson's r: 0.1832
Stage 1 - Sub 1 Mean Pearson's r: 0.1754
Stage 2 Epoch 5 Overall Pearson's r: 0.2110
Stage 2 Epoch 5 - Sub 1 Mean Pearson's r: 0.2043
...
```

### Tiny-batch overfit test

Confirms Stage 2 can memorize a single training batch.
If the loss does not approach `0` within 500 steps, the architecture cannot learn the task.

```bash
python test/overfit_test.py --cfg-path config.yml --subject-idx 0 --steps 500
```

---

## Key Hyperparameters

| Parameter | Value | Location |
|---|---|---|
| Stage 1 embed dim | 192 | `config.yml / stage1.model.embed_dim` |
| Stage 1 encoder kernel | 45 | `config.yml / stage1.model.encoder_kernel_size` |
| Stage 1 LR | 3e-4 | `config.yml / stage1.lr` |
| Stage 2 latent dim | 128 | `config.yml / stage2.latent_dim` |
| Stage 2 U-Net channels | [256, 256] | `config.yml / stage2.decoder.channels` |
| Stage 2 block type | Conformer | `config.yml / stage2.decoder.*_block_type` |
| Stage 2 LR | 3e-4 | `config.yml / stage2.lr` |
| Euler integration steps | 25 | `config.yml / stage2.n_timesteps` |
| CFM σ_min | 1e-4 | `config.yml / stage2.cfm.sigma_min` |

---

## Metric

Evaluation uses **voxel-wise Pearson's r** averaged across subjects:

$$r_v = \frac{\sum_t (y_v^t - \bar{y}_v)(\hat{y}_v^t - \bar{\hat{y}}_v)}{\sqrt{\sum_t (y_v^t - \bar{y}_v)^2}\,\sqrt{\sum_t (\hat{y}_v^t - \bar{\hat{y}}_v)^2}}$$

The scalar reported is the mean over all `V` voxels and all `S` subjects.
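The metric above can be sketched in a few lines of NumPy. This is a minimal illustration of the formula, assuming arrays of shape `(S, T, V)`; the repository's actual implementation lives in `src/metric.py` and may differ in detail (e.g. masking or epsilon handling).

```python
import numpy as np

def voxelwise_pearson(y_true, y_pred, eps=1e-8):
    """Per-voxel Pearson's r over the time axis; inputs (S, T, V), returns (S, V)."""
    yt = y_true - y_true.mean(axis=1, keepdims=True)   # center over time
    yp = y_pred - y_pred.mean(axis=1, keepdims=True)
    num = (yt * yp).sum(axis=1)
    den = np.sqrt((yt ** 2).sum(axis=1)) * np.sqrt((yp ** 2).sum(axis=1))
    return num / (den + eps)

rng = np.random.default_rng(0)
y = rng.standard_normal((4, 100, 50))        # S=4 subjects, T=100, V=50
r = voxelwise_pearson(y, 2.0 * y + 1.0)      # any affine transform gives r = 1
print(r.shape, round(float(r.mean()), 3))    # (4, 50) 1.0
```

Averaging `r` over its voxel and subject axes yields the single scalar score reported during evaluation.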