Spaces:
Sleeping
Sleeping
Anirudh Balaraman commited on
Commit ·
caf6ee7
1
Parent(s): c67c387
add ci
Browse files- .devcontainer/devcontainer.json +0 -33
- .github/workflows/ci.yaml +25 -0
- .gitignore +3 -1
- Makefile +30 -0
- preprocess_main.py +10 -9
- pyproject.toml +44 -2
- run_cspca.py +25 -23
- run_inference.py +19 -17
- run_pirads.py +22 -14
- src/data/custom_transforms.py +8 -7
- src/data/data_loader.py +21 -16
- src/model/{csPCa_model.py → cspca_model.py} +5 -4
- src/model/{MIL.py → mil.py} +14 -10
- src/preprocessing/center_crop.py +13 -8
- src/preprocessing/generate_heatmap.py +5 -3
- src/preprocessing/histogram_match.py +5 -3
- src/preprocessing/prostate_mask.py +12 -14
- src/preprocessing/register_and_crop.py +10 -7
- src/train/train_cspca.py +15 -16
- src/train/train_pirads.py +21 -24
- src/utils.py +37 -20
- temp.ipynb +22 -217
- tests/test_run.py +41 -8
.devcontainer/devcontainer.json
DELETED
|
@@ -1,33 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"name": "Python 3",
|
| 3 |
-
// Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile
|
| 4 |
-
"image": "mcr.microsoft.com/devcontainers/python:1-3.11-bookworm",
|
| 5 |
-
"customizations": {
|
| 6 |
-
"codespaces": {
|
| 7 |
-
"openFiles": [
|
| 8 |
-
"README.md",
|
| 9 |
-
"app.py"
|
| 10 |
-
]
|
| 11 |
-
},
|
| 12 |
-
"vscode": {
|
| 13 |
-
"settings": {},
|
| 14 |
-
"extensions": [
|
| 15 |
-
"ms-python.python",
|
| 16 |
-
"ms-python.vscode-pylance"
|
| 17 |
-
]
|
| 18 |
-
}
|
| 19 |
-
},
|
| 20 |
-
"updateContentCommand": "[ -f packages.txt ] && sudo apt update && sudo apt upgrade -y && sudo xargs apt install -y <packages.txt; [ -f requirements.txt ] && pip3 install --user -r requirements.txt; pip3 install --user streamlit; echo '✅ Packages installed and Requirements met'",
|
| 21 |
-
"postAttachCommand": {
|
| 22 |
-
"server": "streamlit run app.py --server.enableCORS false --server.enableXsrfProtection false"
|
| 23 |
-
},
|
| 24 |
-
"portsAttributes": {
|
| 25 |
-
"8501": {
|
| 26 |
-
"label": "Application",
|
| 27 |
-
"onAutoForward": "openPreview"
|
| 28 |
-
}
|
| 29 |
-
},
|
| 30 |
-
"forwardPorts": [
|
| 31 |
-
8501
|
| 32 |
-
]
|
| 33 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.github/workflows/ci.yaml
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: CI
|
| 2 |
+
|
| 3 |
+
on:
|
| 4 |
+
push
|
| 5 |
+
|
| 6 |
+
jobs:
|
| 7 |
+
quality-assurance:
|
| 8 |
+
runs-on: ubuntu-latest
|
| 9 |
+
steps:
|
| 10 |
+
- name: Checkout code
|
| 11 |
+
uses: actions/checkout@v4
|
| 12 |
+
|
| 13 |
+
- name: Set up Python
|
| 14 |
+
uses: actions/setup-python@v5
|
| 15 |
+
with:
|
| 16 |
+
python-version: '3.9'
|
| 17 |
+
cache: 'pip' # Speeds up subsequent runs
|
| 18 |
+
|
| 19 |
+
- name: Install dependencies
|
| 20 |
+
run: |
|
| 21 |
+
python -m pip install --upgrade pip
|
| 22 |
+
pip install -r requirements.txt
|
| 23 |
+
|
| 24 |
+
- name: Run CI Suite
|
| 25 |
+
run: make check
|
.gitignore
CHANGED
|
@@ -6,4 +6,6 @@ temp.ipynb
|
|
| 6 |
__pycache__/
|
| 7 |
**/__pycache__/
|
| 8 |
*.pyc
|
| 9 |
-
.ruff_cache
|
|
|
|
|
|
|
|
|
| 6 |
__pycache__/
|
| 7 |
**/__pycache__/
|
| 8 |
*.pyc
|
| 9 |
+
.ruff_cache/
|
| 10 |
+
.mypy_cache/
|
| 11 |
+
.pytest_cache/
|
Makefile
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.PHONY: format lint typecheck test check
|
| 2 |
+
|
| 3 |
+
format:
|
| 4 |
+
ruff format .
|
| 5 |
+
|
| 6 |
+
lint:
|
| 7 |
+
ruff check . --fix
|
| 8 |
+
|
| 9 |
+
typecheck:
|
| 10 |
+
mypy .
|
| 11 |
+
|
| 12 |
+
test:
|
| 13 |
+
pytest
|
| 14 |
+
|
| 15 |
+
clean:
|
| 16 |
+
@echo "Cleaning project..."
|
| 17 |
+
# Delete compiled bytecode
|
| 18 |
+
@python3 -Bc "import pathlib; [p.unlink() for p in pathlib.Path('.').rglob('*.py[co]')]"
|
| 19 |
+
# Delete directory-based caches
|
| 20 |
+
@python3 -Bc "import shutil, pathlib; \
|
| 21 |
+
[shutil.rmtree(p) for p in pathlib.Path('.').rglob('__pycache__')]; \
|
| 22 |
+
[shutil.rmtree(p) for p in pathlib.Path('.').rglob('.ipynb_checkpoints')]; \
|
| 23 |
+
[shutil.rmtree(p) for p in pathlib.Path('.').rglob('.monai-cache')]; \
|
| 24 |
+
[shutil.rmtree(p) for p in pathlib.Path('.').rglob('.mypy_cache')]; \
|
| 25 |
+
[shutil.rmtree(p) for p in pathlib.Path('.').rglob('.ruff_cache')]; \
|
| 26 |
+
[shutil.rmtree(p) for p in pathlib.Path('.').rglob('.pytest_cache')]"
|
| 27 |
+
|
| 28 |
+
# Updated 'check' to clean before running (optional)
|
| 29 |
+
# This ensures you are testing from a "blank slate"
|
| 30 |
+
check: format lint typecheck test clean
|
preprocess_main.py
CHANGED
|
@@ -1,14 +1,15 @@
|
|
| 1 |
-
import os
|
| 2 |
-
from src.preprocessing.register_and_crop import register_files
|
| 3 |
-
from src.preprocessing.prostate_mask import get_segmask
|
| 4 |
-
from src.preprocessing.histogram_match import histmatch
|
| 5 |
-
from src.preprocessing.generate_heatmap import get_heatmap
|
| 6 |
-
import logging
|
| 7 |
-
from src.utils import setup_logging
|
| 8 |
-
from src.utils import validate_steps
|
| 9 |
import argparse
|
|
|
|
|
|
|
|
|
|
| 10 |
import yaml
|
| 11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
def parse_args():
|
| 14 |
parser = argparse.ArgumentParser(description="File preprocessing")
|
|
@@ -37,7 +38,7 @@ def parse_args():
|
|
| 37 |
|
| 38 |
args = parser.parse_args()
|
| 39 |
if args.config:
|
| 40 |
-
with open(args.config
|
| 41 |
config = yaml.safe_load(config_file)
|
| 42 |
args.__dict__.update(config)
|
| 43 |
return args
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import argparse
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
import yaml
|
| 6 |
|
| 7 |
+
from src.preprocessing.generate_heatmap import get_heatmap
|
| 8 |
+
from src.preprocessing.histogram_match import histmatch
|
| 9 |
+
from src.preprocessing.prostate_mask import get_segmask
|
| 10 |
+
from src.preprocessing.register_and_crop import register_files
|
| 11 |
+
from src.utils import setup_logging, validate_steps
|
| 12 |
+
|
| 13 |
|
| 14 |
def parse_args():
|
| 15 |
parser = argparse.ArgumentParser(description="File preprocessing")
|
|
|
|
| 38 |
|
| 39 |
args = parser.parse_args()
|
| 40 |
if args.config:
|
| 41 |
+
with open(args.config) as config_file:
|
| 42 |
config = yaml.safe_load(config_file)
|
| 43 |
args.__dict__.update(config)
|
| 44 |
return args
|
pyproject.toml
CHANGED
|
@@ -1,9 +1,51 @@
|
|
| 1 |
[tool.ruff]
|
| 2 |
line-length = 100
|
|
|
|
| 3 |
|
| 4 |
[tool.ruff.lint]
|
| 5 |
-
select = ["E", "W"]
|
| 6 |
ignore = ["E501"]
|
| 7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
[tool.ruff.format]
|
| 9 |
-
quote-style = "double"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
[tool.ruff]
|
| 2 |
line-length = 100
|
| 3 |
+
target-version = "py39" # Ensures ruff uses 3.9 compatible syntax
|
| 4 |
|
| 5 |
[tool.ruff.lint]
|
| 6 |
+
select = ["E", "W", "F", "I", "B", "N", "UP"]
|
| 7 |
ignore = ["E501"]
|
| 8 |
|
| 9 |
+
[tool.ruff.lint.isort]
|
| 10 |
+
# This makes your imports look like a professional library
|
| 11 |
+
combine-as-imports = true
|
| 12 |
+
force-wrap-aliases = true
|
| 13 |
+
known-first-party = ["src"] # Treats your 'src' folder as a local package
|
| 14 |
+
|
| 15 |
[tool.ruff.format]
|
| 16 |
+
quote-style = "double"
|
| 17 |
+
indent-style = "space"
|
| 18 |
+
skip-magic-trailing-comma = false
|
| 19 |
+
|
| 20 |
+
[tool.ruff.lint.pep8-naming]
|
| 21 |
+
# Add your custom class name here
|
| 22 |
+
ignore-names = ["sitk", "NormalizeIntensity_custom", "NormalizeIntensity_customd"]
|
| 23 |
+
|
| 24 |
+
[tool.mypy]
|
| 25 |
+
ignore_missing_imports = true
|
| 26 |
+
disable_error_code = ["override", "import-untyped"]
|
| 27 |
+
mypy_path = "."
|
| 28 |
+
pretty = true
|
| 29 |
+
show_error_codes = true
|
| 30 |
+
|
| 31 |
+
[[tool.mypy.overrides]]
|
| 32 |
+
# These settings apply specifically to these external libraries
|
| 33 |
+
module = [
|
| 34 |
+
"yaml.*",
|
| 35 |
+
"nrrd.*",
|
| 36 |
+
"nibabel.*",
|
| 37 |
+
"scipy.*",
|
| 38 |
+
"sklearn.*"
|
| 39 |
+
]
|
| 40 |
+
ignore_missing_imports = true
|
| 41 |
+
|
| 42 |
+
[tool.pytest.ini_options]
|
| 43 |
+
# Automatically adds these flags every time you run 'pytest'
|
| 44 |
+
addopts = "-v --showlocals --durations=5"
|
| 45 |
+
|
| 46 |
+
# Where pytest should look for tests
|
| 47 |
+
testpaths = ["tests"]
|
| 48 |
+
|
| 49 |
+
# Patterns to identify test files
|
| 50 |
+
python_files = "test_*.py"
|
| 51 |
+
python_functions = "test_*"
|
run_cspca.py
CHANGED
|
@@ -1,22 +1,23 @@
|
|
| 1 |
import argparse
|
|
|
|
| 2 |
import os
|
| 3 |
import shutil
|
| 4 |
-
import yaml
|
| 5 |
import sys
|
| 6 |
-
import torch
|
| 7 |
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
| 8 |
from monai.utils import set_determinism
|
| 9 |
-
|
| 10 |
-
from src.model.MIL import MILModel_3D
|
| 11 |
-
from src.model.csPCa_model import csPCa_Model
|
| 12 |
from src.data.data_loader import get_dataloader
|
| 13 |
-
from src.
|
|
|
|
| 14 |
from src.train.train_cspca import train_epoch, val_epoch
|
| 15 |
-
import
|
| 16 |
|
| 17 |
|
| 18 |
def main_worker(args):
|
| 19 |
-
mil_model =
|
| 20 |
cache_dir_path = Path(os.path.join(args.logdir, "cache"))
|
| 21 |
|
| 22 |
if args.mode == "train":
|
|
@@ -31,7 +32,7 @@ def main_worker(args):
|
|
| 31 |
|
| 32 |
train_loader = get_dataloader(args, split="train")
|
| 33 |
valid_loader = get_dataloader(args, split="test")
|
| 34 |
-
cspca_model =
|
| 35 |
for submodule in [
|
| 36 |
cspca_model.backbone.net,
|
| 37 |
cspca_model.backbone.myfc,
|
|
@@ -45,23 +46,17 @@ def main_worker(args):
|
|
| 45 |
)
|
| 46 |
|
| 47 |
old_loss = float("inf")
|
| 48 |
-
old_auc = 0.0
|
| 49 |
for epoch in range(args.epochs):
|
| 50 |
train_loss, train_auc = train_epoch(
|
| 51 |
cspca_model, train_loader, optimizer, epoch=epoch, args=args
|
| 52 |
)
|
| 53 |
-
logging.info(
|
| 54 |
-
f"EPOCH {epoch} TRAIN loss: {train_loss:.4f} AUC: {train_auc:.4f}"
|
| 55 |
-
)
|
| 56 |
val_metric = val_epoch(cspca_model, valid_loader, epoch=epoch, args=args)
|
| 57 |
logging.info(
|
| 58 |
f"EPOCH {epoch} VAL loss: {val_metric['loss']:.4f} AUC: {val_metric['auc']:.4f}"
|
| 59 |
)
|
| 60 |
if val_metric["loss"] < old_loss:
|
| 61 |
old_loss = val_metric["loss"]
|
| 62 |
-
old_auc = val_metric["auc"]
|
| 63 |
-
sensitivity = val_metric["sensitivity"]
|
| 64 |
-
specificity = val_metric["specificity"]
|
| 65 |
save_cspca_checkpoint(cspca_model, val_metric, model_dir)
|
| 66 |
|
| 67 |
args.checkpoint_cspca = os.path.join(model_dir, "cspca_model.pth")
|
|
@@ -69,7 +64,7 @@ def main_worker(args):
|
|
| 69 |
shutil.rmtree(cache_dir_path)
|
| 70 |
|
| 71 |
|
| 72 |
-
cspca_model =
|
| 73 |
checkpt = torch.load(args.checkpoint_cspca, map_location="cpu")
|
| 74 |
cspca_model.load_state_dict(checkpt["state_dict"])
|
| 75 |
cspca_model = cspca_model.to(args.device)
|
|
@@ -96,6 +91,7 @@ def main_worker(args):
|
|
| 96 |
get_metrics(metrics_dict)
|
| 97 |
|
| 98 |
|
|
|
|
| 99 |
def parse_args():
|
| 100 |
parser = argparse.ArgumentParser(
|
| 101 |
description="Multiple Instance Learning (MIL) for csPCa risk prediction."
|
|
@@ -164,7 +160,7 @@ def parse_args():
|
|
| 164 |
)
|
| 165 |
args = parser.parse_args()
|
| 166 |
if args.config:
|
| 167 |
-
with open(args.config
|
| 168 |
config = yaml.safe_load(config_file)
|
| 169 |
args.__dict__.update(config)
|
| 170 |
|
|
@@ -172,12 +168,13 @@ def parse_args():
|
|
| 172 |
|
| 173 |
|
| 174 |
if __name__ == "__main__":
|
| 175 |
-
|
| 176 |
args = parse_args()
|
| 177 |
if args.project_dir is None:
|
| 178 |
-
args.project_dir = Path(__file__).resolve().parent
|
| 179 |
|
| 180 |
-
slurm_job_name = os.getenv(
|
|
|
|
|
|
|
| 181 |
if slurm_job_name:
|
| 182 |
args.run_name = slurm_job_name
|
| 183 |
|
|
@@ -207,10 +204,15 @@ if __name__ == "__main__":
|
|
| 207 |
|
| 208 |
if args.dry_run:
|
| 209 |
logging.info("Dry run mode enabled.")
|
| 210 |
-
args.epochs =
|
| 211 |
args.batch_size = 2
|
| 212 |
args.workers = 0
|
| 213 |
-
args.num_seeds =
|
| 214 |
args.wandb = False
|
|
|
|
|
|
|
| 215 |
|
| 216 |
main_worker(args)
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import argparse
|
| 2 |
+
import logging
|
| 3 |
import os
|
| 4 |
import shutil
|
|
|
|
| 5 |
import sys
|
|
|
|
| 6 |
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import yaml
|
| 10 |
from monai.utils import set_determinism
|
| 11 |
+
|
|
|
|
|
|
|
| 12 |
from src.data.data_loader import get_dataloader
|
| 13 |
+
from src.model.cspca_model import CSPCAModel
|
| 14 |
+
from src.model.mil import MILModel3D
|
| 15 |
from src.train.train_cspca import train_epoch, val_epoch
|
| 16 |
+
from src.utils import get_metrics, save_cspca_checkpoint, setup_logging
|
| 17 |
|
| 18 |
|
| 19 |
def main_worker(args):
|
| 20 |
+
mil_model = MILModel3D(num_classes=args.num_classes, mil_mode=args.mil_mode)
|
| 21 |
cache_dir_path = Path(os.path.join(args.logdir, "cache"))
|
| 22 |
|
| 23 |
if args.mode == "train":
|
|
|
|
| 32 |
|
| 33 |
train_loader = get_dataloader(args, split="train")
|
| 34 |
valid_loader = get_dataloader(args, split="test")
|
| 35 |
+
cspca_model = CSPCAModel(backbone=mil_model).to(args.device)
|
| 36 |
for submodule in [
|
| 37 |
cspca_model.backbone.net,
|
| 38 |
cspca_model.backbone.myfc,
|
|
|
|
| 46 |
)
|
| 47 |
|
| 48 |
old_loss = float("inf")
|
|
|
|
| 49 |
for epoch in range(args.epochs):
|
| 50 |
train_loss, train_auc = train_epoch(
|
| 51 |
cspca_model, train_loader, optimizer, epoch=epoch, args=args
|
| 52 |
)
|
| 53 |
+
logging.info(f"EPOCH {epoch} TRAIN loss: {train_loss:.4f} AUC: {train_auc:.4f}")
|
|
|
|
|
|
|
| 54 |
val_metric = val_epoch(cspca_model, valid_loader, epoch=epoch, args=args)
|
| 55 |
logging.info(
|
| 56 |
f"EPOCH {epoch} VAL loss: {val_metric['loss']:.4f} AUC: {val_metric['auc']:.4f}"
|
| 57 |
)
|
| 58 |
if val_metric["loss"] < old_loss:
|
| 59 |
old_loss = val_metric["loss"]
|
|
|
|
|
|
|
|
|
|
| 60 |
save_cspca_checkpoint(cspca_model, val_metric, model_dir)
|
| 61 |
|
| 62 |
args.checkpoint_cspca = os.path.join(model_dir, "cspca_model.pth")
|
|
|
|
| 64 |
shutil.rmtree(cache_dir_path)
|
| 65 |
|
| 66 |
|
| 67 |
+
cspca_model = CSPCAModel(backbone=mil_model).to(args.device)
|
| 68 |
checkpt = torch.load(args.checkpoint_cspca, map_location="cpu")
|
| 69 |
cspca_model.load_state_dict(checkpt["state_dict"])
|
| 70 |
cspca_model = cspca_model.to(args.device)
|
|
|
|
| 91 |
get_metrics(metrics_dict)
|
| 92 |
|
| 93 |
|
| 94 |
+
|
| 95 |
def parse_args():
|
| 96 |
parser = argparse.ArgumentParser(
|
| 97 |
description="Multiple Instance Learning (MIL) for csPCa risk prediction."
|
|
|
|
| 160 |
)
|
| 161 |
args = parser.parse_args()
|
| 162 |
if args.config:
|
| 163 |
+
with open(args.config) as config_file:
|
| 164 |
config = yaml.safe_load(config_file)
|
| 165 |
args.__dict__.update(config)
|
| 166 |
|
|
|
|
| 168 |
|
| 169 |
|
| 170 |
if __name__ == "__main__":
|
|
|
|
| 171 |
args = parse_args()
|
| 172 |
if args.project_dir is None:
|
| 173 |
+
args.project_dir = Path(__file__).resolve().parent # Set project directory
|
| 174 |
|
| 175 |
+
slurm_job_name = os.getenv(
|
| 176 |
+
"SLURM_JOB_NAME"
|
| 177 |
+
) # If the script is submitted via slurm, job name is the run name
|
| 178 |
if slurm_job_name:
|
| 179 |
args.run_name = slurm_job_name
|
| 180 |
|
|
|
|
| 204 |
|
| 205 |
if args.dry_run:
|
| 206 |
logging.info("Dry run mode enabled.")
|
| 207 |
+
args.epochs = 1
|
| 208 |
args.batch_size = 2
|
| 209 |
args.workers = 0
|
| 210 |
+
args.num_seeds = 1
|
| 211 |
args.wandb = False
|
| 212 |
+
args.tile_size = 10
|
| 213 |
+
args.tile_count = 5
|
| 214 |
|
| 215 |
main_worker(args)
|
| 216 |
+
|
| 217 |
+
if args.dry_run:
|
| 218 |
+
shutil.rmtree(args.logdir)
|
run_inference.py
CHANGED
|
@@ -1,18 +1,20 @@
|
|
| 1 |
import argparse
|
|
|
|
|
|
|
| 2 |
import os
|
| 3 |
-
|
| 4 |
import torch
|
| 5 |
-
import
|
| 6 |
-
from src.model.MIL import MILModel_3D
|
| 7 |
-
from src.model.csPCa_model import csPCa_Model
|
| 8 |
-
from src.utils import setup_logging, get_parent_image, get_patch_coordinate
|
| 9 |
-
from src.preprocessing.register_and_crop import register_files
|
| 10 |
-
from src.preprocessing.prostate_mask import get_segmask
|
| 11 |
-
from src.preprocessing.histogram_match import histmatch
|
| 12 |
-
from src.preprocessing.generate_heatmap import get_heatmap
|
| 13 |
-
from src.data.data_loader import data_transform, list_data_collate
|
| 14 |
from monai.data import Dataset
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
|
| 18 |
def parse_args():
|
|
@@ -36,7 +38,7 @@ def parse_args():
|
|
| 36 |
|
| 37 |
args = parser.parse_args()
|
| 38 |
if args.config:
|
| 39 |
-
with open(args.config
|
| 40 |
config = yaml.safe_load(config_file)
|
| 41 |
args.__dict__.update(config)
|
| 42 |
return args
|
|
@@ -64,14 +66,14 @@ if __name__ == "__main__":
|
|
| 64 |
args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 65 |
|
| 66 |
logging.info("Loading PIRADS model")
|
| 67 |
-
pirads_model =
|
| 68 |
pirads_checkpoint = torch.load(
|
| 69 |
os.path.join(args.project_dir, "models", "pirads.pt"), map_location="cpu"
|
| 70 |
)
|
| 71 |
pirads_model.load_state_dict(pirads_checkpoint["state_dict"])
|
| 72 |
pirads_model.to(args.device)
|
| 73 |
logging.info("Loading csPCa model")
|
| 74 |
-
cspca_model =
|
| 75 |
checkpt = torch.load(
|
| 76 |
os.path.join(args.project_dir, "models", "cspca_model.pth"), map_location="cpu"
|
| 77 |
)
|
|
@@ -109,7 +111,7 @@ if __name__ == "__main__":
|
|
| 109 |
cspca_model.eval()
|
| 110 |
patches_top_5_list = []
|
| 111 |
with torch.no_grad():
|
| 112 |
-
for
|
| 113 |
data = batch_data["image"].as_subclass(torch.Tensor).to(args.device)
|
| 114 |
logits = pirads_model(data)
|
| 115 |
pirads_score = torch.argmax(logits, dim=1)
|
|
@@ -137,10 +139,10 @@ if __name__ == "__main__":
|
|
| 137 |
patches_top_5.append(patch_temp)
|
| 138 |
patches_top_5_list.append(patches_top_5)
|
| 139 |
coords_list = []
|
| 140 |
-
for i in args.data_list:
|
| 141 |
parent_image = get_parent_image([i], args)
|
| 142 |
|
| 143 |
-
coords = get_patch_coordinate(
|
| 144 |
coords_list.append(coords)
|
| 145 |
output_dict = {}
|
| 146 |
|
|
|
|
| 1 |
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import logging
|
| 4 |
import os
|
| 5 |
+
|
| 6 |
import torch
|
| 7 |
+
import yaml
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
from monai.data import Dataset
|
| 9 |
+
|
| 10 |
+
from src.data.data_loader import data_transform, list_data_collate
|
| 11 |
+
from src.model.cspca_model import CSPCAModel
|
| 12 |
+
from src.model.mil import MILModel3D
|
| 13 |
+
from src.preprocessing.generate_heatmap import get_heatmap
|
| 14 |
+
from src.preprocessing.histogram_match import histmatch
|
| 15 |
+
from src.preprocessing.prostate_mask import get_segmask
|
| 16 |
+
from src.preprocessing.register_and_crop import register_files
|
| 17 |
+
from src.utils import get_parent_image, get_patch_coordinate, setup_logging
|
| 18 |
|
| 19 |
|
| 20 |
def parse_args():
|
|
|
|
| 38 |
|
| 39 |
args = parser.parse_args()
|
| 40 |
if args.config:
|
| 41 |
+
with open(args.config) as config_file:
|
| 42 |
config = yaml.safe_load(config_file)
|
| 43 |
args.__dict__.update(config)
|
| 44 |
return args
|
|
|
|
| 66 |
args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 67 |
|
| 68 |
logging.info("Loading PIRADS model")
|
| 69 |
+
pirads_model = MILModel3D(num_classes=args.num_classes, mil_mode=args.mil_mode)
|
| 70 |
pirads_checkpoint = torch.load(
|
| 71 |
os.path.join(args.project_dir, "models", "pirads.pt"), map_location="cpu"
|
| 72 |
)
|
| 73 |
pirads_model.load_state_dict(pirads_checkpoint["state_dict"])
|
| 74 |
pirads_model.to(args.device)
|
| 75 |
logging.info("Loading csPCa model")
|
| 76 |
+
cspca_model = CSPCAModel(backbone=pirads_model).to(args.device)
|
| 77 |
checkpt = torch.load(
|
| 78 |
os.path.join(args.project_dir, "models", "cspca_model.pth"), map_location="cpu"
|
| 79 |
)
|
|
|
|
| 111 |
cspca_model.eval()
|
| 112 |
patches_top_5_list = []
|
| 113 |
with torch.no_grad():
|
| 114 |
+
for _, batch_data in enumerate(loader):
|
| 115 |
data = batch_data["image"].as_subclass(torch.Tensor).to(args.device)
|
| 116 |
logits = pirads_model(data)
|
| 117 |
pirads_score = torch.argmax(logits, dim=1)
|
|
|
|
| 139 |
patches_top_5.append(patch_temp)
|
| 140 |
patches_top_5_list.append(patches_top_5)
|
| 141 |
coords_list = []
|
| 142 |
+
for j, i in enumerate(args.data_list):
|
| 143 |
parent_image = get_parent_image([i], args)
|
| 144 |
|
| 145 |
+
coords = get_patch_coordinate(patches_top_5_list[j], parent_image)
|
| 146 |
coords_list.append(coords)
|
| 147 |
output_dict = {}
|
| 148 |
|
run_pirads.py
CHANGED
|
@@ -1,19 +1,21 @@
|
|
| 1 |
import argparse
|
|
|
|
| 2 |
import os
|
| 3 |
import shutil
|
| 4 |
-
import time
|
| 5 |
-
import yaml
|
| 6 |
import sys
|
|
|
|
|
|
|
|
|
|
| 7 |
import numpy as np
|
| 8 |
import torch
|
| 9 |
-
from torch.utils.tensorboard import SummaryWriter
|
| 10 |
-
from monai.utils import set_determinism
|
| 11 |
import wandb
|
| 12 |
-
import
|
| 13 |
-
from
|
|
|
|
|
|
|
| 14 |
from src.data.data_loader import get_dataloader
|
|
|
|
| 15 |
from src.train.train_pirads import train_epoch, val_epoch
|
| 16 |
-
from src.model.MIL import MILModel_3D
|
| 17 |
from src.utils import save_pirads_checkpoint, setup_logging
|
| 18 |
|
| 19 |
|
|
@@ -21,7 +23,7 @@ def main_worker(args):
|
|
| 21 |
if args.device == torch.device("cuda"):
|
| 22 |
torch.backends.cudnn.benchmark = True
|
| 23 |
|
| 24 |
-
model =
|
| 25 |
start_epoch = 0
|
| 26 |
best_acc = 0.0
|
| 27 |
if args.checkpoint is not None:
|
|
@@ -250,19 +252,20 @@ def parse_args():
|
|
| 250 |
)
|
| 251 |
args = parser.parse_args()
|
| 252 |
if args.config:
|
| 253 |
-
with open(args.config
|
| 254 |
config = yaml.safe_load(config_file)
|
| 255 |
args.__dict__.update(config)
|
| 256 |
return args
|
| 257 |
|
| 258 |
|
| 259 |
if __name__ == "__main__":
|
| 260 |
-
|
| 261 |
args = parse_args()
|
| 262 |
if args.project_dir is None:
|
| 263 |
-
args.project_dir = Path(__file__).resolve().parent
|
| 264 |
|
| 265 |
-
slurm_job_name = os.getenv(
|
|
|
|
|
|
|
| 266 |
if slurm_job_name:
|
| 267 |
args.run_name = slurm_job_name
|
| 268 |
|
|
@@ -288,11 +291,13 @@ if __name__ == "__main__":
|
|
| 288 |
|
| 289 |
if args.dry_run:
|
| 290 |
logging.info("Dry run mode enabled.")
|
| 291 |
-
args.epochs =
|
| 292 |
args.batch_size = 2
|
| 293 |
args.workers = 0
|
| 294 |
-
args.num_seeds =
|
| 295 |
args.wandb = False
|
|
|
|
|
|
|
| 296 |
|
| 297 |
mode_wandb = "online" if args.wandb and args.mode != "test" else "disabled"
|
| 298 |
|
|
@@ -314,3 +319,6 @@ if __name__ == "__main__":
|
|
| 314 |
main_worker(args)
|
| 315 |
|
| 316 |
wandb.finish()
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import argparse
|
| 2 |
+
import logging
|
| 3 |
import os
|
| 4 |
import shutil
|
|
|
|
|
|
|
| 5 |
import sys
|
| 6 |
+
import time
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
import numpy as np
|
| 10 |
import torch
|
|
|
|
|
|
|
| 11 |
import wandb
|
| 12 |
+
import yaml
|
| 13 |
+
from monai.utils import set_determinism
|
| 14 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 15 |
+
|
| 16 |
from src.data.data_loader import get_dataloader
|
| 17 |
+
from src.model.mil import MILModel3D
|
| 18 |
from src.train.train_pirads import train_epoch, val_epoch
|
|
|
|
| 19 |
from src.utils import save_pirads_checkpoint, setup_logging
|
| 20 |
|
| 21 |
|
|
|
|
| 23 |
if args.device == torch.device("cuda"):
|
| 24 |
torch.backends.cudnn.benchmark = True
|
| 25 |
|
| 26 |
+
model = MILModel3D(num_classes=args.num_classes, mil_mode=args.mil_mode)
|
| 27 |
start_epoch = 0
|
| 28 |
best_acc = 0.0
|
| 29 |
if args.checkpoint is not None:
|
|
|
|
| 252 |
)
|
| 253 |
args = parser.parse_args()
|
| 254 |
if args.config:
|
| 255 |
+
with open(args.config) as config_file:
|
| 256 |
config = yaml.safe_load(config_file)
|
| 257 |
args.__dict__.update(config)
|
| 258 |
return args
|
| 259 |
|
| 260 |
|
| 261 |
if __name__ == "__main__":
|
|
|
|
| 262 |
args = parse_args()
|
| 263 |
if args.project_dir is None:
|
| 264 |
+
args.project_dir = Path(__file__).resolve().parent # Set project directory
|
| 265 |
|
| 266 |
+
slurm_job_name = os.getenv(
|
| 267 |
+
"SLURM_JOB_NAME"
|
| 268 |
+
) # If the script is submitted via slurm, job name is the run name
|
| 269 |
if slurm_job_name:
|
| 270 |
args.run_name = slurm_job_name
|
| 271 |
|
|
|
|
| 291 |
|
| 292 |
if args.dry_run:
|
| 293 |
logging.info("Dry run mode enabled.")
|
| 294 |
+
args.epochs = 1
|
| 295 |
args.batch_size = 2
|
| 296 |
args.workers = 0
|
| 297 |
+
args.num_seeds = 1
|
| 298 |
args.wandb = False
|
| 299 |
+
args.tile_size = 10
|
| 300 |
+
args.tile_count = 5
|
| 301 |
|
| 302 |
mode_wandb = "online" if args.wandb and args.mode != "test" else "disabled"
|
| 303 |
|
|
|
|
| 319 |
main_worker(args)
|
| 320 |
|
| 321 |
wandb.finish()
|
| 322 |
+
|
| 323 |
+
if args.dry_run:
|
| 324 |
+
shutil.rmtree(args.logdir)
|
src/data/custom_transforms.py
CHANGED
|
@@ -1,18 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import numpy as np
|
| 2 |
import torch
|
| 3 |
-
from typing import Union
|
| 4 |
-
from monai.transforms import MapTransform
|
| 5 |
from monai.config import DtypeLike, KeysCollection
|
| 6 |
from monai.config.type_definitions import NdarrayOrTensor
|
| 7 |
from monai.data.meta_obj import get_track_meta
|
|
|
|
| 8 |
from monai.transforms.transform import Transform
|
| 9 |
from monai.transforms.utils import soft_clip
|
| 10 |
from monai.transforms.utils_pytorch_numpy_unification import clip, percentile
|
| 11 |
from monai.utils.enums import TransformBackends
|
| 12 |
from monai.utils.type_conversion import convert_data_type, convert_to_dst_type, convert_to_tensor
|
| 13 |
from scipy.ndimage import binary_dilation
|
| 14 |
-
import cv2
|
| 15 |
-
from collections.abc import Hashable, Mapping, Sequence
|
| 16 |
|
| 17 |
|
| 18 |
class DilateAndSaveMaskd(MapTransform):
|
|
@@ -100,7 +101,7 @@ class ClipMaskIntensityPercentiles(Transform):
|
|
| 100 |
self.channel_wise = channel_wise
|
| 101 |
self.dtype = dtype
|
| 102 |
|
| 103 |
-
def _clip(self, img: NdarrayOrTensor, mask_data: NdarrayOrTensor) ->
|
| 104 |
masked_img = img * (mask_data > 0)
|
| 105 |
if self.sharpness_factor is not None:
|
| 106 |
lower_percentile = (
|
|
@@ -125,8 +126,8 @@ class ClipMaskIntensityPercentiles(Transform):
|
|
| 125 |
)
|
| 126 |
img = clip(img, lower_percentile, upper_percentile)
|
| 127 |
|
| 128 |
-
|
| 129 |
-
return
|
| 130 |
|
| 131 |
def __call__(self, img: NdarrayOrTensor, mask_data: NdarrayOrTensor) -> NdarrayOrTensor:
|
| 132 |
"""
|
|
|
|
| 1 |
+
from collections.abc import Hashable, Mapping, Sequence
|
| 2 |
+
from typing import Union
|
| 3 |
+
|
| 4 |
+
import cv2
|
| 5 |
import numpy as np
|
| 6 |
import torch
|
|
|
|
|
|
|
| 7 |
from monai.config import DtypeLike, KeysCollection
|
| 8 |
from monai.config.type_definitions import NdarrayOrTensor
|
| 9 |
from monai.data.meta_obj import get_track_meta
|
| 10 |
+
from monai.transforms import MapTransform
|
| 11 |
from monai.transforms.transform import Transform
|
| 12 |
from monai.transforms.utils import soft_clip
|
| 13 |
from monai.transforms.utils_pytorch_numpy_unification import clip, percentile
|
| 14 |
from monai.utils.enums import TransformBackends
|
| 15 |
from monai.utils.type_conversion import convert_data_type, convert_to_dst_type, convert_to_tensor
|
| 16 |
from scipy.ndimage import binary_dilation
|
|
|
|
|
|
|
| 17 |
|
| 18 |
|
| 19 |
class DilateAndSaveMaskd(MapTransform):
|
|
|
|
| 101 |
self.channel_wise = channel_wise
|
| 102 |
self.dtype = dtype
|
| 103 |
|
| 104 |
+
def _clip(self, img: NdarrayOrTensor, mask_data: NdarrayOrTensor) -> torch.Tensor:
|
| 105 |
masked_img = img * (mask_data > 0)
|
| 106 |
if self.sharpness_factor is not None:
|
| 107 |
lower_percentile = (
|
|
|
|
| 126 |
)
|
| 127 |
img = clip(img, lower_percentile, upper_percentile)
|
| 128 |
|
| 129 |
+
img_tensor = convert_to_tensor(img, track_meta=False)
|
| 130 |
+
return img_tensor
|
| 131 |
|
| 132 |
def __call__(self, img: NdarrayOrTensor, mask_data: NdarrayOrTensor) -> NdarrayOrTensor:
|
| 133 |
"""
|
src/data/data_loader.py
CHANGED
|
@@ -1,30 +1,33 @@
|
|
|
|
|
| 1 |
import os
|
|
|
|
|
|
|
| 2 |
import numpy as np
|
| 3 |
-
|
|
|
|
| 4 |
from monai.transforms import (
|
| 5 |
Compose,
|
| 6 |
-
LoadImaged,
|
| 7 |
-
ToTensord,
|
| 8 |
ConcatItemsd,
|
| 9 |
DeleteItemsd,
|
| 10 |
EnsureTyped,
|
| 11 |
-
|
| 12 |
NormalizeIntensityd,
|
| 13 |
-
|
| 14 |
RandWeightedCropd,
|
|
|
|
|
|
|
|
|
|
| 15 |
)
|
|
|
|
|
|
|
| 16 |
from .custom_transforms import (
|
| 17 |
-
NormalizeIntensity_customd,
|
| 18 |
ClipMaskIntensityPercentilesd,
|
| 19 |
ElementwiseProductd,
|
|
|
|
| 20 |
)
|
| 21 |
-
import torch
|
| 22 |
-
from torch.utils.data.dataloader import default_collate
|
| 23 |
-
from typing import Literal
|
| 24 |
-
import collections.abc
|
| 25 |
|
| 26 |
|
| 27 |
-
def list_data_collate(batch:
|
| 28 |
"""
|
| 29 |
Combine instances from a list of dicts into a single dict, by stacking them along first dim
|
| 30 |
[{'image' : 3xHxW}, {'image' : 3xHxW}, {'image' : 3xHxW}...] - > {'image' : Nx3xHxW}
|
|
@@ -42,13 +45,13 @@ def list_data_collate(batch: collections.abc.Sequence):
|
|
| 42 |
return default_collate(batch)
|
| 43 |
|
| 44 |
|
| 45 |
-
def data_transform(args):
|
| 46 |
if args.use_heatmap:
|
| 47 |
transform = Compose(
|
| 48 |
[
|
| 49 |
LoadImaged(
|
| 50 |
keys=["image", "mask", "dwi", "adc", "heatmap"],
|
| 51 |
-
reader=ITKReader
|
| 52 |
ensure_channel_first=True,
|
| 53 |
dtype=np.float32,
|
| 54 |
),
|
|
@@ -75,7 +78,7 @@ def data_transform(args):
|
|
| 75 |
[
|
| 76 |
LoadImaged(
|
| 77 |
keys=["image", "mask", "dwi", "adc"],
|
| 78 |
-
reader=ITKReader
|
| 79 |
ensure_channel_first=True,
|
| 80 |
dtype=np.float32,
|
| 81 |
),
|
|
@@ -101,14 +104,16 @@ def data_transform(args):
|
|
| 101 |
return transform
|
| 102 |
|
| 103 |
|
| 104 |
-
def get_dataloader(
|
|
|
|
|
|
|
| 105 |
data_list = load_decathlon_datalist(
|
| 106 |
data_list_file_path=args.dataset_json,
|
| 107 |
data_list_key=split,
|
| 108 |
base_dir=args.data_root,
|
| 109 |
)
|
| 110 |
if args.dry_run:
|
| 111 |
-
data_list = data_list[:
|
| 112 |
cache_dir_ = os.path.join(args.logdir, "cache")
|
| 113 |
os.makedirs(os.path.join(cache_dir_, split), exist_ok=True)
|
| 114 |
transform = data_transform(args)
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
import os
|
| 3 |
+
from typing import Literal
|
| 4 |
+
|
| 5 |
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
from monai.data import PersistentDataset, load_decathlon_datalist
|
| 8 |
from monai.transforms import (
|
| 9 |
Compose,
|
|
|
|
|
|
|
| 10 |
ConcatItemsd,
|
| 11 |
DeleteItemsd,
|
| 12 |
EnsureTyped,
|
| 13 |
+
LoadImaged,
|
| 14 |
NormalizeIntensityd,
|
| 15 |
+
RandCropByPosNegLabeld,
|
| 16 |
RandWeightedCropd,
|
| 17 |
+
ToTensord,
|
| 18 |
+
Transform,
|
| 19 |
+
Transposed,
|
| 20 |
)
|
| 21 |
+
from torch.utils.data.dataloader import default_collate
|
| 22 |
+
|
| 23 |
from .custom_transforms import (
|
|
|
|
| 24 |
ClipMaskIntensityPercentilesd,
|
| 25 |
ElementwiseProductd,
|
| 26 |
+
NormalizeIntensity_customd,
|
| 27 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
|
| 30 |
+
def list_data_collate(batch: list):
|
| 31 |
"""
|
| 32 |
Combine instances from a list of dicts into a single dict, by stacking them along first dim
|
| 33 |
[{'image' : 3xHxW}, {'image' : 3xHxW}, {'image' : 3xHxW}...] - > {'image' : Nx3xHxW}
|
|
|
|
| 45 |
return default_collate(batch)
|
| 46 |
|
| 47 |
|
| 48 |
+
def data_transform(args: argparse.Namespace) -> Transform:
|
| 49 |
if args.use_heatmap:
|
| 50 |
transform = Compose(
|
| 51 |
[
|
| 52 |
LoadImaged(
|
| 53 |
keys=["image", "mask", "dwi", "adc", "heatmap"],
|
| 54 |
+
reader="ITKReader",
|
| 55 |
ensure_channel_first=True,
|
| 56 |
dtype=np.float32,
|
| 57 |
),
|
|
|
|
| 78 |
[
|
| 79 |
LoadImaged(
|
| 80 |
keys=["image", "mask", "dwi", "adc"],
|
| 81 |
+
reader="ITKReader",
|
| 82 |
ensure_channel_first=True,
|
| 83 |
dtype=np.float32,
|
| 84 |
),
|
|
|
|
| 104 |
return transform
|
| 105 |
|
| 106 |
|
| 107 |
+
def get_dataloader(
|
| 108 |
+
args: argparse.Namespace, split: Literal["train", "test"]
|
| 109 |
+
) -> torch.utils.data.DataLoader:
|
| 110 |
data_list = load_decathlon_datalist(
|
| 111 |
data_list_file_path=args.dataset_json,
|
| 112 |
data_list_key=split,
|
| 113 |
base_dir=args.data_root,
|
| 114 |
)
|
| 115 |
if args.dry_run:
|
| 116 |
+
data_list = data_list[:2] # Use only 8 samples for dry run
|
| 117 |
cache_dir_ = os.path.join(args.logdir, "cache")
|
| 118 |
os.makedirs(os.path.join(cache_dir_, split), exist_ok=True)
|
| 119 |
transform = data_transform(args)
|
src/model/{csPCa_model.py → cspca_model.py}
RENAMED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
from __future__ import annotations
|
|
|
|
| 2 |
import torch
|
| 3 |
import torch.nn as nn
|
| 4 |
from monai.utils.module import optional_import
|
|
@@ -17,8 +18,8 @@ class SimpleNN(nn.Module):
|
|
| 17 |
input_dim (int): The number of input features.
|
| 18 |
"""
|
| 19 |
|
| 20 |
-
def __init__(self, input_dim):
|
| 21 |
-
super(
|
| 22 |
self.net = nn.Sequential(
|
| 23 |
nn.Linear(input_dim, 256),
|
| 24 |
nn.ReLU(),
|
|
@@ -42,7 +43,7 @@ class SimpleNN(nn.Module):
|
|
| 42 |
return self.net(x)
|
| 43 |
|
| 44 |
|
| 45 |
-
class
|
| 46 |
"""
|
| 47 |
Clinically Significant Prostate Cancer (csPCa) risk prediction model using a MIL backbone.
|
| 48 |
|
|
@@ -67,7 +68,7 @@ class csPCa_Model(nn.Module):
|
|
| 67 |
backbone: The MIL based PI-RADS classifier.
|
| 68 |
"""
|
| 69 |
|
| 70 |
-
def __init__(self, backbone):
|
| 71 |
super().__init__()
|
| 72 |
self.backbone = backbone
|
| 73 |
self.fc_dim = backbone.myfc.in_features
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
+
|
| 3 |
import torch
|
| 4 |
import torch.nn as nn
|
| 5 |
from monai.utils.module import optional_import
|
|
|
|
| 18 |
input_dim (int): The number of input features.
|
| 19 |
"""
|
| 20 |
|
| 21 |
+
def __init__(self, input_dim: int) -> None:
|
| 22 |
+
super().__init__()
|
| 23 |
self.net = nn.Sequential(
|
| 24 |
nn.Linear(input_dim, 256),
|
| 25 |
nn.ReLU(),
|
|
|
|
| 43 |
return self.net(x)
|
| 44 |
|
| 45 |
|
| 46 |
+
class CSPCAModel(nn.Module):
|
| 47 |
"""
|
| 48 |
Clinically Significant Prostate Cancer (csPCa) risk prediction model using a MIL backbone.
|
| 49 |
|
|
|
|
| 68 |
backbone: The MIL based PI-RADS classifier.
|
| 69 |
"""
|
| 70 |
|
| 71 |
+
def __init__(self, backbone: nn.Module) -> None:
|
| 72 |
super().__init__()
|
| 73 |
self.backbone = backbone
|
| 74 |
self.fc_dim = backbone.myfc.in_features
|
src/model/{MIL.py → mil.py}
RENAMED
|
@@ -1,14 +1,16 @@
|
|
| 1 |
from __future__ import annotations
|
|
|
|
| 2 |
from typing import cast
|
|
|
|
| 3 |
import torch
|
| 4 |
import torch.nn as nn
|
| 5 |
-
from monai.utils.module import optional_import
|
| 6 |
from monai.networks.nets import resnet
|
|
|
|
| 7 |
|
| 8 |
models, _ = optional_import("torchvision.models")
|
| 9 |
|
| 10 |
|
| 11 |
-
class
|
| 12 |
"""
|
| 13 |
Multiple Instance Learning (MIL) model, with a backbone classification model.
|
| 14 |
Adapted from MONAI, modified for 3D images. The expected shape of input data is `[B, N, C, D, H, W]`,
|
|
@@ -56,6 +58,7 @@ class MILModel_3D(nn.Module):
|
|
| 56 |
self.mil_mode = mil_mode.lower()
|
| 57 |
self.attention = nn.Sequential()
|
| 58 |
self.transformer: nn.Module | None = None
|
|
|
|
| 59 |
|
| 60 |
if backbone is None:
|
| 61 |
net = resnet.resnet18(
|
|
@@ -63,8 +66,9 @@ class MILModel_3D(nn.Module):
|
|
| 63 |
n_input_channels=3,
|
| 64 |
num_classes=5,
|
| 65 |
)
|
|
|
|
| 66 |
nfc = net.fc.in_features # save the number of final features
|
| 67 |
-
net.fc = torch.nn.Identity() #
|
| 68 |
|
| 69 |
self.extra_outputs: dict[str, torch.Tensor] = {}
|
| 70 |
|
|
@@ -90,7 +94,7 @@ class MILModel_3D(nn.Module):
|
|
| 90 |
|
| 91 |
if getattr(net, "fc", None) is not None:
|
| 92 |
nfc = net.fc.in_features # save the number of final features
|
| 93 |
-
net.fc = torch.nn.Identity() #
|
| 94 |
else:
|
| 95 |
raise ValueError(
|
| 96 |
"Unable to detect FC layer for the torchvision model " + str(backbone),
|
|
@@ -100,8 +104,13 @@ class MILModel_3D(nn.Module):
|
|
| 100 |
elif isinstance(backbone, nn.Module):
|
| 101 |
# use a custom backbone
|
| 102 |
net = backbone
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
nfc = backbone_num_features
|
| 104 |
-
net.fc = torch.nn.Identity() #
|
| 105 |
|
| 106 |
if mil_mode == "att_trans_pyramid":
|
| 107 |
# register hooks to capture outputs of intermediate layers
|
|
@@ -109,11 +118,6 @@ class MILModel_3D(nn.Module):
|
|
| 109 |
"Cannot use att_trans_pyramid with custom backbone. Have to use the default ResNet 18 backbone."
|
| 110 |
)
|
| 111 |
|
| 112 |
-
if backbone_num_features is None:
|
| 113 |
-
raise ValueError(
|
| 114 |
-
"Number of endencoder features must be provided for a custom backbone model"
|
| 115 |
-
)
|
| 116 |
-
|
| 117 |
else:
|
| 118 |
raise ValueError("Unsupported backbone")
|
| 119 |
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
+
|
| 3 |
from typing import cast
|
| 4 |
+
|
| 5 |
import torch
|
| 6 |
import torch.nn as nn
|
|
|
|
| 7 |
from monai.networks.nets import resnet
|
| 8 |
+
from monai.utils.module import optional_import
|
| 9 |
|
| 10 |
models, _ = optional_import("torchvision.models")
|
| 11 |
|
| 12 |
|
| 13 |
+
class MILModel3D(nn.Module):
|
| 14 |
"""
|
| 15 |
Multiple Instance Learning (MIL) model, with a backbone classification model.
|
| 16 |
Adapted from MONAI, modified for 3D images. The expected shape of input data is `[B, N, C, D, H, W]`,
|
|
|
|
| 58 |
self.mil_mode = mil_mode.lower()
|
| 59 |
self.attention = nn.Sequential()
|
| 60 |
self.transformer: nn.Module | None = None
|
| 61 |
+
net: nn.Module
|
| 62 |
|
| 63 |
if backbone is None:
|
| 64 |
net = resnet.resnet18(
|
|
|
|
| 66 |
n_input_channels=3,
|
| 67 |
num_classes=5,
|
| 68 |
)
|
| 69 |
+
assert net.fc is not None
|
| 70 |
nfc = net.fc.in_features # save the number of final features
|
| 71 |
+
net.fc = torch.nn.Identity() # type: ignore[assignment]
|
| 72 |
|
| 73 |
self.extra_outputs: dict[str, torch.Tensor] = {}
|
| 74 |
|
|
|
|
| 94 |
|
| 95 |
if getattr(net, "fc", None) is not None:
|
| 96 |
nfc = net.fc.in_features # save the number of final features
|
| 97 |
+
net.fc = torch.nn.Identity() # type: ignore[assignment]
|
| 98 |
else:
|
| 99 |
raise ValueError(
|
| 100 |
"Unable to detect FC layer for the torchvision model " + str(backbone),
|
|
|
|
| 104 |
elif isinstance(backbone, nn.Module):
|
| 105 |
# use a custom backbone
|
| 106 |
net = backbone
|
| 107 |
+
|
| 108 |
+
if backbone_num_features is None:
|
| 109 |
+
raise ValueError(
|
| 110 |
+
"Number of endencoder features must be provided for a custom backbone model"
|
| 111 |
+
)
|
| 112 |
nfc = backbone_num_features
|
| 113 |
+
net.fc = torch.nn.Identity() # type: ignore[assignment]
|
| 114 |
|
| 115 |
if mil_mode == "att_trans_pyramid":
|
| 116 |
# register hooks to capture outputs of intermediate layers
|
|
|
|
| 118 |
"Cannot use att_trans_pyramid with custom backbone. Have to use the default ResNet 18 backbone."
|
| 119 |
)
|
| 120 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
else:
|
| 122 |
raise ValueError("Unsupported backbone")
|
| 123 |
|
src/preprocessing/center_crop.py
CHANGED
|
@@ -12,7 +12,7 @@
|
|
| 12 |
|
| 13 |
|
| 14 |
# import argparse
|
| 15 |
-
from typing import Union
|
| 16 |
|
| 17 |
import SimpleITK as sitk # noqa N813
|
| 18 |
|
|
@@ -21,7 +21,11 @@ def _flatten(t):
|
|
| 21 |
return [item for sublist in t for item in sublist]
|
| 22 |
|
| 23 |
|
| 24 |
-
def crop(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
"""
|
| 26 |
Crops a sitk.Image while retaining correct spacing. Negative margins will lead to zero padding
|
| 27 |
|
|
@@ -30,13 +34,14 @@ def crop(image: sitk.Image, margin: Union[int, float], interpolator=sitk.sitkLin
|
|
| 30 |
margin: margins to crop. Single integer or float (percentage crop),
|
| 31 |
lists of int/float or nestes lists are supported.
|
| 32 |
"""
|
| 33 |
-
if isinstance(
|
| 34 |
-
|
|
|
|
| 35 |
else:
|
| 36 |
-
assert isinstance(
|
| 37 |
-
|
| 38 |
|
| 39 |
-
margin = [m if isinstance(m, (tuple, list)) else [m, m] for m in
|
| 40 |
old_size = image.GetSize()
|
| 41 |
|
| 42 |
# calculate new origin and new image size
|
|
@@ -46,7 +51,7 @@ def crop(image: sitk.Image, margin: Union[int, float], interpolator=sitk.sitkLin
|
|
| 46 |
)
|
| 47 |
to_crop = [[int(sz * _m) for _m in m] for sz, m in zip(old_size, margin)]
|
| 48 |
elif all([isinstance(m, int) for m in _flatten(margin)]):
|
| 49 |
-
to_crop = margin
|
| 50 |
else:
|
| 51 |
raise ValueError("Wrong format of margins.")
|
| 52 |
|
|
|
|
| 12 |
|
| 13 |
|
| 14 |
# import argparse
|
| 15 |
+
from typing import Union, cast
|
| 16 |
|
| 17 |
import SimpleITK as sitk # noqa N813
|
| 18 |
|
|
|
|
| 21 |
return [item for sublist in t for item in sublist]
|
| 22 |
|
| 23 |
|
| 24 |
+
def crop(
|
| 25 |
+
image: sitk.Image,
|
| 26 |
+
margin_: Union[int, float, list[Union[int, float]]],
|
| 27 |
+
interpolator=sitk.sitkLinear,
|
| 28 |
+
):
|
| 29 |
"""
|
| 30 |
Crops a sitk.Image while retaining correct spacing. Negative margins will lead to zero padding
|
| 31 |
|
|
|
|
| 34 |
margin: margins to crop. Single integer or float (percentage crop),
|
| 35 |
lists of int/float or nestes lists are supported.
|
| 36 |
"""
|
| 37 |
+
if isinstance(margin_, (list, tuple)):
|
| 38 |
+
margin_temp = margin_
|
| 39 |
+
assert len(margin_) == 3, "expected margin to be of length 3"
|
| 40 |
else:
|
| 41 |
+
assert isinstance(margin_, (int, float)), "expected margin to be a float value"
|
| 42 |
+
margin_temp = [margin_, margin_, margin_]
|
| 43 |
|
| 44 |
+
margin = [m if isinstance(m, (tuple, list)) else [m, m] for m in margin_temp]
|
| 45 |
old_size = image.GetSize()
|
| 46 |
|
| 47 |
# calculate new origin and new image size
|
|
|
|
| 51 |
)
|
| 52 |
to_crop = [[int(sz * _m) for _m in m] for sz, m in zip(old_size, margin)]
|
| 53 |
elif all([isinstance(m, int) for m in _flatten(margin)]):
|
| 54 |
+
to_crop = cast(list[list[int]], margin)
|
| 55 |
else:
|
| 56 |
raise ValueError("Wrong format of margins.")
|
| 57 |
|
src/preprocessing/generate_heatmap.py
CHANGED
|
@@ -1,11 +1,13 @@
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
-
|
| 3 |
import nrrd
|
|
|
|
| 4 |
from tqdm import tqdm
|
| 5 |
-
import logging
|
| 6 |
|
| 7 |
|
| 8 |
-
def get_heatmap(args):
|
| 9 |
"""
|
| 10 |
Generate heatmaps from DWI (Diffusion Weighted Imaging) and ADC (Apparent Diffusion Coefficient) medical imaging data.
|
| 11 |
This function processes medical imaging files (DWI and ADC) along with their corresponding
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import logging
|
| 3 |
import os
|
| 4 |
+
|
| 5 |
import nrrd
|
| 6 |
+
import numpy as np
|
| 7 |
from tqdm import tqdm
|
|
|
|
| 8 |
|
| 9 |
|
| 10 |
+
def get_heatmap(args: argparse.Namespace) -> argparse.Namespace:
|
| 11 |
"""
|
| 12 |
Generate heatmaps from DWI (Diffusion Weighted Imaging) and ADC (Apparent Diffusion Coefficient) medical imaging data.
|
| 13 |
This function processes medical imaging files (DWI and ADC) along with their corresponding
|
src/preprocessing/histogram_match.py
CHANGED
|
@@ -1,8 +1,10 @@
|
|
|
|
|
|
|
|
| 1 |
import os
|
|
|
|
| 2 |
import nrrd
|
| 3 |
-
from skimage import exposure
|
| 4 |
-
import logging
|
| 5 |
import numpy as np
|
|
|
|
| 6 |
from tqdm import tqdm
|
| 7 |
|
| 8 |
|
|
@@ -34,7 +36,7 @@ def get_histmatched(
|
|
| 34 |
return matched_img
|
| 35 |
|
| 36 |
|
| 37 |
-
def histmatch(args):
|
| 38 |
files = os.listdir(args.t2_dir)
|
| 39 |
|
| 40 |
t2_histmatched_dir = os.path.join(args.output_dir, "t2_histmatched")
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import logging
|
| 3 |
import os
|
| 4 |
+
|
| 5 |
import nrrd
|
|
|
|
|
|
|
| 6 |
import numpy as np
|
| 7 |
+
from skimage import exposure
|
| 8 |
from tqdm import tqdm
|
| 9 |
|
| 10 |
|
|
|
|
| 36 |
return matched_img
|
| 37 |
|
| 38 |
|
| 39 |
+
def histmatch(args: argparse.Namespace) -> argparse.Namespace:
|
| 40 |
files = os.listdir(args.t2_dir)
|
| 41 |
|
| 42 |
t2_histmatched_dir = os.path.join(args.output_dir, "t2_histmatched")
|
src/preprocessing/prostate_mask.py
CHANGED
|
@@ -1,31 +1,29 @@
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
-
|
| 3 |
import nrrd
|
| 4 |
-
|
| 5 |
-
from monai.bundle import ConfigParser
|
| 6 |
import torch
|
| 7 |
-
|
| 8 |
-
|
| 9 |
from monai.transforms import (
|
| 10 |
Compose,
|
|
|
|
|
|
|
| 11 |
LoadImaged,
|
| 12 |
-
ScaleIntensityd,
|
| 13 |
NormalizeIntensityd,
|
| 14 |
-
)
|
| 15 |
-
from monai.utils import set_determinism
|
| 16 |
-
from monai.transforms import (
|
| 17 |
-
EnsureChannelFirstd,
|
| 18 |
Orientationd,
|
|
|
|
| 19 |
Spacingd,
|
| 20 |
-
EnsureTyped,
|
| 21 |
)
|
| 22 |
-
from monai.
|
| 23 |
-
import
|
| 24 |
|
| 25 |
set_determinism(43)
|
| 26 |
|
| 27 |
|
| 28 |
-
def get_segmask(args):
|
| 29 |
"""
|
| 30 |
Generate prostate segmentation masks using a pre-trained deep learning model.
|
| 31 |
This function performs inference on T2-weighted MRI images to segment the prostate gland.
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import logging
|
| 3 |
import os
|
| 4 |
+
|
| 5 |
import nrrd
|
| 6 |
+
import numpy as np
|
|
|
|
| 7 |
import torch
|
| 8 |
+
from monai.bundle import ConfigParser
|
| 9 |
+
from monai.data import MetaTensor
|
| 10 |
from monai.transforms import (
|
| 11 |
Compose,
|
| 12 |
+
EnsureChannelFirstd,
|
| 13 |
+
EnsureTyped,
|
| 14 |
LoadImaged,
|
|
|
|
| 15 |
NormalizeIntensityd,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
Orientationd,
|
| 17 |
+
ScaleIntensityd,
|
| 18 |
Spacingd,
|
|
|
|
| 19 |
)
|
| 20 |
+
from monai.utils import set_determinism
|
| 21 |
+
from tqdm import tqdm
|
| 22 |
|
| 23 |
set_determinism(43)
|
| 24 |
|
| 25 |
|
| 26 |
+
def get_segmask(args: argparse.Namespace) -> argparse.Namespace:
|
| 27 |
"""
|
| 28 |
Generate prostate segmentation masks using a pre-trained deep learning model.
|
| 29 |
This function performs inference on T2-weighted MRI images to segment the prostate gland.
|
src/preprocessing/register_and_crop.py
CHANGED
|
@@ -1,12 +1,15 @@
|
|
| 1 |
-
import
|
|
|
|
| 2 |
import os
|
| 3 |
-
|
|
|
|
| 4 |
from picai_prep.preprocessing import PreprocessingSettings, Sample
|
|
|
|
|
|
|
| 5 |
from .center_crop import crop
|
| 6 |
-
import logging
|
| 7 |
|
| 8 |
|
| 9 |
-
def register_files(args):
|
| 10 |
"""
|
| 11 |
Register and crop medical images (T2, DWI, and ADC) to a standardized spacing and size.
|
| 12 |
This function reads medical images from specified directories, resamples them to a
|
|
@@ -55,9 +58,9 @@ def register_files(args):
|
|
| 55 |
|
| 56 |
pat_case = Sample(
|
| 57 |
scans=[
|
| 58 |
-
images_to_preprocess
|
| 59 |
-
images_to_preprocess
|
| 60 |
-
images_to_preprocess
|
| 61 |
],
|
| 62 |
settings=PreprocessingSettings(
|
| 63 |
spacing=[3.0, 0.4, 0.4], matrix_size=[new_size[2], new_size[1], new_size[0]]
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import logging
|
| 3 |
import os
|
| 4 |
+
|
| 5 |
+
import SimpleITK as sitk
|
| 6 |
from picai_prep.preprocessing import PreprocessingSettings, Sample
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
|
| 9 |
from .center_crop import crop
|
|
|
|
| 10 |
|
| 11 |
|
| 12 |
+
def register_files(args: argparse.Namespace) -> argparse.Namespace:
|
| 13 |
"""
|
| 14 |
Register and crop medical images (T2, DWI, and ADC) to a standardized spacing and size.
|
| 15 |
This function reads medical images from specified directories, resamples them to a
|
|
|
|
| 58 |
|
| 59 |
pat_case = Sample(
|
| 60 |
scans=[
|
| 61 |
+
images_to_preprocess["t2"],
|
| 62 |
+
images_to_preprocess["hbv"],
|
| 63 |
+
images_to_preprocess["adc"],
|
| 64 |
],
|
| 65 |
settings=PreprocessingSettings(
|
| 66 |
spacing=[3.0, 0.4, 0.4], matrix_size=[new_size[2], new_size[1], new_size[0]]
|
src/train/train_cspca.py
CHANGED
|
@@ -1,8 +1,7 @@
|
|
| 1 |
import torch
|
| 2 |
import torch.nn as nn
|
| 3 |
from monai.metrics import Cumulative, CumulativeAverage
|
| 4 |
-
from sklearn.metrics import confusion_matrix
|
| 5 |
-
from sklearn.metrics import roc_auc_score
|
| 6 |
|
| 7 |
|
| 8 |
def train_epoch(cspca_model, loader, optimizer, epoch, args):
|
|
@@ -10,10 +9,10 @@ def train_epoch(cspca_model, loader, optimizer, epoch, args):
|
|
| 10 |
criterion = nn.BCELoss()
|
| 11 |
loss = 0.0
|
| 12 |
run_loss = CumulativeAverage()
|
| 13 |
-
|
| 14 |
-
|
| 15 |
|
| 16 |
-
for
|
| 17 |
data = batch_data["image"].as_subclass(torch.Tensor).to(args.device)
|
| 18 |
target = batch_data["label"].as_subclass(torch.Tensor).to(args.device)
|
| 19 |
|
|
@@ -24,13 +23,13 @@ def train_epoch(cspca_model, loader, optimizer, epoch, args):
|
|
| 24 |
loss.backward()
|
| 25 |
optimizer.step()
|
| 26 |
|
| 27 |
-
|
| 28 |
-
|
| 29 |
run_loss.append(loss.item())
|
| 30 |
|
| 31 |
loss_epoch = run_loss.aggregate()
|
| 32 |
-
target_list =
|
| 33 |
-
pred_list =
|
| 34 |
auc_epoch = roc_auc_score(target_list, pred_list)
|
| 35 |
|
| 36 |
return loss_epoch, auc_epoch
|
|
@@ -41,10 +40,10 @@ def val_epoch(cspca_model, loader, epoch, args):
|
|
| 41 |
criterion = nn.BCELoss()
|
| 42 |
loss = 0.0
|
| 43 |
run_loss = CumulativeAverage()
|
| 44 |
-
|
| 45 |
-
|
| 46 |
with torch.no_grad():
|
| 47 |
-
for
|
| 48 |
data = batch_data["image"].as_subclass(torch.Tensor).to(args.device)
|
| 49 |
target = batch_data["label"].as_subclass(torch.Tensor).to(args.device)
|
| 50 |
|
|
@@ -52,13 +51,13 @@ def val_epoch(cspca_model, loader, epoch, args):
|
|
| 52 |
output = output.squeeze(1)
|
| 53 |
loss = criterion(output, target)
|
| 54 |
|
| 55 |
-
|
| 56 |
-
|
| 57 |
run_loss.append(loss.item())
|
| 58 |
|
| 59 |
loss_epoch = run_loss.aggregate()
|
| 60 |
-
target_list =
|
| 61 |
-
pred_list =
|
| 62 |
auc_epoch = roc_auc_score(target_list, pred_list)
|
| 63 |
y_pred_categoric = pred_list >= 0.5
|
| 64 |
tn, fp, fn, tp = confusion_matrix(target_list, y_pred_categoric).ravel()
|
|
|
|
| 1 |
import torch
|
| 2 |
import torch.nn as nn
|
| 3 |
from monai.metrics import Cumulative, CumulativeAverage
|
| 4 |
+
from sklearn.metrics import confusion_matrix, roc_auc_score
|
|
|
|
| 5 |
|
| 6 |
|
| 7 |
def train_epoch(cspca_model, loader, optimizer, epoch, args):
|
|
|
|
| 9 |
criterion = nn.BCELoss()
|
| 10 |
loss = 0.0
|
| 11 |
run_loss = CumulativeAverage()
|
| 12 |
+
targets_cumulative = Cumulative()
|
| 13 |
+
preds_cumulative = Cumulative()
|
| 14 |
|
| 15 |
+
for _, batch_data in enumerate(loader):
|
| 16 |
data = batch_data["image"].as_subclass(torch.Tensor).to(args.device)
|
| 17 |
target = batch_data["label"].as_subclass(torch.Tensor).to(args.device)
|
| 18 |
|
|
|
|
| 23 |
loss.backward()
|
| 24 |
optimizer.step()
|
| 25 |
|
| 26 |
+
targets_cumulative.extend(target.detach().cpu())
|
| 27 |
+
preds_cumulative.extend(output.detach().cpu())
|
| 28 |
run_loss.append(loss.item())
|
| 29 |
|
| 30 |
loss_epoch = run_loss.aggregate()
|
| 31 |
+
target_list = targets_cumulative.get_buffer().cpu().numpy()
|
| 32 |
+
pred_list = preds_cumulative.get_buffer().cpu().numpy()
|
| 33 |
auc_epoch = roc_auc_score(target_list, pred_list)
|
| 34 |
|
| 35 |
return loss_epoch, auc_epoch
|
|
|
|
| 40 |
criterion = nn.BCELoss()
|
| 41 |
loss = 0.0
|
| 42 |
run_loss = CumulativeAverage()
|
| 43 |
+
targets_cumulative = Cumulative()
|
| 44 |
+
preds_cumulative = Cumulative()
|
| 45 |
with torch.no_grad():
|
| 46 |
+
for _, batch_data in enumerate(loader):
|
| 47 |
data = batch_data["image"].as_subclass(torch.Tensor).to(args.device)
|
| 48 |
target = batch_data["label"].as_subclass(torch.Tensor).to(args.device)
|
| 49 |
|
|
|
|
| 51 |
output = output.squeeze(1)
|
| 52 |
loss = criterion(output, target)
|
| 53 |
|
| 54 |
+
targets_cumulative.extend(target.detach().cpu())
|
| 55 |
+
preds_cumulative.extend(output.detach().cpu())
|
| 56 |
run_loss.append(loss.item())
|
| 57 |
|
| 58 |
loss_epoch = run_loss.aggregate()
|
| 59 |
+
target_list = targets_cumulative.get_buffer().cpu().numpy()
|
| 60 |
+
pred_list = preds_cumulative.get_buffer().cpu().numpy()
|
| 61 |
auc_epoch = roc_auc_score(target_list, pred_list)
|
| 62 |
y_pred_categoric = pred_list >= 0.5
|
| 63 |
tn, fp, fn, tp = confusion_matrix(target_list, y_pred_categoric).ravel()
|
src/train/train_pirads.py
CHANGED
|
@@ -1,20 +1,27 @@
|
|
|
|
|
|
|
|
| 1 |
import time
|
|
|
|
| 2 |
import numpy as np
|
| 3 |
import torch
|
| 4 |
import torch.nn as nn
|
| 5 |
from monai.metrics import Cumulative, CumulativeAverage
|
| 6 |
from sklearn.metrics import cohen_kappa_score
|
| 7 |
-
import logging
|
| 8 |
|
| 9 |
|
| 10 |
-
def get_lambda_att(epoch, max_lambda=2.0, warmup_epochs=10):
|
| 11 |
if epoch < warmup_epochs:
|
| 12 |
return (epoch / warmup_epochs) * max_lambda
|
| 13 |
else:
|
| 14 |
return max_lambda
|
| 15 |
|
| 16 |
|
| 17 |
-
def get_attention_scores(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
"""
|
| 19 |
Compute attention scores from heatmaps and shuffle data accordingly.
|
| 20 |
This function generates attention scores based on spatial heatmaps, applies
|
|
@@ -136,17 +143,7 @@ def train_epoch(model, loader, optimizer, scaler, epoch, args):
|
|
| 136 |
run_loss.append(loss.detach().cpu())
|
| 137 |
run_acc.append(acc.detach().cpu())
|
| 138 |
logging.info(
|
| 139 |
-
"Epoch {}/{} {}/{} loss: {:.4f} attention loss: {:.4f} acc: {:.4f} grad norm: {:.4f} time {:.2f}s"
|
| 140 |
-
epoch,
|
| 141 |
-
args.epochs,
|
| 142 |
-
idx,
|
| 143 |
-
len(loader),
|
| 144 |
-
loss.item(),
|
| 145 |
-
attn_loss.item(),
|
| 146 |
-
acc,
|
| 147 |
-
total_norm,
|
| 148 |
-
time.time() - start_time,
|
| 149 |
-
)
|
| 150 |
)
|
| 151 |
start_time = time.time()
|
| 152 |
|
|
@@ -164,8 +161,8 @@ def val_epoch(model, loader, epoch, args):
|
|
| 164 |
|
| 165 |
run_loss = CumulativeAverage()
|
| 166 |
run_acc = CumulativeAverage()
|
| 167 |
-
|
| 168 |
-
|
| 169 |
|
| 170 |
start_time = time.time()
|
| 171 |
loss, acc = 0.0, 0.0
|
|
@@ -188,12 +185,10 @@ def val_epoch(model, loader, epoch, args):
|
|
| 188 |
|
| 189 |
run_loss.append(loss.detach().cpu())
|
| 190 |
run_acc.append(acc.detach().cpu())
|
| 191 |
-
|
| 192 |
-
|
| 193 |
logging.info(
|
| 194 |
-
"Val epoch {}/{} {}/{} loss: {:.4f} acc: {:.4f} time {:.2f}s"
|
| 195 |
-
epoch, args.epochs, idx, len(loader), loss, acc, time.time() - start_time
|
| 196 |
-
)
|
| 197 |
)
|
| 198 |
start_time = time.time()
|
| 199 |
|
|
@@ -201,9 +196,11 @@ def val_epoch(model, loader, epoch, args):
|
|
| 201 |
torch.cuda.empty_cache()
|
| 202 |
|
| 203 |
# Calculate QWK metric (Quadratic Weigted Kappa) https://en.wikipedia.org/wiki/Cohen%27s_kappa
|
| 204 |
-
|
| 205 |
-
|
| 206 |
loss_epoch = run_loss.aggregate()
|
| 207 |
acc_epoch = run_acc.aggregate()
|
| 208 |
-
qwk = cohen_kappa_score(
|
|
|
|
|
|
|
| 209 |
return loss_epoch, acc_epoch, qwk
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import logging
|
| 3 |
import time
|
| 4 |
+
|
| 5 |
import numpy as np
|
| 6 |
import torch
|
| 7 |
import torch.nn as nn
|
| 8 |
from monai.metrics import Cumulative, CumulativeAverage
|
| 9 |
from sklearn.metrics import cohen_kappa_score
|
|
|
|
| 10 |
|
| 11 |
|
| 12 |
+
def get_lambda_att(epoch: int, max_lambda: float = 2.0, warmup_epochs: int = 10) -> float:
|
| 13 |
if epoch < warmup_epochs:
|
| 14 |
return (epoch / warmup_epochs) * max_lambda
|
| 15 |
else:
|
| 16 |
return max_lambda
|
| 17 |
|
| 18 |
|
| 19 |
+
def get_attention_scores(
|
| 20 |
+
data: torch.Tensor,
|
| 21 |
+
target: torch.Tensor,
|
| 22 |
+
heatmap: torch.Tensor,
|
| 23 |
+
args: argparse.Namespace,
|
| 24 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 25 |
"""
|
| 26 |
Compute attention scores from heatmaps and shuffle data accordingly.
|
| 27 |
This function generates attention scores based on spatial heatmaps, applies
|
|
|
|
| 143 |
run_loss.append(loss.detach().cpu())
|
| 144 |
run_acc.append(acc.detach().cpu())
|
| 145 |
logging.info(
|
| 146 |
+
f"Epoch {epoch}/{args.epochs} {idx}/{len(loader)} loss: {loss.item():.4f} attention loss: {attn_loss.item():.4f} acc: {acc:.4f} grad norm: {total_norm:.4f} time {time.time() - start_time:.2f}s"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
)
|
| 148 |
start_time = time.time()
|
| 149 |
|
|
|
|
| 161 |
|
| 162 |
run_loss = CumulativeAverage()
|
| 163 |
run_acc = CumulativeAverage()
|
| 164 |
+
preds_cumulative = Cumulative()
|
| 165 |
+
targets_cumulative = Cumulative()
|
| 166 |
|
| 167 |
start_time = time.time()
|
| 168 |
loss, acc = 0.0, 0.0
|
|
|
|
| 185 |
|
| 186 |
run_loss.append(loss.detach().cpu())
|
| 187 |
run_acc.append(acc.detach().cpu())
|
| 188 |
+
preds_cumulative.extend(pred.detach().cpu())
|
| 189 |
+
targets_cumulative.extend(target.detach().cpu())
|
| 190 |
logging.info(
|
| 191 |
+
f"Val epoch {epoch}/{args.epochs} {idx}/{len(loader)} loss: {loss:.4f} acc: {acc:.4f} time {time.time() - start_time:.2f}s"
|
|
|
|
|
|
|
| 192 |
)
|
| 193 |
start_time = time.time()
|
| 194 |
|
|
|
|
| 196 |
torch.cuda.empty_cache()
|
| 197 |
|
| 198 |
# Calculate QWK metric (Quadratic Weigted Kappa) https://en.wikipedia.org/wiki/Cohen%27s_kappa
|
| 199 |
+
preds_cumulative = preds_cumulative.get_buffer().cpu().numpy()
|
| 200 |
+
targets_cumulative = targets_cumulative.get_buffer().cpu().numpy()
|
| 201 |
loss_epoch = run_loss.aggregate()
|
| 202 |
acc_epoch = run_acc.aggregate()
|
| 203 |
+
qwk = cohen_kappa_score(
|
| 204 |
+
targets_cumulative.astype(np.float64), preds_cumulative.astype(np.float64)
|
| 205 |
+
)
|
| 206 |
return loss_epoch, acc_epoch, qwk
|
src/utils.py
CHANGED
|
@@ -1,21 +1,31 @@
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import sys
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
import numpy as np
|
| 4 |
import torch
|
|
|
|
| 5 |
from monai.transforms import (
|
| 6 |
Compose,
|
|
|
|
| 7 |
LoadImaged,
|
| 8 |
ToTensord,
|
| 9 |
-
EnsureTyped,
|
| 10 |
)
|
|
|
|
| 11 |
from .data.custom_transforms import ClipMaskIntensityPercentilesd, NormalizeIntensity_customd
|
| 12 |
-
from monai.data import Dataset, ITKReader
|
| 13 |
-
import logging
|
| 14 |
-
from pathlib import Path
|
| 15 |
-
import cv2
|
| 16 |
|
| 17 |
|
| 18 |
-
def save_pirads_checkpoint(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
"""Save checkpoint for the PI-RADS model"""
|
| 20 |
|
| 21 |
state_dict = model.state_dict()
|
|
@@ -25,7 +35,11 @@ def save_pirads_checkpoint(model, epoch, args, filename="model.pth", best_acc=0)
|
|
| 25 |
logging.info(f"Saving checkpoint {filename}")
|
| 26 |
|
| 27 |
|
| 28 |
-
def save_cspca_checkpoint(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
"""Save checkpoint for the csPCa model"""
|
| 30 |
|
| 31 |
state_dict = model.state_dict()
|
|
@@ -41,7 +55,7 @@ def save_cspca_checkpoint(model, val_metric, model_dir):
|
|
| 41 |
logging.info(f"Saving model with auc: {val_metric['auc']}")
|
| 42 |
|
| 43 |
|
| 44 |
-
def get_metrics(metric_dict: dict):
|
| 45 |
for metric_name, metric_list in metric_dict.items():
|
| 46 |
metric_list = np.array(metric_list)
|
| 47 |
lower = np.percentile(metric_list, 2.5)
|
|
@@ -51,7 +65,7 @@ def get_metrics(metric_dict: dict):
|
|
| 51 |
logging.info(f"95% CI: ({lower:.3f}, {upper:.3f})")
|
| 52 |
|
| 53 |
|
| 54 |
-
def setup_logging(log_file):
|
| 55 |
log_file = Path(log_file)
|
| 56 |
log_file.parent.mkdir(parents=True, exist_ok=True)
|
| 57 |
if log_file.exists():
|
|
@@ -66,13 +80,13 @@ def setup_logging(log_file):
|
|
| 66 |
|
| 67 |
|
| 68 |
def validate_steps(steps):
|
| 69 |
-
|
| 70 |
"get_segmentation_mask": ["register_and_crop"],
|
| 71 |
"histogram_match": ["get_segmentation_mask", "register_and_crop"],
|
| 72 |
"get_heatmap": ["get_segmentation_mask", "histogram_match", "register_and_crop"],
|
| 73 |
}
|
| 74 |
for i, step in enumerate(steps):
|
| 75 |
-
required =
|
| 76 |
for req in required:
|
| 77 |
if req not in steps[:i]:
|
| 78 |
logging.error(
|
|
@@ -81,7 +95,10 @@ def validate_steps(steps):
|
|
| 81 |
sys.exit(1)
|
| 82 |
|
| 83 |
|
| 84 |
-
def get_patch_coordinate(
|
|
|
|
|
|
|
|
|
|
| 85 |
"""
|
| 86 |
Locate the coordinates of top-5 patches within a parent image.
|
| 87 |
|
|
@@ -90,7 +107,7 @@ def get_patch_coordinate(patches_top_5, parent_image):
|
|
| 90 |
coordinates (row, column) and the slice index where each patch is found.
|
| 91 |
|
| 92 |
Args:
|
| 93 |
-
patches_top_5 (list): List of top-5
|
| 94 |
where C is channels, H is height, W is width.
|
| 95 |
parent_image (np.ndarray): 3D image volume with shape (height, width, slices)
|
| 96 |
to search within.
|
|
@@ -137,12 +154,12 @@ def get_patch_coordinate(patches_top_5, parent_image):
|
|
| 137 |
return coords
|
| 138 |
|
| 139 |
|
| 140 |
-
def get_parent_image(temp_data_list, args):
|
| 141 |
transform_image = Compose(
|
| 142 |
[
|
| 143 |
LoadImaged(
|
| 144 |
keys=["image", "mask"],
|
| 145 |
-
reader=ITKReader
|
| 146 |
ensure_channel_first=True,
|
| 147 |
dtype=np.float32,
|
| 148 |
),
|
|
@@ -169,9 +186,9 @@ def visualise_patches():
|
|
| 169 |
for i in range(rows):
|
| 170 |
for j in range(slices):
|
| 171 |
ax = axes[i, j]
|
| 172 |
-
|
| 173 |
if j == 0:
|
| 174 |
-
|
| 175 |
for k in range(parent_image.shape[2]):
|
| 176 |
img_temp = parent_image[:, :, k]
|
| 177 |
H, W = img_temp.shape
|
|
@@ -187,11 +204,11 @@ def visualise_patches():
|
|
| 187 |
break
|
| 188 |
if bool1:
|
| 189 |
break
|
| 190 |
-
|
| 191 |
if bool1:
|
| 192 |
break
|
| 193 |
|
| 194 |
-
|
| 195 |
|
| 196 |
|
| 197 |
ax.imshow(parent_image[:, :, k+j], cmap='gray')
|
|
@@ -199,7 +216,7 @@ def visualise_patches():
|
|
| 199 |
linewidth=2, edgecolor='red', facecolor='none')
|
| 200 |
ax.add_patch(rect)
|
| 201 |
ax.axis('off')
|
| 202 |
-
|
| 203 |
|
| 204 |
plt.tight_layout()
|
| 205 |
plt.show()
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import logging
|
| 3 |
import os
|
| 4 |
import sys
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Any, Union
|
| 7 |
+
|
| 8 |
+
import cv2
|
| 9 |
import numpy as np
|
| 10 |
import torch
|
| 11 |
+
from monai.data import Dataset
|
| 12 |
from monai.transforms import (
|
| 13 |
Compose,
|
| 14 |
+
EnsureTyped,
|
| 15 |
LoadImaged,
|
| 16 |
ToTensord,
|
|
|
|
| 17 |
)
|
| 18 |
+
|
| 19 |
from .data.custom_transforms import ClipMaskIntensityPercentilesd, NormalizeIntensity_customd
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
|
| 22 |
+
def save_pirads_checkpoint(
|
| 23 |
+
model: torch.nn.Module,
|
| 24 |
+
epoch: int,
|
| 25 |
+
args: argparse.Namespace,
|
| 26 |
+
filename: str = "model.pth",
|
| 27 |
+
best_acc: float = 0,
|
| 28 |
+
) -> None:
|
| 29 |
"""Save checkpoint for the PI-RADS model"""
|
| 30 |
|
| 31 |
state_dict = model.state_dict()
|
|
|
|
| 35 |
logging.info(f"Saving checkpoint {filename}")
|
| 36 |
|
| 37 |
|
| 38 |
+
def save_cspca_checkpoint(
|
| 39 |
+
model: torch.nn.Module,
|
| 40 |
+
val_metric: dict[str, Any],
|
| 41 |
+
model_dir: str,
|
| 42 |
+
) -> None:
|
| 43 |
"""Save checkpoint for the csPCa model"""
|
| 44 |
|
| 45 |
state_dict = model.state_dict()
|
|
|
|
| 55 |
logging.info(f"Saving model with auc: {val_metric['auc']}")
|
| 56 |
|
| 57 |
|
| 58 |
+
def get_metrics(metric_dict: dict) -> None:
|
| 59 |
for metric_name, metric_list in metric_dict.items():
|
| 60 |
metric_list = np.array(metric_list)
|
| 61 |
lower = np.percentile(metric_list, 2.5)
|
|
|
|
| 65 |
logging.info(f"95% CI: ({lower:.3f}, {upper:.3f})")
|
| 66 |
|
| 67 |
|
| 68 |
+
def setup_logging(log_file: Union[str, Path]) -> None:
|
| 69 |
log_file = Path(log_file)
|
| 70 |
log_file.parent.mkdir(parents=True, exist_ok=True)
|
| 71 |
if log_file.exists():
|
|
|
|
| 80 |
|
| 81 |
|
| 82 |
def validate_steps(steps):
|
| 83 |
+
requires = {
|
| 84 |
"get_segmentation_mask": ["register_and_crop"],
|
| 85 |
"histogram_match": ["get_segmentation_mask", "register_and_crop"],
|
| 86 |
"get_heatmap": ["get_segmentation_mask", "histogram_match", "register_and_crop"],
|
| 87 |
}
|
| 88 |
for i, step in enumerate(steps):
|
| 89 |
+
required = requires.get(step, [])
|
| 90 |
for req in required:
|
| 91 |
if req not in steps[:i]:
|
| 92 |
logging.error(
|
|
|
|
| 95 |
sys.exit(1)
|
| 96 |
|
| 97 |
|
| 98 |
+
def get_patch_coordinate(
|
| 99 |
+
patches_top_5: list[np.ndarray],
|
| 100 |
+
parent_image: np.ndarray,
|
| 101 |
+
) -> list[tuple[int, int, int]]:
|
| 102 |
"""
|
| 103 |
Locate the coordinates of top-5 patches within a parent image.
|
| 104 |
|
|
|
|
| 107 |
coordinates (row, column) and the slice index where each patch is found.
|
| 108 |
|
| 109 |
Args:
|
| 110 |
+
patches_top_5 (list): List of top-5 patches as np arrays, each with shape (C, H, W)
|
| 111 |
where C is channels, H is height, W is width.
|
| 112 |
parent_image (np.ndarray): 3D image volume with shape (height, width, slices)
|
| 113 |
to search within.
|
|
|
|
| 154 |
return coords
|
| 155 |
|
| 156 |
|
| 157 |
+
def get_parent_image(temp_data_list, args: argparse.Namespace) -> np.ndarray:
|
| 158 |
transform_image = Compose(
|
| 159 |
[
|
| 160 |
LoadImaged(
|
| 161 |
keys=["image", "mask"],
|
| 162 |
+
reader="ITKReader",
|
| 163 |
ensure_channel_first=True,
|
| 164 |
dtype=np.float32,
|
| 165 |
),
|
|
|
|
| 186 |
for i in range(rows):
|
| 187 |
for j in range(slices):
|
| 188 |
ax = axes[i, j]
|
| 189 |
+
|
| 190 |
if j == 0:
|
| 191 |
+
|
| 192 |
for k in range(parent_image.shape[2]):
|
| 193 |
img_temp = parent_image[:, :, k]
|
| 194 |
H, W = img_temp.shape
|
|
|
|
| 204 |
break
|
| 205 |
if bool1:
|
| 206 |
break
|
| 207 |
+
|
| 208 |
if bool1:
|
| 209 |
break
|
| 210 |
|
| 211 |
+
|
| 212 |
|
| 213 |
|
| 214 |
ax.imshow(parent_image[:, :, k+j], cmap='gray')
|
|
|
|
| 216 |
linewidth=2, edgecolor='red', facecolor='none')
|
| 217 |
ax.add_patch(rect)
|
| 218 |
ax.axis('off')
|
| 219 |
+
|
| 220 |
|
| 221 |
plt.tight_layout()
|
| 222 |
plt.show()
|
temp.ipynb
CHANGED
|
@@ -160,7 +160,7 @@
|
|
| 160 |
},
|
| 161 |
{
|
| 162 |
"cell_type": "code",
|
| 163 |
-
"execution_count":
|
| 164 |
"id": "c91a5802",
|
| 165 |
"metadata": {},
|
| 166 |
"outputs": [
|
|
@@ -168,9 +168,18 @@
|
|
| 168 |
"name": "stderr",
|
| 169 |
"output_type": "stream",
|
| 170 |
"text": [
|
| 171 |
-
" 0%| | 0/1 [00:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
"You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
|
| 173 |
-
"
|
|
|
|
|
|
|
| 174 |
]
|
| 175 |
}
|
| 176 |
],
|
|
@@ -188,8 +197,8 @@
|
|
| 188 |
"\n",
|
| 189 |
"args = register_files(args)\n",
|
| 190 |
"args = get_segmask(args)\n",
|
| 191 |
-
"
|
| 192 |
-
"
|
| 193 |
]
|
| 194 |
},
|
| 195 |
{
|
|
@@ -370,7 +379,7 @@
|
|
| 370 |
},
|
| 371 |
{
|
| 372 |
"cell_type": "code",
|
| 373 |
-
"execution_count":
|
| 374 |
"id": "8b5d382e",
|
| 375 |
"metadata": {},
|
| 376 |
"outputs": [],
|
|
@@ -385,7 +394,7 @@
|
|
| 385 |
},
|
| 386 |
{
|
| 387 |
"cell_type": "code",
|
| 388 |
-
"execution_count":
|
| 389 |
"id": "4cf061ec",
|
| 390 |
"metadata": {},
|
| 391 |
"outputs": [
|
|
@@ -419,7 +428,7 @@
|
|
| 419 |
},
|
| 420 |
{
|
| 421 |
"cell_type": "code",
|
| 422 |
-
"execution_count":
|
| 423 |
"id": "fac15515",
|
| 424 |
"metadata": {},
|
| 425 |
"outputs": [],
|
|
@@ -452,7 +461,7 @@
|
|
| 452 |
},
|
| 453 |
{
|
| 454 |
"cell_type": "code",
|
| 455 |
-
"execution_count":
|
| 456 |
"id": "eb80047b",
|
| 457 |
"metadata": {},
|
| 458 |
"outputs": [],
|
|
@@ -493,227 +502,23 @@
|
|
| 493 |
},
|
| 494 |
{
|
| 495 |
"cell_type": "code",
|
| 496 |
-
"execution_count":
|
| 497 |
"id": "dbcfc97f",
|
| 498 |
"metadata": {},
|
| 499 |
"outputs": [
|
| 500 |
{
|
| 501 |
"data": {
|
| 502 |
"text/plain": [
|
| 503 |
-
"
|
| 504 |
-
" -1.6027895 , -1.257521 ],\n",
|
| 505 |
-
" [-0.7786003 , -1.2797965 , -2.0817103 , ..., -2.0371597 ,\n",
|
| 506 |
-
" -2.0817103 , -2.059435 ],\n",
|
| 507 |
-
" [-1.4802749 , -1.658478 , -1.7809926 , ..., -1.4691372 ,\n",
|
| 508 |
-
" -2.1596742 , -1.9591957 ],\n",
|
| 509 |
-
" ...,\n",
|
| 510 |
-
" [ 0.8697783 , 0.34630668, 0.6358867 , ..., 0.624749 ,\n",
|
| 511 |
-
" 0.24606745, 0.03445129],\n",
|
| 512 |
-
" [ 0.44654593, 0.23492976, 0.7806767 , ..., 1.2373221 ,\n",
|
| 513 |
-
" 0.2906182 , -1.0124918 ],\n",
|
| 514 |
-
" [ 0.3017559 , -0.2551287 , 0.27948052, ..., 1.1036698 ,\n",
|
| 515 |
-
" -0.16602719, -1.2463834 ]],\n",
|
| 516 |
-
" \n",
|
| 517 |
-
" [[-0.84542644, -1.3800358 , -1.4357241 , ..., -2.5383558 ,\n",
|
| 518 |
-
" -2.0817103 , -0.99021643],\n",
|
| 519 |
-
" [-0.4444695 , -1.3466226 , -1.8255434 , ..., -1.335485 ,\n",
|
| 520 |
-
" -1.6362027 , -1.4579996 ],\n",
|
| 521 |
-
" [-1.2129703 , -1.7921304 , -2.0371597 , ..., 0.14582822,\n",
|
| 522 |
-
" -0.04351256, -0.82315105],\n",
|
| 523 |
-
" ...,\n",
|
| 524 |
-
" [-0.4556072 , 1.4378005 , 1.5603153 , ..., 0.2906182 ,\n",
|
| 525 |
-
" -0.85656416, -1.1906949 ],\n",
|
| 526 |
-
" [-1.2352457 , -0.35536796, 1.1816336 , ..., -0.39991874,\n",
|
| 527 |
-
" -1.3800358 , -1.2129703 ],\n",
|
| 528 |
-
" [-1.2909342 , -1.1127311 , 0.5467852 , ..., -0.9568034 ,\n",
|
| 529 |
-
" -1.4023111 , -0.9011149 ]],\n",
|
| 530 |
-
" \n",
|
| 531 |
-
" [[-2.1596742 , -1.914645 , -1.7809926 , ..., -1.0124918 ,\n",
|
| 532 |
-
" 0.03445129, 0.44654593],\n",
|
| 533 |
-
" [-2.0037465 , -1.8923696 , -1.8700942 , ..., -1.2129703 ,\n",
|
| 534 |
-
" -1.1572819 , -0.50015795],\n",
|
| 535 |
-
" [-1.8478189 , -1.9369203 , -1.9814711 , ..., -0.65608567,\n",
|
| 536 |
-
" -1.2352457 , -1.4134488 ],\n",
|
| 537 |
-
" ...,\n",
|
| 538 |
-
" [-0.6338103 , -1.0124918 , -0.16602719, ..., 0.41313285,\n",
|
| 539 |
-
" 0.13469052, -0.80087566],\n",
|
| 540 |
-
" [-0.7451872 , -1.3577603 , -0.2996795 , ..., 0.34630668,\n",
|
| 541 |
-
" 0.68043745, 0.45768362],\n",
|
| 542 |
-
" [ 0.0121759 , -0.7674626 , -0.33309257, ..., 0.09013975,\n",
|
| 543 |
-
" 0.46882132, 1.0034306 ]]], dtype=float32),\n",
|
| 544 |
-
" array([[[-0.64494795, -0.789738 , -0.5447087 , ..., 0.0233136 ,\n",
|
| 545 |
-
" 0.14582822, -0.35536796],\n",
|
| 546 |
-
" [-0.8120134 , -0.7340495 , -0.42219412, ..., -0.31081718,\n",
|
| 547 |
-
" 0.46882132, 0.15696591],\n",
|
| 548 |
-
" [-0.08806333, 0.03445129, 0.12355283, ..., -0.2774041 ,\n",
|
| 549 |
-
" 0.73612595, 0.7249882 ],\n",
|
| 550 |
-
" ...,\n",
|
| 551 |
-
" [ 0.09013975, 0.0233136 , -0.24399103, ..., 0.34630668,\n",
|
| 552 |
-
" 0.914329 , 0.6358867 ],\n",
|
| 553 |
-
" [ 0.2906182 , 0.335169 , 0.624749 , ..., 0.24606745,\n",
|
| 554 |
-
" 0.9254667 , 0.9922929 ],\n",
|
| 555 |
-
" [ 0.44654593, 0.70271283, 1.0479814 , ..., -0.09920102,\n",
|
| 556 |
-
" 0.37971976, 0.70271283]],\n",
|
| 557 |
-
" \n",
|
| 558 |
-
" [[-0.23285334, -0.01009948, -0.1326141 , ..., 1.1593583 ,\n",
|
| 559 |
-
" 1.5603153 , 1.5603153 ],\n",
|
| 560 |
-
" [-0.23285334, 0.13469052, 0.1792413 , ..., 0.98115516,\n",
|
| 561 |
-
" 1.5603153 , 1.5603153 ],\n",
|
| 562 |
-
" [-0.36650565, -0.01009948, 0.190379 , ..., 0.8697783 ,\n",
|
| 563 |
-
" 1.5603153 , 1.5603153 ],\n",
|
| 564 |
-
" ...,\n",
|
| 565 |
-
" [-0.16602719, -0.06578795, 0.190379 , ..., -1.4357241 ,\n",
|
| 566 |
-
" -1.368898 , -1.6027895 ],\n",
|
| 567 |
-
" [-0.2662664 , -0.35536796, 0.190379 , ..., -1.513688 ,\n",
|
| 568 |
-
" -1.3800358 , -1.6362027 ],\n",
|
| 569 |
-
" [-0.789738 , -0.7786003 , -0.17716487, ..., -1.7141665 ,\n",
|
| 570 |
-
" -1.5693765 , -1.9591957 ]],\n",
|
| 571 |
-
" \n",
|
| 572 |
-
" [[-0.12147641, -0.01009948, 0.0233136 , ..., 0.769539 ,\n",
|
| 573 |
-
" 0.8140898 , 0.2906182 ],\n",
|
| 574 |
-
" [-0.35536796, -0.2774041 , -0.16602719, ..., 0.55792284,\n",
|
| 575 |
-
" 0.68043745, 0.335169 ],\n",
|
| 576 |
-
" [-0.18830256, 0.1681036 , 0.7918144 , ..., 0.70271283,\n",
|
| 577 |
-
" 0.4910967 , 0.10127745],\n",
|
| 578 |
-
" ...,\n",
|
| 579 |
-
" [-0.97907877, -0.97907877, -0.9011149 , ..., -0.43333182,\n",
|
| 580 |
-
" -0.5447087 , -0.1437518 ],\n",
|
| 581 |
-
" [-1.1238688 , -1.1461442 , -1.0347673 , ..., -0.7117741 ,\n",
|
| 582 |
-
" -0.12147641, 0.479959 ],\n",
|
| 583 |
-
" [-0.8788395 , -0.62267256, -0.9122526 , ..., -0.92339027,\n",
|
| 584 |
-
" -0.42219412, 0.22379206]]], dtype=float32),\n",
|
| 585 |
-
" array([[[ 2.7948052e-01, 1.0368437e+00, 1.5603153e+00, ...,\n",
|
| 586 |
-
" -1.2797965e+00, -8.3428878e-01, -2.4399103e-01],\n",
|
| 587 |
-
" [ 6.7864366e-02, 3.4630668e-01, 1.1704960e+00, ...,\n",
|
| 588 |
-
" -7.7860028e-01, -8.3428878e-01, -3.1081718e-01],\n",
|
| 589 |
-
" [ 2.6834285e-01, 9.0139754e-02, 2.3492976e-01, ...,\n",
|
| 590 |
-
" -6.3381028e-01, -9.5680338e-01, -4.4446951e-01],\n",
|
| 591 |
-
" ...,\n",
|
| 592 |
-
" [-1.3688980e+00, -1.2241080e+00, -9.4566566e-01, ...,\n",
|
| 593 |
-
" -1.0570426e+00, -2.6626641e-01, 1.2707351e+00],\n",
|
| 594 |
-
" [-1.0793180e+00, -8.8997722e-01, -1.1015934e+00, ...,\n",
|
| 595 |
-
" -7.6746261e-01, -5.4650255e-02, 1.0257059e+00],\n",
|
| 596 |
-
" [-7.3404950e-01, -6.5608567e-01, -1.0124918e+00, ...,\n",
|
| 597 |
-
" -5.3357106e-01, 2.6834285e-01, 8.5864055e-01]],\n",
|
| 598 |
-
" \n",
|
| 599 |
-
" [[ 1.1593583e+00, 1.5603153e+00, 1.5603153e+00, ...,\n",
|
| 600 |
-
" -4.5560721e-01, -1.1033872e-01, 8.6977828e-01],\n",
|
| 601 |
-
" [ 8.4750289e-01, 1.5603153e+00, 1.5603153e+00, ...,\n",
|
| 602 |
-
" -1.3020718e+00, -1.0459049e+00, 6.3588673e-01],\n",
|
| 603 |
-
" [ 7.9002060e-02, 1.3469052e-01, -3.1081718e-01, ...,\n",
|
| 604 |
-
" -1.3688980e+00, -1.0570426e+00, 4.9109671e-01],\n",
|
| 605 |
-
" ...,\n",
|
| 606 |
-
" [ 1.5603153e+00, 6.4702439e-01, -4.3512560e-02, ...,\n",
|
| 607 |
-
" 5.2450979e-01, 8.9205366e-01, -2.9967949e-01],\n",
|
| 608 |
-
" [ 1.2930106e+00, 6.1361134e-01, -4.6674490e-01, ...,\n",
|
| 609 |
-
" 2.7948052e-01, 8.4750289e-01, 7.8067672e-01],\n",
|
| 610 |
-
" [ 1.0382106e-03, 2.5720516e-01, -4.1105643e-01, ...,\n",
|
| 611 |
-
" 6.2474900e-01, 1.5603153e+00, 1.5603153e+00]],\n",
|
| 612 |
-
" \n",
|
| 613 |
-
" [[-9.2339027e-01, -1.0236295e+00, -1.0347673e+00, ...,\n",
|
| 614 |
-
" 2.6834285e-01, -1.9944026e-01, -1.1033872e-01],\n",
|
| 615 |
-
" [-8.2315105e-01, -1.1127311e+00, -9.7907877e-01, ...,\n",
|
| 616 |
-
" 1.2818729e+00, 3.2403129e-01, 6.7864366e-02],\n",
|
| 617 |
-
" [-8.0087566e-01, -9.7907877e-01, -8.0087566e-01, ...,\n",
|
| 618 |
-
" 1.5046268e+00, 6.1361134e-01, -1.7716487e-01],\n",
|
| 619 |
-
" ...,\n",
|
| 620 |
-
" [ 1.3932499e+00, 6.8043745e-01, 2.3492976e-01, ...,\n",
|
| 621 |
-
" -1.8144057e+00, -1.2129703e+00, 2.3313597e-02],\n",
|
| 622 |
-
" [ 1.5603153e+00, 1.5603153e+00, 1.5603153e+00, ...,\n",
|
| 623 |
-
" -1.7587173e+00, -1.1572819e+00, 5.6906056e-01],\n",
|
| 624 |
-
" [ 1.5603153e+00, 1.5603153e+00, 1.5603153e+00, ...,\n",
|
| 625 |
-
" -1.8366811e+00, -8.7883949e-01, 1.3598367e+00]]], dtype=float32),\n",
|
| 626 |
-
" array([[[-1.4245864 , -1.3466226 , -1.079318 , ..., -0.36650565,\n",
|
| 627 |
-
" -0.66722333, -1.079318 ],\n",
|
| 628 |
-
" [-1.1350064 , -1.2352457 , -1.4357241 , ..., -0.01009948,\n",
|
| 629 |
-
" -0.41105643, -0.7563249 ],\n",
|
| 630 |
-
" [-1.0459049 , -0.84542644, -1.368898 , ..., -0.22171564,\n",
|
| 631 |
-
" -0.4778826 , -0.82315105],\n",
|
| 632 |
-
" ...,\n",
|
| 633 |
-
" [-2.2599134 , -2.1151235 , -1.6139272 , ..., 0.5245098 ,\n",
|
| 634 |
-
" 0.41313285, 0.37971976],\n",
|
| 635 |
-
" [-2.560631 , -2.204225 , -1.7921304 , ..., 0.10127745,\n",
|
| 636 |
-
" 0.3128936 , 0.37971976],\n",
|
| 637 |
-
" [-2.2933266 , -2.1596742 , -2.0705726 , ..., 0.0233136 ,\n",
|
| 638 |
-
" 0.3017559 , 0.4910967 ]],\n",
|
| 639 |
-
" \n",
|
| 640 |
-
" [[-1.0570426 , -1.3911734 , -1.658478 , ..., 0.769539 ,\n",
|
| 641 |
-
" 0.7806767 , 0.95887977],\n",
|
| 642 |
-
" [-1.6362027 , -1.4134488 , -1.5693765 , ..., 1.0145682 ,\n",
|
| 643 |
-
" 1.1816336 , 1.1370829 ],\n",
|
| 644 |
-
" [-1.914645 , -1.1127311 , -1.224108 , ..., 0.68043745,\n",
|
| 645 |
-
" 0.9922929 , 0.85864055],\n",
|
| 646 |
-
" ...,\n",
|
| 647 |
-
" [-2.1262612 , -2.026022 , -1.7921304 , ..., 0.4910967 ,\n",
|
| 648 |
-
" 0.9031913 , 1.0702567 ],\n",
|
| 649 |
-
" [-2.3712904 , -2.5494936 , -2.304464 , ..., 0.70271283,\n",
|
| 650 |
-
" 0.914329 , 1.0479814 ],\n",
|
| 651 |
-
" [-2.2710512 , -2.3267395 , -2.4715295 , ..., 1.0145682 ,\n",
|
| 652 |
-
" 0.7806767 , 0.7472636 ]],\n",
|
| 653 |
-
" \n",
|
| 654 |
-
" [[-0.94566566, -0.6003972 , -0.85656416, ..., 1.2373221 ,\n",
|
| 655 |
-
" 1.5603153 , 1.5603153 ],\n",
|
| 656 |
-
" [-1.1795572 , -0.38878104, -0.37764335, ..., 1.5603153 ,\n",
|
| 657 |
-
" 1.5603153 , 1.5603153 ],\n",
|
| 658 |
-
" [-1.335485 , -1.3577603 , -0.32195488, ..., 1.5603153 ,\n",
|
| 659 |
-
" 1.5603153 , 1.5603153 ],\n",
|
| 660 |
-
" ...,\n",
|
| 661 |
-
" [-1.224108 , -1.1795572 , -1.4914126 , ..., -0.04351256,\n",
|
| 662 |
-
" -0.5112957 , -0.64494795],\n",
|
| 663 |
-
" [-1.3132095 , -1.8589565 , -1.7587173 , ..., -0.03237487,\n",
|
| 664 |
-
" -0.42219412, -0.1437518 ],\n",
|
| 665 |
-
" [-1.6473403 , -2.4381166 , -2.1262612 , ..., 0.37971976,\n",
|
| 666 |
-
" 0.34630668, 0.46882132]]], dtype=float32),\n",
|
| 667 |
-
" array([[[-0.84542644, -0.8120134 , -0.6894987 , ..., 0.46882132,\n",
|
| 668 |
-
" -0.01009948, -0.39991874],\n",
|
| 669 |
-
" [-0.65608567, -0.8120134 , -0.6338103 , ..., 0.59133595,\n",
|
| 670 |
-
" 0.04558898, -0.5112957 ],\n",
|
| 671 |
-
" [-0.96794105, -0.8342888 , -0.35536796, ..., 0.5801982 ,\n",
|
| 672 |
-
" 0.13469052, -0.37764335],\n",
|
| 673 |
-
" ...,\n",
|
| 674 |
-
" [-0.09920102, -0.11033872, -1.224108 , ..., -0.92339027,\n",
|
| 675 |
-
" -1.2686588 , -0.85656416],\n",
|
| 676 |
-
" [ 0.0121759 , -0.01009948, -1.0236295 , ..., -1.2129703 ,\n",
|
| 677 |
-
" -1.4023111 , -0.7340495 ],\n",
|
| 678 |
-
" [ 0.23492976, -0.07692564, -0.70063645, ..., -0.70063645,\n",
|
| 679 |
-
" -0.934528 , -1.0124918 ]],\n",
|
| 680 |
-
" \n",
|
| 681 |
-
" [[-1.6139272 , -0.97907877, -0.7340495 , ..., -0.8342888 ,\n",
|
| 682 |
-
" -0.6894987 , -1.0347673 ],\n",
|
| 683 |
-
" [-1.0459049 , -0.5892595 , -0.789738 , ..., -1.257521 ,\n",
|
| 684 |
-
" -1.0124918 , -1.4357241 ],\n",
|
| 685 |
-
" [-0.42219412, -0.5892595 , -1.4134488 , ..., -1.5025504 ,\n",
|
| 686 |
-
" -1.2463834 , -1.4802749 ],\n",
|
| 687 |
-
" ...,\n",
|
| 688 |
-
" [ 0.10127745, 0.40199515, 0.13469052, ..., 0.25720516,\n",
|
| 689 |
-
" 0.55792284, 0.12355283],\n",
|
| 690 |
-
" [-0.2885418 , -0.2551287 , 0.41313285, ..., 0.46882132,\n",
|
| 691 |
-
" -0.33309257, -1.2797965 ],\n",
|
| 692 |
-
" [-0.2996795 , -0.64494795, 0.9366044 , ..., 0.56906056,\n",
|
| 693 |
-
" -0.41105643, -1.3466226 ]],\n",
|
| 694 |
-
" \n",
|
| 695 |
-
" [[-1.1238688 , -1.4134488 , -1.7364419 , ..., -1.1350064 ,\n",
|
| 696 |
-
" -1.6918911 , -2.3156018 ],\n",
|
| 697 |
-
" [-1.4914126 , -1.3020718 , -0.99021643, ..., -1.658478 ,\n",
|
| 698 |
-
" -1.7698549 , -1.8812319 ],\n",
|
| 699 |
-
" [-1.5582387 , -0.85656416, -0.43333182, ..., -1.6027895 ,\n",
|
| 700 |
-
" -1.914645 , -1.7698549 ],\n",
|
| 701 |
-
" ...,\n",
|
| 702 |
-
" [ 0.6358867 , 0.9031913 , 1.0702567 , ..., 0.11241514,\n",
|
| 703 |
-
" 0.07900206, 0.34630668],\n",
|
| 704 |
-
" [ 0.70271283, 1.1593583 , 0.9254667 , ..., 0.335169 ,\n",
|
| 705 |
-
" 0.41313285, 0.23492976],\n",
|
| 706 |
-
" [ 0.35744438, 1.1482205 , 0.8697783 , ..., -0.2662664 ,\n",
|
| 707 |
-
" 0.56906056, 0.624749 ]]], dtype=float32)]"
|
| 708 |
]
|
| 709 |
},
|
| 710 |
-
"execution_count":
|
| 711 |
"metadata": {},
|
| 712 |
"output_type": "execute_result"
|
| 713 |
}
|
| 714 |
],
|
| 715 |
"source": [
|
| 716 |
-
"patches_top_5"
|
| 717 |
]
|
| 718 |
},
|
| 719 |
{
|
|
|
|
| 160 |
},
|
| 161 |
{
|
| 162 |
"cell_type": "code",
|
| 163 |
+
"execution_count": 3,
|
| 164 |
"id": "c91a5802",
|
| 165 |
"metadata": {},
|
| 166 |
"outputs": [
|
|
|
|
| 168 |
"name": "stderr",
|
| 169 |
"output_type": "stream",
|
| 170 |
"text": [
|
| 171 |
+
" 0%| | 0/1 [00:00<?, ?it/s]"
|
| 172 |
+
]
|
| 173 |
+
},
|
| 174 |
+
{
|
| 175 |
+
"name": "stderr",
|
| 176 |
+
"output_type": "stream",
|
| 177 |
+
"text": [
|
| 178 |
+
"100%|██████████| 1/1 [00:02<00:00, 2.45s/it]\n",
|
| 179 |
"You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
|
| 180 |
+
"100%|██████████| 1/1 [00:38<00:00, 38.82s/it]\n",
|
| 181 |
+
"100%|██████████| 1/1 [00:00<00:00, 4.11it/s]\n",
|
| 182 |
+
"100%|██████████| 1/1 [00:00<00:00, 9.82it/s]\n"
|
| 183 |
]
|
| 184 |
}
|
| 185 |
],
|
|
|
|
| 197 |
"\n",
|
| 198 |
"args = register_files(args)\n",
|
| 199 |
"args = get_segmask(args)\n",
|
| 200 |
+
"args = histmatch(args)\n",
|
| 201 |
+
"args = get_heatmap(args)\n"
|
| 202 |
]
|
| 203 |
},
|
| 204 |
{
|
|
|
|
| 379 |
},
|
| 380 |
{
|
| 381 |
"cell_type": "code",
|
| 382 |
+
"execution_count": 5,
|
| 383 |
"id": "8b5d382e",
|
| 384 |
"metadata": {},
|
| 385 |
"outputs": [],
|
|
|
|
| 394 |
},
|
| 395 |
{
|
| 396 |
"cell_type": "code",
|
| 397 |
+
"execution_count": 6,
|
| 398 |
"id": "4cf061ec",
|
| 399 |
"metadata": {},
|
| 400 |
"outputs": [
|
|
|
|
| 428 |
},
|
| 429 |
{
|
| 430 |
"cell_type": "code",
|
| 431 |
+
"execution_count": 7,
|
| 432 |
"id": "fac15515",
|
| 433 |
"metadata": {},
|
| 434 |
"outputs": [],
|
|
|
|
| 461 |
},
|
| 462 |
{
|
| 463 |
"cell_type": "code",
|
| 464 |
+
"execution_count": 8,
|
| 465 |
"id": "eb80047b",
|
| 466 |
"metadata": {},
|
| 467 |
"outputs": [],
|
|
|
|
| 502 |
},
|
| 503 |
{
|
| 504 |
"cell_type": "code",
|
| 505 |
+
"execution_count": 9,
|
| 506 |
"id": "dbcfc97f",
|
| 507 |
"metadata": {},
|
| 508 |
"outputs": [
|
| 509 |
{
|
| 510 |
"data": {
|
| 511 |
"text/plain": [
|
| 512 |
+
"list"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 513 |
]
|
| 514 |
},
|
| 515 |
+
"execution_count": 9,
|
| 516 |
"metadata": {},
|
| 517 |
"output_type": "execute_result"
|
| 518 |
}
|
| 519 |
],
|
| 520 |
"source": [
|
| 521 |
+
"type(patches_top_5)"
|
| 522 |
]
|
| 523 |
},
|
| 524 |
{
|
tests/test_run.py
CHANGED
|
@@ -20,7 +20,15 @@ def test_run_pirads_training():
|
|
| 20 |
|
| 21 |
# Run the script with the config
|
| 22 |
result = subprocess.run(
|
| 23 |
-
[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
capture_output=True,
|
| 25 |
text=True,
|
| 26 |
)
|
|
@@ -28,6 +36,7 @@ def test_run_pirads_training():
|
|
| 28 |
# Check that it ran without errors
|
| 29 |
assert result.returncode == 0, f"Script failed with:\n{result.stderr}"
|
| 30 |
|
|
|
|
| 31 |
def test_run_pirads_inference():
|
| 32 |
"""
|
| 33 |
Test that run_cspca.py runs without crashing using an existing YAML config.
|
|
@@ -45,14 +54,23 @@ def test_run_pirads_inference():
|
|
| 45 |
|
| 46 |
# Run the script with the config
|
| 47 |
result = subprocess.run(
|
| 48 |
-
[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
capture_output=True,
|
| 50 |
text=True,
|
| 51 |
)
|
| 52 |
|
| 53 |
# Check that it ran without errors
|
| 54 |
assert result.returncode == 0, f"Script failed with:\n{result.stderr}"
|
| 55 |
-
|
|
|
|
| 56 |
def test_run_cspca_training():
|
| 57 |
"""
|
| 58 |
Test that run_cspca.py runs without crashing using an existing YAML config.
|
|
@@ -70,14 +88,23 @@ def test_run_cspca_training():
|
|
| 70 |
|
| 71 |
# Run the script with the config
|
| 72 |
result = subprocess.run(
|
| 73 |
-
[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
capture_output=True,
|
| 75 |
text=True,
|
| 76 |
)
|
| 77 |
|
| 78 |
# Check that it ran without errors
|
| 79 |
assert result.returncode == 0, f"Script failed with:\n{result.stderr}"
|
| 80 |
-
|
|
|
|
| 81 |
def test_run_cspca_inference():
|
| 82 |
"""
|
| 83 |
Test that run_cspca.py runs without crashing using an existing YAML config.
|
|
@@ -95,12 +122,18 @@ def test_run_cspca_inference():
|
|
| 95 |
|
| 96 |
# Run the script with the config
|
| 97 |
result = subprocess.run(
|
| 98 |
-
[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
capture_output=True,
|
| 100 |
text=True,
|
| 101 |
)
|
| 102 |
|
| 103 |
# Check that it ran without errors
|
| 104 |
assert result.returncode == 0, f"Script failed with:\n{result.stderr}"
|
| 105 |
-
|
| 106 |
-
|
|
|
|
| 20 |
|
| 21 |
# Run the script with the config
|
| 22 |
result = subprocess.run(
|
| 23 |
+
[
|
| 24 |
+
sys.executable,
|
| 25 |
+
str(script_path),
|
| 26 |
+
"--mode",
|
| 27 |
+
"train",
|
| 28 |
+
"--config",
|
| 29 |
+
str(config_path),
|
| 30 |
+
"--dry_run",
|
| 31 |
+
],
|
| 32 |
capture_output=True,
|
| 33 |
text=True,
|
| 34 |
)
|
|
|
|
| 36 |
# Check that it ran without errors
|
| 37 |
assert result.returncode == 0, f"Script failed with:\n{result.stderr}"
|
| 38 |
|
| 39 |
+
|
| 40 |
def test_run_pirads_inference():
|
| 41 |
"""
|
| 42 |
Test that run_cspca.py runs without crashing using an existing YAML config.
|
|
|
|
| 54 |
|
| 55 |
# Run the script with the config
|
| 56 |
result = subprocess.run(
|
| 57 |
+
[
|
| 58 |
+
sys.executable,
|
| 59 |
+
str(script_path),
|
| 60 |
+
"--mode",
|
| 61 |
+
"test",
|
| 62 |
+
"--config",
|
| 63 |
+
str(config_path),
|
| 64 |
+
"--dry_run",
|
| 65 |
+
],
|
| 66 |
capture_output=True,
|
| 67 |
text=True,
|
| 68 |
)
|
| 69 |
|
| 70 |
# Check that it ran without errors
|
| 71 |
assert result.returncode == 0, f"Script failed with:\n{result.stderr}"
|
| 72 |
+
|
| 73 |
+
|
| 74 |
def test_run_cspca_training():
|
| 75 |
"""
|
| 76 |
Test that run_cspca.py runs without crashing using an existing YAML config.
|
|
|
|
| 88 |
|
| 89 |
# Run the script with the config
|
| 90 |
result = subprocess.run(
|
| 91 |
+
[
|
| 92 |
+
sys.executable,
|
| 93 |
+
str(script_path),
|
| 94 |
+
"--mode",
|
| 95 |
+
"train",
|
| 96 |
+
"--config",
|
| 97 |
+
str(config_path),
|
| 98 |
+
"--dry_run",
|
| 99 |
+
],
|
| 100 |
capture_output=True,
|
| 101 |
text=True,
|
| 102 |
)
|
| 103 |
|
| 104 |
# Check that it ran without errors
|
| 105 |
assert result.returncode == 0, f"Script failed with:\n{result.stderr}"
|
| 106 |
+
|
| 107 |
+
|
| 108 |
def test_run_cspca_inference():
|
| 109 |
"""
|
| 110 |
Test that run_cspca.py runs without crashing using an existing YAML config.
|
|
|
|
| 122 |
|
| 123 |
# Run the script with the config
|
| 124 |
result = subprocess.run(
|
| 125 |
+
[
|
| 126 |
+
sys.executable,
|
| 127 |
+
str(script_path),
|
| 128 |
+
"--mode",
|
| 129 |
+
"test",
|
| 130 |
+
"--config",
|
| 131 |
+
str(config_path),
|
| 132 |
+
"--dry_run",
|
| 133 |
+
],
|
| 134 |
capture_output=True,
|
| 135 |
text=True,
|
| 136 |
)
|
| 137 |
|
| 138 |
# Check that it ran without errors
|
| 139 |
assert result.returncode == 0, f"Script failed with:\n{result.stderr}"
|
|
|
|
|
|