---
license: bsd-3-clause
library_name: braindecode
pipeline_tag: feature-extraction
tags:
- eeg
- biosignal
- pytorch
- neuroscience
- braindecode
- convolutional
- transformer
---
# AttentionBaseNet
AttentionBaseNet from Wimpff et al. (2023).
> **Architecture-only repository.** This repo documents the
> `braindecode.models.AttentionBaseNet` class. **No pretrained weights are
> distributed here** — instantiate the model and train it on your own
> data, or fine-tune from a published foundation-model checkpoint
> separately.
## Quick start
```bash
pip install braindecode
```
```python
from braindecode.models import AttentionBaseNet
model = AttentionBaseNet(
    n_chans=22,
    sfreq=250,
    input_window_seconds=4.0,
    n_outputs=4,
)
```
The signal-shape arguments above are example defaults — adjust them
to match your recording.
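As a quick sanity check, you can run a forward pass on a dummy batch (the
batch size here is arbitrary; for the instantiation above,
``n_times = sfreq * input_window_seconds = 1000``):
```python
import torch

x = torch.randn(8, 22, 1000)  # (batch, n_chans, n_times)
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # torch.Size([8, 4]): one score per class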
## Documentation
- Full API reference (parameters, references, architecture figure):
<https://braindecode.org/stable/generated/braindecode.models.AttentionBaseNet.html>
- Interactive browser with live instantiation:
<https://huggingface.co/spaces/braindecode/model-explorer>
- Source on GitHub: <https://github.com/braindecode/braindecode/blob/master/braindecode/models/attentionbasenet.py#L29>
## Architecture description
The block below is the rendered class docstring (parameters,
references, architecture figure where available).
<div class='bd-doc'><main>
<p>AttentionBaseNet from Wimpff M et al (2023) [Martin2023]_.</p>
<span style="display:inline-block;padding:2px 8px;border-radius:4px;background:#5cb85c;color:white;font-size:11px;font-weight:600;margin-right:4px;">Convolution</span><span style="display:inline-block;padding:2px 8px;border-radius:4px;background:#56B4E9;color:white;font-size:11px;font-weight:600;margin-right:4px;">Attention/Transformer</span>
.. figure:: https://content.cld.iop.org/journals/1741-2552/21/3/036020/revision2/jnead48b9f2_hr.jpg
:align: center
:alt: AttentionBaseNet Architecture
:width: 640px
.. rubric:: Architectural Overview
AttentionBaseNet is a *convolution-first* network with a *channel-attention* stage.
The end-to-end flow is:
- (i) :class:`_FeatureExtractor` learns a temporal filter bank and per-filter spatial
projections (depthwise across electrodes), then condenses time by pooling;
- (ii) **Channel Expansion** uses a ``1x1`` convolution to set the feature width;
- (iii) :class:`_ChannelAttentionBlock` refines features via depthwise–pointwise temporal
convs and an optional channel-attention module (SE/CBAM/ECA/…);
- (iv) **Classifier** flattens the sequence and applies a linear readout.
This design mirrors shallow CNN pipelines (EEGNet-style stem) but inserts a pluggable
attention unit that *re-weights channels* (and optionally temporal positions) before
classification.
.. rubric:: Macro Components
- :class:`_FeatureExtractor` **(Shallow conv stem → condensed feature map)**
- *Operations.*
- **Temporal conv** (:class:`torch.nn.Conv2d`) with kernel ``(1, L_t)`` creates a learned
FIR-like filter bank with ``n_temporal_filters`` maps.
- **Depthwise spatial conv** (:class:`torch.nn.Conv2d`, ``groups=n_temporal_filters``)
with kernel ``(n_chans, 1)`` learns per-filter spatial projections over the full montage.
- **BatchNorm → ELU → AvgPool → Dropout** stabilize and downsample time.
- Output shape: ``(B, F2, 1, T₁)`` with ``F2 = n_temporal_filters x spatial_expansion``.
*Interpretability/robustness.* Temporal kernels behave as analyzable FIR filters; the
depthwise spatial step yields rhythm-specific topographies. Pooling acts as a local
integrator that reduces variance on short EEG windows.
- **Channel Expansion**
- *Operations.*
- A ``1x1`` conv → BN → activation maps ``F2 → ch_dim`` without changing
the temporal length ``T₁`` (shape: ``(B, ch_dim, 1, T₁)``).
This sets the embedding width for the attention block.
- :class:`_ChannelAttentionBlock` **(temporal refinement + channel attention)**
- *Operations.*
- **Depthwise temporal conv** ``(1, L_a)`` (groups=``ch_dim``) + **pointwise ``1x1``**,
BN and activation → preserves shape ``(B, ch_dim, 1, T₁)`` while refining timing.
- **Optional attention module** (see *Additional Mechanisms*) applies channel reweighting
(some variants also apply temporal gating).
- **AvgPool (1, P₂)** with stride ``(1, S₂)`` and **Dropout** → outputs
``(B, ch_dim, 1, T₂)``.
*Role.* Emphasizes informative channels (and, in certain modes, salient time steps)
before the classifier; complements the convolutional priors with adaptive re-weighting.
- **Classifier (aggregation + readout)**
*Operations.* :class:`torch.nn.Flatten` → :class:`torch.nn.Linear` from
``(B, ch_dim·T₂)`` to classes.
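Putting the macro components together, below is a minimal PyTorch sketch of
stages (i)-(iv) using the default hyperparameters from the Parameters section;
layer ordering and padding choices are illustrative assumptions, not
braindecode's exact internals.
.. code::

    import torch
    import torch.nn as nn

    n_chans, n_times = 22, 1000
    n_temporal_filters, spatial_expansion, ch_dim = 40, 1, 16
    F2 = n_temporal_filters * spatial_expansion

    # (i) conv stem: temporal filter bank + depthwise spatial projection + pooling
    stem = nn.Sequential(
        nn.Conv2d(1, n_temporal_filters, (1, 15), padding="same", bias=False),
        nn.Conv2d(n_temporal_filters, F2, (n_chans, 1),
                  groups=n_temporal_filters, bias=False),
        nn.BatchNorm2d(F2), nn.ELU(),
        nn.AvgPool2d((1, 75), stride=(1, 15)), nn.Dropout(0.5),
    )
    # (ii) channel expansion: 1x1 conv sets the embedding width ch_dim
    expand = nn.Sequential(
        nn.Conv2d(F2, ch_dim, (1, 1), bias=False),
        nn.BatchNorm2d(ch_dim), nn.ELU(),
    )
    # (iii) depthwise temporal refinement + pointwise mix; a channel-attention
    # module (e.g. SE) would sit between the activation and the pooling
    refine = nn.Sequential(
        nn.Conv2d(ch_dim, ch_dim, (1, 15), groups=ch_dim,
                  padding="same", bias=False),
        nn.Conv2d(ch_dim, ch_dim, (1, 1), bias=False),
        nn.BatchNorm2d(ch_dim), nn.ELU(),
        nn.AvgPool2d((1, 8), stride=(1, 8)), nn.Dropout(0.5),
    )

    x = torch.randn(8, 1, n_chans, n_times)   # (B, 1, n_chans, n_times)
    h = refine(expand(stem(x)))               # (B, ch_dim, 1, T₂)
    print(h.shape)                            # torch.Size([8, 16, 1, 7])
    # (iv) classifier: flatten to (B, ch_dim*T₂) and apply a linear readout
    head = nn.Sequential(nn.Flatten(), nn.Linear(ch_dim * h.shape[-1], 4))
    print(head(h).shape)                      # torch.Size([8, 4])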
.. rubric:: Convolutional Details
- **Temporal (where time-domain patterns are learned).**
Wide kernels in the stem (``(1, L_t)``) act as a learned filter bank for oscillatory
bands/transients; the attention block's depthwise temporal conv (``(1, L_a)``) sharpens
short-term dynamics after downsampling. Pool sizes/strides (``P₁,S₁`` then ``P₂,S₂``)
set the token rate and effective temporal resolution.
- **Spatial (how electrodes are processed).**
A depthwise spatial conv with kernel ``(n_chans, 1)`` spans the full montage to
learn *per-temporal-filter* spatial projections (no cross-filter mixing at this step),
mirroring the interpretable spatial stage in shallow CNNs.
- **Spectral (how frequency content is captured).**
No explicit Fourier/wavelet transform is used in the stem—spectral selectivity
emerges from learned temporal kernels. When ``attention_mode="fca"``, a frequency
channel attention (DCT-based) summarizes frequencies to drive channel weights.
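As a sanity check on the token rate, the pooled sequence lengths follow the
standard no-padding pooling arithmetic (a sketch; braindecode computes these
internally, and any 'valid' temporal convolutions would shorten the input
slightly before pooling):
.. code::

    def pooled_len(n, pool, stride):
        # output length of average pooling without padding
        return (n - pool) // stride + 1

    t1 = pooled_len(1000, 75, 15)   # stem pool (P₁=75, S₁=15) -> 62
    t2 = pooled_len(t1, 8, 8)       # attention-block pool (P₂=8, S₂=8) -> 7
    print(t1, t2)                   # 62 7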
.. rubric:: Attention / Sequential Modules
- **Type.** Channel attention chosen by ``attention_mode`` (SE, ECA, CBAM, CAT, GSoP,
EncNet, GE, GCT, SRM, CATLite). Most operate purely on channels; CBAM/CAT additionally
include temporal attention.
- **Shapes.** Input/Output around attention: ``(B, ch_dim, 1, T₁)``. Re-arrangements
(if any) are internal to the module; the block returns the same shape before pooling.
- **Role.** Re-weights channels (and optionally time) to highlight informative sources
and suppress distractors, improving SNR ahead of the linear head.
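To make the channel-attention step concrete, here is a generic
squeeze-and-excitation gate over the ``(B, ch_dim, 1, T)`` layout; this is a
sketch of the ``"se"`` mechanism in general, not braindecode's implementation:
.. code::

    import torch
    import torch.nn as nn

    class SqueezeExcite(nn.Module):
        """Generic SE gate over a (B, C, 1, T) feature map."""
        def __init__(self, ch_dim=16, reduction_rate=4):
            super().__init__()
            self.gate = nn.Sequential(
                nn.Linear(ch_dim, ch_dim // reduction_rate), nn.ReLU(),
                nn.Linear(ch_dim // reduction_rate, ch_dim), nn.Sigmoid(),
            )

        def forward(self, x):                   # x: (B, C, 1, T)
            s = x.mean(dim=(2, 3))              # squeeze: (B, C)
            w = self.gate(s)[:, :, None, None]  # excitation: gates in (0, 1)
            return x * w                        # same shape, channels re-weighted

    x = torch.randn(8, 16, 1, 62)
    print(SqueezeExcite()(x).shape)             # torch.Size([8, 16, 1, 62])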
.. rubric:: Additional Mechanisms
**Attention variants at a glance:**
- ``"se"``: Squeeze-and-Excitation (global pooling → bottleneck → gates).
- ``"gsop"``: Global second-order pooling (covariance-aware channel weights).
- ``"fca"``: Frequency Channel Attention (DCT summary; uses ``seq_len`` and ``freq_idx``).
- ``"encnet"``: EncNet with learned codewords (uses ``n_codewords``).
- ``"eca"``: Efficient Channel Attention (local 1-D conv over channel descriptor; uses ``kernel_size``).
- ``"ge"``: Gather–Excite (context pooling with optional MLP; can use ``extra_params``).
- ``"gct"``: Gated Channel Transformation (global context normalization + gating).
- ``"srm"``: Style-based recalibration (mean–std descriptors; optional MLP).
- ``"cbam"``: Channel then temporal attention (uses ``kernel_size``).
- ``"cat"`` / ``"catlite"``: Collaborative (channel ± temporal) attention; *lite* omits temporal.
**Auto-compatibility on short inputs:**
If the input duration is too short for the configured kernels/pools, the implementation
**automatically rescales** temporal lengths/strides downward (with a warning) to keep
shapes valid and preserve the pipeline semantics.
.. rubric:: Usage and Configuration
- ``n_temporal_filters``, ``temp_filter_length`` and ``spatial_expansion``:
control the capacity and the number of spatial projections in the stem.
- ``pool_length_inp``, ``pool_stride_inp`` then ``pool_length``, ``pool_stride``:
trade temporal resolution for compute; they determine the final sequence length ``T₂``.
- ``ch_dim``: width after the ``1x1`` expansion and the effective embedding size for attention.
- ``attention_mode`` + its specific hyperparameters (``reduction_rate``,
``kernel_size``, ``seq_len``, ``freq_idx``, ``n_codewords``, ``use_mlp``):
select and tune the reweighting mechanism.
- ``drop_prob_inp`` and ``drop_prob_attn``: regularize stem and attention stages.
- **Training tips.**
Start with moderate pooling (e.g., ``P₁=75, S₁=15``) and ELU activations; enable attention
only after the stem learns stable filters. For small datasets, prefer simpler modes
(``"se"``, ``"eca"``) before heavier ones (``"gsop"``, ``"encnet"``).
Parameters
----------
n_temporal_filters : int, optional
Number of temporal convolutional filters in the first layer. This defines
the number of output channels after the temporal convolution.
Default is 40.
temp_filter_length : int, default=15
The length of the temporal filters in the convolutional layers.
spatial_expansion : int, optional
Multiplicative factor to expand the spatial dimensions. Used to increase
the capacity of the model by expanding spatial features. Default is 1.
pool_length_inp : int, optional
Length of the pooling window in the input layer. Determines how much
temporal information is aggregated during pooling. Default is 75.
pool_stride_inp : int, optional
Stride of the pooling operation in the input layer. Controls the
downsampling factor in the temporal dimension. Default is 15.
drop_prob_inp : float, optional
Dropout rate applied after the input layer. This is the probability of
zeroing out elements during training to prevent overfitting.
Default is 0.5.
ch_dim : int, optional
Number of channels in the subsequent convolutional layers. This controls
the depth of the network after the initial layer. Default is 16.
attention_mode : str, optional
The type of attention mechanism to apply. If `None`, no attention is applied.
- "se" for Squeeze-and-excitation network
- "gsop" for Global Second-Order Pooling
- "fca" for Frequency Channel Attention Network
- "encnet" for context encoding module
- "eca" for Efficient channel attention for deep convolutional neural networks
- "ge" for Gather-Excite
- "gct" for Gated Channel Transformation
- "srm" for Style-based Recalibration Module
- "cbam" for Convolutional Block Attention Module
- "cat" for Learning to collaborate channel and temporal attention
from multi-information fusion
- "catlite" for Learning to collaborate channel attention
from multi-information fusion (lite version, cat w/o temporal attention)
pool_length : int, default=8
The length of the window for the average pooling operation.
pool_stride : int, default=8
The stride of the average pooling operation.
drop_prob_attn : float, default=0.5
The dropout rate for regularization for the attention layer. Values should be between 0 and 1.
reduction_rate : int, default=4
The reduction rate used in the attention mechanism to reduce dimensionality
and computational complexity.
use_mlp : bool, default=False
Flag to indicate whether an MLP (Multi-Layer Perceptron) should be used within
the attention mechanism for further processing.
freq_idx : int, default=0
DCT index used in fca attention mechanism.
n_codewords : int, default=4
The number of codewords (clusters) used in attention mechanisms that employ
quantization or clustering strategies.
kernel_size : int, default=9
The kernel size used in certain types of attention mechanisms for convolution
operations.
activation : type[nn.Module], default=nn.ELU
Activation function class to apply. Should be a PyTorch activation
module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
extra_params : bool, default=False
Flag to indicate whether additional, custom parameters should be passed to
the attention mechanism.
Notes
-----
- Sequence length after each stage is computed internally; the final classifier expects
a flattened ``ch_dim x T₂`` vector.
- Attention operates on *channel* dimension by design; temporal gating exists only in
specific variants (CBAM/CAT).
- The paper and the original code, with more details about the methodological
choices, are available in [Martin2023]_ and [MartinCode]_.
.. versionadded:: 0.9
References
----------
.. [Martin2023] Wimpff, M., Gizzi, L., Zerfowski, J. and Yang, B., 2023.
EEG motor imagery decoding: A framework for comparative analysis with
channel attention mechanisms. arXiv preprint arXiv:2310.11198.
.. [MartinCode] Wimpff, M., Gizzi, L., Zerfowski, J. and Yang, B.
GitHub https://github.com/martinwimpff/channel-attention (accessed 2024-03-28)
.. rubric:: Hugging Face Hub integration
When the optional ``huggingface_hub`` package is installed, all models
automatically gain the ability to be pushed to and loaded from the
Hugging Face Hub. Install with::

    pip install braindecode[hub]
**Pushing a model to the Hub:**
.. code::

    from braindecode.models import AttentionBaseNet

    # Train your model
    model = AttentionBaseNet(n_chans=22, n_outputs=4, n_times=1000)
    # ... training code ...

    # Push to the Hub
    model.push_to_hub(
        repo_id="username/my-attentionbasenet-model",
        commit_message="Initial model upload",
    )
**Loading a model from the Hub:**
.. code::

    from braindecode.models import AttentionBaseNet

    # Load pretrained model
    model = AttentionBaseNet.from_pretrained("username/my-attentionbasenet-model")

    # Load with a different number of outputs (head is rebuilt automatically)
    model = AttentionBaseNet.from_pretrained(
        "username/my-attentionbasenet-model", n_outputs=4
    )
**Extracting features and replacing the head:**
.. code::

    import torch

    x = torch.randn(1, model.n_chans, model.n_times)

    # Extract encoder features (consistent dict across all models)
    out = model(x, return_features=True)
    features = out["features"]

    # Replace the classification head
    model.reset_head(n_outputs=10)
**Saving and restoring full configuration:**
.. code::

    import json

    config = model.get_config()  # all __init__ params
    with open("config.json", "w") as f:
        json.dump(config, f)

    model2 = AttentionBaseNet.from_config(config)  # reconstruct (no weights)
All model parameters (both EEG-specific and model-specific such as
dropout rates, activation functions, number of filters) are automatically
saved to the Hub and restored when loading.
See :ref:`load-pretrained-models` for a complete tutorial.</main>
</div>
## Citation
Please cite both the original paper for this architecture (see the
*References* section above) and braindecode:
```bibtex
@article{aristimunha2025braindecode,
title = {Braindecode: a deep learning library for raw electrophysiological data},
author = {Aristimunha, Bruno and others},
journal = {Zenodo},
year = {2025},
doi = {10.5281/zenodo.17699192},
}
```
## License
BSD-3-Clause for the model code (matching braindecode).
If you fine-tune from a pretrained checkpoint, the resulting weights
inherit the license of that checkpoint and its training corpus.