---
license: bsd-3-clause
library_name: braindecode
pipeline_tag: feature-extraction
tags:
- eeg
- biosignal
- pytorch
- neuroscience
- braindecode
- foundation-model
- transformer
---
# CBraMod
**C**riss-**C**ross **Bra**in **Mod**el for EEG Decoding from Wang et al. (2025).
> **Architecture-only repository.** This repo documents the
> `braindecode.models.CBraMod` class. **No pretrained weights are
> distributed here** — instantiate the model and train it on your own
> data, or fine-tune from a published foundation-model checkpoint
> separately.
## Quick start
```bash
pip install braindecode
```
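```python
import torch
from braindecode.models import CBraMod

# 22 EEG channels, 4-second windows sampled at 200 Hz, binary classification
model = CBraMod(
    n_chans=22,
    sfreq=200,
    input_window_seconds=4.0,
    n_outputs=2,
)

# Input shape: (batch, n_chans, n_times), here n_times = 200 * 4.0 = 800
x = torch.randn(8, 22, 800)
logits = model(x)  # shape (8, 2)
```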
The signal-shape arguments above are example defaults — adjust them
to match your recording.
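
The model cuts each channel into patches of `patch_size` samples (default 200, i.e. 1 s at 200 Hz; see *Parameters* below). The sketch below is a quick way to check your own `sfreq` and window length before instantiating. `patch_grid` is a hypothetical helper. It assumes non-overlapping patches and a window that divides evenly into whole patches; neither assumption is a guarantee made by this card.

```python
def patch_grid(sfreq, input_window_seconds, patch_size=200):
    """Return (n_times, n_patches) for one input window.

    Assumption (not stated by braindecode): the window must split into a
    whole number of non-overlapping patches of `patch_size` samples.
    """
    n_times = round(sfreq * input_window_seconds)
    if n_times % patch_size != 0:
        raise ValueError(
            f"n_times={n_times} is not a multiple of patch_size={patch_size}"
        )
    return n_times, n_times // patch_size


# Quick-start settings: 200 Hz x 4 s = 800 samples -> 4 one-second patches
print(patch_grid(sfreq=200, input_window_seconds=4.0))  # (800, 4)
```

If the check fails, resampling or changing the window length so that `n_times` is a multiple of `patch_size` avoids partial patches.
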
## Documentation
- Full API reference (parameters, references, architecture figure):
<https://braindecode.org/stable/generated/braindecode.models.CBraMod.html>
- Interactive browser with live instantiation:
<https://huggingface.co/spaces/braindecode/model-explorer>
- Source on GitHub: <https://github.com/braindecode/braindecode/blob/master/braindecode/models/cbramod.py#L23>
## Architecture description
The block below is the rendered class docstring (parameters,
references, architecture figure where available).
<div class='bd-doc'><main>
<p><strong>C</strong>riss-<strong>C</strong>ross <strong>Bra</strong>in <strong>Mod</strong>el for EEG Decoding from Wang et al. (2025) <a class="citation-reference" href="#cbramod" id="citation-reference-1" role="doc-biblioref">[cbramod]</a>.</p>
<span style="display:inline-block;padding:2px 8px;border-radius:4px;background:#d9534f;color:white;font-size:11px;font-weight:600;margin-right:4px;">Foundation Model</span><span style="display:inline-block;padding:2px 8px;border-radius:4px;background:#56B4E9;color:white;font-size:11px;font-weight:600;margin-right:4px;">Attention/Transformer</span><figure class="align-center">
<img alt="CBraMod pre-training overview" src="https://raw.githubusercontent.com/wjq-learning/CBraMod/refs/heads/main/figure/model.png" style="width: 1000px;" />
</figure>
<p>CBraMod is a foundation model for EEG decoding built on a criss-cross transformer
architecture that models the spatial and temporal structure of EEG signals separately.
It is pre-trained with masked EEG patch reconstruction on the Temple University Hospital EEG Corpus (TUEG),
the largest public EEG corpus, and reports state-of-the-art performance across
diverse downstream BCI and clinical applications.</p>
<p><strong>Key Innovation: Criss-Cross Attention</strong></p>
<p>Unlike existing EEG foundation models that use full attention to model all spatial and temporal
dependencies together, CBraMod separates spatial and temporal dependencies through a
<strong>criss-cross transformer</strong> architecture:</p>
<ul class="simple">
<li><p><strong>Spatial Attention</strong>: models dependencies across channels within each temporal patch</p></li>
<li><p><strong>Temporal Attention</strong>: models dependencies across temporal patches within each channel</p></li>
</ul>
<p>This design is inspired by criss-cross attention in computer vision and exploits the
grid structure of EEG signals (channels by time). Compared with full attention, it reduces
FLOPs by ~32% while also improving performance and speeding up convergence.</p>
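
The NumPy sketch below shows the two attention axes on a grid of patch embeddings shaped `(n_chans, n_patches, emb_dim)`. It makes three simplifying assumptions:

- the embedding is split in half, with one half attended across channels and the other across patches, and the two results are concatenated;
- a single head with Q = K = V = input, so there are no learned projections;
- no residual connection or feed-forward layer.

This illustrates the idea only. It is not the braindecode implementation.

```python
import numpy as np


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def self_attention(x):
    """Single-head attention over the second-to-last axis, with Q = K = V = x."""
    scores = x @ np.swapaxes(x, -1, -2) / np.sqrt(x.shape[-1])
    return softmax(scores, axis=-1) @ x


def criss_cross_attention(x):
    """x: (n_chans, n_patches, emb_dim) grid of patch embeddings."""
    half = x.shape[-1] // 2
    x_s, x_t = x[..., :half], x[..., half:]
    # Spatial branch: one sequence per patch whose tokens are the channels.
    s = np.swapaxes(self_attention(np.swapaxes(x_s, 0, 1)), 0, 1)
    # Temporal branch: one sequence per channel whose tokens are the patches.
    t = self_attention(x_t)
    return np.concatenate([s, t], axis=-1)


rng = np.random.default_rng(0)
x = rng.standard_normal((16, 10, 200))  # 16 channels, 10 patches, emb_dim 200
print(criss_cross_attention(x).shape)  # (16, 10, 200)
```

Each spatial sequence has only `n_chans` tokens and each temporal sequence only `n_patches` tokens. Full attention would instead use one sequence of `n_chans * n_patches` tokens. As a consequence, the spatial half for patch *p* depends only on patch *p*, and the temporal half for channel *c* depends only on channel *c*.
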
<p><strong>Asymmetric Conditional Positional Encoding (ACPE)</strong></p>
<p>Rather than using fixed positional embeddings, CBraMod employs <strong>Asymmetric Conditional
Positional Encoding</strong> that dynamically generates positional embeddings using a convolutional
network. This enables the model to:</p>
<ul class="simple">
<li><p>Capture relative positional information adaptively</p></li>
<li><p>Handle diverse EEG channel formats (different channel counts and reference schemes)</p></li>
<li><p>Generalize to arbitrary downstream EEG formats without retraining</p></li>
<li><p>Support various reference schemes (earlobe, average, REST, bipolar)</p></li>
</ul>
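
The card says only that ACPE generates positional embeddings from the input with a convolutional network. The sketch below makes several assumptions beyond that:

- a single depthwise 2-D convolution over the (channel, patch) grid;
- an asymmetric kernel of size (19, 7), taller along channels than along patches;
- the convolution output is added to the patch embeddings, in the style of conditional positional encodings;
- random weights in place of learned ones.

Because the kernel does not depend on the grid size, the same weights apply to any channel count.

```python
import numpy as np


def depthwise_conv2d_same(x, w):
    """Zero-padded 'same' depthwise 2-D cross-correlation.

    x: (n_chans, n_patches, emb_dim) grid of patch embeddings.
    w: (kh, kw, emb_dim), one kernel per embedding dimension; kh and kw odd.
    """
    kh, kw, _ = w.shape
    ph, pw = kh // 2, kw // 2
    xp = np.pad(x, ((ph, ph), (pw, pw), (0, 0)))
    n_chans, n_patches, _ = x.shape
    out = np.zeros_like(x)
    for i in range(kh):
        for j in range(kw):
            # w[i, j] has shape (emb_dim,) and broadcasts over the grid
            out += w[i, j] * xp[i:i + n_chans, j:j + n_patches]
    return out


def acpe(x, w):
    """Add input-conditioned positional embeddings to the patch embeddings."""
    return x + depthwise_conv2d_same(x, w)


rng = np.random.default_rng(0)
w = 0.01 * rng.standard_normal((19, 7, 200))  # assumed asymmetric kernel: 19 x 7
for n_chans in (4, 19, 64):  # the same weights work for any channel count
    x = rng.standard_normal((n_chans, 10, 200))
    print(n_chans, acpe(x, w).shape)
```
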
<p><strong>Pretraining Highlights</strong></p>
<ul class="simple">
<li><p><strong>Pretraining Dataset</strong>: Temple University Hospital EEG Corpus (TUEG), the largest public EEG corpus</p></li>
<li><p><strong>Pretraining Task</strong>: Self-supervised masked EEG patch reconstruction from both time-domain
and frequency-domain EEG signals</p></li>
<li><p><strong>Model Parameters</strong>: ~4.0M parameters (very compact compared to other foundation models)</p></li>
<li><p><strong>Fast Convergence</strong>: Reaches competitive downstream results after the first epoch
and converges fully within ~10 epochs (vs. ~30 for supervised models such as EEGConformer)</p></li>
</ul>
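
The masked-reconstruction objective is sketched below. It samples a random subset of (channel, patch) positions, blanks them in the input, and scores the reconstruction only on those positions. Several choices here are assumptions for illustration:

- the 50% mask ratio;
- filling masked patches with zeros;
- the mean-squared-error loss;
- the all-zeros stand-in for the encoder plus reconstruction head.

```python
import numpy as np


def random_patch_mask(n_chans, n_patches, ratio, rng):
    """Boolean (n_chans, n_patches) mask with round(ratio * total) True entries."""
    n_total = n_chans * n_patches
    mask = np.zeros(n_total, dtype=bool)
    mask[rng.choice(n_total, size=int(round(ratio * n_total)), replace=False)] = True
    return mask.reshape(n_chans, n_patches)


def masked_mse(pred, target, mask):
    """Mean squared error computed over masked patches only."""
    return float(np.mean((pred[mask] - target[mask]) ** 2))


rng = np.random.default_rng(0)
patches = rng.standard_normal((16, 10, 200))  # (n_chans, n_patches, patch_size)
mask = random_patch_mask(16, 10, ratio=0.5, rng=rng)  # assumed 50% ratio
# Masked patches are zeroed here; a learnable mask token is another common choice.
model_input = np.where(mask[..., None], 0.0, patches)
pred = np.zeros_like(patches)  # hypothetical stand-in for encoder + reconstruction head
print(mask.sum(), masked_mse(pred, patches, mask))  # 80 masked patches, then the loss
```

Restricting the loss to masked positions means the model gains nothing from copying visible patches. It has to infer hidden patches from their spatial and temporal neighbours.
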
<p><strong>Macro Components</strong></p>
<ul class="simple">
<li><p><strong>Patch Encoding Network</strong>: Converts raw EEG patches into embeddings</p></li>
<li><p><strong>Asymmetric Conditional Positional Encoding (ACPE)</strong>: Generates spatial-temporal positional
embeddings adaptively from input EEG format</p></li>
<li><p><strong>Criss-Cross Transformer Blocks</strong> (12 layers by default): combine spatial and temporal attention
in every block to learn EEG representations</p></li>
<li><p><strong>Reconstruction Head</strong>: Reconstructs masked EEG patches during pretraining</p></li>
<li><p><strong>Task head</strong> (<span class="docutils literal">final_layer</span>): flattens summary tokens across patches and maps them to
<span class="docutils literal">n_outputs</span>; if <span class="docutils literal">return_encoder_output=True</span>, the encoder features are returned instead.</p></li>
</ul>
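
The task head's two modes are sketched below in NumPy: flatten-and-project versus returning the encoder features. `task_head` is a hypothetical function, not the braindecode `final_layer`. It assumes the encoder output is shaped `(batch, n_chans, n_patches, emb_dim)` and that all of those embeddings are flattened. The exact tokens braindecode flattens may differ.

```python
import numpy as np


def task_head(features, weight, bias, return_encoder_output=False):
    """features: (batch, n_chans, n_patches, emb_dim) encoder output."""
    if return_encoder_output:
        return features  # hand the encoder features back unchanged
    flat = features.reshape(features.shape[0], -1)  # (batch, n_chans * n_patches * emb_dim)
    return flat @ weight + bias  # (batch, n_outputs) logits


rng = np.random.default_rng(0)
batch, n_chans, n_patches, emb_dim, n_outputs = 8, 22, 4, 200, 2
feats = rng.standard_normal((batch, n_chans, n_patches, emb_dim))
weight = 0.01 * rng.standard_normal((n_chans * n_patches * emb_dim, n_outputs))
bias = np.zeros(n_outputs)
print(task_head(feats, weight, bias).shape)  # (8, 2)
print(task_head(feats, weight, bias, return_encoder_output=True).shape)  # (8, 22, 4, 200)
```

Under this assumption, the linear layer's input size grows with both channel count and window length. That is one reason to retrain the head for each downstream montage.
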
<p>The model is highly efficient, requiring only ~318.9M FLOPs on a typical 16-channel, 10-second
EEG recording (significantly lower than full attention baselines).</p>
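
A back-of-the-envelope count shows where the savings come from. The sketch below counts only the multiply-accumulates of the attention-score product Q·Kᵀ for the 16-channel, 10-second case, assuming 1-second patches, `emb_dim=200`, and an even split of the embedding between the spatial and temporal branches. These helper functions are hypothetical and are not a FLOP profiler.

```python
def score_macs_full(n_chans, n_patches, emb_dim):
    """Multiply-accumulates for Q @ K^T with full attention over all tokens."""
    n_tokens = n_chans * n_patches
    return n_tokens * n_tokens * emb_dim


def score_macs_criss_cross(n_chans, n_patches, emb_dim):
    """Same term when each half of emb_dim attends along one axis only."""
    half = emb_dim // 2
    spatial = n_patches * (n_chans * n_chans * half)  # one C x C map per patch
    temporal = n_chans * (n_patches * n_patches * half)  # one P x P map per channel
    return spatial + temporal


full = score_macs_full(16, 10, 200)  # 5_120_000
cc = score_macs_criss_cross(16, 10, 200)  # 416_000
print(full, cc, round(cc / full, 3))  # 5120000 416000 0.081
```

The score term alone shrinks by roughly 12x here. The Q/K/V and output projections and the feed-forward layers cost the same under both designs, which is consistent with the much smaller whole-model saving (~32%) quoted above.
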
<p><strong>Known Limitations</strong></p>
<ul class="simple">
<li><p><strong>Data Quality</strong>: The TUEG corpus contains noisy (&quot;dirty&quot;) recordings; the crude filtering applied
before pretraining discarded part of the corpus, reducing the amount of usable pretraining data</p></li>
<li><p><strong>Channel Dependency</strong>: Performance degrades with very sparse electrode setups (e.g., &lt;4 channels)</p></li>
<li><p><strong>Computational Resources</strong>: While efficient, foundation models have higher deployment
requirements than lightweight models</p></li>
<li><p><strong>Limited Scaling Exploration</strong>: Future work should explore scaling laws at billion-parameter levels
and integration with large pre-trained vision/language models</p></li>
</ul>
<aside class="admonition important">
<p class="admonition-title">Important</p>
<p><strong>Pre-trained Weights Available</strong></p>
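<p>Pre-trained weights for this architecture are available on the Hugging Face Hub (not in this repository)
and can be loaded with the model's <span class="docutils literal">from_pretrained</span> class method.
Your own trained model can be uploaded with <span class="docutils literal">push_to_hub</span>.</p>
<p>Hub integration requires installing <span class="docutils literal">braindecode[hub]</span>.</p>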
</aside>
<section id="parameters">
<h2>Parameters</h2>
<dl class="simple">
<dt>patch_size<span class="classifier">int, default=200</span></dt>
<dd><p>Temporal patch size in samples (200 samples = 1 second at 200 Hz).</p>
</dd>
<dt>dim_feedforward<span class="classifier">int, default=800</span></dt>
<dd><p>Dimension of the feedforward network in Transformer layers.</p>
</dd>
<dt>n_layer<span class="classifier">int, default=12</span></dt>
<dd><p>Number of Transformer layers.</p>
</dd>
<dt>nhead<span class="classifier">int, default=8</span></dt>
<dd><p>Number of attention heads.</p>
</dd>
<dt>activation<span class="classifier">type[nn.Module], default=nn.GELU</span></dt>
<dd><p>Activation function used in Transformer feedforward layers.</p>
</dd>
<dt>emb_dim<span class="classifier">int, default=200</span></dt>
<dd><p>Output embedding dimension.</p>
</dd>
<dt>drop_prob<span class="classifier">float, default=0.1</span></dt>
<dd><p>Dropout probability.</p>
</dd>
<dt>return_encoder_output<span class="classifier">bool, default=False</span></dt>
<dd><p>If False (default), the encoder features are flattened and passed through a final linear layer
to produce class logits of size <span class="docutils literal">n_outputs</span>.
If True, the model returns the encoder output features instead.</p>
</dd>
</dl>
</section>
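
When changing `emb_dim` or `nhead`, it helps to check that the attention widths divide evenly. The helper below is hypothetical. It assumes each criss-cross branch receives half of `emb_dim` and half of the heads, and that each branch's width must be divisible by its head count, as PyTorch multi-head attention requires. Treat it as a sanity check under that assumption, not as braindecode's validation logic.

```python
def criss_cross_head_dims(emb_dim=200, nhead=8):
    """Return (branch_dim, branch_heads, head_dim), assuming each attention
    branch gets half of emb_dim and half of the heads."""
    if emb_dim % 2 or nhead % 2:
        raise ValueError("emb_dim and nhead must both be even")
    branch_dim, branch_heads = emb_dim // 2, nhead // 2
    if branch_dim % branch_heads:
        raise ValueError(f"branch width {branch_dim} is not divisible by {branch_heads} heads")
    return branch_dim, branch_heads, branch_dim // branch_heads


print(criss_cross_head_dims())  # defaults: (100, 4, 25)
```
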
<section id="references">
<h2>References</h2>
<div role="list" class="citation-list">
<div class="citation" id="cbramod" role="doc-biblioentry">
<span class="label"><span class="fn-bracket">[</span><a role="doc-backlink" href="#citation-reference-1">cbramod</a><span class="fn-bracket">]</span></span>
<p>Wang, J., Zhao, S., Luo, Z., Zhou, Y., Jiang, H., Li, S., Li, T., &amp; Pan, G. (2025).
CBraMod: A Criss-Cross Brain Foundation Model for EEG Decoding.
In The Thirteenth International Conference on Learning Representations (ICLR 2025).
<a class="reference external" href="https://arxiv.org/abs/2412.07236">https://arxiv.org/abs/2412.07236</a></p>
</div>
</div>
<p><strong>Hugging Face Hub integration</strong></p>
<p>When the optional <span class="docutils literal">huggingface_hub</span> package is installed, all models
automatically gain the ability to be pushed to and loaded from the
Hugging Face Hub. Install with:</p>
<pre class="literal-block">pip install braindecode[hub]</pre>
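<p><strong>Pushing, loading, and extracting features:</strong></p>
<p>Use <span class="docutils literal">push_to_hub</span> to upload a model and the class method
<span class="docutils literal">from_pretrained</span> to load one. To obtain encoder features instead of class logits,
instantiate the model with <span class="docutils literal">return_encoder_output=True</span>.</p>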
<p><strong>Saving and restoring full configuration:</strong></p>
<p>All model parameters (both EEG-specific and model-specific such as
dropout rates, activation functions, number of filters) are automatically
saved to the Hub and restored when loading.</p>
<p>See the braindecode documentation on loading pre-trained models for a complete tutorial.</p>
</section>
</main>
</div>
## Citation
Please cite both the original paper for this architecture (see the
*References* section above) and braindecode:
```bibtex
@article{aristimunha2025braindecode,
title = {Braindecode: a deep learning library for raw electrophysiological data},
author = {Aristimunha, Bruno and others},
journal = {Zenodo},
year = {2025},
doi = {10.5281/zenodo.17699192},
}
```
## License
BSD-3-Clause for the model code (matching braindecode).
If you fine-tune from a pretrained checkpoint, the resulting weights
inherit the licence of that checkpoint and its training corpus.