---
license: bsd-3-clause
library_name: braindecode
pipeline_tag: feature-extraction
tags:
- eeg
- biosignal
- pytorch
- neuroscience
- braindecode
- convolutional
- transformer
---

# AttentionBaseNet

AttentionBaseNet from Wimpff, M. et al. (2023).

> **Architecture-only repository.** This repo documents the
> `braindecode.models.AttentionBaseNet` class. **No pretrained weights are
> distributed here** — instantiate the model and train it on your own
> data, or fine-tune from a published foundation-model checkpoint
> separately.

## Quick start

```bash
pip install braindecode
```

```python
from braindecode.models import AttentionBaseNet

model = AttentionBaseNet(
    n_chans=22,
    sfreq=250,
    input_window_seconds=4.0,
    n_outputs=4,
)
```

The signal-shape arguments above are example defaults — adjust them to match
your recording.
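As a quick shape check (a sketch, assuming the instantiation above;
braindecode models take batched `(batch, n_chans, n_times)` tensors):

```python
import torch

# Sanity-check a forward pass with the model created above.
# n_times is derived internally: sfreq * input_window_seconds = 250 * 4 = 1000.
x = torch.randn(8, 22, 1000)   # (batch, n_chans, n_times)
with torch.no_grad():
    y = model(x)
print(y.shape)                 # expected: torch.Size([8, 4]), one score per class
```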

## Documentation

- Full API reference (parameters, references, architecture figure):
- Interactive browser with live instantiation:
- Source on GitHub:

## Architecture description

The block below is the rendered class docstring (parameters, references,
architecture figure where available).

AttentionBaseNet from Wimpff M et al (2023) [Martin2023]_.

.. figure:: https://content.cld.iop.org/journals/1741-2552/21/3/036020/revision2/jnead48b9f2_hr.jpg
   :align: center
   :alt: AttentionBaseNet Architecture
   :width: 640px

.. rubric:: Architectural Overview

AttentionBaseNet is a *convolution-first* network with a *channel-attention*
stage. The end-to-end flow is:

- (i) :class:`_FeatureExtractor` learns a temporal filter bank and per-filter
  spatial projections (depthwise across electrodes), then condenses time by
  pooling;
- (ii) **Channel Expansion** uses a ``1x1`` convolution to set the feature
  width;
- (iii) :class:`_ChannelAttentionBlock` refines features via
  depthwise–pointwise temporal convs and an optional channel-attention module
  (SE/CBAM/ECA/…);
- (iv) **Classifier** flattens the sequence and applies a linear readout.

This design mirrors shallow CNN pipelines (EEGNet-style stem) but inserts a
pluggable attention unit that *re-weights channels* (and optionally temporal
positions) before classification.

.. rubric:: Macro Components

- :class:`_FeatureExtractor` **(Shallow conv stem → condensed feature map)**

  *Operations.*

  - **Temporal conv** (:class:`torch.nn.Conv2d`) with kernel ``(1, L_t)``
    creates a learned FIR-like filter bank with ``n_temporal_filters`` maps.
  - **Depthwise spatial conv** (:class:`torch.nn.Conv2d`,
    ``groups=n_temporal_filters``) with kernel ``(n_chans, 1)`` learns
    per-filter spatial projections over the full montage.
  - **BatchNorm → ELU → AvgPool → Dropout** stabilize and downsample time.
  - Output shape: ``(B, F2, 1, T₁)`` with
    ``F2 = n_temporal_filters x spatial_expansion``.

  *Interpretability/robustness.* Temporal kernels behave as analyzable FIR
  filters; the depthwise spatial step yields rhythm-specific topographies.
  Pooling acts as a local integrator that reduces variance on short EEG
  windows. A minimal sketch of this stem follows the *Convolutional Details*
  list below.

- **Channel Expansion**

  *Operations.* A ``1x1`` conv → BN → activation maps ``F2 → ch_dim`` without
  changing the temporal length ``T₁`` (shape: ``(B, ch_dim, 1, T₁)``). This
  sets the embedding width for the attention block.

- :class:`_ChannelAttentionBlock` **(temporal refinement + channel attention)**

  *Operations.*

  - **Depthwise temporal conv** ``(1, L_a)`` (groups=``ch_dim``) +
    **pointwise ``1x1``**, BN and activation → preserves shape
    ``(B, ch_dim, 1, T₁)`` while refining timing.
  - **Optional attention module** (see *Additional Mechanisms*) applies
    channel reweighting (some variants also apply temporal gating).
  - **AvgPool (1, P₂)** with stride ``(1, S₂)`` and **Dropout** → outputs
    ``(B, ch_dim, 1, T₂)``.

  *Role.* Emphasizes informative channels (and, in certain modes, salient
  time steps) before the classifier; complements the convolutional priors
  with adaptive re-weighting.

- **Classifier (aggregation + readout)**

  *Operations.* :class:`torch.nn.Flatten` → :class:`torch.nn.Linear` from
  ``(B, ch_dim·T₂)`` to classes.

.. rubric:: Convolutional Details

- **Temporal (where time-domain patterns are learned).** Wide kernels in the
  stem (``(1, L_t)``) act as a learned filter bank for oscillatory
  bands/transients; the attention block's depthwise temporal conv
  (``(1, L_a)``) sharpens short-term dynamics after downsampling. Pool
  sizes/strides (``P₁,S₁`` then ``P₂,S₂``) set the token rate and effective
  temporal resolution.
- **Spatial (how electrodes are processed).** A depthwise spatial conv with
  kernel ``(n_chans, 1)`` spans the full montage to learn
  *per-temporal-filter* spatial projections (no cross-filter mixing at this
  step), mirroring the interpretable spatial stage in shallow CNNs.
- **Spectral (how frequency content is captured).** No explicit
  Fourier/wavelet transform is used in the stem—spectral selectivity emerges
  from learned temporal kernels. When ``attention_mode="fca"``, a frequency
  channel attention (DCT-based) summarizes frequencies to drive channel
  weights.
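To make the stem's shape bookkeeping concrete, here is a minimal stand-alone
PyTorch sketch of the stem described above (illustrative only, not the
braindecode implementation; default hyperparameters assumed):

.. code::

    import torch
    from torch import nn

    # Illustrative stand-alone sketch of the convolutional stem
    # (not the braindecode implementation); default hyperparameters assumed.
    n_chans, n_temporal_filters, spatial_expansion = 22, 40, 1
    temp_filter_length, pool_length_inp, pool_stride_inp = 15, 75, 15
    f2 = n_temporal_filters * spatial_expansion

    stem = nn.Sequential(
        # temporal filter bank: (B, 1, C, T) -> (B, F1, C, T)
        nn.Conv2d(1, n_temporal_filters, (1, temp_filter_length),
                  padding=(0, temp_filter_length // 2), bias=False),
        # depthwise spatial projection over the montage: -> (B, F2, 1, T)
        nn.Conv2d(n_temporal_filters, f2, (n_chans, 1),
                  groups=n_temporal_filters, bias=False),
        nn.BatchNorm2d(f2),
        nn.ELU(),
        # temporal condensation: T1 = (T - P1) // S1 + 1
        nn.AvgPool2d((1, pool_length_inp), stride=(1, pool_stride_inp)),
        nn.Dropout(0.5),
    )

    x = torch.randn(8, 1, n_chans, 1000)  # (B, 1, C, T)
    print(stem(x).shape)                  # torch.Size([8, 40, 1, 62])

With ``T = 1000``, the stem output length is
``T₁ = (1000 - 75) // 15 + 1 = 62``.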
.. rubric:: Attention / Sequential Modules

- **Type.** Channel attention chosen by ``attention_mode`` (SE, ECA, CBAM,
  CAT, GSoP, EncNet, GE, GCT, SRM, CATLite). Most operate purely on channels;
  CBAM/CAT additionally include temporal attention.
- **Shapes.** Input/Output around attention: ``(B, ch_dim, 1, T₁)``.
  Re-arrangements (if any) are internal to the module; the block returns the
  same shape before pooling.
- **Role.** Re-weights channels (and optionally time) to highlight
  informative sources and suppress distractors, improving SNR ahead of the
  linear head.

.. rubric:: Additional Mechanisms

**Attention variants at a glance:**

- ``"se"``: Squeeze-and-Excitation (global pooling → bottleneck → gates).
- ``"gsop"``: Global second-order pooling (covariance-aware channel weights).
- ``"fca"``: Frequency Channel Attention (DCT summary; uses ``seq_len`` and
  ``freq_idx``).
- ``"encnet"``: EncNet with learned codewords (uses ``n_codewords``).
- ``"eca"``: Efficient Channel Attention (local 1-D conv over channel
  descriptor; uses ``kernel_size``).
- ``"ge"``: Gather–Excite (context pooling with optional MLP; can use
  ``extra_params``).
- ``"gct"``: Gated Channel Transformation (global context normalization +
  gating).
- ``"srm"``: Style-based recalibration (mean–std descriptors; optional MLP).
- ``"cbam"``: Channel then temporal attention (uses ``kernel_size``).
- ``"cat"`` / ``"catlite"``: Collaborative (channel ± temporal) attention;
  *lite* omits temporal.

**Auto-compatibility on short inputs:** If the input duration is too short
for the configured kernels/pools, the implementation **automatically
rescales** temporal lengths/strides downward (with a warning) to keep shapes
valid and preserve the pipeline semantics.

.. rubric:: Usage and Configuration

- ``n_temporal_filters``, ``temp_filter_length`` and ``spatial_expansion``:
  control the capacity and the number of spatial projections in the stem.
- ``pool_length_inp``, ``pool_stride_inp`` then ``pool_length``,
  ``pool_stride``: trade temporal resolution for compute; they determine the
  final sequence length ``T₂``.
- ``ch_dim``: width after the ``1x1`` expansion and the effective embedding
  size for attention.
- ``attention_mode`` + its specific hyperparameters (``reduction_rate``,
  ``kernel_size``, ``seq_len``, ``freq_idx``, ``n_codewords``, ``use_mlp``):
  select and tune the reweighting mechanism.
- ``drop_prob_inp`` and ``drop_prob_attn``: regularize stem and attention
  stages.
- **Training tips.** Start with moderate pooling (e.g., ``P₁=75,S₁=15``) and
  ELU activations; enable attention only after the stem learns stable
  filters. For small datasets, prefer simpler modes (``"se"``, ``"eca"``)
  before heavier ones (``"gsop"``, ``"encnet"``); see the configuration
  sketch below.
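A configuration sketch following those tips; the parameter names come from
the *Parameters* section below, and the signal-shape values are illustrative:

.. code::

    from braindecode.models import AttentionBaseNet

    # Configuration sketch following the tips above; signal-shape values
    # (n_chans, n_times) are illustrative, adjust to your recordings.
    model = AttentionBaseNet(
        n_chans=22,
        n_outputs=4,
        n_times=1000,
        n_temporal_filters=40,    # stem capacity
        temp_filter_length=15,    # temporal kernel length
        pool_length_inp=75,       # moderate pooling, per the tips above
        pool_stride_inp=15,
        ch_dim=16,                # embedding width for attention
        attention_mode="se",      # start simple; try "eca", "cbam", ... later
        reduction_rate=4,
    )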
Parameters
----------
n_temporal_filters : int, default=40
    Number of temporal convolutional filters in the first layer. This defines
    the number of output channels after the temporal convolution.
temp_filter_length : int, default=15
    The length of the temporal filters in the convolutional layers.
spatial_expansion : int, default=1
    Multiplicative factor to expand the spatial dimensions. Used to increase
    the capacity of the model by expanding spatial features.
pool_length_inp : int, default=75
    Length of the pooling window in the input layer. Determines how much
    temporal information is aggregated during pooling.
pool_stride_inp : int, default=15
    Stride of the pooling operation in the input layer. Controls the
    downsampling factor in the temporal dimension.
drop_prob_inp : float, default=0.5
    Dropout rate applied after the input layer. This is the probability of
    zeroing out elements during training to prevent overfitting.
ch_dim : int, default=16
    Number of channels in the subsequent convolutional layers. This controls
    the depth of the network after the initial layer.
attention_mode : str, optional
    The type of attention mechanism to apply. If ``None``, no attention is
    applied.

    - "se" for Squeeze-and-Excitation network
    - "gsop" for Global Second-Order Pooling
    - "fca" for Frequency Channel Attention Network
    - "encnet" for the context encoding module
    - "eca" for Efficient Channel Attention for deep convolutional neural
      networks
    - "ge" for Gather-Excite
    - "gct" for Gated Channel Transformation
    - "srm" for Style-based Recalibration Module
    - "cbam" for Convolutional Block Attention Module
    - "cat" for learning to collaborate channel and temporal attention from
      multi-information fusion
    - "catlite" for learning to collaborate channel attention from
      multi-information fusion (lite version, cat w/o temporal attention)
pool_length : int, default=8
    The length of the window for the average pooling operation.
pool_stride : int, default=8
    The stride of the average pooling operation.
drop_prob_attn : float, default=0.5
    The dropout rate for regularization for the attention layer. Values
    should be between 0 and 1.
reduction_rate : int, default=4
    The reduction rate used in the attention mechanism to reduce
    dimensionality and computational complexity.
use_mlp : bool, default=False
    Flag to indicate whether an MLP (Multi-Layer Perceptron) should be used
    within the attention mechanism for further processing.
freq_idx : int, default=0
    DCT index used in the fca attention mechanism.
n_codewords : int, default=4
    The number of codewords (clusters) used in attention mechanisms that
    employ quantization or clustering strategies.
kernel_size : int, default=9
    The kernel size used in certain types of attention mechanisms for
    convolution operations.
activation : type[nn.Module], default=nn.ELU
    Activation function class to apply. Should be a PyTorch activation module
    class like ``nn.ReLU`` or ``nn.ELU``.
extra_params : bool, default=False
    Flag to indicate whether additional, custom parameters should be passed
    to the attention mechanism.

Notes
-----
- Sequence length after each stage is computed internally; the final
  classifier expects a flattened ``ch_dim x T₂`` vector.
- Attention operates on the *channel* dimension by design; temporal gating
  exists only in specific variants (CBAM/CAT).
- The paper and original code, with more details about the methodological
  choices, are available in [Martin2023]_ and [MartinCode]_.

.. versionadded:: 0.9

References
----------
.. [Martin2023] Wimpff, M., Gizzi, L., Zerfowski, J. and Yang, B., 2023.
   EEG motor imagery decoding: A framework for comparative analysis with
   channel attention mechanisms. arXiv preprint arXiv:2310.11198.
.. [MartinCode] Wimpff, M., Gizzi, L., Zerfowski, J. and Yang, B. GitHub
   repository: https://github.com/martinwimpff/channel-attention
   (accessed 2024-03-28)
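As a quick sanity check of the sequence-length bookkeeping from the notes
above (assuming the default pooling parameters and pooling without padding):

.. code::

    # Sanity-check the sequence-length arithmetic (T -> T1 -> T2) from the
    # notes above, assuming default pooling parameters and no padding.
    def pooled_len(t, pool, stride):
        return (t - pool) // stride + 1

    n_times = 1000                      # e.g. 4 s at 250 Hz
    t1 = pooled_len(n_times, 75, 15)    # stem pooling -> 62
    t2 = pooled_len(t1, 8, 8)           # attention-block pooling -> 7
    print(t1, t2, 16 * t2)              # 62 7 112 (flattened ch_dim x T2)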
.. rubric:: Hugging Face Hub integration

When the optional ``huggingface_hub`` package is installed, all models
automatically gain the ability to be pushed to and loaded from the Hugging
Face Hub. Install with::

    pip install braindecode[hub]

**Pushing a model to the Hub:**

.. code::

    from braindecode.models import AttentionBaseNet

    # Train your model
    model = AttentionBaseNet(n_chans=22, n_outputs=4, n_times=1000)
    # ... training code ...

    # Push to the Hub
    model.push_to_hub(
        repo_id="username/my-attentionbasenet-model",
        commit_message="Initial model upload",
    )

**Loading a model from the Hub:**

.. code::

    from braindecode.models import AttentionBaseNet

    # Load pretrained model
    model = AttentionBaseNet.from_pretrained("username/my-attentionbasenet-model")

    # Load with a different number of outputs (head is rebuilt automatically)
    model = AttentionBaseNet.from_pretrained(
        "username/my-attentionbasenet-model", n_outputs=4
    )

**Extracting features and replacing the head:**

.. code::

    import torch

    x = torch.randn(1, model.n_chans, model.n_times)

    # Extract encoder features (consistent dict across all models)
    out = model(x, return_features=True)
    features = out["features"]

    # Replace the classification head
    model.reset_head(n_outputs=10)

**Saving and restoring full configuration:**

.. code::

    import json

    config = model.get_config()  # all __init__ params
    with open("config.json", "w") as f:
        json.dump(config, f)

    model2 = AttentionBaseNet.from_config(config)  # reconstruct (no weights)

All model parameters (both EEG-specific and model-specific such as dropout
rates, activation functions, number of filters) are automatically saved to
the Hub and restored when loading. See :ref:`load-pretrained-models` for a
complete tutorial.
## Citation

Please cite both the original paper for this architecture (see the
*References* section above) and braindecode:

```bibtex
@article{aristimunha2025braindecode,
  title   = {Braindecode: a deep learning library for raw electrophysiological data},
  author  = {Aristimunha, Bruno and others},
  journal = {Zenodo},
  year    = {2025},
  doi     = {10.5281/zenodo.17699192},
}
```

## License

BSD-3-Clause for the model code (matching braindecode). Pretraining-derived
weights, if you fine-tune from a checkpoint, inherit the license of that
checkpoint and its training corpus.