---
license: bsd-3-clause
library_name: braindecode
pipeline_tag: feature-extraction
tags:
- eeg
- biosignal
- pytorch
- neuroscience
- braindecode
- convolutional
- transformer
---
# AttentionBaseNet
AttentionBaseNet from Wimpff M et al. (2023).
> **Architecture-only repository.** This repo documents the
> `braindecode.models.AttentionBaseNet` class. **No pretrained weights are
> distributed here** — instantiate the model and train it on your own
> data, or fine-tune from a published foundation-model checkpoint
> separately.
## Quick start
```bash
pip install braindecode
```
```python
from braindecode.models import AttentionBaseNet
model = AttentionBaseNet(
    n_chans=22,
    sfreq=250,
    input_window_seconds=4.0,
    n_outputs=4,
)
```
The signal-shape arguments above (`n_chans`, `sfreq`, `input_window_seconds`) are
example values; adjust them to match your recording, and set `n_outputs` to the
number of target classes.
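The model consumes batches shaped `(batch, n_chans, n_times)`. The sketch below assumes braindecode derives `n_times` by multiplying `sfreq` by `input_window_seconds` and truncating to an integer; `input_shape` is a hypothetical helper that only does this arithmetic.

```python
def input_shape(n_chans: int, sfreq: float, input_window_seconds: float,
                batch_size: int = 1) -> tuple:
    """Hypothetical helper: the (batch, n_chans, n_times) shape of one batch.

    Assumes n_times = int(sfreq * input_window_seconds).
    """
    n_times = int(sfreq * input_window_seconds)
    return (batch_size, n_chans, n_times)


# Quick-start values: 22 channels sampled at 250 Hz, 4 s windows
print(input_shape(n_chans=22, sfreq=250, input_window_seconds=4.0))  # (1, 22, 1000)
```

Under that assumption the quick-start window is 1000 samples long, the same `n_times` value used in the Hub example further down.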
## Documentation
- Full API reference (parameters, references, architecture figure):
<https://braindecode.org/stable/generated/braindecode.models.AttentionBaseNet.html>
- Interactive browser with live instantiation:
<https://huggingface.co/spaces/braindecode/model-explorer>
- Source on GitHub: <https://github.com/braindecode/braindecode/blob/master/braindecode/models/attentionbasenet.py#L29>
## Architecture description
The block below is the rendered class docstring (parameters,
references, architecture figure where available).
<div class='bd-doc'><main>
<p>AttentionBaseNet from Wimpff M et al (2023) [Martin2023]_.</p>
<span style="display:inline-block;padding:2px 8px;border-radius:4px;background:#5cb85c;color:white;font-size:11px;font-weight:600;margin-right:4px;">Convolution</span><span style="display:inline-block;padding:2px 8px;border-radius:4px;background:#56B4E9;color:white;font-size:11px;font-weight:600;margin-right:4px;">Attention/Transformer</span>
.. figure:: https://content.cld.iop.org/journals/1741-2552/21/3/036020/revision2/jnead48b9f2_hr.jpg
   :align: center
   :alt: AttentionBaseNet Architecture
   :width: 640px
.. rubric:: Architectural Overview
AttentionBaseNet is a *convolution-first* network with a *channel-attention* stage.
The end-to-end flow is:
- (i) :class:`_FeatureExtractor` learns a temporal filter bank and per-filter spatial
  projections (depthwise across electrodes), then condenses time by pooling;
- (ii) **Channel Expansion** uses a ``1x1`` convolution to set the feature width;
- (iii) :class:`_ChannelAttentionBlock` refines features via depthwise–pointwise temporal
  convs and an optional channel-attention module (SE/CBAM/ECA/…);
- (iv) **Classifier** flattens the sequence and applies a linear readout.
This design mirrors shallow CNN pipelines (EEGNet-style stem) but inserts a pluggable
attention unit that *re-weights channels* (and optionally temporal positions) before
classification.
.. rubric:: Macro Components
- :class:`_FeatureExtractor` **(Shallow conv stem → condensed feature map)**
  - *Operations.*
    - **Temporal conv** (:class:`torch.nn.Conv2d`) with kernel ``(1, L_t)`` creates a learned
      FIR-like filter bank with ``n_temporal_filters`` maps.
    - **Depthwise spatial conv** (:class:`torch.nn.Conv2d`, ``groups=n_temporal_filters``)
      with kernel ``(n_chans, 1)`` learns per-filter spatial projections over the full montage.
    - **BatchNorm → activation (ELU by default) → AvgPool → Dropout** stabilize and
      downsample time.
    - Output shape: ``(B, F2, 1, T₁)`` with ``F2 = n_temporal_filters x spatial_expansion``.
  - *Interpretability/robustness.* Temporal kernels behave as analyzable FIR filters; the
    depthwise spatial step yields rhythm-specific topographies. Pooling acts as a local
    integrator that reduces variance on short EEG windows.
- **Channel Expansion**
  - *Operations.*
    - A ``1x1`` conv → BN → activation maps ``F2 → ch_dim`` without changing
      the temporal length ``T₁`` (shape: ``(B, ch_dim, 1, T₁)``).
      This sets the embedding width for the attention block.
- :class:`_ChannelAttentionBlock` **(temporal refinement + channel attention)**
  - *Operations.*
    - **Depthwise temporal conv** ``(1, L_a)`` (``groups=ch_dim``) + **pointwise ``1x1``**,
      BN and activation → preserves shape ``(B, ch_dim, 1, T₁)`` while refining timing.
    - **Optional attention module** (see *Additional Mechanisms*) applies channel reweighting
      (some variants also apply temporal gating).
    - **AvgPool (1, P₂)** with stride ``(1, S₂)`` and **Dropout** → outputs
      ``(B, ch_dim, 1, T₂)``.
  - *Role.* Emphasizes informative channels (and, in certain modes, salient time steps)
    before the classifier; complements the convolutional priors with adaptive re-weighting.
- **Classifier (aggregation + readout)**
  - *Operations.* :class:`torch.nn.Flatten` → :class:`torch.nn.Linear` from
    ``(B, ch_dim·T₂)`` to ``(B, n_outputs)`` class scores.
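The shapes listed above can be traced with plain arithmetic. This is a sketch under stated assumptions, not braindecode's internal code: temporal convolutions are taken to preserve length, both average pools are assumed unpadded with floor rounding, and ``attentionbasenet_shapes`` is a hypothetical helper whose defaults mirror the documented parameter defaults and a 1000-sample input window.

```python
def attentionbasenet_shapes(n_chans=22, n_times=1000, n_temporal_filters=40,
                            spatial_expansion=1, ch_dim=16,
                            pool_length_inp=75, pool_stride_inp=15,
                            pool_length=8, pool_stride=8, batch=1):
    """Hypothetical helper tracing tensor shapes through the four stages.

    Assumes length-preserving temporal convolutions and unpadded,
    floor-mode average pooling.
    """
    def pooled(length, kernel, stride):
        # Output length of an unpadded average pool with floor rounding
        return (length - kernel) // stride + 1

    f2 = n_temporal_filters * spatial_expansion
    t1 = pooled(n_times, pool_length_inp, pool_stride_inp)
    t2 = pooled(t1, pool_length, pool_stride)
    return {
        "input": (batch, n_chans, n_times),
        "feature_extractor": (batch, f2, 1, t1),
        "channel_expansion": (batch, ch_dim, 1, t1),
        "attention_block": (batch, ch_dim, 1, t2),
        "classifier_input": (batch, ch_dim * t2),
    }


for stage, shape in attentionbasenet_shapes().items():
    print(f"{stage:>18}: {shape}")
```

Under these assumptions the defaults give T₁ = 62 and T₂ = 7, so the linear head reads a ``16 x 7 = 112``-dimensional vector per example.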
.. rubric:: Convolutional Details
- **Temporal (where time-domain patterns are learned).**
  Wide kernels in the stem (``(1, L_t)``) act as a learned filter bank for oscillatory
  bands/transients; the attention block's depthwise temporal conv (``(1, L_a)``) sharpens
  short-term dynamics after downsampling. Pool sizes/strides (``P₁, S₁`` then ``P₂, S₂``)
  set the token rate and effective temporal resolution.
- **Spatial (how electrodes are processed).**
  A depthwise spatial conv with kernel ``(n_chans, 1)`` spans the full montage to
  learn *per-temporal-filter* spatial projections (no cross-filter mixing at this step),
  mirroring the interpretable spatial stage in shallow CNNs.
- **Spectral (how frequency content is captured).**
  No explicit Fourier/wavelet transform is used in the stem; spectral selectivity
  emerges from the learned temporal kernels. When ``attention_mode="fca"``, a frequency
  channel attention (DCT-based) summarizes frequencies to drive channel weights.
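Because the stem's temporal kernels are plain FIR filters, the frequency response of a kernel can be read directly from its taps, H(f) = Σₖ h[k] exp(−j2πfk / f_s). The sketch below uses a hypothetical 5-tap moving-average kernel as a stand-in for a learned one; extracting real taps from a trained model's weights is left to you. PyTorch's ``Conv2d`` computes a cross-correlation, which time-reverses the kernel, but time reversal does not change the magnitude response.

```python
import cmath
import math


def fir_magnitude_response(kernel, freq_hz, sfreq):
    """|H(f)| of an FIR kernel: |sum_k h[k] * exp(-j * 2*pi * f * k / sfreq)|."""
    omega = 2.0 * math.pi * freq_hz / sfreq
    return abs(sum(h * cmath.exp(-1j * omega * k) for k, h in enumerate(kernel)))


sfreq = 250.0            # sampling rate of the quick-start example
kernel = [0.2] * 5       # hypothetical stand-in for a learned temporal kernel
for f in (0.0, 10.0, 25.0, 50.0, 100.0):
    gain = fir_magnitude_response(kernel, f, sfreq)
    print(f"{f:6.1f} Hz: |H| = {gain:.3f}")
```

For this toy kernel the gain is 1 at 0 Hz and zero at multiples of f_s / 5 (50 Hz and 100 Hz here), the familiar low-pass-with-nulls shape of a moving average; a trained kernel would instead show band-pass peaks at the rhythms it has learned.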
.. rubric:: Attention / Sequential Modules
- **Type.** Channel attention chosen by ``attention_mode`` (SE, GSoP, FCA, EncNet, ECA,
  GE, GCT, SRM, CBAM, CAT, CATLite). Most operate purely on channels; CBAM/CAT additionally
  include temporal attention.
- **Shapes.** Input/Output around attention: ``(B, ch_dim, 1, T₁)``. Re-arrangements
  (if any) are internal to the module; the block returns the same shape before pooling.
- **Role.** Re-weights channels (and optionally time) to highlight informative sources
  and suppress distractors, which can improve the signal-to-noise ratio ahead of the
  linear head.
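To make the shape-preserving re-weighting concrete, here is a minimal Squeeze-and-Excitation sketch (the ``"se"`` mode) in plain Python on a ``(ch_dim, T₁)`` map, with the singleton spatial axis dropped. It follows the textbook SE formulation (temporal average → bottleneck of ``C // reduction_rate`` units → sigmoid gates); biases are omitted and the random weights are placeholders, so braindecode's own module may differ in such details.

```python
import math
import random


def squeeze_excite(x, w1, w2):
    """Minimal Squeeze-and-Excitation on a (C, T) feature map (sketch).

    x  : C lists of T floats
    w1 : (C // r) x C bottleneck weights (biases omitted)
    w2 : C x (C // r) expansion weights (biases omitted)
    Returns (y, gates): y has the same (C, T) shape as x, and each channel
    of y is the matching channel of x scaled by a gate in (0, 1).
    """
    # Squeeze: average each channel over time -> one descriptor per channel
    z = [sum(row) / len(row) for row in x]
    # Excitation: bottleneck layer with ReLU ...
    hidden = [max(0.0, sum(w * zi for w, zi in zip(row, z))) for row in w1]
    # ... then expansion layer with a sigmoid, giving one gate per channel
    gates = [1.0 / (1.0 + math.exp(-sum(w * h for w, h in zip(row, hidden))))
             for row in w2]
    # Scale: multiply every time step of channel c by gates[c]
    y = [[g * v for v in row] for g, row in zip(gates, x)]
    return y, gates


random.seed(0)
C, T, r = 16, 62, 4  # ch_dim and reduction_rate defaults; T1 = 62 for 1000 samples
x = [[random.gauss(0.0, 1.0) for _ in range(T)] for _ in range(C)]
w1 = [[random.gauss(0.0, 0.1) for _ in range(C)] for _ in range(C // r)]
w2 = [[random.gauss(0.0, 0.1) for _ in range(C // r)] for _ in range(C)]
y, gates = squeeze_excite(x, w1, w2)
print(len(y), len(y[0]))  # 16 62: same shape as the input
```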
.. rubric:: Additional Mechanisms
**Attention variants at a glance:**
- ``"se"``: Squeeze-and-Excitation (global pooling → bottleneck → gates).
- ``"gsop"``: Global second-order pooling (covariance-aware channel weights).
- ``"fca"``: Frequency Channel Attention (DCT summary; uses ``seq_len`` and ``freq_idx``).
- ``"encnet"``: EncNet with learned codewords (uses ``n_codewords``).
- ``"eca"``: Efficient Channel Attention (local 1-D conv over channel descriptor; uses ``kernel_size``).
- ``"ge"``: Gather–Excite (context pooling with optional MLP; can use ``extra_params``).
- ``"gct"``: Gated Channel Transformation (global context normalization + gating).
- ``"srm"``: Style-based recalibration (mean–std descriptors; optional MLP).
- ``"cbam"``: Channel then temporal attention (uses ``kernel_size``).
- ``"cat"`` / ``"catlite"``: Collaborative (channel ± temporal) attention; *lite* omits temporal.
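The ``"eca"`` mode replaces SE's bottleneck with a single 1-D convolution that slides along the per-channel descriptor. A minimal sketch, assuming zero padding and an odd ``kernel_size`` so every channel gets exactly one gate; the function name is hypothetical and the weights are hand-picked toy values rather than learned ones.

```python
import math


def efficient_channel_attention(x, conv_weights):
    """Minimal ECA on a (C, T) feature map (sketch).

    conv_weights: odd-length list of k weights shared by all channels; one
    1-D convolution slides them along the per-channel descriptor.
    Returns (y, gates) with y shaped like x.
    """
    k = len(conv_weights)
    pad = k // 2
    z = [sum(row) / len(row) for row in x]      # average over time per channel
    z_padded = [0.0] * pad + z + [0.0] * pad    # zero padding keeps C outputs
    gates = []
    for c in range(len(z)):
        window = z_padded[c:c + k]              # channel c and its neighbours
        score = sum(w * v for w, v in zip(conv_weights, window))
        gates.append(1.0 / (1.0 + math.exp(-score)))
    y = [[g * v for v in row] for g, row in zip(gates, x)]
    return y, gates


# Toy input: 4 channels x 3 time steps, kernel of size k = 3
x = [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [0.0, 0.0, 0.0], [-1.0, -1.0, -1.0]]
y, gates = efficient_channel_attention(x, conv_weights=[0.5, 1.0, 0.5])
print([round(g, 3) for g in gates])  # [0.881, 0.924, 0.622, 0.269]
```

Only ``kernel_size`` weights are learned, independent of ``ch_dim``, which is one reason ECA counts among the lighter modes.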
**Auto-compatibility on short inputs:**
If the input duration is too short for the configured kernels/pools, the implementation
**automatically rescales** temporal lengths/strides downward (with a warning) to keep
shapes valid and preserve the pipeline semantics.
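The default pooling stack needs a minimum window length, which is what makes such rescaling necessary. The sketch below does not reproduce braindecode's rescaling rule; assuming length-preserving convolutions and unpadded floor-mode pools, it only finds the shortest input for which the default pools still emit one output step. Both helpers are hypothetical.

```python
def pooled_length(length, kernel, stride):
    """Output length of an unpadded, floor-mode pool (0 if the input is too short)."""
    return (length - kernel) // stride + 1 if length >= kernel else 0


def min_valid_n_times(pool_length_inp=75, pool_stride_inp=15,
                      pool_length=8, pool_stride=8):
    """Smallest n_times for which both pools still emit at least one step.

    Assumes length-preserving temporal convolutions before each pool.
    """
    n_times = pool_length_inp
    while pooled_length(pooled_length(n_times, pool_length_inp, pool_stride_inp),
                        pool_length, pool_stride) < 1:
        n_times += 1
    return n_times


n_min = min_valid_n_times()
print(n_min, n_min / 250)  # 180 samples, i.e. 0.72 s at 250 Hz
```

Under these assumptions, windows shorter than about 0.72 s at 250 Hz could not pass through the default pools unchanged.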
.. rubric:: Usage and Configuration
- ``n_temporal_filters``, ``temp_filter_length`` and ``spatial_expansion``:
  control the capacity and the number of spatial projections in the stem.
- ``pool_length_inp``, ``pool_stride_inp`` then ``pool_length``, ``pool_stride``:
  trade temporal resolution for compute; they determine the final sequence length ``T₂``.
- ``ch_dim``: width after the ``1x1`` expansion and the effective embedding size for attention.
- ``attention_mode`` + its specific hyperparameters (``reduction_rate``,
  ``kernel_size``, ``seq_len``, ``freq_idx``, ``n_codewords``, ``use_mlp``):
  select and tune the reweighting mechanism.
- ``drop_prob_inp`` and ``drop_prob_attn``: regularize the stem and attention stages.
- **Training tips.**
  Start with moderate pooling (e.g., ``P₁=75, S₁=15``) and ELU activations; enable attention
  only after the stem learns stable filters. For small datasets, prefer simpler modes
  (``"se"``, ``"eca"``) before heavier ones (``"gsop"``, ``"encnet"``).
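The pooling trade-off can be quantified with the standard receptive-field recursion. This sketch assumes both temporal kernels (stem and attention block) use the documented ``temp_filter_length`` default of 15 taps with stride 1 and that the pools use their documented defaults; the spatial and ``1x1`` stages are left out because they do not mix time. ``temporal_receptive_field`` is a hypothetical helper.

```python
def temporal_receptive_field(layers, sfreq):
    """Receptive field and output step of stacked 1-D temporal layers.

    layers: (name, kernel, stride) tuples in forward order.
    Standard recursion: rf += (kernel - 1) * jump, then jump *= stride.
    Returns (rf_samples, rf_seconds, step_samples, steps_per_second).
    """
    rf, jump = 1, 1
    for _name, kernel, stride in layers:
        rf += (kernel - 1) * jump
        jump *= stride
    return rf, rf / sfreq, jump, sfreq / jump


# Assumed temporal stack with the documented defaults (15-tap kernels)
layers = [
    ("stem temporal conv", 15, 1),
    ("stem average pool", 75, 15),
    ("attention depthwise temporal conv", 15, 1),
    ("attention average pool", 8, 8),
]
rf, rf_s, step, rate = temporal_receptive_field(layers, sfreq=250.0)
print(f"receptive field: {rf} samples ({rf_s:.2f} s)")
print(f"output step: {step} samples ({rate:.2f} steps per second)")
```

Under these assumptions each of the T₂ output steps summarizes 404 samples (about 1.6 s at 250 Hz) and consecutive steps are 120 samples apart, roughly two steps per second; larger ``pool_stride_inp`` or ``pool_stride`` values lower that rate further.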
Parameters
----------
n_temporal_filters : int, optional
    Number of temporal convolutional filters in the first layer. This defines
    the number of output channels after the temporal convolution.
    Default is 40.
temp_filter_length : int, default=15
    The length of the temporal filters in the convolutional layers.
spatial_expansion : int, optional
    Number of spatial filters learned per temporal filter; the stem outputs
    ``F2 = n_temporal_filters x spatial_expansion`` feature maps. Increasing it
    adds capacity to the spatial stage. Default is 1.
pool_length_inp : int, optional
    Length of the pooling window in the input layer. Determines how much
    temporal information is aggregated during pooling. Default is 75.
pool_stride_inp : int, optional
    Stride of the pooling operation in the input layer. Controls the
    downsampling factor in the temporal dimension. Default is 15.
drop_prob_inp : float, optional
    Dropout rate applied after the input layer. This is the probability of
    zeroing out elements during training to prevent overfitting.
    Default is 0.5.
ch_dim : int, optional
    Number of feature maps after the ``1x1`` channel expansion, i.e. the width
    (not the depth) of the attention block. Default is 16.
attention_mode : str, optional
    The type of attention mechanism to apply. If ``None``, no attention is applied.
    - ``"se"`` for Squeeze-and-Excitation network
    - ``"gsop"`` for Global Second-Order Pooling
    - ``"fca"`` for Frequency Channel Attention Network
    - ``"encnet"`` for context encoding module
    - ``"eca"`` for Efficient Channel Attention for deep convolutional neural networks
    - ``"ge"`` for Gather-Excite
    - ``"gct"`` for Gated Channel Transformation
    - ``"srm"`` for Style-based Recalibration Module
    - ``"cbam"`` for Convolutional Block Attention Module
    - ``"cat"`` for Learning to collaborate channel and temporal attention
      from multi-information fusion
    - ``"catlite"`` for Learning to collaborate channel attention
      from multi-information fusion (lite version, cat w/o temporal attention)
pool_length : int, default=8
    The length of the average-pooling window in the attention block.
pool_stride : int, default=8
    The stride of the average-pooling operation in the attention block.
drop_prob_attn : float, default=0.5
    Dropout rate applied in the attention block. Values should be between 0 and 1.
reduction_rate : int, default=4
    The reduction rate used in the attention mechanism to reduce dimensionality
    and computational complexity.
use_mlp : bool, default=False
    Flag to indicate whether an MLP (Multi-Layer Perceptron) should be used within
    the attention mechanism for further processing.
freq_idx : int, default=0
    DCT index used in the ``"fca"`` attention mechanism.
n_codewords : int, default=4
    The number of codewords (clusters) used in attention mechanisms that employ
    quantization or clustering strategies (``"encnet"``).
kernel_size : int, default=9
    The kernel size used for the convolution inside certain attention mechanisms
    (``"eca"``, ``"cbam"``).
activation : type[nn.Module], default=nn.ELU
    Activation function class to apply. Should be a PyTorch activation
    module class like ``nn.ReLU`` or ``nn.ELU``.
extra_params : bool, default=False
    Flag to indicate whether additional, custom parameters should be passed to
    the attention mechanism.
Notes
-----
- Sequence length after each stage is computed internally; the final classifier expects
  a flattened ``ch_dim x T₂`` vector.
- Attention operates on the *channel* dimension by design; temporal gating exists only in
  specific variants (CBAM/CAT).
- The paper and original code, with more details about the methodological
  choices, are available in [Martin2023]_ and [MartinCode]_.
.. versionadded:: 0.9
References
----------
.. [Martin2023] Wimpff, M., Gizzi, L., Zerfowski, J. and Yang, B., 2023.
EEG motor imagery decoding: A framework for comparative analysis with
channel attention mechanisms. arXiv preprint arXiv:2310.11198.
.. [MartinCode] Wimpff, M., Gizzi, L., Zerfowski, J. and Yang, B.
GitHub https://github.com/martinwimpff/channel-attention (accessed 2024-03-28)
.. rubric:: Hugging Face Hub integration
When the optional ``huggingface_hub`` package is installed, all models
automatically gain the ability to be pushed to and loaded from the
Hugging Face Hub. Install with::
    pip install "braindecode[hub]"
**Pushing a model to the Hub:**
.. code::
    from braindecode.models import AttentionBaseNet
    # Train your model
    model = AttentionBaseNet(n_chans=22, n_outputs=4, n_times=1000)
    # ... training code ...
    # Push to the Hub
    model.push_to_hub(
        repo_id="username/my-attentionbasenet-model",
        commit_message="Initial model upload",
    )
**Loading a model from the Hub:**
.. code::
    from braindecode.models import AttentionBaseNet
    # Load a model (with its weights) previously pushed to the Hub
    model = AttentionBaseNet.from_pretrained("username/my-attentionbasenet-model")
    # Load with a different number of outputs (head is rebuilt automatically)
    model = AttentionBaseNet.from_pretrained("username/my-attentionbasenet-model", n_outputs=2)
**Extracting features and replacing the head:**
.. code::
    import torch
    x = torch.randn(1, model.n_chans, model.n_times)
    # Extract encoder features (consistent dict across all models)
    out = model(x, return_features=True)
    features = out["features"]
    # Replace the classification head
    model.reset_head(n_outputs=10)
**Saving and restoring full configuration:**
.. code::
    import json
    config = model.get_config()  # all __init__ params
    with open("config.json", "w") as f:
        json.dump(config, f)
    model2 = AttentionBaseNet.from_config(config)  # reconstruct (no weights)
All model parameters (both EEG-specific and model-specific such as
dropout rates, activation functions, number of filters) are automatically
saved to the Hub and restored when loading.
See :ref:`load-pretrained-models` for a complete tutorial.</main>
</div>
## Citation
Please cite both the original paper for this architecture (see the
*References* section above) and braindecode:
```bibtex
@article{aristimunha2025braindecode,
  title = {Braindecode: a deep learning library for raw electrophysiological data},
  author = {Aristimunha, Bruno and others},
  journal = {Zenodo},
  year = {2025},
  doi = {10.5281/zenodo.17699192},
}
```
## License
BSD-3-Clause for the model code (matching braindecode).
Pretraining-derived weights, if you fine-tune from a checkpoint,
inherit the licence of that checkpoint and its training corpus.