---
license: bsd-3-clause
library_name: braindecode
pipeline_tag: feature-extraction
tags:
  - eeg
  - biosignal
  - pytorch
  - neuroscience
  - braindecode
  - convolutional
  - transformer
---

# AttentionBaseNet

AttentionBaseNet from Wimpff, M. et al. (2023).

> **Architecture-only repository.** This repo documents the
> `braindecode.models.AttentionBaseNet` class. **No pretrained weights are
> distributed here** — instantiate the model and train it on your own
> data, or fine-tune from a published foundation-model checkpoint
> separately.

## Quick start

```bash
pip install braindecode
```

```python
from braindecode.models import AttentionBaseNet

model = AttentionBaseNet(
    n_chans=22,
    sfreq=250,
    input_window_seconds=4.0,
    n_outputs=4,
)
```

The signal-shape arguments above are example defaults — adjust them
to match your recording.
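With those settings the model expects windows of `sfreq × input_window_seconds` samples per channel. A quick sanity check of the expected input shape (plain Python; the variable names are illustrative, not braindecode API):

```python
# Expected input shape for the configuration above:
# (batch, n_chans, n_times) with n_times = sfreq * input_window_seconds.
sfreq = 250
input_window_seconds = 4.0
n_chans = 22

n_times = int(sfreq * input_window_seconds)  # 1000 samples per 4 s window
input_shape = (1, n_chans, n_times)
print(input_shape)  # (1, 22, 1000)
```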

## Documentation

- Full API reference (parameters, references, architecture figure):
  <https://braindecode.org/stable/generated/braindecode.models.AttentionBaseNet.html>
- Interactive browser with live instantiation:
  <https://huggingface.co/spaces/braindecode/model-explorer>
- Source on GitHub: <https://github.com/braindecode/braindecode/blob/master/braindecode/models/attentionbasenet.py#L29>

## Architecture description

The block below is the rendered class docstring (parameters,
references, architecture figure where available).

<div class='bd-doc'><main>
<p>AttentionBaseNet from Wimpff, M. et al. (2023) [Martin2023]_.</p>
<span style="display:inline-block;padding:2px 8px;border-radius:4px;background:#5cb85c;color:white;font-size:11px;font-weight:600;margin-right:4px;">Convolution</span><span style="display:inline-block;padding:2px 8px;border-radius:4px;background:#56B4E9;color:white;font-size:11px;font-weight:600;margin-right:4px;">Attention/Transformer</span>



 .. figure:: https://content.cld.iop.org/journals/1741-2552/21/3/036020/revision2/jnead48b9f2_hr.jpg
     :align: center
     :alt: AttentionBaseNet Architecture
     :width: 640px

 .. rubric:: Architectural Overview

 AttentionBaseNet is a *convolution-first* network with a *channel-attention* stage.
 The end-to-end flow is:

 - (i) :class:`_FeatureExtractor` learns a temporal filter bank and per-filter spatial
   projections (depthwise across electrodes), then condenses time by pooling;
 - (ii) **Channel Expansion** uses a ``1x1`` convolution to set the feature width;
 - (iii) :class:`_ChannelAttentionBlock` refines features via depthwise–pointwise temporal
   convs and an optional channel-attention module (SE/CBAM/ECA/…);
 - (iv) **Classifier** flattens the sequence and applies a linear readout.

 This design mirrors shallow CNN pipelines (EEGNet-style stem) but inserts a pluggable
 attention unit that *re-weights channels* (and optionally temporal positions) before
 classification.

 .. rubric:: Macro Components

 - :class:`_FeatureExtractor` **(Shallow conv stem → condensed feature map)**

     - *Operations.*
     - **Temporal conv** (:class:`torch.nn.Conv2d`) with kernel ``(1, L_t)`` creates a learned
       FIR-like filter bank with ``n_temporal_filters`` maps.
     - **Depthwise spatial conv** (:class:`torch.nn.Conv2d`, ``groups=n_temporal_filters``)
       with kernel ``(n_chans, 1)`` learns per-filter spatial projections over the full montage.
     - **BatchNorm → ELU → AvgPool → Dropout** stabilize and downsample time.
     - Output shape: ``(B, F2, 1, T₁)`` with ``F2 = n_temporal_filters x spatial_expansion``.

 *Interpretability/robustness.* Temporal kernels behave as analyzable FIR filters; the
 depthwise spatial step yields rhythm-specific topographies. Pooling acts as a local
 integrator that reduces variance on short EEG windows.

 - **Channel Expansion**

     - *Operations.*
     - A ``1x1`` conv → BN → activation maps ``F2 → ch_dim`` without changing
       the temporal length ``T₁`` (shape: ``(B, ch_dim, 1, T₁)``).
       This sets the embedding width for the attention block.

 - :class:`_ChannelAttentionBlock` **(temporal refinement + channel attention)**

     - *Operations.*
     - **Depthwise temporal conv** ``(1, L_a)`` (groups=``ch_dim``) + **pointwise ``1x1``**,
       BN and activation → preserves shape ``(B, ch_dim, 1, T₁)`` while refining timing.
     - **Optional attention module** (see *Additional Mechanisms*) applies channel reweighting
       (some variants also apply temporal gating).
     - **AvgPool (1, P₂)** with stride ``(1, S₂)`` and **Dropout** → outputs
       ``(B, ch_dim, 1, T₂)``.

 *Role.* Emphasizes informative channels (and, in certain modes, salient time steps)
 before the classifier; complements the convolutional priors with adaptive re-weighting.

 - **Classifier (aggregation + readout)**

 *Operations.* :class:`torch.nn.Flatten` → :class:`torch.nn.Linear` from
 ``(B, ch_dim·T₂)`` to classes.

 .. rubric:: Convolutional Details

 - **Temporal (where time-domain patterns are learned).**
     Wide kernels in the stem (``(1, L_t)``) act as a learned filter bank for oscillatory
     bands/transients; the attention block's depthwise temporal conv (``(1, L_a)``) sharpens
     short-term dynamics after downsampling. Pool sizes/strides (``P₁,S₁`` then ``P₂,S₂``)
     set the token rate and effective temporal resolution.

 - **Spatial (how electrodes are processed).**
     A depthwise spatial conv with kernel ``(n_chans, 1)`` spans the full montage to
     learn *per-temporal-filter* spatial projections (no cross-filter mixing at this step),
     mirroring the interpretable spatial stage in shallow CNNs.

 - **Spectral (how frequency content is captured).**
     No explicit Fourier/wavelet transform is used in the stem—spectral selectivity
     emerges from learned temporal kernels. When ``attention_mode="fca"``, a frequency
     channel attention (DCT-based) summarizes frequencies to drive channel weights.
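 The sequence lengths after the two pooling stages follow the usual no-padding
 pooling formula, ``T_out = (T_in - pool) // stride + 1``. A minimal sketch with
 the default pool sizes/strides named above, assuming a 4 s window at 250 Hz
 (``pooled_len`` is an illustrative helper, not braindecode API):

```python
def pooled_len(t_in, pool, stride):
    """Output length of average pooling with no padding."""
    return (t_in - pool) // stride + 1

T = 1000                      # 4 s at 250 Hz
T1 = pooled_len(T, 75, 15)    # stem pool, P1=75, S1=15  -> 62
T2 = pooled_len(T1, 8, 8)     # attention-block pool, P2=8, S2=8 -> 7
print(T1, T2)                 # 62 7
# The classifier then flattens ch_dim * T2 features (16 * 7 = 112 by default).
```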

 .. rubric:: Attention / Sequential Modules

 - **Type.** Channel attention chosen by ``attention_mode`` (SE, ECA, CBAM, CAT, GSoP,
     EncNet, GE, GCT, SRM, CATLite). Most operate purely on channels; CBAM/CAT additionally
     include temporal attention.

 - **Shapes.** Input/Output around attention: ``(B, ch_dim, 1, T₁)``. Re-arrangements
     (if any) are internal to the module; the block returns the same shape before pooling.

 - **Role.** Re-weights channels (and optionally time) to highlight informative sources
     and suppress distractors, improving SNR ahead of the linear head.

 .. rubric:: Additional Mechanisms

 **Attention variants at a glance:**

 - ``"se"``: Squeeze-and-Excitation (global pooling → bottleneck → gates).
 - ``"gsop"``: Global second-order pooling (covariance-aware channel weights).
 - ``"fca"``: Frequency Channel Attention (DCT summary; uses ``seq_len`` and ``freq_idx``).
 - ``"encnet"``: EncNet with learned codewords (uses ``n_codewords``).
 - ``"eca"``: Efficient Channel Attention (local 1-D conv over channel descriptor; uses ``kernel_size``).
 - ``"ge"``: Gather–Excite (context pooling with optional MLP; can use ``extra_params``).
 - ``"gct"``: Gated Channel Transformation (global context normalization + gating).
 - ``"srm"``: Style-based recalibration (mean–std descriptors; optional MLP).
 - ``"cbam"``: Channel then temporal attention (uses ``kernel_size``).
 - ``"cat"`` / ``"catlite"``: Collaborative (channel ± temporal) attention; *lite* omits temporal.

 **Auto-compatibility on short inputs:**

     If the input duration is too short for the configured kernels/pools, the implementation
     **automatically rescales** temporal lengths/strides downward (with a warning) to keep
     shapes valid and preserve the pipeline semantics.

 .. rubric:: Usage and Configuration

 - ``n_temporal_filters``, ``temp_filter_length`` and ``spatial_expansion``:
     control the capacity and the number of spatial projections in the stem.
 - ``pool_length_inp``, ``pool_stride_inp`` then ``pool_length``, ``pool_stride``:
     trade temporal resolution for compute; they determine the final sequence length ``T₂``.
 - ``ch_dim``: width after the ``1x1`` expansion and the effective embedding size for attention.
 - ``attention_mode`` + its specific hyperparameters (``reduction_rate``,
     ``kernel_size``, ``seq_len``, ``freq_idx``, ``n_codewords``, ``use_mlp``):
     select and tune the reweighting mechanism.
 - ``drop_prob_inp`` and ``drop_prob_attn``: regularize stem and attention stages.
 - **Training tips.**

     Start with moderate pooling (e.g., ``P₁=75,S₁=15``) and ELU activations; enable attention
     only after the stem learns stable filters. For small datasets, prefer simpler modes
     (``"se"``, ``"eca"``) before heavier ones (``"gsop"``, ``"encnet"``).

 Parameters
 ----------
 n_temporal_filters : int, optional
     Number of temporal convolutional filters in the first layer. This defines
     the number of output channels after the temporal convolution.
     Default is 40.
 temp_filter_length : int, default=15
     The length of the temporal filters in the convolutional layers.
 spatial_expansion : int, optional
     Multiplicative factor to expand the spatial dimensions. Used to increase
     the capacity of the model by expanding spatial features. Default is 1.
 pool_length_inp : int, optional
     Length of the pooling window in the input layer. Determines how much
     temporal information is aggregated during pooling. Default is 75.
 pool_stride_inp : int, optional
     Stride of the pooling operation in the input layer. Controls the
     downsampling factor in the temporal dimension. Default is 15.
 drop_prob_inp : float, optional
     Dropout rate applied after the input layer. This is the probability of
     zeroing out elements during training to prevent overfitting.
     Default is 0.5.
 ch_dim : int, optional
     Number of channels in the subsequent convolutional layers. This controls
     the depth of the network after the initial layer. Default is 16.
 attention_mode : str, optional
     The type of attention mechanism to apply. If `None`, no attention is applied.

     - "se" for Squeeze-and-excitation network
     - "gsop" for Global Second-Order Pooling
     - "fca" for Frequency Channel Attention Network
     - "encnet" for context encoding module
     - "eca" for Efficient channel attention for deep convolutional neural networks
     - "ge" for Gather-Excite
     - "gct" for Gated Channel Transformation
     - "srm" for Style-based Recalibration Module
     - "cbam" for Convolutional Block Attention Module
     - "cat" for Learning to collaborate channel and temporal attention
       from multi-information fusion
     - "catlite" for Learning to collaborate channel attention
       from multi-information fusion (lite version, cat w/o temporal attention)

 pool_length : int, default=8
     The length of the window for the average pooling operation.
 pool_stride : int, default=8
     The stride of the average pooling operation.
 drop_prob_attn : float, default=0.5
     The dropout rate for regularization for the attention layer. Values should be between 0 and 1.
 reduction_rate : int, default=4
     The reduction rate used in the attention mechanism to reduce dimensionality
     and computational complexity.
 use_mlp : bool, default=False
     Flag to indicate whether an MLP (Multi-Layer Perceptron) should be used within
     the attention mechanism for further processing.
 freq_idx : int, default=0
     DCT index used in fca attention mechanism.
 n_codewords : int, default=4
     The number of codewords (clusters) used in attention mechanisms that employ
     quantization or clustering strategies.
 kernel_size : int, default=9
     The kernel size used in certain types of attention mechanisms for convolution
     operations.
 activation : type[nn.Module], default=nn.ELU
     Activation function class to apply. Should be a PyTorch activation
     module class such as ``nn.ReLU`` or ``nn.ELU``.
 extra_params : bool, default=False
     Flag to indicate whether additional, custom parameters should be passed to
     the attention mechanism.

 Notes
 -----
 - Sequence length after each stage is computed internally; the final classifier expects
   a flattened ``ch_dim x T₂`` vector.
 - Attention operates on *channel* dimension by design; temporal gating exists only in
   specific variants (CBAM/CAT).
 - The paper and original code, with more details about the methodological
   choices, are available in [Martin2023]_ and [MartinCode]_.

 .. versionadded:: 0.9

 References
 ----------
 .. [Martin2023] Wimpff, M., Gizzi, L., Zerfowski, J. and Yang, B., 2023.
     EEG motor imagery decoding: A framework for comparative analysis with
     channel attention mechanisms. arXiv preprint arXiv:2310.11198.
 .. [MartinCode] Wimpff, M., Gizzi, L., Zerfowski, J. and Yang, B.
     GitHub https://github.com/martinwimpff/channel-attention (accessed 2024-03-28)

 .. rubric:: Hugging Face Hub integration

 When the optional ``huggingface_hub`` package is installed, all models
 automatically gain the ability to be pushed to and loaded from the
 Hugging Face Hub. Install with::

     pip install braindecode[hub]

 **Pushing a model to the Hub:**

 .. code:: python

     from braindecode.models import AttentionBaseNet

     # Train your model
     model = AttentionBaseNet(n_chans=22, n_outputs=4, n_times=1000)
     # ... training code ...

     # Push to the Hub
     model.push_to_hub(
         repo_id="username/my-attentionbasenet-model",
         commit_message="Initial model upload",
     )

 **Loading a model from the Hub:**

 .. code:: python

     from braindecode.models import AttentionBaseNet

     # Load pretrained model
     model = AttentionBaseNet.from_pretrained("username/my-attentionbasenet-model")

     # Load with a different number of outputs (head is rebuilt automatically)
     model = AttentionBaseNet.from_pretrained("username/my-attentionbasenet-model", n_outputs=4)

 **Extracting features and replacing the head:**

 .. code:: python

     import torch

     x = torch.randn(1, model.n_chans, model.n_times)
     # Extract encoder features (consistent dict across all models)
     out = model(x, return_features=True)
     features = out["features"]

     # Replace the classification head
     model.reset_head(n_outputs=10)

 **Saving and restoring full configuration:**

 .. code:: python

     import json

     config = model.get_config()            # all __init__ params
     with open("config.json", "w") as f:
         json.dump(config, f)

     model2 = AttentionBaseNet.from_config(config)    # reconstruct (no weights)

 All model parameters (both EEG-specific and model-specific such as
 dropout rates, activation functions, number of filters) are automatically
 saved to the Hub and restored when loading.

 See :ref:`load-pretrained-models` for a complete tutorial.</main>
</div>

## Citation

Please cite both the original paper for this architecture (see the
*References* section above) and braindecode:

```bibtex
@article{aristimunha2025braindecode,
  title   = {Braindecode: a deep learning library for raw electrophysiological data},
  author  = {Aristimunha, Bruno and others},
  journal = {Zenodo},
  year    = {2025},
  doi     = {10.5281/zenodo.17699192},
}
```

## License

BSD-3-Clause for the model code (matching braindecode).
If you fine-tune from a published checkpoint, the resulting weights inherit
the license of that checkpoint and its training corpus.