---
license: bsd-3-clause
library_name: braindecode
pipeline_tag: feature-extraction
tags:
- eeg
- biosignal
- pytorch
- neuroscience
- braindecode
- convolutional
- transformer
---

# AttentionBaseNet
AttentionBaseNet from Wimpff et al. (2023).
**Architecture-only repository.** This repo documents the `braindecode.models.AttentionBaseNet` class. No pretrained weights are distributed here — instantiate the model and train it on your own data, or fine-tune from a published foundation-model checkpoint separately.
## Quick start

```bash
pip install braindecode
```

```python
from braindecode.models import AttentionBaseNet

model = AttentionBaseNet(
    n_chans=22,
    sfreq=250,
    input_window_seconds=4.0,
    n_outputs=4,
)
```
The signal-shape arguments above are example defaults — adjust them to match your recording.
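To sanity-check the tensor shapes, here is a quick smoke test against the model instantiated above; the input layout is `(batch, n_chans, n_times)`, where `n_times = sfreq × input_window_seconds = 1000` samples:

```python
import torch

# Input layout: (batch, n_chans, n_times); 4 s at 250 Hz gives 1000 samples.
x = torch.randn(8, 22, 1000)
y = model(x)
print(y.shape)  # expected: torch.Size([8, 4]), one logit per output class
```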
## Documentation
- Full API reference (parameters, references, architecture figure): https://braindecode.org/stable/generated/braindecode.models.AttentionBaseNet.html
- Interactive browser with live instantiation: https://huggingface.co/spaces/braindecode/model-explorer
- Source on GitHub: https://github.com/braindecode/braindecode/blob/master/braindecode/models/attentionbasenet.py#L29
## Architecture description
The block below is the rendered class docstring (parameters, references, architecture figure where available).
AttentionBaseNet from Wimpff et al. (2023) [Martin2023]_.
.. figure:: https://content.cld.iop.org/journals/1741-2552/21/3/036020/revision2/jnead48b9f2_hr.jpg
   :align: center
   :alt: AttentionBaseNet Architecture
   :width: 640px
.. rubric:: Architectural Overview
AttentionBaseNet is a convolution-first network with a channel-attention stage. The end-to-end flow is:
- (i) :class:`_FeatureExtractor` learns a temporal filter bank and per-filter spatial projections (depthwise across electrodes), then condenses time by pooling;
- (ii) Channel Expansion uses a ``1x1`` convolution to set the feature width;
- (iii) :class:`_ChannelAttentionBlock` refines features via depthwise–pointwise temporal convs and an optional channel-attention module (SE/CBAM/ECA/…);
- (iv) Classifier flattens the sequence and applies a linear readout.
This design mirrors shallow CNN pipelines (EEGNet-style stem) but inserts a pluggable attention unit that re-weights channels (and optionally temporal positions) before classification.
.. rubric:: Macro Components
:class:`_FeatureExtractor` (shallow conv stem → condensed feature map)

- Operations.

  - Temporal conv (:class:`torch.nn.Conv2d`) with kernel ``(1, L_t)`` creates a learned FIR-like filter bank with ``n_temporal_filters`` maps.
  - Depthwise spatial conv (:class:`torch.nn.Conv2d`, ``groups=n_temporal_filters``) with kernel ``(n_chans, 1)`` learns per-filter spatial projections over the full montage.
  - BatchNorm → ELU → AvgPool → Dropout stabilize and downsample time.

- Output shape: ``(B, F2, 1, T₁)`` with ``F2 = n_temporal_filters × spatial_expansion``.
Interpretability/robustness. Temporal kernels behave as analyzable FIR filters; the depthwise spatial step yields rhythm-specific topographies. Pooling acts as a local integrator that reduces variance on short EEG windows.
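For intuition, here is a minimal shape-faithful sketch of this stem under the defaults quoted in the Parameters section below (``n_temporal_filters=40``, ``temp_filter_length=15``, ``pool_length_inp=75``, ``pool_stride_inp=15``, ``spatial_expansion=1``); this is an illustration, not braindecode's implementation.

.. code:: python

    import torch
    from torch import nn

    B, C, T = 8, 22, 1000                                  # batch, electrodes, samples
    F, L_t = 40, 15                                        # temporal filters, kernel length

    stem = nn.Sequential(
        nn.Conv2d(1, F, (1, L_t), padding=(0, L_t // 2)),  # learned FIR-like filter bank
        nn.Conv2d(F, F, (C, 1), groups=F),                 # depthwise spatial projection
        nn.BatchNorm2d(F),
        nn.ELU(),
        nn.AvgPool2d((1, 75), stride=(1, 15)),             # condense time to T₁
        nn.Dropout(0.5),
    )

    h = stem(torch.randn(B, 1, C, T))
    print(h.shape)  # torch.Size([8, 40, 1, 62]) = (B, F2, 1, T₁)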
Channel Expansion

- Operations.

  - A ``1x1`` conv → BN → activation maps ``F2 → ch_dim`` without changing the temporal length ``T₁`` (shape: ``(B, ch_dim, 1, T₁)``). This sets the embedding width for the attention block.
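Continuing the sketch above, the expansion stage is a single ``1x1`` convolution (the widths ``F2=40`` and ``ch_dim=16`` are illustrative defaults):

.. code:: python

    import torch
    from torch import nn

    # Map F2=40 stem features to ch_dim=16 without touching the temporal axis.
    expand = nn.Sequential(nn.Conv2d(40, 16, (1, 1)), nn.BatchNorm2d(16), nn.ELU())

    h = torch.randn(8, 40, 1, 62)    # stem output from the previous sketch
    print(expand(h).shape)           # torch.Size([8, 16, 1, 62]), T₁ preserved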
:class:`_ChannelAttentionBlock` (temporal refinement + channel attention)

- Operations.

  - Depthwise temporal conv ``(1, L_a)`` (``groups=ch_dim``) + pointwise ``1x1``, BN and activation → preserves shape ``(B, ch_dim, 1, T₁)`` while refining timing.
  - Optional attention module (see Additional Mechanisms) applies channel reweighting (some variants also apply temporal gating).
  - AvgPool ``(1, P₂)`` with stride ``(1, S₂)`` and Dropout → outputs ``(B, ch_dim, 1, T₂)``.
Role. Emphasizes informative channels (and, in certain modes, salient time steps) before the classifier; complements the convolutional priors with adaptive re-weighting.
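A sketch of the temporal-refinement path (the attention module itself is sketched under the next rubric); the kernel length ``L_a=15`` here is an assumption chosen for illustration:

.. code:: python

    import torch
    from torch import nn

    ch_dim, L_a = 16, 15
    refine = nn.Sequential(
        nn.Conv2d(ch_dim, ch_dim, (1, L_a), padding=(0, L_a // 2), groups=ch_dim),
        nn.Conv2d(ch_dim, ch_dim, (1, 1)),   # pointwise mixing across channels
        nn.BatchNorm2d(ch_dim),
        nn.ELU(),
    )
    pool = nn.Sequential(nn.AvgPool2d((1, 8), stride=(1, 8)), nn.Dropout(0.5))

    h = torch.randn(8, ch_dim, 1, 62)        # (B, ch_dim, 1, T₁)
    print(pool(refine(h)).shape)             # torch.Size([8, 16, 1, 7]) = (B, ch_dim, 1, T₂)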
Classifier (aggregation + readout)

- Operations. :class:`torch.nn.Flatten` → :class:`torch.nn.Linear` from ``(B, ch_dim·T₂)`` to classes.
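The readout is then just a flatten plus a linear map (sizes carried over from the sketches above):

.. code:: python

    import torch
    from torch import nn

    ch_dim, T2, n_outputs = 16, 7, 4
    head = nn.Sequential(nn.Flatten(), nn.Linear(ch_dim * T2, n_outputs))

    h = torch.randn(8, ch_dim, 1, T2)
    print(head(h).shape)  # torch.Size([8, 4])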
.. rubric:: Convolutional Details
Temporal (where time-domain patterns are learned). Wide kernels in the stem (``(1, L_t)``) act as a learned filter bank for oscillatory bands/transients; the attention block's depthwise temporal conv (``(1, L_a)``) sharpens short-term dynamics after downsampling. Pool sizes/strides (``P₁``, ``S₁`` then ``P₂``, ``S₂``) set the token rate and effective temporal resolution.

Spatial (how electrodes are processed). A depthwise spatial conv with kernel ``(n_chans, 1)`` spans the full montage to learn per-temporal-filter spatial projections (no cross-filter mixing at this step), mirroring the interpretable spatial stage in shallow CNNs.

Spectral (how frequency content is captured). No explicit Fourier/wavelet transform is used in the stem—spectral selectivity emerges from learned temporal kernels. When ``attention_mode="fca"``, a frequency channel attention (DCT-based) module summarizes frequencies to drive channel weights.
.. rubric:: Attention / Sequential Modules
Type. Channel attention chosen by ``attention_mode`` (SE, ECA, CBAM, CAT, GSoP, EncNet, GE, GCT, SRM, CATLite). Most operate purely on channels; CBAM/CAT additionally include temporal attention.

Shapes. Input/output around attention: ``(B, ch_dim, 1, T₁)``. Re-arrangements (if any) are internal to the module; the block returns the same shape before pooling.

Role. Re-weights channels (and optionally time) to highlight informative sources and suppress distractors, improving SNR ahead of the linear head.
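For intuition, a minimal sketch of the simplest variant (``"se"``), assuming the ``(B, ch_dim, 1, T₁)`` layout described above; braindecode's actual attention modules differ in detail:

.. code:: python

    import torch
    from torch import nn

    class SqueezeExcite(nn.Module):
        """Minimal SE-style channel attention over (B, C, 1, T) maps (sketch only)."""

        def __init__(self, ch_dim: int, reduction_rate: int = 4):
            super().__init__()
            self.gate = nn.Sequential(
                nn.Linear(ch_dim, ch_dim // reduction_rate),
                nn.ReLU(),
                nn.Linear(ch_dim // reduction_rate, ch_dim),
                nn.Sigmoid(),
            )

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            s = x.mean(dim=(2, 3))              # squeeze: global average pool → (B, C)
            w = self.gate(s)[:, :, None, None]  # excite: per-channel gates in (0, 1)
            return x * w                        # reweight channels, shape unchanged

    h = torch.randn(8, 16, 1, 62)
    print(SqueezeExcite(16)(h).shape)           # torch.Size([8, 16, 1, 62])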
.. rubric:: Additional Mechanisms
Attention variants at a glance:

- ``"se"``: Squeeze-and-Excitation (global pooling → bottleneck → gates).
- ``"gsop"``: Global second-order pooling (covariance-aware channel weights).
- ``"fca"``: Frequency Channel Attention (DCT summary; uses ``seq_len`` and ``freq_idx``).
- ``"encnet"``: EncNet with learned codewords (uses ``n_codewords``).
- ``"eca"``: Efficient Channel Attention (local 1-D conv over the channel descriptor; uses ``kernel_size``).
- ``"ge"``: Gather–Excite (context pooling with optional MLP; can use ``extra_params``).
- ``"gct"``: Gated Channel Transformation (global context normalization + gating).
- ``"srm"``: Style-based recalibration (mean–std descriptors; optional MLP).
- ``"cbam"``: Channel then temporal attention (uses ``kernel_size``).
- ``"cat"``/``"catlite"``: Collaborative (channel ± temporal) attention; lite omits temporal.
Auto-compatibility on short inputs:
If the input duration is too short for the configured kernels/pools, the implementation
**automatically rescales** temporal lengths/strides downward (with a warning) to keep
shapes valid and preserve the pipeline semantics.
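Assuming the rescaling behavior described above, a deliberately short window should surface the warning rather than a shape error (values below are illustrative):

.. code:: python

    import warnings

    from braindecode.models import AttentionBaseNet

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # 100 samples leaves too little time for the default pooling pipeline
        # (pool_length_inp=75, pool_stride_inp=15, then pool_length=8).
        model = AttentionBaseNet(n_chans=22, n_outputs=4, n_times=100)

    for w in caught:
        print(w.message)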
.. rubric:: Usage and Configuration
- ``n_temporal_filters``, ``temp_filter_length`` and ``spatial_expansion``: control the capacity and the number of spatial projections in the stem.
- ``pool_length_inp``, ``pool_stride_inp`` then ``pool_length``, ``pool_stride``: trade temporal resolution for compute; they determine the final sequence length ``T₂``.
- ``ch_dim``: width after the ``1x1`` expansion and the effective embedding size for attention.
- ``attention_mode`` + its specific hyperparameters (``reduction_rate``, ``kernel_size``, ``seq_len``, ``freq_idx``, ``n_codewords``, ``use_mlp``): select and tune the reweighting mechanism.
- ``drop_prob_inp`` and ``drop_prob_attn``: regularize the stem and attention stages.

Training tips. Start with moderate pooling (e.g., ``P₁=75``, ``S₁=15``) and ELU activations; enable attention only after the stem learns stable filters. For small datasets, prefer simpler modes (``"se"``, ``"eca"``) before heavier ones (``"gsop"``, ``"encnet"``).
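Following these tips, a conservative starting configuration might look like this (all keyword arguments are documented in the Parameters section below):

.. code:: python

    from braindecode.models import AttentionBaseNet

    model = AttentionBaseNet(
        n_chans=22, n_outputs=4, n_times=1000,
        pool_length_inp=75, pool_stride_inp=15,  # moderate pooling (P₁=75, S₁=15)
        attention_mode="se", reduction_rate=4,   # simple channel attention first
        drop_prob_inp=0.5, drop_prob_attn=0.5,
    )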
Parameters
----------
n_temporal_filters : int, optional
Number of temporal convolutional filters in the first layer. This defines
the number of output channels after the temporal convolution.
Default is 40.
temp_filter_length : int, default=15
The length of the temporal filters in the convolutional layers.
spatial_expansion : int, optional
Multiplicative factor to expand the spatial dimensions. Used to increase
the capacity of the model by expanding spatial features. Default is 1.
pool_length_inp : int, optional
Length of the pooling window in the input layer. Determines how much
temporal information is aggregated during pooling. Default is 75.
pool_stride_inp : int, optional
Stride of the pooling operation in the input layer. Controls the
downsampling factor in the temporal dimension. Default is 15.
drop_prob_inp : float, optional
Dropout rate applied after the input layer. This is the probability of
zeroing out elements during training to prevent overfitting.
Default is 0.5.
ch_dim : int, optional
Number of channels in the subsequent convolutional layers. This controls
the depth of the network after the initial layer. Default is 16.
attention_mode : str, optional
The type of attention mechanism to apply. If None, no attention is applied.
- "se" for Squeeze-and-excitation network
- "gsop" for Global Second-Order Pooling
- "fca" for Frequency Channel Attention Network
- "encnet" for context encoding module
- "eca" for Efficient channel attention for deep convolutional neural networks
- "ge" for Gather-Excite
- "gct" for Gated Channel Transformation
- "srm" for Style-based Recalibration Module
- "cbam" for Convolutional Block Attention Module
- "cat" for Learning to collaborate channel and temporal attention
from multi-information fusion
- "catlite" for Learning to collaborate channel attention
from multi-information fusion (lite version, cat w/o temporal attention)
pool_length : int, default=8
The length of the window for the average pooling operation.
pool_stride : int, default=8
The stride of the average pooling operation.
drop_prob_attn : float, default=0.5
Dropout rate used to regularize the attention layer. Values should be between 0 and 1.
reduction_rate : int, default=4
The reduction rate used in the attention mechanism to reduce dimensionality
and computational complexity.
use_mlp : bool, default=False
Flag to indicate whether an MLP (Multi-Layer Perceptron) should be used within
the attention mechanism for further processing.
freq_idx : int, default=0
DCT index used in fca attention mechanism.
n_codewords : int, default=4
The number of codewords (clusters) used in attention mechanisms that employ
quantization or clustering strategies.
kernel_size : int, default=9
The kernel size used in certain types of attention mechanisms for convolution
operations.
activation : type[nn.Module], default=nn.ELU
Activation function class to apply. Should be a PyTorch activation
module class like nn.ReLU or nn.ELU. Default is nn.ELU.
extra_params : bool, default=False
Flag to indicate whether additional, custom parameters should be passed to
the attention mechanism.
Notes
-----
- Sequence length after each stage is computed internally; the final classifier expects a flattened ``ch_dim × T₂`` vector (a helper for computing these lengths is sketched after this list).
- Attention operates on the channel dimension by design; temporal gating exists only in specific variants (CBAM/CAT).
- The paper and original code, with more details about the methodological choices, are available in [Martin2023]_ and [MartinCode]_.
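A back-of-the-envelope helper for those sequence lengths, assuming standard no-padding average pooling and the defaults quoted above:

.. code:: python

    def pooled_len(n_times: int, pool: int, stride: int) -> int:
        """Output length of average pooling with no padding (assumed formula)."""
        return (n_times - pool) // stride + 1

    T1 = pooled_len(1000, 75, 15)  # after the stem pooling (P₁=75, S₁=15) → 62
    T2 = pooled_len(T1, 8, 8)      # after the attention-block pooling (P₂=8, S₂=8) → 7
    print(T1, T2, 16 * T2)         # flattened classifier input: ch_dim · T₂ = 112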
.. versionadded:: 0.9
References
----------

.. [Martin2023] Wimpff, M., Gizzi, L., Zerfowski, J. and Yang, B., 2023. EEG motor imagery decoding: a framework for comparative analysis with channel attention mechanisms. arXiv preprint arXiv:2310.11198.
.. [MartinCode] Wimpff, M., Gizzi, L., Zerfowski, J. and Yang, B. GitHub repository: https://github.com/martinwimpff/channel-attention (accessed 2024-03-28).
.. rubric:: Hugging Face Hub integration
When the optional huggingface_hub package is installed, all models
automatically gain the ability to be pushed to and loaded from the
Hugging Face Hub. Install with::

    pip install braindecode[hub]
Pushing a model to the Hub:
.. code:: python

    from braindecode.models import AttentionBaseNet

    # Train your model
    model = AttentionBaseNet(n_chans=22, n_outputs=4, n_times=1000)
    # ... training code ...

    # Push to the Hub
    model.push_to_hub(
        repo_id="username/my-attentionbasenet-model",
        commit_message="Initial model upload",
    )
Loading a model from the Hub:
.. code:: python

    from braindecode.models import AttentionBaseNet

    # Load pretrained model
    model = AttentionBaseNet.from_pretrained("username/my-attentionbasenet-model")

    # Load with a different number of outputs (head is rebuilt automatically)
    model = AttentionBaseNet.from_pretrained(
        "username/my-attentionbasenet-model", n_outputs=4
    )
Extracting features and replacing the head:
.. code:: python

    import torch

    x = torch.randn(1, model.n_chans, model.n_times)

    # Extract encoder features (consistent dict across all models)
    out = model(x, return_features=True)
    features = out["features"]

    # Replace the classification head
    model.reset_head(n_outputs=10)
Saving and restoring full configuration:
.. code:: python

    import json

    config = model.get_config()  # all __init__ params
    with open("config.json", "w") as f:
        json.dump(config, f)

    model2 = AttentionBaseNet.from_config(config)  # reconstruct (no weights)
All model parameters (both EEG-specific and model-specific such as dropout rates, activation functions, number of filters) are automatically saved to the Hub and restored when loading.
See :ref:`load-pretrained-models` for a complete tutorial.
## Citation
Please cite both the original paper for this architecture (see the References section above) and braindecode:
```bibtex
@article{aristimunha2025braindecode,
  title   = {Braindecode: a deep learning library for raw electrophysiological data},
  author  = {Aristimunha, Bruno and others},
  journal = {Zenodo},
  year    = {2025},
  doi     = {10.5281/zenodo.17699192},
}
```
## License
BSD-3-Clause for the model code (matching braindecode). Pretraining-derived weights, if you fine-tune from a checkpoint, inherit the license of that checkpoint and its training corpus.