Image Segmentation
medical
biology
Jose committed
Commit ee42561 · 0 Parent(s)

first commit github branch
.gitattributes ADDED
@@ -0,0 +1,37 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.jpg filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,6 @@
+ *.pyc
+ __pycache__
+ *.egg-info
+ *.zip
+ /samples/fundus/*
+ !/samples/fundus/original
README.md ADDED
@@ -0,0 +1,143 @@
+ ---
+ license: agpl-3.0
+ pipeline_tag: image-segmentation
+ tags:
+ - medical
+ - biology
+ ---
+
+ # VascX models
+
+ This repository contains instructions for using the VascX models from the paper [VascX Models: Model Ensembles for Retinal Vascular Analysis from Color Fundus Images](https://arxiv.org/abs/2409.16016).
+
+ The model weights are hosted on [Hugging Face](https://huggingface.co/Eyened/vascx).
+
+ <img src="imgs/CHASEDB1_12R_rgb.png" width="240" height="240" style="display:inline"><img src="imgs/CHASEDB1_12R.png" width="240" height="240" style="display:inline">
+
+ <img src="imgs/DRIVE_22_rgb.png" width="240" height="240" style="display:inline"><img src="imgs/DRIVE_22.png" width="240" height="240" style="display:inline">
+
+ <img src="imgs/HRF_04_g_rgb.png" width="240" height="240" style="display:inline"><img src="imgs/HRF_04_g.png" width="240" height="240" style="display:inline">
+
+ ## Installation
+
+ To install the full fundus analysis pipeline, including fundus preprocessing, model inference code, and vascular biomarker extraction:
+
+ 1. Create a conda or virtualenv virtual environment, or otherwise ensure a clean environment.
+
+ 2. Install the [rtnls_inference package](https://github.com/Eyened/retinalysis-inference).
+
+
+ ## `vascx run` Command
+
+ The `run` command provides a comprehensive pipeline for processing fundus images, performing various analyses, and creating visualizations.
+
+ ### Usage
+
+ ```bash
+ vascx run DATA_PATH OUTPUT_PATH [OPTIONS]
+ ```
+
+ ### Arguments
+
+ - `DATA_PATH`: Path to input data. Can be either:
+   - A directory containing fundus images
+   - A CSV file with a 'path' column containing paths to images
+
+ - `OUTPUT_PATH`: Directory where processed results will be stored
+
+ ### Options
+
+ | Option | Default | Description |
+ |--------|---------|-------------|
+ | `--preprocess/--no-preprocess` | `--preprocess` | Run preprocessing to standardize images for model input |
+ | `--vessels/--no-vessels` | `--vessels` | Run vessel segmentation and artery-vein classification |
+ | `--disc/--no-disc` | `--disc` | Run optic disc segmentation |
+ | `--quality/--no-quality` | `--quality` | Run image quality assessment |
+ | `--fovea/--no-fovea` | `--fovea` | Run fovea detection |
+ | `--overlay/--no-overlay` | `--overlay` | Create visualization overlays combining all results |
+ | `--n_jobs` | `4` | Number of preprocessing workers for parallel processing |
+
+ ### Output Structure
+
+ When run with default options, the command creates the following structure in `OUTPUT_PATH`:
+
+ ```
+ OUTPUT_PATH/
+ ├── preprocessed_rgb/ # Standardized fundus images
+ ├── vessels/ # Vessel segmentation results
+ ├── artery_vein/ # Artery-vein classification
+ ├── disc/ # Optic disc segmentation
+ ├── overlays/ # Visualization images
+ ├── bounds.csv # Image boundary information
+ ├── quality.csv # Image quality scores
+ └── fovea.csv # Fovea coordinates
+ ```
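The CSV outputs listed above are plain tables that pandas can read directly. A minimal sketch of filtering images by quality score (the data here is fabricated in memory to stand in for `quality.csv`; in practice you would use `pd.read_csv(OUTPUT_PATH / "quality.csv", index_col=0)`, and the 0.5 threshold is purely illustrative, not a recommendation from the paper):

```python
import pandas as pd

# quality.csv holds one row per image id and three quality scores q1..q3
# (higher is better). This toy frame mimics its layout.
quality = pd.DataFrame(
    {"q1": [0.92, 0.31, 0.78], "q2": [0.88, 0.40, 0.75], "q3": [0.95, 0.22, 0.81]},
    index=["img_a", "img_b", "img_c"],
)

# Keep images whose mean quality passes an illustrative threshold.
good_ids = quality[quality.mean(axis=1) > 0.5].index.tolist()
print(good_ids)  # ['img_a', 'img_c']
```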
+
+ ### Processing Stages
+
+ 1. **Preprocessing**:
+ - Standardizes input images for consistent analysis
+ - Outputs preprocessed images and boundary information
+
+ 2. **Quality Assessment**:
+ - Evaluates image quality with three quality metrics (q1, q2, q3)
+ - Higher scores indicate better image quality
+
+ 3. **Vessel Segmentation and Artery-Vein Classification**:
+ - Identifies blood vessels in the retina
+ - Classifies vessel pixels as artery (1), vein (2), or artery-vein crossing (3)
+
+ 4. **Optic Disc Segmentation**:
+ - Identifies the optic disc location and boundaries
+
+ 5. **Fovea Detection**:
+ - Determines the coordinates of the fovea (center of vision)
+
+ 6. **Visualization Overlays**:
+ - Creates color-coded images showing:
+ - Arteries in red
+ - Veins in blue
+ - Optic disc in white
+ - Fovea marked with a yellow X
+
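The artery/vein label encoding above can be unpacked into boolean masks with NumPy. A minimal sketch, assuming 0 is the background label (an assumption; only 1–3 are stated above) and choosing to count crossings toward both vessel trees; the toy array stands in for a map loaded from the `artery_vein/` folder:

```python
import numpy as np

# Toy 4x4 artery-vein label map; in practice this would be loaded from
# a file in artery_vein/, e.g. np.array(Image.open(path)).
av = np.array(
    [
        [0, 1, 1, 0],
        [0, 1, 3, 2],
        [0, 0, 2, 2],
        [0, 0, 0, 2],
    ]
)

# Crossings (3) are pixels where both trees overlap, so include them in each mask.
artery_mask = (av == 1) | (av == 3)
vein_mask = (av == 2) | (av == 3)

print(int(artery_mask.sum()), int(vein_mask.sum()))  # 4 5
```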
+ ### Examples
+
+ **Process a directory of images with all analyses:**
+ ```bash
+ vascx run /path/to/images /path/to/output
+ ```
+
+ **Process specific images listed in a CSV:**
+ ```bash
+ vascx run /path/to/image_list.csv /path/to/output
+ ```
+
+ **Only run preprocessing and vessel segmentation:**
+ ```bash
+ vascx run /path/to/images /path/to/output --no-disc --no-quality --no-fovea --no-overlay
+ ```
+
+ **Skip preprocessing on already preprocessed images:**
+ ```bash
+ vascx run /path/to/preprocessed/images /path/to/output --no-preprocess
+ ```
+
+ **Increase parallel processing workers:**
+ ```bash
+ vascx run /path/to/images /path/to/output --n_jobs 8
+ ```
+
+ ### Notes
+
+ - The CSV input must contain a 'path' column with image file paths
+ - If the CSV includes an 'id' column, these IDs will be used instead of filenames
+ - When `--no-preprocess` is used, input images must already be in the proper format
+ - The overlay visualization requires at least one analysis component to be enabled
+
+ ### Running the steps separately
+
+ To speed up re-execution of vascx, we recommend running the preprocessing and segmentation steps separately:
+
+ 1. Preprocessing. See [this notebook](./notebooks/0_preprocess.ipynb). This step is CPU-heavy and benefits from parallelization (see notebook).
+
+ 2. Inference. See [this notebook](./notebooks/1_segment_preprocessed.ipynb). All models can be run on a single GPU with >10GB VRAM.
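Per the notes above, a CSV input needs a 'path' column and may carry an optional 'id' column. A minimal sketch of building one with pandas (the image paths and IDs are placeholders):

```python
import pandas as pd

# Build the input CSV: a required 'path' column and an optional 'id'
# column that overrides filename-derived IDs.
df = pd.DataFrame(
    {
        "path": ["/data/fundus/scan_001.png", "/data/fundus/scan_002.png"],
        "id": ["patient1_OD", "patient1_OS"],
    }
)
df.to_csv("image_list.csv", index=False)
```

The resulting file can then be passed in place of a directory: `vascx run image_list.csv /path/to/output`.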
fovea/fovea_july24.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1af042f7e2a398f512be8a8d54cc480300312c3f3692c13bd98be65439a33222
+ size 352714676
imgs/CHASEDB1_08L.png ADDED

Git LFS Details

  • SHA256: 4b3537bcda4faa0abd2f187bf508d9dfc3b469f73f3b889d158bd3ef30fa64a9
  • Pointer size: 131 Bytes
  • Size of remote file: 694 kB
imgs/CHASEDB1_08L_rgb.png ADDED

Git LFS Details

  • SHA256: 923cbd785406a4d370b48cc0ffe2525309d35f8ecdf66a7552db6d0e3b0fd758
  • Pointer size: 131 Bytes
  • Size of remote file: 757 kB
imgs/CHASEDB1_12R.png ADDED

Git LFS Details

  • SHA256: d5457e090dc4de46bdc5c7eae45e536d680862e6059eff2c46f7425030672a79
  • Pointer size: 131 Bytes
  • Size of remote file: 804 kB
imgs/CHASEDB1_12R_rgb.png ADDED

Git LFS Details

  • SHA256: d0af405bbd3e8df582bfdd4cd91ca0006aeea307a129221c2d5d46da0ed62234
  • Pointer size: 131 Bytes
  • Size of remote file: 883 kB
imgs/DRIVE_22.png ADDED

Git LFS Details

  • SHA256: cf12b1603f3a50aa125a327aefa07c512ed4b804243e70b3d23b2f4145416d91
  • Pointer size: 131 Bytes
  • Size of remote file: 852 kB
imgs/DRIVE_22_rgb.png ADDED

Git LFS Details

  • SHA256: 87df6604a7348fd328cc5c4e51c028bd996e183fea5f012d9b045de15d8608eb
  • Pointer size: 131 Bytes
  • Size of remote file: 893 kB
imgs/DRIVE_40.png ADDED

Git LFS Details

  • SHA256: 33a24859edb67575ee6fbd2c797dc903cd64df13d314649a2bb9643706895c70
  • Pointer size: 131 Bytes
  • Size of remote file: 834 kB
imgs/DRIVE_40_rgb.png ADDED

Git LFS Details

  • SHA256: b0dcb48533f7b6859a4187eab7ca386e0655be5f7e356ad1d46d02bb3b52caa7
  • Pointer size: 131 Bytes
  • Size of remote file: 874 kB
imgs/HRF_04_g.png ADDED

Git LFS Details

  • SHA256: 64113c3789edace497c717418879e7257a0b20f73af972ac61417ba3c709a50f
  • Pointer size: 131 Bytes
  • Size of remote file: 711 kB
imgs/HRF_04_g_rgb.png ADDED

Git LFS Details

  • SHA256: 4f4f9698e15221b6dd61a3636c5b266b35fc6b221dbbaa1fc25b9e4b410c77b9
  • Pointer size: 131 Bytes
  • Size of remote file: 843 kB
imgs/HRF_07_dr.png ADDED

Git LFS Details

  • SHA256: cfcb0a41b79cd31d3531e0277ec8c44fbc829cc8c733f7fc21c99183453dcd17
  • Pointer size: 131 Bytes
  • Size of remote file: 767 kB
imgs/HRF_07_dr_rgb.png ADDED

Git LFS Details

  • SHA256: 74046f4d4d50dd3673394ba1cf1db33aeabfdd8e733c9fa1f0ad1fc9ef38dd15
  • Pointer size: 131 Bytes
  • Size of remote file: 898 kB
imgs/samples_vascx_hrf.png ADDED

Git LFS Details

  • SHA256: 17499c0fef958fe55ed8bc359d71d803048ef16c106e7cee78e01d95a38de1ec
  • Pointer size: 132 Bytes
  • Size of remote file: 6.08 MB
notebooks/0_preprocess.ipynb ADDED
@@ -0,0 +1,138 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from pathlib import Path\n",
+ "\n",
+ "import pandas as pd\n",
+ "\n",
+ "from rtnls_fundusprep.preprocessor import parallel_preprocess"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Preprocessing\n",
+ "\n",
+ "This code will preprocess the images and write .png files containing the square fundus image and the contrast-enhanced version.\n",
+ "\n",
+ "This step is not strictly necessary, but it is useful if you want to run the preprocessing step separately before model inference.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Create a list of files to be preprocessed:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ds_path = Path(\"../samples/fundus\")\n",
+ "files = list((ds_path / \"original\").glob(\"*\"))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Images with a .dcm extension will be read as DICOM and their pixel_array interpreted as RGB. All other images will be read using PIL's Image.open."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "0it [00:00, ?it/s][Parallel(n_jobs=4)]: Using backend LokyBackend with 4 concurrent workers.\n",
+ "6it [00:00, 154.80it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Error with image ../samples/fundus/original/HRF_07_dr.jpg\n",
+ "Error with image ../samples/fundus/original/HRF_04_g.jpg\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[Parallel(n_jobs=4)]: Done 2 out of 6 | elapsed: 0.9s remaining: 1.8s\n",
+ "[Parallel(n_jobs=4)]: Done 3 out of 6 | elapsed: 1.5s remaining: 1.5s\n",
+ "[Parallel(n_jobs=4)]: Done 4 out of 6 | elapsed: 1.5s remaining: 0.8s\n",
+ "[Parallel(n_jobs=4)]: Done 6 out of 6 | elapsed: 1.6s finished\n"
+ ]
+ }
+ ],
+ "source": [
+ "bounds = parallel_preprocess(\n",
+ " files, # List of image files\n",
+ " rgb_path=ds_path / \"rgb\", # Output path for RGB images\n",
+ " ce_path=ds_path / \"ce\", # Output path for Contrast Enhanced images\n",
+ " n_jobs=4, # number of preprocessing workers\n",
+ ")\n",
+ "df_bounds = pd.DataFrame(bounds).set_index(\"id\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The preprocessor will produce RGB and contrast-enhanced preprocessed images cropped to a square, and return a dataframe with the image bounds that can be used to reconstruct the original image. Output files are named after the input images, but with a .png extension. Be careful when providing multiple inputs that share the same filename (ignoring extension), as this will result in overwritten images. Exceptions during preprocessing will not stop execution but will print an error. Images that failed preprocessing for any reason are marked with `success=False` in the df_bounds dataframe."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "df_bounds.to_csv(ds_path / \"meta.csv\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "retinalysis",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+ }
notebooks/1_segment_preprocessed.ipynb ADDED
@@ -0,0 +1,217 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from pathlib import Path\n",
+ "\n",
+ "import torch\n",
+ "\n",
+ "from rtnls_inference import (\n",
+ " HeatmapRegressionEnsemble,\n",
+ " SegmentationEnsemble,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Segmentation of preprocessed images\n",
+ "\n",
+ "Here we segment images preprocessed using 0_preprocess.ipynb.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ds_path = Path(\"../samples/fundus\")\n",
+ "\n",
+ "# input folders. these are the folders where we stored the preprocessed images\n",
+ "rgb_path = ds_path / \"rgb\"\n",
+ "ce_path = ds_path / \"ce\"\n",
+ "\n",
+ "# these are the output folders for:\n",
+ "av_path = ds_path / \"av\" # artery-vein segmentations\n",
+ "discs_path = ds_path / \"discs\" # optic disc segmentations\n",
+ "overlays_path = ds_path / \"overlays\" # optional overlay visualizations\n",
+ "\n",
+ "device = torch.device(\"cuda:0\") # device to use for inference"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "rgb_paths = sorted(list(rgb_path.glob(\"*.png\")))\n",
+ "ce_paths = sorted(list(ce_path.glob(\"*.png\")))\n",
+ "paired_paths = list(zip(rgb_paths, ce_paths))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "paired_paths[0] # important to make sure that the paths are paired correctly"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Artery-vein segmentation\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "av_ensemble = SegmentationEnsemble.from_huggingface('Eyened/vascx:artery_vein/av_july24.pt').to(device)\n",
+ "\n",
+ "av_ensemble.predict_preprocessed(paired_paths, dest_path=av_path, num_workers=2)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Disc segmentation\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "disc_ensemble = SegmentationEnsemble.from_huggingface('Eyened/vascx:disc/disc_july24.pt').to(device)\n",
+ "disc_ensemble.predict_preprocessed(paired_paths, dest_path=discs_path, num_workers=2)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Fovea detection\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "fovea_ensemble = HeatmapRegressionEnsemble.from_huggingface('Eyened/vascx:fovea/fovea_july24.pt').to(device)\n",
+ "# note: this model does not use contrast enhanced images\n",
+ "df = fovea_ensemble.predict_preprocessed(paired_paths, num_workers=2)\n",
+ "df.columns = [\"mean_x\", \"mean_y\"]\n",
+ "df.to_csv(ds_path / \"fovea.csv\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "df"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Plotting the retinas (optional)\n",
+ "\n",
+ "This will only work if you ran all the models and stored the outputs using the same folder/file names as above.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from vascx.fundus.loader import RetinaLoader\n",
+ "\n",
+ "from rtnls_enface.utils.plotting import plot_gridfns\n",
+ "\n",
+ "loader = RetinaLoader.from_folder(ds_path)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "plot_gridfns([ret.plot for ret in loader[:6]])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Storing visualizations (optional)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "if not overlays_path.exists():\n",
+ " overlays_path.mkdir()\n",
+ "for ret in loader:\n",
+ " fig, _ = ret.plot()\n",
+ " fig.savefig(overlays_path / f\"{ret.id}.png\", bbox_inches=\"tight\", pad_inches=0)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "retinalysis",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+ }
samples/fundus/original/CHASEDB1_08L.png ADDED

Git LFS Details

  • SHA256: 16735352efdb2d951be1f07882e3906ee159c81e12e69c8230a7172d562cfc6b
  • Pointer size: 131 Bytes
  • Size of remote file: 621 kB
samples/fundus/original/CHASEDB1_12R.png ADDED

Git LFS Details

  • SHA256: a541717baf5d83c7295657604d1f529e9ed1cd3a4327aa224dbb83d80d49cc3a
  • Pointer size: 131 Bytes
  • Size of remote file: 776 kB
samples/fundus/original/DRIVE_22.png ADDED

Git LFS Details

  • SHA256: 58a0a44558d23d9cd4ffc60326abf91eed824bbe5718e995cb181595499f595b
  • Pointer size: 131 Bytes
  • Size of remote file: 394 kB
samples/fundus/original/DRIVE_40.png ADDED

Git LFS Details

  • SHA256: 0d8d7685974b7c0eff3583245dbb9e88a1a6a82ed60dbe09112364ba51894438
  • Pointer size: 131 Bytes
  • Size of remote file: 387 kB
samples/fundus/original/HRF_04_g.jpg ADDED

Git LFS Details

  • SHA256: fc9ed13ef42502eeecb3f1754dc0d3b72a454c82884b40dde934e8a516495588
  • Pointer size: 132 Bytes
  • Size of remote file: 1.9 MB
samples/fundus/original/HRF_07_dr.jpg ADDED

Git LFS Details

  • SHA256: 203ddec480816b6c9d7ea3c19c1ff0870a5a61b5b6c9a176300402ac47fbc10f
  • Pointer size: 131 Bytes
  • Size of remote file: 921 kB
setup.py ADDED
@@ -0,0 +1,33 @@
+ from setuptools import find_packages, setup
+
+ with open("README.md", "r") as fh:
+     long_description = fh.read()
+
+ setup(
+     name="vascx_models",
+     # using versioneer for versioning using git tags
+     # https://github.com/python-versioneer/python-versioneer/blob/master/INSTALL.md
+     # version=versioneer.get_version(),
+     # cmdclass=versioneer.get_cmdclass(),
+     author="Jose Vargas",
+     author_email="j.vargasquiros@erasmusmc.nl",
+     description="Retinal analysis toolbox for Python",
+     long_description=long_description,
+     long_description_content_type="text/markdown",
+     packages=find_packages(),
+     include_package_data=True,
+     zip_safe=False,
+     entry_points={
+         "console_scripts": [
+             "vascx = vascx_models.cli:cli",
+         ]
+     },
+     install_requires=[
+         "numpy == 1.*",
+         "pandas == 2.*",
+         "tqdm == 4.*",
+         "Pillow == 9.*",
+         "click==8.*",
+     ],
+     python_requires=">=3.10, <3.11",
+ )
vascx_models/cli.py ADDED
@@ -0,0 +1,198 @@
+ from pathlib import Path
+
+ import click
+ import pandas as pd
+ import torch
+
+ from rtnls_fundusprep.cli import _run_preprocessing
+
+ from .inference import (
+     run_fovea_detection,
+     run_quality_estimation,
+     run_segmentation_disc,
+     run_segmentation_vessels_and_av,
+ )
+ from .utils import batch_create_overlays
+
+
+ @click.group(name="vascx")
+ def cli():
+     pass
+
+
+ @cli.command()
+ @click.argument("data_path", type=click.Path(exists=True))
+ @click.argument("output_path", type=click.Path())
+ @click.option(
+     "--preprocess/--no-preprocess",
+     default=True,
+     help="Run preprocessing or use preprocessed images",
+ )
+ @click.option(
+     "--vessels/--no-vessels", default=True, help="Run vessels and AV segmentation"
+ )
+ @click.option("--disc/--no-disc", default=True, help="Run optic disc segmentation")
+ @click.option(
+     "--quality/--no-quality", default=True, help="Run image quality estimation"
+ )
+ @click.option("--fovea/--no-fovea", default=True, help="Run fovea detection")
+ @click.option(
+     "--overlay/--no-overlay", default=True, help="Create visualization overlays"
+ )
+ @click.option("--n_jobs", type=int, default=4, help="Number of preprocessing workers")
+ def run(
+     data_path, output_path, preprocess, vessels, disc, quality, fovea, overlay, n_jobs
+ ):
+     """Run the complete inference pipeline on fundus images.
+
+     DATA_PATH is either a directory containing images or a CSV file with 'path' column.
+     OUTPUT_PATH is the directory where results will be stored.
+     """
+
+     output_path = Path(output_path)
+     output_path.mkdir(exist_ok=True, parents=True)
+
+     # Setup output directories
+     preprocess_rgb_path = output_path / "preprocessed_rgb"
+     vessels_path = output_path / "vessels"
+     av_path = output_path / "artery_vein"
+     disc_path = output_path / "disc"
+     overlay_path = output_path / "overlays"
+
+     # Create required directories
+     if preprocess:
+         preprocess_rgb_path.mkdir(exist_ok=True, parents=True)
+     if vessels:
+         av_path.mkdir(exist_ok=True, parents=True)
+         vessels_path.mkdir(exist_ok=True, parents=True)
+     if disc:
+         disc_path.mkdir(exist_ok=True, parents=True)
+     if overlay:
+         overlay_path.mkdir(exist_ok=True, parents=True)
+
+     bounds_path = output_path / "bounds.csv" if preprocess else None
+     quality_path = output_path / "quality.csv" if quality else None
+     fovea_path = output_path / "fovea.csv" if fovea else None
+
+     # Determine if input is a folder or CSV file
+     data_path = Path(data_path)
+     is_csv = data_path.suffix.lower() == ".csv"
+
+     # Get files to process
+     files = []
+     ids = None
+
+     if is_csv:
+         click.echo(f"Reading file paths from CSV: {data_path}")
+         try:
+             df = pd.read_csv(data_path)
+             if "path" not in df.columns:
+                 click.echo("Error: CSV must contain a 'path' column")
+                 return
+
+             # Get file paths and convert to Path objects
+             files = [Path(p) for p in df["path"]]
+
+             if "id" in df.columns:
+                 ids = df["id"].tolist()
+                 click.echo("Using IDs from CSV 'id' column")
+
+         except Exception as e:
+             click.echo(f"Error reading CSV file: {e}")
+             return
+     else:
+         click.echo(f"Finding files in directory: {data_path}")
+         files = list(data_path.glob("*"))
+         ids = [f.stem for f in files]
+
+     if not files:
+         click.echo("No files found to process")
+         return
+
+     click.echo(f"Found {len(files)} files to process")
+
+     # Step 1: Preprocess images if requested
+     if preprocess:
+         click.echo("Running preprocessing...")
+         _run_preprocessing(
+             files=files,
+             ids=ids,
+             rgb_path=preprocess_rgb_path,
+             bounds_path=bounds_path,
+             n_jobs=n_jobs,
+         )
+         # Use the preprocessed images for subsequent steps
+         preprocessed_files = list(preprocess_rgb_path.glob("*.png"))
+     else:
+         # Use the input files directly
+         preprocessed_files = files
+         ids = [f.stem for f in preprocessed_files]
+
+     # Set up GPU device
+     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+     click.echo(f"Using device: {device}")
+
+     # Step 2: Run quality estimation if requested
+     if quality:
+         click.echo("Running quality estimation...")
+         df_quality = run_quality_estimation(
+             fpaths=preprocessed_files, ids=ids, device=device
+         )
+         df_quality.to_csv(quality_path)
+         click.echo(f"Quality results saved to {quality_path}")
+
+     # Step 3: Run vessels and AV segmentation if requested
+     if vessels:
+         click.echo("Running vessels and AV segmentation...")
+         run_segmentation_vessels_and_av(
+             rgb_paths=preprocessed_files,
+             ids=ids,
+             av_path=av_path,
+             vessels_path=vessels_path,
+             device=device,
+         )
+         click.echo(f"Vessel segmentation saved to {vessels_path}")
+         click.echo(f"AV segmentation saved to {av_path}")
+
+     # Step 4: Run optic disc segmentation if requested
+     if disc:
+         click.echo("Running optic disc segmentation...")
+         run_segmentation_disc(
+             rgb_paths=preprocessed_files, ids=ids, output_path=disc_path, device=device
+         )
+         click.echo(f"Disc segmentation saved to {disc_path}")
+
+     # Step 5: Run fovea detection if requested
+     df_fovea = None
+     if fovea:
+         click.echo("Running fovea detection...")
+         df_fovea = run_fovea_detection(
+             rgb_paths=preprocessed_files, ids=ids, device=device
+         )
+         df_fovea.to_csv(fovea_path)
+         click.echo(f"Fovea detection results saved to {fovea_path}")
+
+     # Step 6: Create overlays if requested
+     if overlay:
+         click.echo("Creating visualization overlays...")
+
+         # Prepare fovea data if available
+         fovea_data = None
+         if df_fovea is not None:
+             fovea_data = {
+                 idx: (row["x_fovea"], row["y_fovea"])
+                 for idx, row in df_fovea.iterrows()
+             }
+
+         # Create visualization overlays
+         batch_create_overlays(
+             rgb_dir=preprocess_rgb_path if preprocess else data_path,
+             output_dir=overlay_path,
+             av_dir=av_path,
+             disc_dir=disc_path,
+             fovea_data=fovea_data,
+         )
+
+         click.echo(f"Visualization overlays saved to {overlay_path}")
+
+     click.echo(f"All requested processing complete. Results saved to {output_path}")
vascx_models/inference.py ADDED
@@ -0,0 +1,269 @@
1
+ import os
2
+ from pathlib import Path
3
+ from typing import List, Optional
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+ import torch
8
+ from PIL import Image
9
+ from tqdm import tqdm
10
+
11
+ from rtnls_inference.ensembles.ensemble_classification import ClassificationEnsemble
12
+ from rtnls_inference.ensembles.ensemble_heatmap_regression import (
13
+ HeatmapRegressionEnsemble,
14
+ )
15
+ from rtnls_inference.ensembles.ensemble_segmentation import SegmentationEnsemble
16
+ from rtnls_inference.utils import decollate_batch, extract_keypoints_from_heatmaps
17
+
18
+
+ def run_quality_estimation(fpaths, ids, device: torch.device):
+     """Run the image-quality classification ensemble and return one row per image."""
+     ensemble_quality = ClassificationEnsemble.from_release("quality.pt").to(device)
+     dataloader = ensemble_quality._make_inference_dataloader(
+         fpaths,
+         ids=ids,
+         num_workers=8,
+         preprocess=False,
+         batch_size=16,
+     )
+
+     output_ids, outputs = [], []
+     with torch.no_grad():
+         for batch in tqdm(dataloader):
+             if len(batch) == 0:
+                 continue
+
+             im = batch["image"].to(device)
+
+             # QUALITY
+             quality = ensemble_quality.predict_step(im)
+             quality = torch.mean(quality, dim=0)  # average over models
+
+             items = {"id": batch["id"], "quality": quality}
+             items = decollate_batch(items)
+
+             for item in items:
+                 output_ids.append(item["id"])
+                 outputs.append(item["quality"].tolist())
+
+     return pd.DataFrame(
+         outputs,
+         index=output_ids,
+         columns=["q1", "q2", "q3"],
+     )
+
+
+ def run_segmentation_vessels_and_av(
+     rgb_paths: List[Path],
+     ce_paths: Optional[List[Path]] = None,
+     ids: Optional[List[str]] = None,
+     av_path: Optional[Path] = None,
+     vessels_path: Optional[Path] = None,
+     device: torch.device = torch.device(
+         "cuda:0" if torch.cuda.is_available() else "cpu"
+     ),
+ ) -> None:
+     """
+     Run AV and vessel segmentation on the provided images.
+
+     Args:
+         rgb_paths: List of paths to RGB fundus images
+         ce_paths: Optional list of paths to contrast-enhanced images
+         ids: Optional list of ids to pass to _make_inference_dataloader
+         av_path: Folder in which to store output AV segmentations
+         vessels_path: Folder in which to store output vessel segmentations
+         device: Device to run inference on
+     """
+     # Create output directories if they don't exist
+     if av_path is not None:
+         av_path.mkdir(exist_ok=True, parents=True)
+     if vessels_path is not None:
+         vessels_path.mkdir(exist_ok=True, parents=True)
+
+     # Load models
+     ensemble_av = SegmentationEnsemble.from_release("av_july24.pt").to(device).eval()
+     ensemble_vessels = (
+         SegmentationEnsemble.from_release("vessels_july24.pt").to(device).eval()
+     )
+
+     # Prepare input paths
+     if ce_paths is None:
+         # If CE paths are not provided, use RGB paths for both inputs
+         fpaths = rgb_paths
+     else:
+         # If CE paths are provided, pair them with RGB paths
+         if len(rgb_paths) != len(ce_paths):
+             raise ValueError("rgb_paths and ce_paths must have the same length")
+         fpaths = list(zip(rgb_paths, ce_paths))
+
+     # Create dataloader
+     dataloader = ensemble_av._make_inference_dataloader(
+         fpaths,
+         ids=ids,
+         num_workers=8,
+         preprocess=False,
+         batch_size=8,
+     )
+
+     # Run inference
+     with torch.no_grad():
+         for batch in tqdm(dataloader):
+             # AV segmentation
+             if av_path is not None:
+                 with torch.autocast(device_type=device.type):
+                     proba = ensemble_av.forward(batch["image"].to(device))
+                     proba = torch.mean(proba, dim=0)  # average over models
+                     proba = torch.permute(proba, (0, 2, 3, 1))  # NCHW -> NHWC
+                     proba = torch.nn.functional.softmax(proba, dim=-1)
+
+                 items = {
+                     "id": batch["id"],
+                     "image": proba,
+                 }
+
+                 items = decollate_batch(items)
+                 for item in items:
+                     fpath = os.path.join(av_path, f"{item['id']}.png")
+                     mask = np.argmax(item["image"], -1)
+                     Image.fromarray(mask.squeeze().astype(np.uint8)).save(fpath)
+
+             # Vessel segmentation
+             if vessels_path is not None:
+                 with torch.autocast(device_type=device.type):
+                     proba = ensemble_vessels.forward(batch["image"].to(device))
+                     proba = torch.mean(proba, dim=0)  # average over models
+                     proba = torch.permute(proba, (0, 2, 3, 1))  # NCHW -> NHWC
+                     proba = torch.nn.functional.softmax(proba, dim=-1)
+
+                 items = {
+                     "id": batch["id"],
+                     "image": proba,
+                 }
+
+                 items = decollate_batch(items)
+                 for item in items:
+                     fpath = os.path.join(vessels_path, f"{item['id']}.png")
+                     mask = np.argmax(item["image"], -1)
+                     Image.fromarray(mask.squeeze().astype(np.uint8)).save(fpath)
+
+
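The post-processing shared by the AV and vessel branches (average over the model axis, softmax over channels, argmax to a per-pixel class index, save as an 8-bit PNG) can be sketched in isolation. This is a minimal single-image sketch with NumPy standing in for the torch tensors; the model count (4) and class count (3) are illustrative assumptions:

```python
import os
import tempfile

import numpy as np
from PIL import Image

# Fake ensemble output: 4 models x 3 classes x 8 x 8 logits (M, C, H, W).
rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 3, 8, 8))

proba = logits.mean(axis=0)             # average over models -> (C, H, W)
proba = np.transpose(proba, (1, 2, 0))  # CHW -> HWC
proba = np.exp(proba) / np.exp(proba).sum(axis=-1, keepdims=True)  # softmax

mask = np.argmax(proba, axis=-1).astype(np.uint8)  # per-pixel class index

# Save and reload, as the inference loop does per image id.
fpath = os.path.join(tempfile.mkdtemp(), "mask.png")
Image.fromarray(mask).save(fpath)
reloaded = np.array(Image.open(fpath))
```

The class-index PNG round-trips losslessly because the values stay in the uint8 range.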
149
+ def run_segmentation_disc(
150
+ rgb_paths: List[Path],
151
+ ce_paths: Optional[List[Path]] = None,
152
+ ids: Optional[List[str]] = None,
153
+ output_path: Optional[Path] = None,
154
+ device: torch.device = torch.device(
155
+ "cuda:0" if torch.cuda.is_available() else "cpu"
156
+ ),
157
+ ) -> None:
158
+ ensemble_disc = (
159
+ SegmentationEnsemble.from_release("disc_july24.pt").to(device).eval()
160
+ )
161
+
162
+ # Prepare input paths
163
+ if ce_paths is None:
164
+ # If CE paths are not provided, use RGB paths for both inputs
165
+ fpaths = rgb_paths
166
+ else:
167
+ # If CE paths are provided, pair them with RGB paths
168
+ if len(rgb_paths) != len(ce_paths):
169
+ raise ValueError("rgb_paths and ce_paths must have the same length")
170
+ fpaths = list(zip(rgb_paths, ce_paths))
171
+
172
+ dataloader = ensemble_disc._make_inference_dataloader(
173
+ fpaths,
174
+ ids=ids,
175
+ num_workers=8,
176
+ preprocess=False,
177
+ batch_size=8,
178
+ )
179
+
180
+ with torch.no_grad():
181
+ for batch in tqdm(dataloader):
182
+ # AV
183
+ with torch.autocast(device_type=device.type):
184
+ proba = ensemble_disc.forward(batch["image"].to(device))
185
+ proba = torch.mean(proba, dim=0) # average over models
186
+ proba = torch.permute(proba, (0, 2, 3, 1)) # NCHW -> NHWC
187
+ proba = torch.nn.functional.softmax(proba, dim=-1)
188
+
189
+ items = {
190
+ "id": batch["id"],
191
+ "image": proba,
192
+ }
193
+
194
+ items = decollate_batch(items)
195
+ items = [dataloader.dataset.transform.undo_item(item) for item in items]
196
+ for i, item in enumerate(items):
197
+ fpath = os.path.join(output_path, f"{item['id']}.png")
198
+
199
+ mask = np.argmax(item["image"], -1)
200
+ Image.fromarray(mask.squeeze().astype(np.uint8)).save(fpath)
201
+
202
+
+ def run_fovea_detection(
+     rgb_paths: List[Path],
+     ce_paths: Optional[List[Path]] = None,
+     ids: Optional[List[str]] = None,
+     device: torch.device = torch.device(
+         "cuda:0" if torch.cuda.is_available() else "cpu"
+     ),
+ ) -> pd.DataFrame:
+     """Run fovea localisation and return a DataFrame of x_fovea/y_fovea per image."""
+     ensemble_fovea = HeatmapRegressionEnsemble.from_release("fovea_july24.pt").to(
+         device
+     )
+
+     # Prepare input paths
+     if ce_paths is None:
+         # If CE paths are not provided, use RGB paths for both inputs
+         fpaths = rgb_paths
+     else:
+         # If CE paths are provided, pair them with RGB paths
+         if len(rgb_paths) != len(ce_paths):
+             raise ValueError("rgb_paths and ce_paths must have the same length")
+         fpaths = list(zip(rgb_paths, ce_paths))
+
+     dataloader = ensemble_fovea._make_inference_dataloader(
+         fpaths,
+         ids=ids,
+         num_workers=8,
+         preprocess=False,
+         batch_size=8,
+     )
+
+     output_ids, outputs = [], []
+     with torch.no_grad():
+         for batch in tqdm(dataloader):
+             if len(batch) == 0:
+                 continue
+
+             im = batch["image"].to(device)
+
+             # FOVEA DETECTION
+             with torch.autocast(device_type=device.type):
+                 heatmap = ensemble_fovea.forward(im)
+                 keypoints = extract_keypoints_from_heatmaps(heatmap)
+
+             kp_fovea = torch.mean(keypoints, dim=0)  # average over models
+
+             items = {
+                 "id": batch["id"],
+                 "keypoints": kp_fovea,
+                 "metadata": batch["metadata"],
+             }
+             items = decollate_batch(items)
+
+             # Undo the preprocessing transform so coordinates map to the input images
+             items = [dataloader.dataset.transform.undo_item(item) for item in items]
+
+             for item in items:
+                 output_ids.append(item["id"])
+                 outputs.append(item["keypoints"][0].tolist())
+
+     return pd.DataFrame(
+         outputs,
+         index=output_ids,
+         columns=["x_fovea", "y_fovea"],
+     )
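`extract_keypoints_from_heatmaps` comes from `rtnls_inference.utils` and its implementation is not shown in this commit; the core idea, taking the coordinates of the heatmap maximum and then averaging the per-model keypoints as `run_fovea_detection` does, can be sketched with NumPy. This is a simplified stand-in, not the library's implementation:

```python
import numpy as np

def keypoint_from_heatmap(heatmap: np.ndarray) -> tuple:
    """Return (x, y) of the hottest pixel in a 2D heatmap."""
    y, x = np.unravel_index(np.argmax(heatmap), heatmap.shape)
    return float(x), float(y)

# Synthetic heatmap with a known peak at (x=5, y=3).
hm = np.zeros((8, 8))
hm[3, 5] = 1.0
print(keypoint_from_heatmap(hm))  # (5.0, 3.0)

# Averaging keypoints over an ensemble, as run_fovea_detection does
# with torch.mean(keypoints, dim=0):
kps = np.array([[5.0, 3.0], [5.2, 3.4]])
mean_kp = kps.mean(axis=0)  # approximately [5.1, 3.2]
```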
vascx_models/utils.py ADDED
@@ -0,0 +1,160 @@
+ from pathlib import Path
+ from typing import Dict, Optional, Tuple
+
+ import numpy as np
+ from PIL import Image, ImageDraw
+
+
+ def create_fundus_overlay(
+     rgb_path: str,
+     av_path: Optional[str] = None,
+     disc_path: Optional[str] = None,
+     fovea_location: Optional[Tuple[int, int]] = None,
+     output_path: Optional[str] = None,
+ ) -> np.ndarray:
+     """
+     Create a visualization of a fundus image with overlaid segmentations and markers.
+
+     Args:
+         rgb_path: Path to the RGB fundus image
+         av_path: Optional path to artery-vein segmentation (1=artery, 2=vein, 3=intersection)
+         disc_path: Optional path to binary disc segmentation
+         fovea_location: Optional (x, y) tuple indicating the location of the fovea
+         output_path: Optional path to save the visualization image
+
+     Returns:
+         Numpy array containing the visualization image
+     """
+     # Load RGB image
+     rgb_img = np.array(Image.open(rgb_path))
+
+     # Create output image starting with the RGB image
+     output_img = rgb_img.copy()
+
+     # Load and overlay AV segmentation if provided
+     if av_path:
+         av_mask = np.array(Image.open(av_path))
+
+         # Create masks for arteries (1), veins (2) and intersections (3)
+         artery_mask = av_mask == 1
+         vein_mask = av_mask == 2
+         intersection_mask = av_mask == 3
+
+         # Intersections are painted as both artery and vein
+         artery_combined = np.logical_or(artery_mask, intersection_mask)
+         vein_combined = np.logical_or(vein_mask, intersection_mask)
+
+         # Pure red for arteries
+         output_img[artery_combined, 0] = 255
+         output_img[artery_combined, 1] = 0
+         output_img[artery_combined, 2] = 0
+
+         # Pure blue for veins (applied second, so crossings end up blue)
+         output_img[vein_combined, 0] = 0
+         output_img[vein_combined, 1] = 0
+         output_img[vein_combined, 2] = 255
+
+     # Load and overlay optic disc segmentation if provided
+     if disc_path:
+         disc_mask = np.array(Image.open(disc_path)) > 0
+
+         # Apply white color for disc
+         output_img[disc_mask, :] = [255, 255, 255]
+
+     # Convert to PIL image for drawing the fovea marker
+     pil_img = Image.fromarray(output_img)
+
+     # Add fovea marker if provided
+     if fovea_location:
+         draw = ImageDraw.Draw(pil_img)
+         x, y = fovea_location
+         marker_size = (
+             min(pil_img.width, pil_img.height) // 50
+         )  # Scale marker with image size
+
+         # Draw yellow X at fovea location
+         draw.line(
+             [(x - marker_size, y - marker_size), (x + marker_size, y + marker_size)],
+             fill=(255, 255, 0),
+             width=2,
+         )
+         draw.line(
+             [(x - marker_size, y + marker_size), (x + marker_size, y - marker_size)],
+             fill=(255, 255, 0),
+             width=2,
+         )
+
+     # Convert back to numpy array
+     output_img = np.array(pil_img)
+
+     # Save output if path provided
+     if output_path:
+         Image.fromarray(output_img).save(output_path)
+
+     return output_img
+
+
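The boolean-mask colouring used in `create_fundus_overlay` (label 1 = artery, 2 = vein, 3 = intersection, with intersections painted as both classes so the later vein pass wins) can be checked on a tiny synthetic label image:

```python
import numpy as np

# 2x2 label image: background, artery, vein, crossing.
av_mask = np.array([[0, 1], [2, 3]], dtype=np.uint8)
img = np.zeros((2, 2, 3), dtype=np.uint8)

artery = np.logical_or(av_mask == 1, av_mask == 3)
vein = np.logical_or(av_mask == 2, av_mask == 3)

img[artery] = [255, 0, 0]  # red arteries (applied first)
img[vein] = [0, 0, 255]    # blue veins (applied second, overwrites crossings)

print(img[0, 1], img[1, 0], img[1, 1])
# artery pixel is red, vein pixel is blue, the crossing ends up blue
```

Assigning a 3-vector through a boolean mask sets all three channels at once, which is equivalent to the per-channel assignments in the function above.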
+ def batch_create_overlays(
+     rgb_dir: Path,
+     output_dir: Path,
+     av_dir: Optional[Path] = None,
+     disc_dir: Optional[Path] = None,
+     fovea_data: Optional[Dict[str, Tuple[int, int]]] = None,
+ ) -> None:
+     """
+     Create visualization overlays for a batch of images.
+
+     Args:
+         rgb_dir: Directory containing RGB fundus images
+         output_dir: Directory to save visualization images
+         av_dir: Optional directory containing AV segmentations
+         disc_dir: Optional directory containing disc segmentations
+         fovea_data: Optional dictionary mapping image IDs to fovea coordinates
+     """
+     # Create output directory if it doesn't exist
+     output_dir.mkdir(exist_ok=True, parents=True)
+
+     # Get all RGB images
+     rgb_files = list(rgb_dir.glob("*.png"))
+     if not rgb_files:
+         return
+
+     # Process each image
+     for rgb_file in rgb_files:
+         image_id = rgb_file.stem
+
+         # Check for corresponding AV segmentation
+         av_file = None
+         if av_dir:
+             av_file_path = av_dir / f"{image_id}.png"
+             if av_file_path.exists():
+                 av_file = str(av_file_path)
+
+         # Check for corresponding disc segmentation
+         disc_file = None
+         if disc_dir:
+             disc_file_path = disc_dir / f"{image_id}.png"
+             if disc_file_path.exists():
+                 disc_file = str(disc_file_path)
+
+         # Get fovea location if available
+         fovea_location = None
+         if fovea_data and image_id in fovea_data:
+             fovea_location = fovea_data[image_id]
+
+         # Create output path
+         output_file = output_dir / f"{image_id}.png"
+
+         # Create and save overlay
+         create_fundus_overlay(
+             rgb_path=str(rgb_file),
+             av_path=av_file,
+             disc_path=disc_file,
+             fovea_location=fovea_location,
+             output_path=str(output_file),
+         )
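`batch_create_overlays` pairs files across directories purely by filename stem, silently skipping overlays for which no matching segmentation exists. That lookup can be exercised on temporary directories; this is a sketch of the matching logic only, not the full overlay pipeline:

```python
import tempfile
from pathlib import Path

rgb_dir = Path(tempfile.mkdtemp())
av_dir = Path(tempfile.mkdtemp())

# Two RGB images, but an AV segmentation for only one of them.
for name in ["a.png", "b.png"]:
    (rgb_dir / name).touch()
(av_dir / "a.png").touch()

pairs = {}
for rgb_file in rgb_dir.glob("*.png"):
    image_id = rgb_file.stem  # filename without extension, used as the image id
    av_file = av_dir / f"{image_id}.png"
    pairs[image_id] = str(av_file) if av_file.exists() else None

print(pairs)  # 'a' maps to its AV path, 'b' maps to None
```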