metadata
license: apache-2.0
tags:
  - eeg
  - neuroscience
  - foundation-model
  - embeddings
  - matryoshka
pipeline_tag: feature-extraction
library_name: neuroencoder
extra_gated_prompt: |-
  The MRL model is currently gated. Access is granted to verified researchers.
  Please briefly describe your institution, role, and intended use.
  If you have a private invitation code, paste it in the "Intended use" field.
extra_gated_fields:
  Institution: text
  Role: text
  Intended use: text
  I agree to use this model for research purposes only: checkbox

EPI Embedding

An EEG embedding model distilled from EPI-250k, a foundation model trained on ~250,000 hours of clinical EEG.

The model produces a 768-dimensional embedding. It is trained with Matryoshka Representation Learning (MRL), so you can use the full 768 dimensions or truncate to 384, 192, 48, or 16.
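With Matryoshka-style embeddings, a smaller representation is usually obtained by keeping the leading coordinates and normalizing again. The NumPy sketch below illustrates that idea on random data. `truncate_embedding` is a hypothetical helper, and it is an assumption that the library's `dim` argument behaves the same way.

```python
import numpy as np

MRL_DIMS = (768, 384, 192, 48, 16)  # sizes listed above


def truncate_embedding(emb: np.ndarray, dim: int) -> np.ndarray:
    """Keep the first `dim` coordinates of each row and L2-normalize again."""
    if dim not in MRL_DIMS:
        raise ValueError(f"dim must be one of {MRL_DIMS}, got {dim}")
    prefix = emb[:, :dim]
    norms = np.linalg.norm(prefix, axis=1, keepdims=True)
    return prefix / np.clip(norms, 1e-12, None)  # guard against zero rows


# Stand-in for real model output: 4 random unit-length 768-d vectors.
rng = np.random.default_rng(0)
full = rng.standard_normal((4, 768))
full /= np.linalg.norm(full, axis=1, keepdims=True)

small = truncate_embedding(full, 48)
print(small.shape)  # (4, 48)
```

Normalizing after truncation keeps dot products interpretable as cosine similarities at every size.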

Usage

Install:

pip install neuroencoder

Then:

import mne, neuroencoder as ne
from neuroencoder import MRL

raw = mne.io.read_raw_edf("recording.edf", preload=True)
model = MRL.from_pretrained()                         # auto-downloads on first use

embeddings = model.embed(
    raw.get_data(),
    sfreq=raw.info["sfreq"],
    channel_names=raw.ch_names,
    dim=192,
)
# -> numpy array, shape [N, 192], L2-normalized

ne.explore(embeddings)                                # interactive Apple Embedding Atlas
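Because the returned rows are L2-normalized, cosine similarity between two windows reduces to a dot product. Below is a minimal sketch of a nearest-window lookup, using random vectors in place of real embeddings. `most_similar` is a hypothetical helper, not part of neuroencoder.

```python
import numpy as np


def most_similar(embeddings: np.ndarray, query_idx: int, k: int = 3) -> np.ndarray:
    """Return indices of the k rows most similar to row `query_idx`."""
    sims = embeddings @ embeddings[query_idx]  # cosine similarity for unit-length rows
    sims[query_idx] = -np.inf                  # exclude the query itself
    return np.argsort(sims)[::-1][:k]


rng = np.random.default_rng(0)
emb = rng.standard_normal((100, 192))
emb /= np.linalg.norm(emb, axis=1, keepdims=True)
emb[7] = emb[42]  # plant an exact duplicate of window 42 so the lookup has a known answer

neighbors = most_similar(emb, query_idx=42)
print(neighbors[0])  # 7
```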

model.embed runs the full pipeline (filter -> resample -> 8-region average -> 30 s sliding window -> embed) and returns a NumPy array. For more control, split it into two steps, preprocessing and prediction:

images = ne.preprocess(eeg, sfreq=256, channel_names=ch_names)   # [N, 8, 224, 224]
embeddings = model.predict(images, dim=192)                       # torch tensor on model device
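To give a feel for the 30 s windowing step, the sketch below cuts a `[channels, samples]` array into fixed-length windows with NumPy. It is only an illustration: `sliding_windows` is a hypothetical helper, and the non-overlapping stride is an assumption, since the step size used by `ne.preprocess` is not stated here.

```python
import numpy as np


def sliding_windows(signal, sfreq, window_s=30.0, step_s=30.0):
    """Split [channels, samples] into [N, channels, window_samples]."""
    win = int(round(window_s * sfreq))
    step = int(round(step_s * sfreq))
    n_samples = signal.shape[1]
    if n_samples < win:
        # Recording shorter than one window: return an empty batch.
        return np.empty((0, signal.shape[0], win), dtype=signal.dtype)
    starts = range(0, n_samples - win + 1, step)
    return np.stack([signal[:, s:s + win] for s in starts])


# 8 region-averaged channels, 95 s of synthetic data at 256 Hz.
sfreq = 256.0
signal = np.zeros((8, int(95 * sfreq)))
windows = sliding_windows(signal, sfreq)
print(windows.shape)  # (3, 8, 7680)
```

The trailing 5 s that do not fill a whole window are dropped in this sketch; a real pipeline might pad or overlap instead.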

Loading directly from a checkpoint

model = MRL.from_checkpoint("path/to/last.ckpt")

from_checkpoint handles both raw state dicts and PyTorch Lightning checkpoints.
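The two formats differ mainly in nesting: a PyTorch Lightning checkpoint is a dictionary that stores the weights under a `state_dict` key next to training metadata, while a raw state dict maps parameter names to tensors directly. The following is a hedged sketch of how a loader might tell them apart. `extract_state_dict` and the `model.` key prefix are illustrative assumptions, not the library's internals.

```python
def extract_state_dict(ckpt: dict, prefix: str = "model.") -> dict:
    """Return a flat name -> weights mapping from either checkpoint layout."""
    # Lightning checkpoints nest the weights under "state_dict".
    state = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
    # Lightning modules often hold the network as an attribute, which prefixes
    # every key; the prefix depends on the attribute name (assumed "model." here).
    return {
        (name[len(prefix):] if name.startswith(prefix) else name): weights
        for name, weights in state.items()
    }


# Plain lists stand in for tensors so the example runs without torch.
raw = {"encoder.weight": [1.0], "encoder.bias": [0.0]}
lightning = {
    "epoch": 3,
    "state_dict": {"model.encoder.weight": [1.0], "model.encoder.bias": [0.0]},
}
print(extract_state_dict(raw) == extract_state_dict(lightning))  # True
```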

Benchmarks

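All results use frozen linear probes with 5-fold subject-level cross-validation and are reported as balanced accuracy (%). The first column is EPI-250k, our base foundation model (not publicly released). It is the approximate upper bound on what the distilled MRL model can preserve, although the distilled model matches or slightly exceeds it on a few tasks (for example TUSL and Mumtaz). The remaining columns are the MRL model at each truncation dimension.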
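Two details of this protocol are easy to get wrong: folds are split by subject, so no subject appears in both train and test, and balanced accuracy is the mean of per-class recalls rather than plain accuracy. The NumPy sketch below illustrates both on toy data. The helper names are hypothetical; in practice scikit-learn's `GroupKFold` and `balanced_accuracy_score` cover the same ground.

```python
import numpy as np


def balanced_accuracy(y_true, y_pred):
    """Mean recall over the classes present in y_true, in percent."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    recalls = [np.mean(y_pred[y_true == c] == c) for c in np.unique(y_true)]
    return 100.0 * float(np.mean(recalls))


def subject_folds(subjects, n_folds=5, seed=0):
    """Yield (train_idx, test_idx) so each subject lands in exactly one test fold."""
    subjects = np.asarray(subjects)
    unique = np.unique(subjects)
    rng = np.random.default_rng(seed)
    # Shuffle subjects, then deal them round-robin into folds.
    fold_of = dict(zip(rng.permutation(unique), np.arange(len(unique)) % n_folds))
    fold_ids = np.array([fold_of[s] for s in subjects])
    for k in range(n_folds):
        yield np.flatnonzero(fold_ids != k), np.flatnonzero(fold_ids == k)


# Imbalanced toy labels: always predicting class 0 is 90% accurate
# but only 50% balanced accuracy.
y_true = np.array([0] * 9 + [1])
y_pred = np.zeros(10, dtype=int)
print(balanced_accuracy(y_true, y_pred))  # 50.0

subjects = np.repeat(np.arange(10), 4)  # 10 subjects, 4 epochs each
for train_idx, test_idx in subject_folds(subjects):
    # No subject is shared between the train and test side of a fold.
    assert not set(subjects[train_idx]) & set(subjects[test_idx])
```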

Private clinical tasks

40,909 annotated 30-second epochs from the Swiss Epilepsy Center.

| Task | EPI-250k | 768 | 384 | 192 | 48 | 16 |
|---|---|---|---|---|---|---|
| Seizure / Wake | 93.4 | 93.1 | 92.7 | 92.5 | 91.5 | 84.1 |
| Sleep (5-class) | 85.1 | 77.0 | 77.4 | 76.9 | 76.5 | 73.2 |
| Artifact / Wake | 90.2 | 90.5 | 90.3 | 90.5 | 90.7 | 65.9 |
| Seizure / Sleep | 88.8 | 85.2 | 84.9 | 84.0 | 82.1 | 79.4 |
| Spike / Seizure | 81.5 | 76.2 | 75.9 | 74.7 | 71.0 | 65.5 |
| Spike / Wake | 97.0 | 94.8 | 94.7 | 94.6 | 92.9 | 87.2 |
| Artifact / Spike | 78.8 | 76.0 | 75.6 | 75.3 | 74.4 | 70.4 |
| Category (6-cls) | 36.3 | 33.6 | 33.3 | 32.8 | 31.7 | 27.4 |
| Clinical Sub (7-cls) | 42.7 | 31.4 | 31.4 | 31.4 | 27.0 | 23.7 |
| All Sublabels (49-cls) | 22.1 | 14.8 | 14.4 | 13.7 | 12.3 | 10.6 |

Public benchmarks

10 standard public EEG datasets, evaluated under identical conditions.

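| Task | EPI-250k | 768 | 384 | 192 | 48 | 16 |
|---|---|---|---|---|---|---|
| TUAB | 73.1 | 72.4 | 72.5 | 72.9 | 72.2 | 70.4 |
| TUEV | 54.5 | 45.9 | 47.2 | 46.7 | 42.8 | 32.1 |
| TUAR | 45.2 | 43.0 | 42.9 | 42.2 | 39.5 | 36.5 |
| TUSL | 73.3 | 71.5 | 75.1 | 77.1 | 71.3 | 69.7 |
| Mumtaz | 82.1 | 80.7 | 81.8 | 82.6 | 83.2 | 83.1 |
| Schizo | 71.1 | 70.1 | 69.4 | 69.5 | 69.4 | 66.7 |
| MentArith | 60.9 | 60.2 | 59.9 | 58.6 | 55.6 | 52.2 |
| ADFTD | 43.2 | 40.0 | 40.0 | 41.0 | 38.6 | 35.9 |
| PhysioMI | 30.3 | 28.3 | 28.4 | 27.3 | 27.7 | 25.2 |
| Parkinsons | 62.9 | 58.9 | 58.6 | 58.2 | 55.9 | 53.2 |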

Numeric column headers (768, 384, ...) are the MRL truncation dimensions.

Documentation

Citation

Paper in preparation. A citation will be added once published.