PhurinutR committed on
Commit
8b11b00
·
1 Parent(s): 7a63dcf

added installation guide

Files changed (3)
  1. .gitignore +1 -1
  2. README.md +57 -1
  3. dataset/README.md +22 -0
.gitignore CHANGED
@@ -1,3 +1,3 @@
1
- dataset
2
  .venv
3
  runs
 
1
+ dataset/ptb_xl/
2
  .venv
3
  runs
README.md CHANGED
@@ -1,2 +1,58 @@
1
  # ecg_reconstruction
2
- Just Playing around for now.
1
  # ecg_reconstruction
2
+
3
+ Self-supervised **12-lead ECG reconstruction** with a **masked autoencoder (MAE)** based on the ideas of **[CoRe-ECG](https://arxiv.org/abs/2604.11359)**: spatio-temporal dual masking (STDM), a **visibility-restricted encoder**, and a full **decoder** that predicts masked patches. This repo implements the **reconstruction branch only**; the training loop has no downstream classification head.
4
+
5
+ ## Techniques at a glance
6
+
7
+
8
+ | Idea | What it does here |
9
+ | --------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
10
+ | **Patch embedding** | Each lead is split into non-overlapping time patches (default **75 samples** per patch, **66 patches** → **4950** samples per lead at **500 Hz**). Each patch is linearly embedded and combined with a learned **2D positional embedding** over lead × time indices. |
11
+ | **STDM (Spatio-Temporal Dual Masking)** | Per batch and per time index, either **full temporal masking** (all leads supervised at that index) or **partial masking**: **k** leads stay **visible** (encoder input), remaining leads are split into **reconstruction targets** vs **dropped** (no loss). Controlled by `p_time`, `p_lead`, and `num_visible_leads` (`k`) in `MAEConfig` / `mae/stdm.py`. |
12
+ | **Visibility-restricted encoder** | Standard self-attention would let masked positions “peek” at neighbors. Here, **additive attention bias** restricts visible tokens to attend to visible keys; non-visible positions use a **stabilized identity row** so the encoder does not mix information across the visibility boundary (`mae/encoder.py`). |
13
+ | **Decoder + mask token** | Encoder outputs are projected; **visible** slots keep signal tokens, **non-visible** slots get a learned **mask token** plus decoder positions. A shallow **decoder stack** (full attention over all positions) predicts the raw **patch samples**; the loss is **MSE restricted to the STDM supervision mask `M`** (`mae/losses.py`). |
14
+ | **Preprocessing** | **Butterworth bandpass** 0.65–40 Hz, zero-phase `filtfilt`, per lead (`preprocessor.py`), aligned with common ECG MAE setups. |
15
+ | **Data** | **PTB-XL** high-rate records (`filename_hr` @ 500 Hz), official **stratified folds** 1–8 train, 9 val, 10 test (`ptb_xl_dataset.py`). |
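The patch arithmetic in the **Patch embedding** row can be sketched in a few lines of plain Python. The constant and function names below are illustrative assumptions, not the identifiers used in `preprocessor.py`; only the numbers (75 samples per patch, 66 patches, 4950 samples per lead) come from the table.

```python
# Illustrative constants; the real ones live in preprocessor.py under names
# that are not shown here.
PATCH_LEN = 75        # samples per patch
NUM_PATCHES = 66      # patches per lead
SIGNAL_LEN = PATCH_LEN * NUM_PATCHES  # 4950 samples per lead


def split_into_patches(lead_signal, patch_len=PATCH_LEN):
    """Split one lead into non-overlapping patches of `patch_len` samples."""
    if len(lead_signal) % patch_len != 0:
        raise ValueError("signal length must be a multiple of the patch length")
    return [
        lead_signal[start:start + patch_len]
        for start in range(0, len(lead_signal), patch_len)
    ]


# One synthetic lead of the expected length.
demo_lead = [0.0] * SIGNAL_LEN
demo_patches = split_into_patches(demo_lead)
```

Applying the same split to each of the 12 leads gives a 12 × 66 grid of patches, which is the grid that the 2D position and the STDM masks are indexed over.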
16
+
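A minimal sketch of the STDM sampling described in the table, for a single example, may help make the three roles (visible, reconstruction target, dropped) concrete. It is not the code in `mae/stdm.py`. The default values are made up, and the reading of `p_lead` as the per-lead probability that a non-visible lead becomes a reconstruction target is an assumption.

```python
import random


def sample_stdm_masks(num_leads=12, num_patches=66, p_time=0.3, p_lead=0.5,
                      num_visible_leads=3, rng=None):
    """Return (visible, target) boolean grids of shape leads x time indices.

    Assumed semantics (hypothetical, for illustration only):
    - with probability p_time a time index is fully masked and every lead
      at that index is a reconstruction target;
    - otherwise num_visible_leads leads stay visible, and each remaining
      lead is a target with probability p_lead, else dropped (no loss).
    """
    rng = rng or random.Random()
    visible = [[False] * num_patches for _ in range(num_leads)]
    target = [[False] * num_patches for _ in range(num_leads)]
    for t in range(num_patches):
        if rng.random() < p_time:
            # Full temporal masking: all leads supervised at this index.
            for lead in range(num_leads):
                target[lead][t] = True
            continue
        # Partial masking: k leads visible, the rest split target/dropped.
        kept = set(rng.sample(range(num_leads), num_visible_leads))
        for lead in range(num_leads):
            if lead in kept:
                visible[lead][t] = True
            elif rng.random() < p_lead:
                target[lead][t] = True
    return visible, target


visible_mask, target_mask = sample_stdm_masks(rng=random.Random(0))
```

Here `target_mask` plays the role of the supervision mask `M`: positions that are neither visible nor targets are the dropped ones and contribute no loss.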
17
+
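The visibility restriction can be illustrated with a small additive bias matrix that would be added to the attention logits before the softmax. This is a sketch under one interpretation of the table row, not the code in `mae/encoder.py`: visible queries see only visible keys, and each non-visible query attends only to itself, so its softmax row stays finite (the "identity row").

```python
NEG_INF = float("-inf")


def visibility_attention_bias(visible):
    """Build an additive attention bias from a flat list of visibility flags.

    bias[q][k] is 0.0 where query q may attend to key k and -inf elsewhere.
    """
    n = len(visible)
    bias = [[NEG_INF] * n for _ in range(n)]
    for q in range(n):
        if visible[q]:
            # Visible tokens attend to visible keys only.
            for k in range(n):
                if visible[k]:
                    bias[q][k] = 0.0
        else:
            # Non-visible tokens attend to themselves only, which avoids a
            # row of all -inf and the NaNs it would produce in the softmax.
            bias[q][q] = 0.0
    return bias


demo_bias = visibility_attention_bias([True, False, True])
```

With this bias, no visible token receives information from a non-visible position, and no non-visible token receives information from anything but itself.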
18
+ ## Project layout
19
+
20
+ - `train_ecg_mae.py` — training loop: **AdamW**, TensorBoard scalars and periodic **reconstruction figures**, checkpoints.
21
+ - `mae/` — config, **ECGDataMAE** model, STDM sampling, encoder/decoder blocks, loss.
22
+ - `preprocessor.py` — filtering, patch length / signal window constants shared with the model.
23
+ - `ptb_xl_dataset.py` — CSV-driven PTB-XL loading via **WFDB**.
24
+ - `inference.py` — `load_pipeline`, `reconstruct`, checkpoint I/O, plotting helpers for dashboards or notebooks.
25
+ - `visualization.ipynb` — exploratory plots.
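As a rough illustration of the filtering that `preprocessor.py` is described as doing (Butterworth bandpass 0.65–40 Hz, zero-phase `filtfilt`, per lead), here is a sketch using SciPy. The function name and the filter order are assumptions; the README does not state the order.

```python
import numpy as np
from scipy.signal import butter, filtfilt

FS_HZ = 500.0               # PTB-XL high-rate sampling frequency
LOW_HZ, HIGH_HZ = 0.65, 40.0
FILTER_ORDER = 3            # assumption: the README does not give the order


def bandpass_ecg(signal, fs=FS_HZ, low=LOW_HZ, high=HIGH_HZ, order=FILTER_ORDER):
    """Zero-phase Butterworth bandpass applied along the last (time) axis."""
    b, a = butter(order, [low, high], btype="bandpass", fs=fs)
    # filtfilt runs the filter forward and backward, cancelling phase shift.
    return filtfilt(b, a, signal, axis=-1)


# Synthetic 12-lead input: a 10 Hz sine riding on a constant offset.
t = np.arange(4950) / FS_HZ
demo_ecg = np.tile(1.0 + np.sin(2 * np.pi * 10.0 * t), (12, 1))
demo_filtered = bandpass_ecg(demo_ecg)
```

Because the filter runs along the time axis, each lead is processed independently; the constant offset falls below the passband and is removed.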
26
+
27
+ ## Setup
28
+
29
+ Create a virtual environment:
30
+
31
+ ```bash
32
+ python -m venv .venv
33
+ source .venv/bin/activate
34
+ ```
35
+
36
+ Install **PyTorch** separately, choosing the CPU or CUDA build that matches your machine. Then install the other dependencies:
37
+
38
+ ```bash
39
+ pip install -r requirements.txt
40
+ ```
41
+
42
+ See the [dataset guide](./dataset/README.md) for how to download the PTB-XL dataset.
43
+
44
+ ## Train
45
+
46
+ ```bash
47
+ python train_ecg_mae.py --data-root dataset/ptb_xl --log-dir runs/EXPERIMENT_NAME --epochs 80
48
+ ```
49
+
50
+ Useful flags: `--batch-size`, `--lr`, `--weight-decay`, `--resume path/to/checkpoint.pt`, `--ckpt-every`, `--vis-every`. Logs and checkpoints are written under `--log-dir`; as long as that directory lives under `runs/`, it is covered by the `runs` entry in `.gitignore`, so artifacts stay local.
51
+
52
+ ## Inference
53
+
54
+ Load a saved checkpoint and run `reconstruct()` from `inference.py` on tensors shaped `(batch, 12, signal_length)` with the same preprocessing as training.
55
+
56
+ ## Note on patch length vs. paper
57
+
58
+ The implementation keeps **75 samples per patch**, the paper's patch length in samples. At **500 Hz** each patch therefore spans 0.15 s, half the 0.30 s it would cover at 250 Hz; the code comments in `preprocessor.py` describe this tradeoff explicitly.
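The arithmetic behind this note is easy to check. The helper below is a hypothetical one-liner, not something taken from `preprocessor.py`.

```python
def patch_duration_seconds(patch_len, sampling_rate_hz):
    """Temporal extent of one patch in seconds."""
    return patch_len / sampling_rate_hz


duration_at_500 = patch_duration_seconds(75, 500)   # this repo's setting
duration_at_250 = patch_duration_seconds(75, 250)   # a 250 Hz setup
window_seconds = patch_duration_seconds(75 * 66, 500)  # full 66-patch window
```

Keeping the patch length fixed in samples halves the time each patch covers when the sampling rate doubles, and the full 66-patch window covers 9.9 s of signal at 500 Hz.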
dataset/README.md ADDED
@@ -0,0 +1,22 @@
1
+ # PTB-XL dataset (local install)
2
+
3
+ Training and evaluation use **PTB-XL v1.0.3** at **500 Hz**: waveforms are read from paths in the `filename_hr` column of `ptbxl_database.csv` (under `records500/`). See `ptb_xl_dataset.py` for how splits use `strat_fold`.
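The fold-based split can be sketched with the standard library alone. This is an illustration of the 1–8 / 9 / 10 convention, not the code in `ptb_xl_dataset.py`; the function name is hypothetical and the demo rows are made up, although `ecg_id`, `strat_fold` and `filename_hr` are real columns of `ptbxl_database.csv`.

```python
import csv
import io

TRAIN_FOLDS = range(1, 9)  # folds 1-8
VAL_FOLD = 9
TEST_FOLD = 10


def split_by_strat_fold(rows):
    """Group `filename_hr` paths into train/val/test by `strat_fold`."""
    splits = {"train": [], "val": [], "test": []}
    for row in rows:
        fold = int(row["strat_fold"])
        if fold in TRAIN_FOLDS:
            splits["train"].append(row["filename_hr"])
        elif fold == VAL_FOLD:
            splits["val"].append(row["filename_hr"])
        elif fold == TEST_FOLD:
            splits["test"].append(row["filename_hr"])
    return splits


# Made-up rows in the shape of ptbxl_database.csv, for demonstration only.
DEMO_CSV = """ecg_id,strat_fold,filename_hr
1,3,records500/00000/00001_hr
2,9,records500/00000/00002_hr
3,10,records500/00000/00003_hr
"""
demo_splits = split_by_strat_fold(csv.DictReader(io.StringIO(DEMO_CSV)))
```

With the real file, the same function would take `csv.DictReader(open("dataset/ptb_xl/ptbxl_database.csv", newline=""))` as its input.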
4
+
5
+ `dataset/ptb_xl/` is listed in the repo `.gitignore`, so downloaded files stay on your machine and are not committed.
6
+
7
+ ## Prerequisites
8
+
9
+ - **AWS CLI v2** with the `aws` command available ([install guide](https://docs.aws.amazon.com/cli/latest/userguide/getting-started-install.html)).
10
+ - Enough **disk space** for the release (a full **1.0.3** sync takes a few gigabytes; the exact size depends on PhysioNet’s current layout).
11
+ - Network access to the public **PhysioNet** S3 bucket (no AWS account required for this bucket).
12
+
13
+ ## Download
14
+
15
+ From the **repository root** (parent of `dataset/`):
16
+
17
+ ```bash
18
+ cd dataset
19
+ aws s3 sync --no-sign-request s3://physionet-open/ptb-xl/1.0.3/ ptb_xl
20
+ ```
21
+
22
+ This creates `dataset/ptb_xl/` with metadata and WFDB files, including `ptbxl_database.csv` and the `records500/` tree used for high-rate records.
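After the sync finishes, a quick check that the two entries named above exist can catch an interrupted download. The helper below is a hypothetical convenience, not part of the repo.

```python
from pathlib import Path


def missing_ptbxl_entries(data_root="dataset/ptb_xl"):
    """Return the expected PTB-XL entries that are absent under data_root."""
    root = Path(data_root)
    missing = []
    if not (root / "ptbxl_database.csv").is_file():
        missing.append("ptbxl_database.csv")
    if not (root / "records500").is_dir():
        missing.append("records500/")
    return missing
```

Called from the repository root, `missing_ptbxl_entries()` returns an empty list when both the metadata CSV and the `records500/` tree are in place. It does not verify that every record file was downloaded.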