Jose commited on
Commit ·
ee42561
0
Parent(s):
first commit github branch
Browse files- .gitattributes +37 -0
- .gitignore +6 -0
- README.md +143 -0
- fovea/fovea_july24.pt +3 -0
- imgs/CHASEDB1_08L.png +3 -0
- imgs/CHASEDB1_08L_rgb.png +3 -0
- imgs/CHASEDB1_12R.png +3 -0
- imgs/CHASEDB1_12R_rgb.png +3 -0
- imgs/DRIVE_22.png +3 -0
- imgs/DRIVE_22_rgb.png +3 -0
- imgs/DRIVE_40.png +3 -0
- imgs/DRIVE_40_rgb.png +3 -0
- imgs/HRF_04_g.png +3 -0
- imgs/HRF_04_g_rgb.png +3 -0
- imgs/HRF_07_dr.png +3 -0
- imgs/HRF_07_dr_rgb.png +3 -0
- imgs/samples_vascx_hrf.png +3 -0
- notebooks/0_preprocess.ipynb +138 -0
- notebooks/1_segment_preprocessed.ipynb +217 -0
- samples/fundus/original/CHASEDB1_08L.png +3 -0
- samples/fundus/original/CHASEDB1_12R.png +3 -0
- samples/fundus/original/DRIVE_22.png +3 -0
- samples/fundus/original/DRIVE_40.png +3 -0
- samples/fundus/original/HRF_04_g.jpg +3 -0
- samples/fundus/original/HRF_07_dr.jpg +3 -0
- setup.py +33 -0
- vascx_models/cli.py +198 -0
- vascx_models/inference.py +269 -0
- vascx_models/utils.py +160 -0
.gitattributes
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.pyc
|
| 2 |
+
__pycache__
|
| 3 |
+
*.egg-info
|
| 4 |
+
*.zip
|
| 5 |
+
/samples/fundus/*
|
| 6 |
+
!/samples/fundus/original
|
README.md
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: agpl-3.0
|
| 3 |
+
pipeline_tag: image-segmentation
|
| 4 |
+
tags:
|
| 5 |
+
- medical
|
| 6 |
+
- biology
|
| 7 |
+
---
|
| 8 |
+
|
| 9 |
+
# VascX models
|
| 10 |
+
|
| 11 |
+
This repository contains the instructions for using the VascX models from the paper [VascX Models: Model Ensembles for Retinal Vascular Analysis from Color Fundus Images](https://arxiv.org/abs/2409.16016).
|
| 12 |
+
|
| 13 |
+
The model weights are in [huggingface](https://huggingface.co/Eyened/vascx).
|
| 14 |
+
|
| 15 |
+
<img src="imgs/CHASEDB1_12R_rgb.png" width="240" height="240" style="display:inline"><img src="imgs/CHASEDB1_12R.png" width="240" height="240" style="display:inline">
|
| 16 |
+
|
| 17 |
+
<img src="imgs/DRIVE_22_rgb.png" width="240" height="240" style="display:inline"><img src="imgs/DRIVE_22.png" width="240" height="240" style="display:inline">
|
| 18 |
+
|
| 19 |
+
<img src="imgs/HRF_04_g_rgb.png" width="240" height="240" style="display:inline"><img src="imgs/HRF_04_g.png" width="240" height="240" style="display:inline">
|
| 20 |
+
|
| 21 |
+
## Installation
|
| 22 |
+
|
| 23 |
+
To install the entire fundus analysis pipeline including fundus preprocessing, model inference code and vascular biomarker extraction:
|
| 24 |
+
|
| 25 |
+
1. Create a conda or virtualenv virtual environment, or otherwise ensure a clean environment.
|
| 26 |
+
|
| 27 |
+
2. Install the [rtnls_inference package](https://github.com/Eyened/retinalysis-inference).
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
## `vascx run` Command
|
| 31 |
+
|
| 32 |
+
The `run` command provides a comprehensive pipeline for processing fundus images, performing various analyses, and creating visualizations.
|
| 33 |
+
|
| 34 |
+
### Usage
|
| 35 |
+
|
| 36 |
+
```bash
|
| 37 |
+
vascx run DATA_PATH OUTPUT_PATH [OPTIONS]
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
### Arguments
|
| 41 |
+
|
| 42 |
+
- `DATA_PATH`: Path to input data. Can be either:
|
| 43 |
+
- A directory containing fundus images
|
| 44 |
+
- A CSV file with a 'path' column containing paths to images
|
| 45 |
+
|
| 46 |
+
- `OUTPUT_PATH`: Directory where processed results will be stored
|
| 47 |
+
|
| 48 |
+
### Options
|
| 49 |
+
|
| 50 |
+
| Option | Default | Description |
|
| 51 |
+
|--------|---------|-------------|
|
| 52 |
+
| `--preprocess/--no-preprocess` | `--preprocess` | Run preprocessing to standardize images for model input |
|
| 53 |
+
| `--vessels/--no-vessels` | `--vessels` | Run vessel segmentation and artery-vein classification |
|
| 54 |
+
| `--disc/--no-disc` | `--disc` | Run optic disc segmentation |
|
| 55 |
+
| `--quality/--no-quality` | `--quality` | Run image quality assessment |
|
| 56 |
+
| `--fovea/--no-fovea` | `--fovea` | Run fovea detection |
|
| 57 |
+
| `--overlay/--no-overlay` | `--overlay` | Create visualization overlays combining all results |
|
| 58 |
+
| `--n_jobs` | `4` | Number of preprocessing workers for parallel processing |
|
| 59 |
+
|
| 60 |
+
### Output Structure
|
| 61 |
+
|
| 62 |
+
When run with default options, the command creates the following structure in `OUTPUT_PATH`:
|
| 63 |
+
|
| 64 |
+
```
|
| 65 |
+
OUTPUT_PATH/
|
| 66 |
+
├── preprocessed_rgb/ # Standardized fundus images
|
| 67 |
+
├── vessels/ # Vessel segmentation results
|
| 68 |
+
├── artery_vein/ # Artery-vein classification
|
| 69 |
+
├── disc/ # Optic disc segmentation
|
| 70 |
+
├── overlays/ # Visualization images
|
| 71 |
+
├── bounds.csv # Image boundary information
|
| 72 |
+
├── quality.csv # Image quality scores
|
| 73 |
+
└── fovea.csv # Fovea coordinates
|
| 74 |
+
```
|
| 75 |
+
|
| 76 |
+
### Processing Stages
|
| 77 |
+
|
| 78 |
+
1. **Preprocessing**:
|
| 79 |
+
- Standardizes input images for consistent analysis
|
| 80 |
+
- Outputs preprocessed images and boundary information
|
| 81 |
+
|
| 82 |
+
2. **Quality Assessment**:
|
| 83 |
+
- Evaluates image quality with three quality metrics (q1, q2, q3)
|
| 84 |
+
- Higher scores indicate better image quality
|
| 85 |
+
|
| 86 |
+
3. **Vessel Segmentation and Artery-Vein Classification**:
|
| 87 |
+
- Identifies blood vessels in the retina
|
| 88 |
+
- Classifies vessels as arteries (1) or veins (2) with intersections (3)
|
| 89 |
+
|
| 90 |
+
4. **Optic Disc Segmentation**:
|
| 91 |
+
- Identifies the optic disc location and boundaries
|
| 92 |
+
|
| 93 |
+
5. **Fovea Detection**:
|
| 94 |
+
- Determines the coordinates of the fovea (center of vision)
|
| 95 |
+
|
| 96 |
+
6. **Visualization Overlays**:
|
| 97 |
+
- Creates color-coded images showing:
|
| 98 |
+
- Arteries in red
|
| 99 |
+
- Veins in blue
|
| 100 |
+
- Optic disc in white
|
| 101 |
+
- Fovea marked with yellow X
|
| 102 |
+
|
| 103 |
+
### Examples
|
| 104 |
+
|
| 105 |
+
**Process a directory of images with all analyses:**
|
| 106 |
+
```bash
|
| 107 |
+
vascx run /path/to/images /path/to/output
|
| 108 |
+
```
|
| 109 |
+
|
| 110 |
+
**Process specific images listed in a CSV:**
|
| 111 |
+
```bash
|
| 112 |
+
vascx run /path/to/image_list.csv /path/to/output
|
| 113 |
+
```
|
| 114 |
+
|
| 115 |
+
**Only run preprocessing and vessel segmentation:**
|
| 116 |
+
```bash
|
| 117 |
+
vascx run /path/to/images /path/to/output --no-disc --no-quality --no-fovea --no-overlay
|
| 118 |
+
```
|
| 119 |
+
|
| 120 |
+
**Skip preprocessing on already preprocessed images:**
|
| 121 |
+
```bash
|
| 122 |
+
vascx run /path/to/preprocessed/images /path/to/output --no-preprocess
|
| 123 |
+
```
|
| 124 |
+
|
| 125 |
+
**Increase parallel processing workers:**
|
| 126 |
+
```bash
|
| 127 |
+
vascx run /path/to/images /path/to/output --n_jobs 8
|
| 128 |
+
```
|
| 129 |
+
|
| 130 |
+
### Notes
|
| 131 |
+
|
| 132 |
+
- The CSV input must contain a 'path' column with image file paths
|
| 133 |
+
- If the CSV includes an 'id' column, these IDs will be used instead of filenames
|
| 134 |
+
- When `--no-preprocess` is used, input images must already be in the proper format
|
| 135 |
+
- The overlay visualization requires at least one analysis component to be enabled
|
| 136 |
+
|
| 137 |
+
###
|
| 138 |
+
|
| 139 |
+
To speed up re-execution of vascx we recommend to run the preprocessing and segmentation steps separately:
|
| 140 |
+
|
| 141 |
+
1. Preprocessing. See [this notebook](./notebooks/0_preprocess.ipynb). This step is CPU-heavy and benefits from parallelization (see notebook).
|
| 142 |
+
|
| 143 |
+
2. Inference. See [this notebook](./notebooks/1_segment_preprocessed.ipynb). All models can be ran in a single GPU with >10GB VRAM.
|
fovea/fovea_july24.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1af042f7e2a398f512be8a8d54cc480300312c3f3692c13bd98be65439a33222
|
| 3 |
+
size 352714676
|
imgs/CHASEDB1_08L.png
ADDED
|
Git LFS Details
|
imgs/CHASEDB1_08L_rgb.png
ADDED
|
Git LFS Details
|
imgs/CHASEDB1_12R.png
ADDED
|
Git LFS Details
|
imgs/CHASEDB1_12R_rgb.png
ADDED
|
Git LFS Details
|
imgs/DRIVE_22.png
ADDED
|
Git LFS Details
|
imgs/DRIVE_22_rgb.png
ADDED
|
Git LFS Details
|
imgs/DRIVE_40.png
ADDED
|
Git LFS Details
|
imgs/DRIVE_40_rgb.png
ADDED
|
Git LFS Details
|
imgs/HRF_04_g.png
ADDED
|
Git LFS Details
|
imgs/HRF_04_g_rgb.png
ADDED
|
Git LFS Details
|
imgs/HRF_07_dr.png
ADDED
|
Git LFS Details
|
imgs/HRF_07_dr_rgb.png
ADDED
|
Git LFS Details
|
imgs/samples_vascx_hrf.png
ADDED
|
Git LFS Details
|
notebooks/0_preprocess.ipynb
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"outputs": [],
|
| 8 |
+
"source": [
|
| 9 |
+
"from pathlib import Path\n",
|
| 10 |
+
"\n",
|
| 11 |
+
"import pandas as pd\n",
|
| 12 |
+
"\n",
|
| 13 |
+
"from rtnls_fundusprep.preprocessor import parallel_preprocess"
|
| 14 |
+
]
|
| 15 |
+
},
|
| 16 |
+
{
|
| 17 |
+
"cell_type": "markdown",
|
| 18 |
+
"metadata": {},
|
| 19 |
+
"source": [
|
| 20 |
+
"## Preprocessing\n",
|
| 21 |
+
"\n",
|
| 22 |
+
"This code will preprocess the images and write .png files with the square fundus image and the contrast enhanced version\n",
|
| 23 |
+
"\n",
|
| 24 |
+
"This step is not strictly necessary, but it is useful if you want to run the preprocessing step separately before model inference\n"
|
| 25 |
+
]
|
| 26 |
+
},
|
| 27 |
+
{
|
| 28 |
+
"cell_type": "markdown",
|
| 29 |
+
"metadata": {},
|
| 30 |
+
"source": [
|
| 31 |
+
"Create a list of files to be preprocessed:"
|
| 32 |
+
]
|
| 33 |
+
},
|
| 34 |
+
{
|
| 35 |
+
"cell_type": "code",
|
| 36 |
+
"execution_count": 2,
|
| 37 |
+
"metadata": {},
|
| 38 |
+
"outputs": [],
|
| 39 |
+
"source": [
|
| 40 |
+
"ds_path = Path(\"../samples/fundus\")\n",
|
| 41 |
+
"files = list((ds_path / \"original\").glob(\"*\"))"
|
| 42 |
+
]
|
| 43 |
+
},
|
| 44 |
+
{
|
| 45 |
+
"cell_type": "markdown",
|
| 46 |
+
"metadata": {},
|
| 47 |
+
"source": [
|
| 48 |
+
"Images with .dcm extension will be read as dicom and the pixel_array will be read as RGB. All other images will be read using PIL's Image.open"
|
| 49 |
+
]
|
| 50 |
+
},
|
| 51 |
+
{
|
| 52 |
+
"cell_type": "code",
|
| 53 |
+
"execution_count": 3,
|
| 54 |
+
"metadata": {},
|
| 55 |
+
"outputs": [
|
| 56 |
+
{
|
| 57 |
+
"name": "stderr",
|
| 58 |
+
"output_type": "stream",
|
| 59 |
+
"text": [
|
| 60 |
+
"0it [00:00, ?it/s][Parallel(n_jobs=4)]: Using backend LokyBackend with 4 concurrent workers.\n",
|
| 61 |
+
"6it [00:00, 154.80it/s]\n"
|
| 62 |
+
]
|
| 63 |
+
},
|
| 64 |
+
{
|
| 65 |
+
"name": "stdout",
|
| 66 |
+
"output_type": "stream",
|
| 67 |
+
"text": [
|
| 68 |
+
"Error with image ../samples/fundus/original/HRF_07_dr.jpg\n",
|
| 69 |
+
"Error with image ../samples/fundus/original/HRF_04_g.jpg\n"
|
| 70 |
+
]
|
| 71 |
+
},
|
| 72 |
+
{
|
| 73 |
+
"name": "stderr",
|
| 74 |
+
"output_type": "stream",
|
| 75 |
+
"text": [
|
| 76 |
+
"[Parallel(n_jobs=4)]: Done 2 out of 6 | elapsed: 0.9s remaining: 1.8s\n",
|
| 77 |
+
"[Parallel(n_jobs=4)]: Done 3 out of 6 | elapsed: 1.5s remaining: 1.5s\n",
|
| 78 |
+
"[Parallel(n_jobs=4)]: Done 4 out of 6 | elapsed: 1.5s remaining: 0.8s\n",
|
| 79 |
+
"[Parallel(n_jobs=4)]: Done 6 out of 6 | elapsed: 1.6s finished\n"
|
| 80 |
+
]
|
| 81 |
+
}
|
| 82 |
+
],
|
| 83 |
+
"source": [
|
| 84 |
+
"bounds = parallel_preprocess(\n",
|
| 85 |
+
" files, # List of image files\n",
|
| 86 |
+
" rgb_path=ds_path / \"rgb\", # Output path for RGB images\n",
|
| 87 |
+
" ce_path=ds_path / \"ce\", # Output path for Contrast Enhanced images\n",
|
| 88 |
+
" n_jobs=4, # number of preprocessing workers\n",
|
| 89 |
+
")\n",
|
| 90 |
+
"df_bounds = pd.DataFrame(bounds).set_index(\"id\")"
|
| 91 |
+
]
|
| 92 |
+
},
|
| 93 |
+
{
|
| 94 |
+
"cell_type": "markdown",
|
| 95 |
+
"metadata": {},
|
| 96 |
+
"source": [
|
| 97 |
+
"The preprocessor will produce RGB and contrast-enhanced preprocessed images cropped to a square and return a dataframe with the image bounds that can be used to reconstruct the original image. Output files will be named the same as input images, but with .png extension. Be careful with providing multiple inputs with the same filename without extension as this will result in over-written images. Any exceptions during pre-processing will not stop execution but will print error. Images that failed pre-processing for any reason will be marked with `success=False` in the df_bounds dataframe."
|
| 98 |
+
]
|
| 99 |
+
},
|
| 100 |
+
{
|
| 101 |
+
"cell_type": "code",
|
| 102 |
+
"execution_count": 4,
|
| 103 |
+
"metadata": {},
|
| 104 |
+
"outputs": [],
|
| 105 |
+
"source": [
|
| 106 |
+
"df_bounds.to_csv(ds_path / \"meta.csv\")"
|
| 107 |
+
]
|
| 108 |
+
},
|
| 109 |
+
{
|
| 110 |
+
"cell_type": "code",
|
| 111 |
+
"execution_count": null,
|
| 112 |
+
"metadata": {},
|
| 113 |
+
"outputs": [],
|
| 114 |
+
"source": []
|
| 115 |
+
}
|
| 116 |
+
],
|
| 117 |
+
"metadata": {
|
| 118 |
+
"kernelspec": {
|
| 119 |
+
"display_name": "retinalysis",
|
| 120 |
+
"language": "python",
|
| 121 |
+
"name": "python3"
|
| 122 |
+
},
|
| 123 |
+
"language_info": {
|
| 124 |
+
"codemirror_mode": {
|
| 125 |
+
"name": "ipython",
|
| 126 |
+
"version": 3
|
| 127 |
+
},
|
| 128 |
+
"file_extension": ".py",
|
| 129 |
+
"mimetype": "text/x-python",
|
| 130 |
+
"name": "python",
|
| 131 |
+
"nbconvert_exporter": "python",
|
| 132 |
+
"pygments_lexer": "ipython3",
|
| 133 |
+
"version": "3.10.13"
|
| 134 |
+
}
|
| 135 |
+
},
|
| 136 |
+
"nbformat": 4,
|
| 137 |
+
"nbformat_minor": 2
|
| 138 |
+
}
|
notebooks/1_segment_preprocessed.ipynb
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"outputs": [],
|
| 8 |
+
"source": [
|
| 9 |
+
"from pathlib import Path\n",
|
| 10 |
+
"\n",
|
| 11 |
+
"import torch\n",
|
| 12 |
+
"\n",
|
| 13 |
+
"from rtnls_inference import (\n",
|
| 14 |
+
" HeatmapRegressionEnsemble,\n",
|
| 15 |
+
" SegmentationEnsemble,\n",
|
| 16 |
+
")"
|
| 17 |
+
]
|
| 18 |
+
},
|
| 19 |
+
{
|
| 20 |
+
"cell_type": "markdown",
|
| 21 |
+
"metadata": {},
|
| 22 |
+
"source": [
|
| 23 |
+
"## Segmentation of preprocessed images\n",
|
| 24 |
+
"\n",
|
| 25 |
+
"Here we segment images preprocessed using 0_preprocess.ipynb\n"
|
| 26 |
+
]
|
| 27 |
+
},
|
| 28 |
+
{
|
| 29 |
+
"cell_type": "markdown",
|
| 30 |
+
"metadata": {},
|
| 31 |
+
"source": []
|
| 32 |
+
},
|
| 33 |
+
{
|
| 34 |
+
"cell_type": "code",
|
| 35 |
+
"execution_count": 2,
|
| 36 |
+
"metadata": {},
|
| 37 |
+
"outputs": [],
|
| 38 |
+
"source": [
|
| 39 |
+
"ds_path = Path(\"../samples/fundus\")\n",
|
| 40 |
+
"\n",
|
| 41 |
+
"# input folders. these are the folders where we stored the preprocessed images\n",
|
| 42 |
+
"rgb_path = ds_path / \"rgb\"\n",
|
| 43 |
+
"ce_path = ds_path / \"ce\"\n",
|
| 44 |
+
"\n",
|
| 45 |
+
"# these are the output folders for:\n",
|
| 46 |
+
"av_path = ds_path / \"av\" # artery-vein segmentations\n",
|
| 47 |
+
"discs_path = ds_path / \"discs\" # optic disc segmentations\n",
|
| 48 |
+
"overlays_path = ds_path / \"overlays\" # optional overlay visualizations\n",
|
| 49 |
+
"\n",
|
| 50 |
+
"device = torch.device(\"cuda:0\") # device to use for inference"
|
| 51 |
+
]
|
| 52 |
+
},
|
| 53 |
+
{
|
| 54 |
+
"cell_type": "code",
|
| 55 |
+
"execution_count": 3,
|
| 56 |
+
"metadata": {},
|
| 57 |
+
"outputs": [],
|
| 58 |
+
"source": [
|
| 59 |
+
"rgb_paths = sorted(list(rgb_path.glob(\"*.png\")))\n",
|
| 60 |
+
"ce_paths = sorted(list(ce_path.glob(\"*.png\")))\n",
|
| 61 |
+
"paired_paths = list(zip(rgb_paths, ce_paths))"
|
| 62 |
+
]
|
| 63 |
+
},
|
| 64 |
+
{
|
| 65 |
+
"cell_type": "code",
|
| 66 |
+
"execution_count": null,
|
| 67 |
+
"metadata": {},
|
| 68 |
+
"outputs": [],
|
| 69 |
+
"source": [
|
| 70 |
+
"paired_paths[0] # important to make sure that the paths are paired correctly"
|
| 71 |
+
]
|
| 72 |
+
},
|
| 73 |
+
{
|
| 74 |
+
"cell_type": "markdown",
|
| 75 |
+
"metadata": {},
|
| 76 |
+
"source": [
|
| 77 |
+
"### Artery-vein segmentation\n"
|
| 78 |
+
]
|
| 79 |
+
},
|
| 80 |
+
{
|
| 81 |
+
"cell_type": "code",
|
| 82 |
+
"execution_count": null,
|
| 83 |
+
"metadata": {},
|
| 84 |
+
"outputs": [],
|
| 85 |
+
"source": [
|
| 86 |
+
"av_ensemble = SegmentationEnsemble.from_huggingface('Eyened/vascx:artery_vein/av_july24.pt').to(device)\n",
|
| 87 |
+
"\n",
|
| 88 |
+
"av_ensemble.predict_preprocessed(paired_paths, dest_path=av_path, num_workers=2)"
|
| 89 |
+
]
|
| 90 |
+
},
|
| 91 |
+
{
|
| 92 |
+
"cell_type": "markdown",
|
| 93 |
+
"metadata": {},
|
| 94 |
+
"source": [
|
| 95 |
+
"### Disc segmentation\n"
|
| 96 |
+
]
|
| 97 |
+
},
|
| 98 |
+
{
|
| 99 |
+
"cell_type": "code",
|
| 100 |
+
"execution_count": null,
|
| 101 |
+
"metadata": {},
|
| 102 |
+
"outputs": [],
|
| 103 |
+
"source": [
|
| 104 |
+
"disc_ensemble = SegmentationEnsemble.from_huggingface('Eyened/vascx:disc/disc_july24.pt').to(device)\n",
|
| 105 |
+
"disc_ensemble.predict_preprocessed(paired_paths, dest_path=discs_path, num_workers=2)"
|
| 106 |
+
]
|
| 107 |
+
},
|
| 108 |
+
{
|
| 109 |
+
"cell_type": "markdown",
|
| 110 |
+
"metadata": {},
|
| 111 |
+
"source": [
|
| 112 |
+
"### Fovea detection\n"
|
| 113 |
+
]
|
| 114 |
+
},
|
| 115 |
+
{
|
| 116 |
+
"cell_type": "code",
|
| 117 |
+
"execution_count": null,
|
| 118 |
+
"metadata": {},
|
| 119 |
+
"outputs": [],
|
| 120 |
+
"source": [
|
| 121 |
+
"fovea_ensemble = HeatmapRegressionEnsemble.from_huggingface('Eyened/vascx:fovea/fovea_july24.pt').to(device)\n",
|
| 122 |
+
"# note: this model does not use contrast enhanced images\n",
|
| 123 |
+
"df = fovea_ensemble.predict_preprocessed(paired_paths, num_workers=2)\n",
|
| 124 |
+
"df.columns = [\"mean_x\", \"mean_y\"]\n",
|
| 125 |
+
"df.to_csv(ds_path / \"fovea.csv\")"
|
| 126 |
+
]
|
| 127 |
+
},
|
| 128 |
+
{
|
| 129 |
+
"cell_type": "code",
|
| 130 |
+
"execution_count": null,
|
| 131 |
+
"metadata": {},
|
| 132 |
+
"outputs": [],
|
| 133 |
+
"source": [
|
| 134 |
+
"df"
|
| 135 |
+
]
|
| 136 |
+
},
|
| 137 |
+
{
|
| 138 |
+
"cell_type": "markdown",
|
| 139 |
+
"metadata": {},
|
| 140 |
+
"source": [
|
| 141 |
+
"### Plotting the retinas (optional)\n",
|
| 142 |
+
"\n",
|
| 143 |
+
"This will only work if you ran all the models and stored the outputs using the same folder/file names as above\n"
|
| 144 |
+
]
|
| 145 |
+
},
|
| 146 |
+
{
|
| 147 |
+
"cell_type": "code",
|
| 148 |
+
"execution_count": null,
|
| 149 |
+
"metadata": {},
|
| 150 |
+
"outputs": [],
|
| 151 |
+
"source": [
|
| 152 |
+
"from vascx.fundus.loader import RetinaLoader\n",
|
| 153 |
+
"\n",
|
| 154 |
+
"from rtnls_enface.utils.plotting import plot_gridfns\n",
|
| 155 |
+
"\n",
|
| 156 |
+
"loader = RetinaLoader.from_folder(ds_path)"
|
| 157 |
+
]
|
| 158 |
+
},
|
| 159 |
+
{
|
| 160 |
+
"cell_type": "code",
|
| 161 |
+
"execution_count": null,
|
| 162 |
+
"metadata": {},
|
| 163 |
+
"outputs": [],
|
| 164 |
+
"source": [
|
| 165 |
+
"plot_gridfns([ret.plot for ret in loader[:6]])"
|
| 166 |
+
]
|
| 167 |
+
},
|
| 168 |
+
{
|
| 169 |
+
"cell_type": "markdown",
|
| 170 |
+
"metadata": {},
|
| 171 |
+
"source": [
|
| 172 |
+
"### Storing visualizations (optional)\n"
|
| 173 |
+
]
|
| 174 |
+
},
|
| 175 |
+
{
|
| 176 |
+
"cell_type": "code",
|
| 177 |
+
"execution_count": 10,
|
| 178 |
+
"metadata": {},
|
| 179 |
+
"outputs": [],
|
| 180 |
+
"source": [
|
| 181 |
+
"if not overlays_path.exists():\n",
|
| 182 |
+
" overlays_path.mkdir()\n",
|
| 183 |
+
"for ret in loader:\n",
|
| 184 |
+
" fig, _ = ret.plot()\n",
|
| 185 |
+
" fig.savefig(overlays_path / f\"{ret.id}.png\", bbox_inches=\"tight\", pad_inches=0)"
|
| 186 |
+
]
|
| 187 |
+
},
|
| 188 |
+
{
|
| 189 |
+
"cell_type": "code",
|
| 190 |
+
"execution_count": null,
|
| 191 |
+
"metadata": {},
|
| 192 |
+
"outputs": [],
|
| 193 |
+
"source": []
|
| 194 |
+
}
|
| 195 |
+
],
|
| 196 |
+
"metadata": {
|
| 197 |
+
"kernelspec": {
|
| 198 |
+
"display_name": "retinalysis",
|
| 199 |
+
"language": "python",
|
| 200 |
+
"name": "python3"
|
| 201 |
+
},
|
| 202 |
+
"language_info": {
|
| 203 |
+
"codemirror_mode": {
|
| 204 |
+
"name": "ipython",
|
| 205 |
+
"version": 3
|
| 206 |
+
},
|
| 207 |
+
"file_extension": ".py",
|
| 208 |
+
"mimetype": "text/x-python",
|
| 209 |
+
"name": "python",
|
| 210 |
+
"nbconvert_exporter": "python",
|
| 211 |
+
"pygments_lexer": "ipython3",
|
| 212 |
+
"version": "3.10.13"
|
| 213 |
+
}
|
| 214 |
+
},
|
| 215 |
+
"nbformat": 4,
|
| 216 |
+
"nbformat_minor": 2
|
| 217 |
+
}
|
samples/fundus/original/CHASEDB1_08L.png
ADDED
|
Git LFS Details
|
samples/fundus/original/CHASEDB1_12R.png
ADDED
|
Git LFS Details
|
samples/fundus/original/DRIVE_22.png
ADDED
|
Git LFS Details
|
samples/fundus/original/DRIVE_40.png
ADDED
|
Git LFS Details
|
samples/fundus/original/HRF_04_g.jpg
ADDED
|
Git LFS Details
|
samples/fundus/original/HRF_07_dr.jpg
ADDED
|
Git LFS Details
|
setup.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from setuptools import find_packages, setup
|
| 2 |
+
|
| 3 |
+
with open("README.md", "r") as fh:
|
| 4 |
+
long_description = fh.read()
|
| 5 |
+
|
| 6 |
+
setup(
|
| 7 |
+
name="vascx_models",
|
| 8 |
+
# using versioneer for versioning using git tags
|
| 9 |
+
# https://github.com/python-versioneer/python-versioneer/blob/master/INSTALL.md
|
| 10 |
+
# version=versioneer.get_version(),
|
| 11 |
+
# cmdclass=versioneer.get_cmdclass(),
|
| 12 |
+
author="Jose Vargas",
|
| 13 |
+
author_email="j.vargasquiros@erasmusmc.nl",
|
| 14 |
+
description="Retinal analysis toolbox for Python",
|
| 15 |
+
long_description=long_description,
|
| 16 |
+
long_description_content_type="text/markdown",
|
| 17 |
+
packages=find_packages(),
|
| 18 |
+
include_package_data=True,
|
| 19 |
+
zip_safe=False,
|
| 20 |
+
entry_points={
|
| 21 |
+
"console_scripts": [
|
| 22 |
+
"vascx = vascx_models.cli:cli",
|
| 23 |
+
]
|
| 24 |
+
},
|
| 25 |
+
install_requires=[
|
| 26 |
+
"numpy == 1.*",
|
| 27 |
+
"pandas == 2.*",
|
| 28 |
+
"tqdm == 4.*",
|
| 29 |
+
"Pillow == 9.*",
|
| 30 |
+
"click==8.*",
|
| 31 |
+
],
|
| 32 |
+
python_requires=">=3.10, <3.11",
|
| 33 |
+
)
|
vascx_models/cli.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
|
| 3 |
+
import click
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from rtnls_fundusprep.cli import _run_preprocessing
|
| 8 |
+
|
| 9 |
+
from .inference import (
|
| 10 |
+
run_fovea_detection,
|
| 11 |
+
run_quality_estimation,
|
| 12 |
+
run_segmentation_disc,
|
| 13 |
+
run_segmentation_vessels_and_av,
|
| 14 |
+
)
|
| 15 |
+
from .utils import batch_create_overlays
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@click.group(name="vascx")
|
| 19 |
+
def cli():
|
| 20 |
+
pass
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@cli.command()
|
| 24 |
+
@click.argument("data_path", type=click.Path(exists=True))
|
| 25 |
+
@click.argument("output_path", type=click.Path())
|
| 26 |
+
@click.option(
|
| 27 |
+
"--preprocess/--no-preprocess",
|
| 28 |
+
default=True,
|
| 29 |
+
help="Run preprocessing or use preprocessed images",
|
| 30 |
+
)
|
| 31 |
+
@click.option(
|
| 32 |
+
"--vessels/--no-vessels", default=True, help="Run vessels and AV segmentation"
|
| 33 |
+
)
|
| 34 |
+
@click.option("--disc/--no-disc", default=True, help="Run optic disc segmentation")
|
| 35 |
+
@click.option(
|
| 36 |
+
"--quality/--no-quality", default=True, help="Run image quality estimation"
|
| 37 |
+
)
|
| 38 |
+
@click.option("--fovea/--no-fovea", default=True, help="Run fovea detection")
|
| 39 |
+
@click.option(
|
| 40 |
+
"--overlay/--no-overlay", default=True, help="Create visualization overlays"
|
| 41 |
+
)
|
| 42 |
+
@click.option("--n_jobs", type=int, default=4, help="Number of preprocessing workers")
|
| 43 |
+
def run(
|
| 44 |
+
data_path, output_path, preprocess, vessels, disc, quality, fovea, overlay, n_jobs
|
| 45 |
+
):
|
| 46 |
+
"""Run the complete inference pipeline on fundus images.
|
| 47 |
+
|
| 48 |
+
DATA_PATH is either a directory containing images or a CSV file with 'path' column.
|
| 49 |
+
OUTPUT_PATH is the directory where results will be stored.
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
output_path = Path(output_path)
|
| 53 |
+
output_path.mkdir(exist_ok=True, parents=True)
|
| 54 |
+
|
| 55 |
+
# Setup output directories
|
| 56 |
+
preprocess_rgb_path = output_path / "preprocessed_rgb"
|
| 57 |
+
vessels_path = output_path / "vessels"
|
| 58 |
+
av_path = output_path / "artery_vein"
|
| 59 |
+
disc_path = output_path / "disc"
|
| 60 |
+
overlay_path = output_path / "overlays"
|
| 61 |
+
|
| 62 |
+
# Create required directories
|
| 63 |
+
if preprocess:
|
| 64 |
+
preprocess_rgb_path.mkdir(exist_ok=True, parents=True)
|
| 65 |
+
if vessels:
|
| 66 |
+
av_path.mkdir(exist_ok=True, parents=True)
|
| 67 |
+
vessels_path.mkdir(exist_ok=True, parents=True)
|
| 68 |
+
if disc:
|
| 69 |
+
disc_path.mkdir(exist_ok=True, parents=True)
|
| 70 |
+
if overlay:
|
| 71 |
+
overlay_path.mkdir(exist_ok=True, parents=True)
|
| 72 |
+
|
| 73 |
+
bounds_path = output_path / "bounds.csv" if preprocess else None
|
| 74 |
+
quality_path = output_path / "quality.csv" if quality else None
|
| 75 |
+
fovea_path = output_path / "fovea.csv" if fovea else None
|
| 76 |
+
|
| 77 |
+
# Determine if input is a folder or CSV file
|
| 78 |
+
data_path = Path(data_path)
|
| 79 |
+
is_csv = data_path.suffix.lower() == ".csv"
|
| 80 |
+
|
| 81 |
+
# Get files to process
|
| 82 |
+
files = []
|
| 83 |
+
ids = None
|
| 84 |
+
|
| 85 |
+
if is_csv:
|
| 86 |
+
click.echo(f"Reading file paths from CSV: {data_path}")
|
| 87 |
+
try:
|
| 88 |
+
df = pd.read_csv(data_path)
|
| 89 |
+
if "path" not in df.columns:
|
| 90 |
+
click.echo("Error: CSV must contain a 'path' column")
|
| 91 |
+
return
|
| 92 |
+
|
| 93 |
+
# Get file paths and convert to Path objects
|
| 94 |
+
files = [Path(p) for p in df["path"]]
|
| 95 |
+
|
| 96 |
+
if "id" in df.columns:
|
| 97 |
+
ids = df["id"].tolist()
|
| 98 |
+
click.echo("Using IDs from CSV 'id' column")
|
| 99 |
+
|
| 100 |
+
except Exception as e:
|
| 101 |
+
click.echo(f"Error reading CSV file: {e}")
|
| 102 |
+
return
|
| 103 |
+
else:
|
| 104 |
+
click.echo(f"Finding files in directory: {data_path}")
|
| 105 |
+
files = list(data_path.glob("*"))
|
| 106 |
+
ids = [f.stem for f in files]
|
| 107 |
+
|
| 108 |
+
if not files:
|
| 109 |
+
click.echo("No files found to process")
|
| 110 |
+
return
|
| 111 |
+
|
| 112 |
+
click.echo(f"Found {len(files)} files to process")
|
| 113 |
+
|
| 114 |
+
# Step 1: Preprocess images if requested
|
| 115 |
+
if preprocess:
|
| 116 |
+
click.echo("Running preprocessing...")
|
| 117 |
+
_run_preprocessing(
|
| 118 |
+
files=files,
|
| 119 |
+
ids=ids,
|
| 120 |
+
rgb_path=preprocess_rgb_path,
|
| 121 |
+
bounds_path=bounds_path,
|
| 122 |
+
n_jobs=n_jobs,
|
| 123 |
+
)
|
| 124 |
+
# Use the preprocessed images for subsequent steps
|
| 125 |
+
preprocessed_files = list(preprocess_rgb_path.glob("*.png"))
|
| 126 |
+
else:
|
| 127 |
+
# Use the input files directly
|
| 128 |
+
preprocessed_files = files
|
| 129 |
+
ids = [f.stem for f in preprocessed_files]
|
| 130 |
+
|
| 131 |
+
# Set up GPU device
|
| 132 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 133 |
+
click.echo(f"Using device: {device}")
|
| 134 |
+
|
| 135 |
+
# Step 2: Run quality estimation if requested
|
| 136 |
+
if quality:
|
| 137 |
+
click.echo("Running quality estimation...")
|
| 138 |
+
df_quality = run_quality_estimation(
|
| 139 |
+
fpaths=preprocessed_files, ids=ids, device=device
|
| 140 |
+
)
|
| 141 |
+
df_quality.to_csv(quality_path)
|
| 142 |
+
click.echo(f"Quality results saved to {quality_path}")
|
| 143 |
+
|
| 144 |
+
# Step 3: Run vessels and AV segmentation if requested
|
| 145 |
+
if vessels:
|
| 146 |
+
click.echo("Running vessels and AV segmentation...")
|
| 147 |
+
run_segmentation_vessels_and_av(
|
| 148 |
+
rgb_paths=preprocessed_files,
|
| 149 |
+
ids=ids,
|
| 150 |
+
av_path=av_path,
|
| 151 |
+
vessels_path=vessels_path,
|
| 152 |
+
device=device,
|
| 153 |
+
)
|
| 154 |
+
click.echo(f"Vessel segmentation saved to {vessels_path}")
|
| 155 |
+
click.echo(f"AV segmentation saved to {av_path}")
|
| 156 |
+
|
| 157 |
+
# Step 4: Run optic disc segmentation if requested
|
| 158 |
+
if disc:
|
| 159 |
+
click.echo("Running optic disc segmentation...")
|
| 160 |
+
run_segmentation_disc(
|
| 161 |
+
rgb_paths=preprocessed_files, ids=ids, output_path=disc_path, device=device
|
| 162 |
+
)
|
| 163 |
+
click.echo(f"Disc segmentation saved to {disc_path}")
|
| 164 |
+
|
| 165 |
+
# Step 5: Run fovea detection if requested
|
| 166 |
+
df_fovea = None
|
| 167 |
+
if fovea:
|
| 168 |
+
click.echo("Running fovea detection...")
|
| 169 |
+
df_fovea = run_fovea_detection(
|
| 170 |
+
rgb_paths=preprocessed_files, ids=ids, device=device
|
| 171 |
+
)
|
| 172 |
+
df_fovea.to_csv(fovea_path)
|
| 173 |
+
click.echo(f"Fovea detection results saved to {fovea_path}")
|
| 174 |
+
|
| 175 |
+
# Step 6: Create overlays if requested
|
| 176 |
+
if overlay:
|
| 177 |
+
click.echo("Creating visualization overlays...")
|
| 178 |
+
|
| 179 |
+
# Prepare fovea data if available
|
| 180 |
+
fovea_data = None
|
| 181 |
+
if df_fovea is not None:
|
| 182 |
+
fovea_data = {
|
| 183 |
+
idx: (row["x_fovea"], row["y_fovea"])
|
| 184 |
+
for idx, row in df_fovea.iterrows()
|
| 185 |
+
}
|
| 186 |
+
|
| 187 |
+
# Create visualization overlays
|
| 188 |
+
batch_create_overlays(
|
| 189 |
+
rgb_dir=preprocess_rgb_path if preprocess else data_path,
|
| 190 |
+
output_dir=overlay_path,
|
| 191 |
+
av_dir=av_path,
|
| 192 |
+
disc_dir=disc_path,
|
| 193 |
+
fovea_data=fovea_data,
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
click.echo(f"Visualization overlays saved to {overlay_path}")
|
| 197 |
+
|
| 198 |
+
click.echo(f"All requested processing complete. Results saved to {output_path}")
|
vascx_models/inference.py
ADDED
|
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from typing import List, Optional
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import torch
|
| 8 |
+
from PIL import Image
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
|
| 11 |
+
from rtnls_inference.ensembles.ensemble_classification import ClassificationEnsemble
|
| 12 |
+
from rtnls_inference.ensembles.ensemble_heatmap_regression import (
|
| 13 |
+
HeatmapRegressionEnsemble,
|
| 14 |
+
)
|
| 15 |
+
from rtnls_inference.ensembles.ensemble_segmentation import SegmentationEnsemble
|
| 16 |
+
from rtnls_inference.utils import decollate_batch, extract_keypoints_from_heatmaps
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def run_quality_estimation(fpaths, ids, device: torch.device):
|
| 20 |
+
ensemble_quality = ClassificationEnsemble.from_release("quality.pt").to(device)
|
| 21 |
+
dataloader = ensemble_quality._make_inference_dataloader(
|
| 22 |
+
fpaths,
|
| 23 |
+
ids=ids,
|
| 24 |
+
num_workers=8,
|
| 25 |
+
preprocess=False,
|
| 26 |
+
batch_size=16,
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
output_ids, outputs = [], []
|
| 30 |
+
with torch.no_grad():
|
| 31 |
+
for batch in tqdm(dataloader):
|
| 32 |
+
if len(batch) == 0:
|
| 33 |
+
continue
|
| 34 |
+
|
| 35 |
+
im = batch["image"].to(device)
|
| 36 |
+
|
| 37 |
+
# QUALITY
|
| 38 |
+
quality = ensemble_quality.predict_step(im)
|
| 39 |
+
quality = torch.mean(quality, dim=0)
|
| 40 |
+
|
| 41 |
+
items = {"id": batch["id"], "quality": quality}
|
| 42 |
+
items = decollate_batch(items)
|
| 43 |
+
|
| 44 |
+
for item in items:
|
| 45 |
+
output_ids.append(item["id"])
|
| 46 |
+
outputs.append(item["quality"].tolist())
|
| 47 |
+
|
| 48 |
+
return pd.DataFrame(
|
| 49 |
+
outputs,
|
| 50 |
+
index=output_ids,
|
| 51 |
+
columns=["q1", "q2", "q3"],
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def run_segmentation_vessels_and_av(
|
| 56 |
+
rgb_paths: List[Path],
|
| 57 |
+
ce_paths: Optional[List[Path]] = None,
|
| 58 |
+
ids: Optional[List[str]] = None,
|
| 59 |
+
av_path: Optional[Path] = None,
|
| 60 |
+
vessels_path: Optional[Path] = None,
|
| 61 |
+
device: torch.device = torch.device(
|
| 62 |
+
"cuda:0" if torch.cuda.is_available() else "cpu"
|
| 63 |
+
),
|
| 64 |
+
) -> None:
|
| 65 |
+
"""
|
| 66 |
+
Run AV and vessel segmentation on the provided images.
|
| 67 |
+
|
| 68 |
+
Args:
|
| 69 |
+
rgb_paths: List of paths to RGB fundus images
|
| 70 |
+
ce_paths: Optional list of paths to contrast enhanced images
|
| 71 |
+
ids: Optional list of ids to pass to _make_inference_dataloader
|
| 72 |
+
av_path: Folder where to store output AV segmentations
|
| 73 |
+
vessels_path: Folder where to store output vessel segmentations
|
| 74 |
+
device: Device to run inference on
|
| 75 |
+
"""
|
| 76 |
+
# Create output directories if they don't exist
|
| 77 |
+
if av_path is not None:
|
| 78 |
+
av_path.mkdir(exist_ok=True, parents=True)
|
| 79 |
+
if vessels_path is not None:
|
| 80 |
+
vessels_path.mkdir(exist_ok=True, parents=True)
|
| 81 |
+
|
| 82 |
+
# Load models
|
| 83 |
+
ensemble_av = SegmentationEnsemble.from_release("av_july24.pt").to(device).eval()
|
| 84 |
+
ensemble_vessels = (
|
| 85 |
+
SegmentationEnsemble.from_release("vessels_july24.pt").to(device).eval()
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
# Prepare input paths
|
| 89 |
+
if ce_paths is None:
|
| 90 |
+
# If CE paths are not provided, use RGB paths for both inputs
|
| 91 |
+
fpaths = rgb_paths
|
| 92 |
+
else:
|
| 93 |
+
# If CE paths are provided, pair them with RGB paths
|
| 94 |
+
if len(rgb_paths) != len(ce_paths):
|
| 95 |
+
raise ValueError("rgb_paths and ce_paths must have the same length")
|
| 96 |
+
fpaths = list(zip(rgb_paths, ce_paths))
|
| 97 |
+
|
| 98 |
+
# Create dataloader
|
| 99 |
+
dataloader = ensemble_av._make_inference_dataloader(
|
| 100 |
+
fpaths,
|
| 101 |
+
ids=ids,
|
| 102 |
+
num_workers=8,
|
| 103 |
+
preprocess=False,
|
| 104 |
+
batch_size=8,
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
# Run inference
|
| 108 |
+
with torch.no_grad():
|
| 109 |
+
for batch in tqdm(dataloader):
|
| 110 |
+
# AV segmentation
|
| 111 |
+
if av_path is not None:
|
| 112 |
+
with torch.autocast(device_type=device.type):
|
| 113 |
+
proba = ensemble_av.forward(batch["image"].to(device))
|
| 114 |
+
proba = torch.mean(proba, dim=0) # average over models
|
| 115 |
+
proba = torch.permute(proba, (0, 2, 3, 1)) # NCHW -> NHWC
|
| 116 |
+
proba = torch.nn.functional.softmax(proba, dim=-1)
|
| 117 |
+
|
| 118 |
+
items = {
|
| 119 |
+
"id": batch["id"],
|
| 120 |
+
"image": proba,
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
items = decollate_batch(items)
|
| 124 |
+
for i, item in enumerate(items):
|
| 125 |
+
fpath = os.path.join(av_path, f"{item['id']}.png")
|
| 126 |
+
mask = np.argmax(item["image"], -1)
|
| 127 |
+
Image.fromarray(mask.squeeze().astype(np.uint8)).save(fpath)
|
| 128 |
+
|
| 129 |
+
# Vessel segmentation
|
| 130 |
+
if vessels_path is not None:
|
| 131 |
+
with torch.autocast(device_type=device.type):
|
| 132 |
+
proba = ensemble_vessels.forward(batch["image"].to(device))
|
| 133 |
+
proba = torch.mean(proba, dim=0) # average over models
|
| 134 |
+
proba = torch.permute(proba, (0, 2, 3, 1)) # NCHW -> NHWC
|
| 135 |
+
proba = torch.nn.functional.softmax(proba, dim=-1)
|
| 136 |
+
|
| 137 |
+
items = {
|
| 138 |
+
"id": batch["id"],
|
| 139 |
+
"image": proba,
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
items = decollate_batch(items)
|
| 143 |
+
for i, item in enumerate(items):
|
| 144 |
+
fpath = os.path.join(vessels_path, f"{item['id']}.png")
|
| 145 |
+
mask = np.argmax(item["image"], -1)
|
| 146 |
+
Image.fromarray(mask.squeeze().astype(np.uint8)).save(fpath)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def run_segmentation_disc(
|
| 150 |
+
rgb_paths: List[Path],
|
| 151 |
+
ce_paths: Optional[List[Path]] = None,
|
| 152 |
+
ids: Optional[List[str]] = None,
|
| 153 |
+
output_path: Optional[Path] = None,
|
| 154 |
+
device: torch.device = torch.device(
|
| 155 |
+
"cuda:0" if torch.cuda.is_available() else "cpu"
|
| 156 |
+
),
|
| 157 |
+
) -> None:
|
| 158 |
+
ensemble_disc = (
|
| 159 |
+
SegmentationEnsemble.from_release("disc_july24.pt").to(device).eval()
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
# Prepare input paths
|
| 163 |
+
if ce_paths is None:
|
| 164 |
+
# If CE paths are not provided, use RGB paths for both inputs
|
| 165 |
+
fpaths = rgb_paths
|
| 166 |
+
else:
|
| 167 |
+
# If CE paths are provided, pair them with RGB paths
|
| 168 |
+
if len(rgb_paths) != len(ce_paths):
|
| 169 |
+
raise ValueError("rgb_paths and ce_paths must have the same length")
|
| 170 |
+
fpaths = list(zip(rgb_paths, ce_paths))
|
| 171 |
+
|
| 172 |
+
dataloader = ensemble_disc._make_inference_dataloader(
|
| 173 |
+
fpaths,
|
| 174 |
+
ids=ids,
|
| 175 |
+
num_workers=8,
|
| 176 |
+
preprocess=False,
|
| 177 |
+
batch_size=8,
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
with torch.no_grad():
|
| 181 |
+
for batch in tqdm(dataloader):
|
| 182 |
+
# AV
|
| 183 |
+
with torch.autocast(device_type=device.type):
|
| 184 |
+
proba = ensemble_disc.forward(batch["image"].to(device))
|
| 185 |
+
proba = torch.mean(proba, dim=0) # average over models
|
| 186 |
+
proba = torch.permute(proba, (0, 2, 3, 1)) # NCHW -> NHWC
|
| 187 |
+
proba = torch.nn.functional.softmax(proba, dim=-1)
|
| 188 |
+
|
| 189 |
+
items = {
|
| 190 |
+
"id": batch["id"],
|
| 191 |
+
"image": proba,
|
| 192 |
+
}
|
| 193 |
+
|
| 194 |
+
items = decollate_batch(items)
|
| 195 |
+
items = [dataloader.dataset.transform.undo_item(item) for item in items]
|
| 196 |
+
for i, item in enumerate(items):
|
| 197 |
+
fpath = os.path.join(output_path, f"{item['id']}.png")
|
| 198 |
+
|
| 199 |
+
mask = np.argmax(item["image"], -1)
|
| 200 |
+
Image.fromarray(mask.squeeze().astype(np.uint8)).save(fpath)
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def run_fovea_detection(
|
| 204 |
+
rgb_paths: List[Path],
|
| 205 |
+
ce_paths: Optional[List[Path]] = None,
|
| 206 |
+
ids: Optional[List[str]] = None,
|
| 207 |
+
device: torch.device = torch.device(
|
| 208 |
+
"cuda:0" if torch.cuda.is_available() else "cpu"
|
| 209 |
+
),
|
| 210 |
+
) -> None:
|
| 211 |
+
# def run_fovea_detection(fpaths, ids, device: torch.device):
|
| 212 |
+
ensemble_fovea = HeatmapRegressionEnsemble.from_release("fovea_july24.pt").to(
|
| 213 |
+
device
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
# Prepare input paths
|
| 217 |
+
if ce_paths is None:
|
| 218 |
+
# If CE paths are not provided, use RGB paths for both inputs
|
| 219 |
+
fpaths = rgb_paths
|
| 220 |
+
else:
|
| 221 |
+
# If CE paths are provided, pair them with RGB paths
|
| 222 |
+
if len(rgb_paths) != len(ce_paths):
|
| 223 |
+
raise ValueError("rgb_paths and ce_paths must have the same length")
|
| 224 |
+
fpaths = list(zip(rgb_paths, ce_paths))
|
| 225 |
+
|
| 226 |
+
dataloader = ensemble_fovea._make_inference_dataloader(
|
| 227 |
+
fpaths,
|
| 228 |
+
ids=ids,
|
| 229 |
+
num_workers=8,
|
| 230 |
+
preprocess=False,
|
| 231 |
+
batch_size=8,
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
output_ids, outputs = [], []
|
| 235 |
+
with torch.no_grad():
|
| 236 |
+
for batch in tqdm(dataloader):
|
| 237 |
+
if len(batch) == 0:
|
| 238 |
+
continue
|
| 239 |
+
|
| 240 |
+
im = batch["image"].to(device)
|
| 241 |
+
|
| 242 |
+
# FOVEA DETECTION
|
| 243 |
+
with torch.autocast(device_type=device.type):
|
| 244 |
+
heatmap = ensemble_fovea.forward(im)
|
| 245 |
+
keypoints = extract_keypoints_from_heatmaps(heatmap)
|
| 246 |
+
|
| 247 |
+
kp_fovea = torch.mean(keypoints, dim=0) # average over models
|
| 248 |
+
|
| 249 |
+
items = {
|
| 250 |
+
"id": batch["id"],
|
| 251 |
+
"keypoints": kp_fovea,
|
| 252 |
+
"metadata": batch["metadata"],
|
| 253 |
+
}
|
| 254 |
+
items = decollate_batch(items)
|
| 255 |
+
|
| 256 |
+
items = [dataloader.dataset.transform.undo_item(item) for item in items]
|
| 257 |
+
|
| 258 |
+
for item in items:
|
| 259 |
+
output_ids.append(item["id"])
|
| 260 |
+
outputs.append(
|
| 261 |
+
[
|
| 262 |
+
*item["keypoints"][0].tolist(),
|
| 263 |
+
]
|
| 264 |
+
)
|
| 265 |
+
return pd.DataFrame(
|
| 266 |
+
outputs,
|
| 267 |
+
index=output_ids,
|
| 268 |
+
columns=["x_fovea", "y_fovea"],
|
| 269 |
+
)
|
vascx_models/utils.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
from typing import Dict, Optional, Tuple
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
from PIL import Image, ImageDraw
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def create_fundus_overlay(
|
| 9 |
+
rgb_path: str,
|
| 10 |
+
av_path: Optional[str] = None,
|
| 11 |
+
disc_path: Optional[str] = None,
|
| 12 |
+
fovea_location: Optional[Tuple[int, int]] = None,
|
| 13 |
+
output_path: Optional[str] = None,
|
| 14 |
+
) -> np.ndarray:
|
| 15 |
+
"""
|
| 16 |
+
Create a visualization of a fundus image with overlaid segmentations and markers.
|
| 17 |
+
|
| 18 |
+
Args:
|
| 19 |
+
rgb_path: Path to the RGB fundus image
|
| 20 |
+
av_path: Optional path to artery-vein segmentation (1=artery, 2=vein, 3=intersection)
|
| 21 |
+
disc_path: Optional path to binary disc segmentation
|
| 22 |
+
fovea_location: Optional (x,y) tuple indicating the location of the fovea
|
| 23 |
+
output_path: Optional path to save the visualization image
|
| 24 |
+
|
| 25 |
+
Returns:
|
| 26 |
+
Numpy array containing the visualization image
|
| 27 |
+
"""
|
| 28 |
+
print(rgb_path, av_path, disc_path, fovea_location, output_path)
|
| 29 |
+
# Load RGB image
|
| 30 |
+
rgb_img = np.array(Image.open(rgb_path))
|
| 31 |
+
|
| 32 |
+
# Create output image starting with the RGB image
|
| 33 |
+
output_img = rgb_img.copy()
|
| 34 |
+
|
| 35 |
+
# Load and overlay AV segmentation if provided
|
| 36 |
+
if av_path:
|
| 37 |
+
av_mask = np.array(Image.open(av_path))
|
| 38 |
+
|
| 39 |
+
# Create masks for arteries (1), veins (2) and intersections (3)
|
| 40 |
+
artery_mask = av_mask == 1
|
| 41 |
+
vein_mask = av_mask == 2
|
| 42 |
+
intersection_mask = av_mask == 3
|
| 43 |
+
|
| 44 |
+
# Combine artery and intersection for visualization
|
| 45 |
+
artery_combined = np.logical_or(artery_mask, intersection_mask)
|
| 46 |
+
vein_combined = np.logical_or(vein_mask, intersection_mask)
|
| 47 |
+
|
| 48 |
+
# Apply colors: red for arteries, blue for veins
|
| 49 |
+
# Red channel - increase for arteries
|
| 50 |
+
output_img[artery_combined, 0] = 255
|
| 51 |
+
output_img[artery_combined, 1] = 0
|
| 52 |
+
output_img[artery_combined, 2] = 0
|
| 53 |
+
|
| 54 |
+
# Blue channel - increase for veins
|
| 55 |
+
output_img[vein_combined, 0] = 0
|
| 56 |
+
output_img[vein_combined, 1] = 0
|
| 57 |
+
output_img[vein_combined, 2] = 255
|
| 58 |
+
|
| 59 |
+
# Load and overlay optic disc segmentation if provided
|
| 60 |
+
if disc_path:
|
| 61 |
+
disc_mask = np.array(Image.open(disc_path)) > 0
|
| 62 |
+
|
| 63 |
+
# Apply white color for disc
|
| 64 |
+
output_img[disc_mask, :] = [255, 255, 255] # White
|
| 65 |
+
|
| 66 |
+
# Convert to PIL image for drawing the fovea marker
|
| 67 |
+
pil_img = Image.fromarray(output_img)
|
| 68 |
+
|
| 69 |
+
# Add fovea marker if provided
|
| 70 |
+
if fovea_location:
|
| 71 |
+
draw = ImageDraw.Draw(pil_img)
|
| 72 |
+
x, y = fovea_location
|
| 73 |
+
marker_size = (
|
| 74 |
+
min(pil_img.width, pil_img.height) // 50
|
| 75 |
+
) # Scale marker with image
|
| 76 |
+
|
| 77 |
+
# Draw yellow X at fovea location
|
| 78 |
+
draw.line(
|
| 79 |
+
[(x - marker_size, y - marker_size), (x + marker_size, y + marker_size)],
|
| 80 |
+
fill=(255, 255, 0),
|
| 81 |
+
width=2,
|
| 82 |
+
)
|
| 83 |
+
draw.line(
|
| 84 |
+
[(x - marker_size, y + marker_size), (x + marker_size, y - marker_size)],
|
| 85 |
+
fill=(255, 255, 0),
|
| 86 |
+
width=2,
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
# Convert back to numpy array
|
| 90 |
+
output_img = np.array(pil_img)
|
| 91 |
+
|
| 92 |
+
# Save output if path provided
|
| 93 |
+
if output_path:
|
| 94 |
+
Image.fromarray(output_img).save(output_path)
|
| 95 |
+
|
| 96 |
+
return output_img
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def batch_create_overlays(
|
| 100 |
+
rgb_dir: Path,
|
| 101 |
+
output_dir: Path,
|
| 102 |
+
av_dir: Optional[Path] = None,
|
| 103 |
+
disc_dir: Optional[Path] = None,
|
| 104 |
+
fovea_data: Optional[Dict[str, Tuple[int, int]]] = None,
|
| 105 |
+
) -> None:
|
| 106 |
+
"""
|
| 107 |
+
Create visualization overlays for a batch of images.
|
| 108 |
+
|
| 109 |
+
Args:
|
| 110 |
+
rgb_dir: Directory containing RGB fundus images
|
| 111 |
+
output_dir: Directory to save visualization images
|
| 112 |
+
av_dir: Optional directory containing AV segmentations
|
| 113 |
+
disc_dir: Optional directory containing disc segmentations
|
| 114 |
+
fovea_data: Optional dictionary mapping image IDs to fovea coordinates
|
| 115 |
+
|
| 116 |
+
Returns:
|
| 117 |
+
List of paths to created visualization images
|
| 118 |
+
"""
|
| 119 |
+
# Create output directory if it doesn't exist
|
| 120 |
+
output_dir.mkdir(exist_ok=True, parents=True)
|
| 121 |
+
|
| 122 |
+
# Get all RGB images
|
| 123 |
+
rgb_files = list(rgb_dir.glob("*.png"))
|
| 124 |
+
if not rgb_files:
|
| 125 |
+
return []
|
| 126 |
+
|
| 127 |
+
# Process each image
|
| 128 |
+
for rgb_file in rgb_files:
|
| 129 |
+
image_id = rgb_file.stem
|
| 130 |
+
|
| 131 |
+
# Check for corresponding AV segmentation
|
| 132 |
+
av_file = None
|
| 133 |
+
if av_dir:
|
| 134 |
+
av_file_path = av_dir / f"{image_id}.png"
|
| 135 |
+
if av_file_path.exists():
|
| 136 |
+
av_file = str(av_file_path)
|
| 137 |
+
|
| 138 |
+
# Check for corresponding disc segmentation
|
| 139 |
+
disc_file = None
|
| 140 |
+
if disc_dir:
|
| 141 |
+
disc_file_path = disc_dir / f"{image_id}.png"
|
| 142 |
+
if disc_file_path.exists():
|
| 143 |
+
disc_file = str(disc_file_path)
|
| 144 |
+
|
| 145 |
+
# Get fovea location if available
|
| 146 |
+
fovea_location = None
|
| 147 |
+
if fovea_data and image_id in fovea_data:
|
| 148 |
+
fovea_location = fovea_data[image_id]
|
| 149 |
+
|
| 150 |
+
# Create output path
|
| 151 |
+
output_file = output_dir / f"{image_id}.png"
|
| 152 |
+
|
| 153 |
+
# Create and save overlay
|
| 154 |
+
create_fundus_overlay(
|
| 155 |
+
rgb_path=str(rgb_file),
|
| 156 |
+
av_path=av_file,
|
| 157 |
+
disc_path=disc_file,
|
| 158 |
+
fovea_location=fovea_location,
|
| 159 |
+
output_path=str(output_file),
|
| 160 |
+
)
|