---
license: bsd-3-clause
library_name: braindecode
pipeline_tag: feature-extraction
tags:
- eeg
- biosignal
- pytorch
- neuroscience
- braindecode
- convolutional
- transformer
---

# AttentionBaseNet

AttentionBaseNet from Wimpff, M., et al. (2023).

> **Architecture-only repository.** This repo documents the
> `braindecode.models.AttentionBaseNet` class. **No pretrained weights are
> distributed here** — instantiate the model and train it on your own
> data, or fine-tune from a published foundation-model checkpoint
> separately.

## Quick start

```bash
pip install braindecode
```

```python
from braindecode.models import AttentionBaseNet

model = AttentionBaseNet(
    n_chans=22,
    sfreq=250,
    input_window_seconds=4.0,
    n_outputs=4,
)
```

The signal-shape arguments above are example defaults — adjust them
to match your recording.
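
A quick smoke test of the resulting shapes (assuming the conventional
braindecode input layout of `(batch, n_chans, n_times)`, where
`n_times = sfreq * input_window_seconds = 1000` here):

```python
import torch

x = torch.randn(8, 22, 1000)  # batch of 8 four-second, 22-channel windows
y = model(x)
print(y.shape)  # expected: torch.Size([8, 4]), one score per class
```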

## Documentation

- Full API reference (parameters, references, architecture figure):
  <https://braindecode.org/stable/generated/braindecode.models.AttentionBaseNet.html>
- Interactive browser with live instantiation:
  <https://huggingface.co/spaces/braindecode/model-explorer>
- Source on GitHub: <https://github.com/braindecode/braindecode/blob/master/braindecode/models/attentionbasenet.py#L29>

## Architecture description

The block below is the rendered class docstring (parameters,
references, architecture figure where available).

<div class='bd-doc'><main>
<p>AttentionBaseNet from Wimpff, M., et al. (2023) [Martin2023]_.</p>
<span style="display:inline-block;padding:2px 8px;border-radius:4px;background:#5cb85c;color:white;font-size:11px;font-weight:600;margin-right:4px;">Convolution</span><span style="display:inline-block;padding:2px 8px;border-radius:4px;background:#56B4E9;color:white;font-size:11px;font-weight:600;margin-right:4px;">Attention/Transformer</span>

.. figure:: https://content.cld.iop.org/journals/1741-2552/21/3/036020/revision2/jnead48b9f2_hr.jpg
   :align: center
   :alt: AttentionBaseNet Architecture
   :width: 640px

.. rubric:: Architectural Overview

AttentionBaseNet is a *convolution-first* network with a *channel-attention* stage.
The end-to-end flow is:

- (i) :class:`_FeatureExtractor` learns a temporal filter bank and per-filter spatial
  projections (depthwise across electrodes), then condenses time by pooling;
- (ii) **Channel Expansion** uses a ``1x1`` convolution to set the feature width;
- (iii) :class:`_ChannelAttentionBlock` refines features via depthwise–pointwise temporal
  convs and an optional channel-attention module (SE/CBAM/ECA/…);
- (iv) **Classifier** flattens the sequence and applies a linear readout.

This design mirrors shallow CNN pipelines (EEGNet-style stem) but inserts a pluggable
attention unit that *re-weights channels* (and optionally temporal positions) before
classification.
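
For orientation, printing the instantiated module lists these stages in order
(a minimal sketch; the exact submodule names in the printout depend on the
braindecode version):

.. code:: python

    from braindecode.models import AttentionBaseNet

    model = AttentionBaseNet(n_chans=22, n_outputs=4, n_times=1000, attention_mode="se")
    print(model)  # conv stem, 1x1 channel expansion, attention block, linear readout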

.. rubric:: Macro Components

- :class:`_FeatureExtractor` **(Shallow conv stem → condensed feature map)**

  *Operations.*

  - **Temporal conv** (:class:`torch.nn.Conv2d`) with kernel ``(1, L_t)`` creates a learned
    FIR-like filter bank with ``n_temporal_filters`` maps.
  - **Depthwise spatial conv** (:class:`torch.nn.Conv2d`, ``groups=n_temporal_filters``)
    with kernel ``(n_chans, 1)`` learns per-filter spatial projections over the full montage.
  - **BatchNorm → ELU → AvgPool → Dropout** stabilize and downsample time.
  - Output shape: ``(B, F2, 1, T₁)`` with ``F2 = n_temporal_filters x spatial_expansion``.

  *Interpretability/robustness.* Temporal kernels behave as analyzable FIR filters; the
  depthwise spatial step yields rhythm-specific topographies. Pooling acts as a local
  integrator that reduces variance on short EEG windows.

- **Channel Expansion**

  *Operations.*

  - A ``1x1`` conv → BN → activation maps ``F2 → ch_dim`` without changing
    the temporal length ``T₁`` (shape: ``(B, ch_dim, 1, T₁)``). This sets the
    embedding width for the attention block.

- :class:`_ChannelAttentionBlock` **(temporal refinement + channel attention)**

  *Operations.*

  - **Depthwise temporal conv** ``(1, L_a)`` (groups=``ch_dim``) + **pointwise ``1x1``**,
    BN and activation → preserves shape ``(B, ch_dim, 1, T₁)`` while refining timing.
  - **Optional attention module** (see *Additional Mechanisms*) applies channel reweighting
    (some variants also apply temporal gating).
  - **AvgPool** ``(1, P₂)`` with stride ``(1, S₂)`` and **Dropout** → outputs
    ``(B, ch_dim, 1, T₂)``.

  *Role.* Emphasizes informative channels (and, in certain modes, salient time steps)
  before the classifier; complements the convolutional priors with adaptive re-weighting.

- **Classifier (aggregation + readout)**

  *Operations.* :class:`torch.nn.Flatten` → :class:`torch.nn.Linear` from
  ``(B, ch_dim·T₂)`` to classes.
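
The sequence lengths ``T₁``/``T₂`` follow standard pooling arithmetic. A
back-of-the-envelope sketch with the default pooling settings, assuming the
convolutions preserve the temporal length (as with "same" padding) so that only
the two average-pooling stages shorten it:

.. code:: python

    def pooled_len(t: int, pool: int, stride: int) -> int:
        # Output length of an average-pooling stage with no padding.
        return (t - pool) // stride + 1

    n_times, ch_dim = 1000, 16                     # 4 s at 250 Hz, default width
    t1 = pooled_len(n_times, pool=75, stride=15)   # stem pool -> T1 = 62
    t2 = pooled_len(t1, pool=8, stride=8)          # attention-block pool -> T2 = 7
    print(t1, t2, ch_dim * t2)                     # flattened classifier input: 112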

.. rubric:: Convolutional Details

- **Temporal (where time-domain patterns are learned).**
  Wide kernels in the stem (``(1, L_t)``) act as a learned filter bank for oscillatory
  bands/transients; the attention block's depthwise temporal conv (``(1, L_a)``) sharpens
  short-term dynamics after downsampling. Pool sizes/strides (``P₁, S₁`` then ``P₂, S₂``)
  set the token rate and effective temporal resolution.

- **Spatial (how electrodes are processed).**
  A depthwise spatial conv with kernel ``(n_chans, 1)`` spans the full montage to
  learn *per-temporal-filter* spatial projections (no cross-filter mixing at this step),
  mirroring the interpretable spatial stage in shallow CNNs.

- **Spectral (how frequency content is captured).**
  No explicit Fourier/wavelet transform is used in the stem—spectral selectivity
  emerges from learned temporal kernels. When ``attention_mode="fca"``, a frequency
  channel attention (DCT-based) summarizes frequencies to drive channel weights.

.. rubric:: Attention / Sequential Modules

- **Type.** Channel attention chosen by ``attention_mode`` (SE, ECA, CBAM, CAT, GSoP,
  EncNet, GE, GCT, SRM, CATLite). Most operate purely on channels; CBAM/CAT additionally
  include temporal attention.

- **Shapes.** Input/Output around attention: ``(B, ch_dim, 1, T₁)``. Re-arrangements
  (if any) are internal to the module; the block returns the same shape before pooling.

- **Role.** Re-weights channels (and optionally time) to highlight informative sources
  and suppress distractors, improving SNR ahead of the linear head.

.. rubric:: Additional Mechanisms

**Attention variants at a glance:**

- ``"se"``: Squeeze-and-Excitation (global pooling → bottleneck → gates).
- ``"gsop"``: Global second-order pooling (covariance-aware channel weights).
- ``"fca"``: Frequency Channel Attention (DCT summary; uses ``seq_len`` and ``freq_idx``).
- ``"encnet"``: EncNet with learned codewords (uses ``n_codewords``).
- ``"eca"``: Efficient Channel Attention (local 1-D conv over channel descriptor; uses ``kernel_size``).
- ``"ge"``: Gather–Excite (context pooling with optional MLP; can use ``extra_params``).
- ``"gct"``: Gated Channel Transformation (global context normalization + gating).
- ``"srm"``: Style-based recalibration (mean–std descriptors; optional MLP).
- ``"cbam"``: Channel then temporal attention (uses ``kernel_size``).
- ``"cat"`` / ``"catlite"``: Collaborative (channel ± temporal) attention; *lite* omits temporal.

**Auto-compatibility on short inputs:**

If the input duration is too short for the configured kernels/pools, the implementation
**automatically rescales** temporal lengths/strides downward (with a warning) to keep
shapes valid and preserve the pipeline semantics.
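
To observe this, instantiate on a deliberately short window and inspect the
emitted warning (a sketch; the exact threshold and warning text depend on the
braindecode version):

.. code:: python

    import warnings
    from braindecode.models import AttentionBaseNet

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        model = AttentionBaseNet(n_chans=22, n_outputs=4, n_times=125)  # 0.5 s at 250 Hz
    for w in caught:
        print(w.message)  # expect a note about rescaled kernel/pool lengths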

.. rubric:: Usage and Configuration

- ``n_temporal_filters``, ``temp_filter_length`` and ``spatial_expansion``:
  control the capacity and the number of spatial projections in the stem.
- ``pool_length_inp``, ``pool_stride_inp`` then ``pool_length``, ``pool_stride``:
  trade temporal resolution for compute; they determine the final sequence length ``T₂``.
- ``ch_dim``: width after the ``1x1`` expansion and the effective embedding size for attention.
- ``attention_mode`` + its specific hyperparameters (``reduction_rate``,
  ``kernel_size``, ``seq_len``, ``freq_idx``, ``n_codewords``, ``use_mlp``):
  select and tune the reweighting mechanism.
- ``drop_prob_inp`` and ``drop_prob_attn``: regularize stem and attention stages.
- **Training tips.**

  Start with moderate pooling (e.g., ``P₁=75, S₁=15``) and ELU activations; enable attention
  only after the stem learns stable filters. For small datasets, prefer simpler modes
  (``"se"``, ``"eca"``) before heavier ones (``"gsop"``, ``"encnet"``).

Parameters
----------
n_temporal_filters : int, optional
    Number of temporal convolutional filters in the first layer. This defines
    the number of output channels after the temporal convolution.
    Default is 40.
temp_filter_length : int, default=15
    The length of the temporal filters in the convolutional layers.
spatial_expansion : int, optional
    Multiplicative factor to expand the spatial dimensions. Used to increase
    the capacity of the model by expanding spatial features. Default is 1.
pool_length_inp : int, optional
    Length of the pooling window in the input layer. Determines how much
    temporal information is aggregated during pooling. Default is 75.
pool_stride_inp : int, optional
    Stride of the pooling operation in the input layer. Controls the
    downsampling factor in the temporal dimension. Default is 15.
drop_prob_inp : float, optional
    Dropout rate applied after the input layer. This is the probability of
    zeroing out elements during training to prevent overfitting.
    Default is 0.5.
ch_dim : int, optional
    Number of channels in the subsequent convolutional layers. This controls
    the depth of the network after the initial layer. Default is 16.
attention_mode : str, optional
    The type of attention mechanism to apply. If ``None``, no attention is applied.

    - "se" for Squeeze-and-Excitation network
    - "gsop" for Global Second-Order Pooling
    - "fca" for Frequency Channel Attention Network
    - "encnet" for context encoding module
    - "eca" for Efficient Channel Attention for deep convolutional neural networks
    - "ge" for Gather-Excite
    - "gct" for Gated Channel Transformation
    - "srm" for Style-based Recalibration Module
    - "cbam" for Convolutional Block Attention Module
    - "cat" for Learning to collaborate channel and temporal attention
      from multi-information fusion
    - "catlite" for Learning to collaborate channel attention
      from multi-information fusion (lite version, cat w/o temporal attention)

pool_length : int, default=8
    The length of the window for the average pooling operation.
pool_stride : int, default=8
    The stride of the average pooling operation.
drop_prob_attn : float, default=0.5
    The dropout rate used to regularize the attention layer. Values should be
    between 0 and 1.
reduction_rate : int, default=4
    The reduction rate used in the attention mechanism to reduce dimensionality
    and computational complexity.
use_mlp : bool, default=False
    Flag to indicate whether an MLP (multi-layer perceptron) should be used within
    the attention mechanism for further processing.
freq_idx : int, default=0
    DCT index used in the fca attention mechanism.
n_codewords : int, default=4
    The number of codewords (clusters) used in attention mechanisms that employ
    quantization or clustering strategies.
kernel_size : int, default=9
    The kernel size used in certain types of attention mechanisms for convolution
    operations.
activation : type[nn.Module], default=nn.ELU
    Activation function class to apply. Should be a PyTorch activation
    module class like ``nn.ReLU`` or ``nn.ELU``.
extra_params : bool, default=False
    Flag to indicate whether additional, custom parameters should be passed to
    the attention mechanism.

Notes
-----
- Sequence length after each stage is computed internally; the final classifier expects
  a flattened ``ch_dim x T₂`` vector.
- Attention operates on the *channel* dimension by design; temporal gating exists only in
  specific variants (CBAM/CAT).
- The paper and original code, with more details about the methodological
  choices, are available in [Martin2023]_ and [MartinCode]_.

.. versionadded:: 0.9

References
----------
.. [Martin2023] Wimpff, M., Gizzi, L., Zerfowski, J. and Yang, B., 2023.
   EEG motor imagery decoding: A framework for comparative analysis with
   channel attention mechanisms. arXiv preprint arXiv:2310.11198.
.. [MartinCode] Wimpff, M., Gizzi, L., Zerfowski, J. and Yang, B.
   GitHub: https://github.com/martinwimpff/channel-attention (accessed 2024-03-28).

.. rubric:: Hugging Face Hub integration

When the optional ``huggingface_hub`` package is installed, all models
automatically gain the ability to be pushed to and loaded from the
Hugging Face Hub. Install with::

    pip install braindecode[hub]

**Pushing a model to the Hub:**

.. code:: python

    from braindecode.models import AttentionBaseNet

    # Train your model
    model = AttentionBaseNet(n_chans=22, n_outputs=4, n_times=1000)
    # ... training code ...

    # Push to the Hub
    model.push_to_hub(
        repo_id="username/my-attentionbasenet-model",
        commit_message="Initial model upload",
    )

**Loading a model from the Hub:**

.. code:: python

    from braindecode.models import AttentionBaseNet

    # Load a pretrained model
    model = AttentionBaseNet.from_pretrained("username/my-attentionbasenet-model")

    # Load with a different number of outputs (the head is rebuilt automatically)
    model = AttentionBaseNet.from_pretrained("username/my-attentionbasenet-model", n_outputs=4)

**Extracting features and replacing the head:**

.. code:: python

    import torch

    x = torch.randn(1, model.n_chans, model.n_times)
    # Extract encoder features (consistent dict across all models)
    out = model(x, return_features=True)
    features = out["features"]

    # Replace the classification head
    model.reset_head(n_outputs=10)

**Saving and restoring full configuration:**

.. code:: python

    import json

    config = model.get_config()  # all __init__ params
    with open("config.json", "w") as f:
        json.dump(config, f)

    model2 = AttentionBaseNet.from_config(config)  # reconstruct (no weights)

All model parameters (both EEG-specific and model-specific, such as
dropout rates, activation functions, and number of filters) are automatically
saved to the Hub and restored when loading.

See :ref:`load-pretrained-models` for a complete tutorial.</main>
</div>

## Citation

Please cite both the original paper for this architecture (see the
*References* section above) and braindecode:

```bibtex
@article{aristimunha2025braindecode,
  title   = {Braindecode: a deep learning library for raw electrophysiological data},
  author  = {Aristimunha, Bruno and others},
  journal = {Zenodo},
  year    = {2025},
  doi     = {10.5281/zenodo.17699192},
}
```
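
For the architecture paper, an entry assembled from the [Martin2023] reference
above (author initials kept as given there):

```bibtex
@article{wimpff2023eeg,
  title   = {EEG motor imagery decoding: A framework for comparative analysis
             with channel attention mechanisms},
  author  = {Wimpff, M. and Gizzi, L. and Zerfowski, J. and Yang, B.},
  journal = {arXiv preprint arXiv:2310.11198},
  year    = {2023},
}
```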

## License

BSD-3-Clause for the model code (matching braindecode).
Pretraining-derived weights, if you fine-tune from a checkpoint,
inherit the license of that checkpoint and its training corpus.