---
license: bsd-3-clause
library_name: braindecode
pipeline_tag: feature-extraction
tags:
  - eeg
  - biosignal
  - pytorch
  - neuroscience
  - braindecode
  - convolutional
  - transformer
---

AttentionBaseNet

AttentionBaseNet from Wimpff, M. et al. (2023).

Architecture-only repository. This repo documents the braindecode.models.AttentionBaseNet class. No pretrained weights are distributed here — instantiate the model and train it on your own data, or fine-tune from a published foundation-model checkpoint separately.

Quick start

```bash
pip install braindecode
```

```python
from braindecode.models import AttentionBaseNet

model = AttentionBaseNet(
    n_chans=22,
    sfreq=250,
    input_window_seconds=4.0,
    n_outputs=4,
)
```

The signal-shape arguments above are example defaults — adjust them to match your recording.
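
A quick way to sanity-check the instantiation is to run a dummy batch through it. The sketch below assumes the `model` from above; the 1000-sample window simply follows from 250 Hz × 4.0 s, and the values are illustrative only:

```python
import torch

n_times = int(250 * 4.0)            # sfreq * input_window_seconds
x = torch.randn(8, 22, n_times)     # dummy batch: (batch, n_chans, n_times)
with torch.no_grad():
    scores = model(x)
print(scores.shape)                 # expected: torch.Size([8, 4]), one score per class
```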

Documentation

Architecture description

The block below is the rendered class docstring (parameters, references, architecture figure where available).

AttentionBaseNet from Wimpff, M. et al. (2023) [Martin2023]_.

Categories: Convolution, Attention/Transformer

.. figure:: https://content.cld.iop.org/journals/1741-2552/21/3/036020/revision2/jnead48b9f2_hr.jpg
   :align: center
   :alt: AttentionBaseNet Architecture
   :width: 640px

.. rubric:: Architectural Overview

AttentionBaseNet is a convolution-first network with a channel-attention stage. The end-to-end flow is:

  • (i) :class:`_FeatureExtractor` learns a temporal filter bank and per-filter spatial projections (depthwise across electrodes), then condenses time by pooling;
  • (ii) Channel Expansion uses a 1x1 convolution to set the feature width;
  • (iii) :class:`_ChannelAttentionBlock` refines features via depthwise–pointwise temporal convs and an optional channel-attention module (SE/CBAM/ECA/…);
  • (iv) Classifier flattens the sequence and applies a linear readout.

This design mirrors shallow CNN pipelines (EEGNet-style stem) but inserts a pluggable attention unit that re-weights channels (and optionally temporal positions) before classification.
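
To see this flow concretely, the sketch below (which assumes the ``model`` from the Quick start above) traces the output shape of each top-level stage with generic forward hooks; the stage names and exact shapes depend on the installed braindecode version:

.. code:: python

   import torch

   shapes = []

   def log_shape(module, inputs, output):
       # record the class name and output shape of each top-level child module
       if hasattr(output, "shape"):
           shapes.append((type(module).__name__, tuple(output.shape)))

   hooks = [m.register_forward_hook(log_shape) for m in model.children()]
   with torch.no_grad():
       model(torch.randn(1, 22, 1000))     # (batch, n_chans, n_times)
   for name, shape in shapes:
       print(f"{name:<30} {shape}")
   for h in hooks:
       h.remove()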

.. rubric:: Macro Components

  • :class:`_FeatureExtractor` (Shallow conv stem → condensed feature map)

    • Operations.
    • Temporal conv (:class:`torch.nn.Conv2d`) with kernel (1, L_t) creates a learned FIR-like filter bank with n_temporal_filters maps.
    • Depthwise spatial conv (:class:`torch.nn.Conv2d`, groups=n_temporal_filters) with kernel (n_chans, 1) learns per-filter spatial projections over the full montage.
    • BatchNorm → ELU → AvgPool → Dropout stabilize and downsample time.
    • Output shape: (B, F2, 1, T₁) with F2 = n_temporal_filters x spatial_expansion.

Interpretability/robustness. Temporal kernels behave as analyzable FIR filters; the depthwise spatial step yields rhythm-specific topographies. Pooling acts as a local integrator that reduces variance on short EEG windows.
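
For intuition, the stem can be approximated with plain PyTorch layers as below. This is an illustrative sketch using the documented defaults, not the library's exact ``_FeatureExtractor`` (padding and layer ordering may differ):

.. code:: python

   import torch
   from torch import nn

   n_chans, n_filters, spatial_expansion = 22, 40, 1
   temp_filter_length, pool_len, pool_stride = 15, 75, 15
   f2 = n_filters * spatial_expansion

   stem = nn.Sequential(
       # temporal filter bank along time: (B, 1, C, T) -> (B, F, C, T)
       nn.Conv2d(1, n_filters, (1, temp_filter_length),
                 padding=(0, temp_filter_length // 2), bias=False),
       # depthwise spatial projection over the montage: (B, F, C, T) -> (B, F2, 1, T)
       nn.Conv2d(n_filters, f2, (n_chans, 1), groups=n_filters, bias=False),
       nn.BatchNorm2d(f2),
       nn.ELU(),
       # temporal pooling condenses time: T -> T1
       nn.AvgPool2d((1, pool_len), stride=(1, pool_stride)),
       nn.Dropout(0.5),
   )

   x = torch.randn(2, 1, n_chans, 1000)   # one EEG window reshaped to (B, 1, C, T)
   print(stem(x).shape)                    # torch.Size([2, 40, 1, 62])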

  • Channel Expansion

    • Operations.
    • A 1x1 conv → BN → activation maps F2 → ch_dim without changing the temporal length T₁ (shape: (B, ch_dim, 1, T₁)). This sets the embedding width for the attention block.
  • :class:`_ChannelAttentionBlock` (temporal refinement + channel attention)

    • Operations.
    • Depthwise temporal conv (1, L_a) (groups=ch_dim) + pointwise 1x1, BN and activation → preserves shape (B, ch_dim, 1, T₁) while refining timing.
    • Optional attention module (see Additional Mechanisms) applies channel reweighting (some variants also apply temporal gating).
    • AvgPool (1, P₂) with stride (1, S₂) and Dropout → outputs (B, ch_dim, 1, T₂).

Role. Emphasizes informative channels (and, in certain modes, salient time steps) before the classifier; complements the convolutional priors with adaptive re-weighting.
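
A minimal sketch of this stage, assuming the default ch_dim=16, reduction_rate=4 and an SE-style attention module (the library's actual expansion and ``_ChannelAttentionBlock`` may differ in detail):

.. code:: python

   import torch
   from torch import nn

   f2, ch_dim, reduction = 40, 16, 4

   # 1x1 channel expansion: F2 -> ch_dim, temporal length unchanged
   expand = nn.Sequential(
       nn.Conv2d(f2, ch_dim, kernel_size=1, bias=False),
       nn.BatchNorm2d(ch_dim),
       nn.ELU(),
   )

   class SqueezeExcite(nn.Module):
       """Minimal SE-style channel attention: pool -> bottleneck -> sigmoid gates."""
       def __init__(self, channels, reduction):
           super().__init__()
           self.pool = nn.AdaptiveAvgPool2d(1)
           self.gate = nn.Sequential(
               nn.Conv2d(channels, channels // reduction, 1),
               nn.ELU(),
               nn.Conv2d(channels // reduction, channels, 1),
               nn.Sigmoid(),
           )

       def forward(self, x):
           return x * self.gate(self.pool(x))   # re-weight channels, shape preserved

   x = torch.randn(2, f2, 1, 62)
   h = expand(x)                                      # (2, 16, 1, 62)
   print(SqueezeExcite(ch_dim, reduction)(h).shape)   # (2, 16, 1, 62)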

  • Classifier (aggregation + readout)

Operations. :class:`torch.nn.Flatten` → :class:`torch.nn.Linear` from (B, ch_dim·T₂) to classes.
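
As a sketch (the sizes here are illustrative, e.g. ch_dim=16 and T₂=7):

.. code:: python

   import torch
   from torch import nn

   ch_dim, t2, n_outputs = 16, 7, 4
   head = nn.Sequential(nn.Flatten(), nn.Linear(ch_dim * t2, n_outputs))
   print(head(torch.randn(2, ch_dim, 1, t2)).shape)   # torch.Size([2, 4])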

.. rubric:: Convolutional Details

  • Temporal (where time-domain patterns are learned). Wide kernels in the stem ((1, L_t)) act as a learned filter bank for oscillatory bands/transients; the attention block's depthwise temporal conv ((1, L_a)) sharpens short-term dynamics after downsampling. Pool sizes/strides (P₁,S₁ then P₂,S₂) set the token rate and effective temporal resolution (see the arithmetic sketch after this list).

  • Spatial (how electrodes are processed). A depthwise spatial conv with kernel (n_chans, 1) spans the full montage to learn per-temporal-filter spatial projections (no cross-filter mixing at this step), mirroring the interpretable spatial stage in shallow CNNs.

  • Spectral (how frequency content is captured). No explicit Fourier/wavelet transform is used in the stem—spectral selectivity emerges from learned temporal kernels. When attention_mode="fca", a frequency channel attention (DCT-based) summarizes frequencies to drive channel weights.
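
The resulting sequence lengths can be estimated with ordinary unpadded pooling arithmetic. The sketch below ignores any length change from the convolutions themselves (which may be padded), so it is an approximation rather than the library's exact bookkeeping:

.. code:: python

   def pooled_len(n: int, pool: int, stride: int) -> int:
       # standard floor formula for an unpadded pooling window
       return (n - pool) // stride + 1

   n_times = 1000                                   # 4 s at 250 Hz
   t1 = pooled_len(n_times, pool=75, stride=15)     # after the stem: 62
   t2 = pooled_len(t1, pool=8, stride=8)            # after the attention block: 7
   print(t1, t2)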

.. rubric:: Attention / Sequential Modules

  • Type. Channel attention chosen by attention_mode (SE, ECA, CBAM, CAT, GSoP, EncNet, GE, GCT, SRM, CATLite). Most operate purely on channels; CBAM/CAT additionally include temporal attention.

  • Shapes. Input/Output around attention: (B, ch_dim, 1, T₁). Re-arrangements (if any) are internal to the module; the block returns the same shape before pooling.

  • Role. Re-weights channels (and optionally time) to highlight informative sources and suppress distractors, improving SNR ahead of the linear head.

.. rubric:: Additional Mechanisms

Attention variants at a glance (a minimal ECA-style sketch follows this list):

  • "se": Squeeze-and-Excitation (global pooling → bottleneck → gates).
  • "gsop": Global second-order pooling (covariance-aware channel weights).
  • "fca": Frequency Channel Attention (DCT summary; uses seq_len and freq_idx).
  • "encnet": EncNet with learned codewords (uses n_codewords).
  • "eca": Efficient Channel Attention (local 1-D conv over channel descriptor; uses kernel_size).
  • "ge": Gather–Excite (context pooling with optional MLP; can use extra_params).
  • "gct": Gated Channel Transformation (global context normalization + gating).
  • "srm": Style-based recalibration (mean–std descriptors; optional MLP).
  • "cbam": Channel then temporal attention (uses kernel_size).
  • "cat" / "catlite": Collaborative (channel ± temporal) attention; lite omits temporal.

Auto-compatibility on short inputs:

 If the input duration is too short for the configured kernels/pools, the implementation
 **automatically rescales** temporal lengths/strides downward (with a warning) to keep
 shapes valid and preserve the pipeline semantics.
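
One way to observe this behaviour, assuming the implementation warns as described (the exact threshold and message are internal to braindecode), is to instantiate the model on a deliberately short window and capture any warnings:

.. code:: python

   import warnings
   from braindecode.models import AttentionBaseNet

   with warnings.catch_warnings(record=True) as caught:
       warnings.simplefilter("always")
       # 0.5 s at 250 Hz is much shorter than the default 75-sample pooling window
       short_model = AttentionBaseNet(n_chans=22, n_outputs=4, n_times=125)

   for w in caught:
       print(w.message)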

.. rubric:: Usage and Configuration

  • n_temporal_filters, temp_filter_length and spatial_expansion: control the capacity and the number of spatial projections in the stem.

  • pool_length_inp, pool_stride_inp then pool_length, pool_stride: trade temporal resolution for compute; they determine the final sequence length T₂.

  • ch_dim: width after the 1x1 expansion and the effective embedding size for attention.

  • attention_mode + its specific hyperparameters (reduction_rate, kernel_size, seq_len, freq_idx, n_codewords, use_mlp): select and tune the reweighting mechanism.

  • drop_prob_inp and drop_prob_attn: regularize stem and attention stages.

  • Training tips.

    Start with moderate pooling (e.g., P₁=75,S₁=15) and ELU activations; enable attention only after the stem learns stable filters. For small datasets, prefer simpler modes ("se", "eca") before heavier ones ("gsop", "encnet").
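
Putting the tips above together, a starting configuration might look like the following; the parameter names are those listed in the Parameters section below, and the values are illustrative, not tuned recommendations:

.. code:: python

   from braindecode.models import AttentionBaseNet

   model = AttentionBaseNet(
       n_chans=22,
       n_outputs=4,
       n_times=1000,              # 4 s at 250 Hz
       n_temporal_filters=40,
       temp_filter_length=15,
       pool_length_inp=75,
       pool_stride_inp=15,
       ch_dim=16,
       attention_mode="se",       # start with a simple mode before heavier ones
       reduction_rate=4,
       drop_prob_inp=0.5,
       drop_prob_attn=0.5,
   )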

Parameters

n_temporal_filters : int, optional
    Number of temporal convolutional filters in the first layer. This defines the number of output channels after the temporal convolution. Default is 40.
temp_filter_length : int, default=15
    The length of the temporal filters in the convolutional layers.
spatial_expansion : int, optional
    Multiplicative factor to expand the spatial dimensions. Used to increase the capacity of the model by expanding spatial features. Default is 1.
pool_length_inp : int, optional
    Length of the pooling window in the input layer. Determines how much temporal information is aggregated during pooling. Default is 75.
pool_stride_inp : int, optional
    Stride of the pooling operation in the input layer. Controls the downsampling factor in the temporal dimension. Default is 15.
drop_prob_inp : float, optional
    Dropout rate applied after the input layer. This is the probability of zeroing out elements during training to prevent overfitting. Default is 0.5.
ch_dim : int, optional
    Number of channels in the subsequent convolutional layers. This controls the depth of the network after the initial layer. Default is 16.
attention_mode : str, optional
    The type of attention mechanism to apply. If None, no attention is applied.

 - "se" for Squeeze-and-excitation network
 - "gsop" for Global Second-Order Pooling
 - "fca" for Frequency Channel Attention Network
 - "encnet" for context encoding module
 - "eca" for Efficient channel attention for deep convolutional neural networks
 - "ge" for Gather-Excite
 - "gct" for Gated Channel Transformation
 - "srm" for Style-based Recalibration Module
 - "cbam" for Convolutional Block Attention Module
 - "cat" for Learning to collaborate channel and temporal attention
   from multi-information fusion
 - "catlite" for Learning to collaborate channel attention
   from multi-information fusion (lite version, cat w/o temporal attention)

pool_length : int, default=8
    The length of the window for the average pooling operation.
pool_stride : int, default=8
    The stride of the average pooling operation.
drop_prob_attn : float, default=0.5
    The dropout rate for regularization for the attention layer. Values should be between 0 and 1.
reduction_rate : int, default=4
    The reduction rate used in the attention mechanism to reduce dimensionality and computational complexity.
use_mlp : bool, default=False
    Flag to indicate whether an MLP (Multi-Layer Perceptron) should be used within the attention mechanism for further processing.
freq_idx : int, default=0
    DCT index used in the fca attention mechanism.
n_codewords : int, default=4
    The number of codewords (clusters) used in attention mechanisms that employ quantization or clustering strategies.
kernel_size : int, default=9
    The kernel size used in certain types of attention mechanisms for convolution operations.
activation : type[nn.Module], default=nn.ELU
    Activation function class to apply. Should be a PyTorch activation module class like nn.ReLU or nn.ELU. Default is nn.ELU.
extra_params : bool, default=False
    Flag to indicate whether additional, custom parameters should be passed to the attention mechanism.

Notes

  • Sequence length after each stage is computed internally; the final classifier expects a flattened ch_dim x T₂ vector.
  • Attention operates on channel dimension by design; temporal gating exists only in specific variants (CBAM/CAT).
  • The paper and original code, with more details about the methodological choices, are available in [Martin2023]_ and [MartinCode]_.

.. versionadded:: 0.9

References

.. [Martin2023] Wimpff, M., Gizzi, L., Zerfowski, J. and Yang, B., 2023. EEG motor imagery decoding: A framework for comparative analysis with channel attention mechanisms. arXiv preprint arXiv:2310.11198.
.. [MartinCode] Wimpff, M., Gizzi, L., Zerfowski, J. and Yang, B. GitHub: https://github.com/martinwimpff/channel-attention (accessed 2024-03-28).

.. rubric:: Hugging Face Hub integration

When the optional huggingface_hub package is installed, all models automatically gain the ability to be pushed to and loaded from the Hugging Face Hub. Install with::

 pip install braindecode[hub]

Pushing a model to the Hub:

.. code:: python

   from braindecode.models import AttentionBaseNet

   # Train your model
   model = AttentionBaseNet(n_chans=22, n_outputs=4, n_times=1000)
   # ... training code ...

   # Push to the Hub
   model.push_to_hub(
       repo_id="username/my-attentionbasenet-model",
       commit_message="Initial model upload",
   )

Loading a model from the Hub:

.. code:: python

   from braindecode.models import AttentionBaseNet

   # Load pretrained model
   model = AttentionBaseNet.from_pretrained("username/my-attentionbasenet-model")

   # Load with a different number of outputs (head is rebuilt automatically)
   model = AttentionBaseNet.from_pretrained("username/my-attentionbasenet-model", n_outputs=4)

Extracting features and replacing the head:

.. code:: python

   import torch

   x = torch.randn(1, model.n_chans, model.n_times)
   # Extract encoder features (consistent dict across all models)
   out = model(x, return_features=True)
   features = out["features"]

   # Replace the classification head
   model.reset_head(n_outputs=10)

Saving and restoring full configuration:

.. code:: python

   import json

   config = model.get_config()                      # all __init__ params
   with open("config.json", "w") as f:
       json.dump(config, f)

   model2 = AttentionBaseNet.from_config(config)    # reconstruct (no weights)

All model parameters (both EEG-specific and model-specific such as dropout rates, activation functions, number of filters) are automatically saved to the Hub and restored when loading.

See :ref:`load-pretrained-models` for a complete tutorial.

Citation

Please cite both the original paper for this architecture (see the References section above) and braindecode:

```bibtex
@article{aristimunha2025braindecode,
  title   = {Braindecode: a deep learning library for raw electrophysiological data},
  author  = {Aristimunha, Bruno and others},
  journal = {Zenodo},
  year    = {2025},
  doi     = {10.5281/zenodo.17699192},
}
```

License

BSD-3-Clause for the model code (matching braindecode). If you fine-tune from a published checkpoint, the resulting weights inherit the license of that checkpoint and its training corpus.