Anirudh Balaraman commited on
Commit
caf6ee7
·
1 Parent(s): c67c387
.devcontainer/devcontainer.json DELETED
@@ -1,33 +0,0 @@
1
- {
2
- "name": "Python 3",
3
- // Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile
4
- "image": "mcr.microsoft.com/devcontainers/python:1-3.11-bookworm",
5
- "customizations": {
6
- "codespaces": {
7
- "openFiles": [
8
- "README.md",
9
- "app.py"
10
- ]
11
- },
12
- "vscode": {
13
- "settings": {},
14
- "extensions": [
15
- "ms-python.python",
16
- "ms-python.vscode-pylance"
17
- ]
18
- }
19
- },
20
- "updateContentCommand": "[ -f packages.txt ] && sudo apt update && sudo apt upgrade -y && sudo xargs apt install -y <packages.txt; [ -f requirements.txt ] && pip3 install --user -r requirements.txt; pip3 install --user streamlit; echo '✅ Packages installed and Requirements met'",
21
- "postAttachCommand": {
22
- "server": "streamlit run app.py --server.enableCORS false --server.enableXsrfProtection false"
23
- },
24
- "portsAttributes": {
25
- "8501": {
26
- "label": "Application",
27
- "onAutoForward": "openPreview"
28
- }
29
- },
30
- "forwardPorts": [
31
- 8501
32
- ]
33
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.github/workflows/ci.yaml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: CI
2
+
3
+ on:
4
+ push
5
+
6
+ jobs:
7
+ quality-assurance:
8
+ runs-on: ubuntu-latest
9
+ steps:
10
+ - name: Checkout code
11
+ uses: actions/checkout@v4
12
+
13
+ - name: Set up Python
14
+ uses: actions/setup-python@v5
15
+ with:
16
+ python-version: '3.9'
17
+ cache: 'pip' # Speeds up subsequent runs
18
+
19
+ - name: Install dependencies
20
+ run: |
21
+ python -m pip install --upgrade pip
22
+ pip install -r requirements.txt
23
+
24
+ - name: Run CI Suite
25
+ run: make check
.gitignore CHANGED
@@ -6,4 +6,6 @@ temp.ipynb
6
  __pycache__/
7
  **/__pycache__/
8
  *.pyc
9
- .ruff_cache
 
 
 
6
  __pycache__/
7
  **/__pycache__/
8
  *.pyc
9
+ .ruff_cache/
10
+ .mypy_cache/
11
+ .pytest_cache/
Makefile ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .PHONY: format lint typecheck test check
2
+
3
+ format:
4
+ ruff format .
5
+
6
+ lint:
7
+ ruff check . --fix
8
+
9
+ typecheck:
10
+ mypy .
11
+
12
+ test:
13
+ pytest
14
+
15
+ clean:
16
+ @echo "Cleaning project..."
17
+ # Delete compiled bytecode
18
+ @python3 -Bc "import pathlib; [p.unlink() for p in pathlib.Path('.').rglob('*.py[co]')]"
19
+ # Delete directory-based caches
20
+ @python3 -Bc "import shutil, pathlib; \
21
+ [shutil.rmtree(p) for p in pathlib.Path('.').rglob('__pycache__')]; \
22
+ [shutil.rmtree(p) for p in pathlib.Path('.').rglob('.ipynb_checkpoints')]; \
23
+ [shutil.rmtree(p) for p in pathlib.Path('.').rglob('.monai-cache')]; \
24
+ [shutil.rmtree(p) for p in pathlib.Path('.').rglob('.mypy_cache')]; \
25
+ [shutil.rmtree(p) for p in pathlib.Path('.').rglob('.ruff_cache')]; \
26
+ [shutil.rmtree(p) for p in pathlib.Path('.').rglob('.pytest_cache')]"
27
+
28
+ # Updated 'check' to clean before running (optional)
29
+ # This ensures you are testing from a "blank slate"
30
+ check: format lint typecheck test clean
preprocess_main.py CHANGED
@@ -1,14 +1,15 @@
1
- import os
2
- from src.preprocessing.register_and_crop import register_files
3
- from src.preprocessing.prostate_mask import get_segmask
4
- from src.preprocessing.histogram_match import histmatch
5
- from src.preprocessing.generate_heatmap import get_heatmap
6
- import logging
7
- from src.utils import setup_logging
8
- from src.utils import validate_steps
9
  import argparse
 
 
 
10
  import yaml
11
 
 
 
 
 
 
 
12
 
13
  def parse_args():
14
  parser = argparse.ArgumentParser(description="File preprocessing")
@@ -37,7 +38,7 @@ def parse_args():
37
 
38
  args = parser.parse_args()
39
  if args.config:
40
- with open(args.config, "r") as config_file:
41
  config = yaml.safe_load(config_file)
42
  args.__dict__.update(config)
43
  return args
 
 
 
 
 
 
 
 
 
1
  import argparse
2
+ import logging
3
+ import os
4
+
5
  import yaml
6
 
7
+ from src.preprocessing.generate_heatmap import get_heatmap
8
+ from src.preprocessing.histogram_match import histmatch
9
+ from src.preprocessing.prostate_mask import get_segmask
10
+ from src.preprocessing.register_and_crop import register_files
11
+ from src.utils import setup_logging, validate_steps
12
+
13
 
14
  def parse_args():
15
  parser = argparse.ArgumentParser(description="File preprocessing")
 
38
 
39
  args = parser.parse_args()
40
  if args.config:
41
+ with open(args.config) as config_file:
42
  config = yaml.safe_load(config_file)
43
  args.__dict__.update(config)
44
  return args
pyproject.toml CHANGED
@@ -1,9 +1,51 @@
1
  [tool.ruff]
2
  line-length = 100
 
3
 
4
  [tool.ruff.lint]
5
- select = ["E", "W"]
6
  ignore = ["E501"]
7
 
 
 
 
 
 
 
8
  [tool.ruff.format]
9
- quote-style = "double"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  [tool.ruff]
2
  line-length = 100
3
+ target-version = "py39" # Ensures ruff uses 3.9 compatible syntax
4
 
5
  [tool.ruff.lint]
6
+ select = ["E", "W", "F", "I", "B", "N", "UP"]
7
  ignore = ["E501"]
8
 
9
+ [tool.ruff.lint.isort]
10
+ # This makes your imports look like a professional library
11
+ combine-as-imports = true
12
+ force-wrap-aliases = true
13
+ known-first-party = ["src"] # Treats your 'src' folder as a local package
14
+
15
  [tool.ruff.format]
16
+ quote-style = "double"
17
+ indent-style = "space"
18
+ skip-magic-trailing-comma = false
19
+
20
+ [tool.ruff.lint.pep8-naming]
21
+ # Add your custom class name here
22
+ ignore-names = ["sitk", "NormalizeIntensity_custom", "NormalizeIntensity_customd"]
23
+
24
+ [tool.mypy]
25
+ ignore_missing_imports = true
26
+ disable_error_code = ["override", "import-untyped"]
27
+ mypy_path = "."
28
+ pretty = true
29
+ show_error_codes = true
30
+
31
+ [[tool.mypy.overrides]]
32
+ # These settings apply specifically to these external libraries
33
+ module = [
34
+ "yaml.*",
35
+ "nrrd.*",
36
+ "nibabel.*",
37
+ "scipy.*",
38
+ "sklearn.*"
39
+ ]
40
+ ignore_missing_imports = true
41
+
42
+ [tool.pytest.ini_options]
43
+ # Automatically adds these flags every time you run 'pytest'
44
+ addopts = "-v --showlocals --durations=5"
45
+
46
+ # Where pytest should look for tests
47
+ testpaths = ["tests"]
48
+
49
+ # Patterns to identify test files
50
+ python_files = "test_*.py"
51
+ python_functions = "test_*"
run_cspca.py CHANGED
@@ -1,22 +1,23 @@
1
  import argparse
 
2
  import os
3
  import shutil
4
- import yaml
5
  import sys
6
- import torch
7
  from pathlib import Path
 
 
 
8
  from monai.utils import set_determinism
9
- import logging
10
- from src.model.MIL import MILModel_3D
11
- from src.model.csPCa_model import csPCa_Model
12
  from src.data.data_loader import get_dataloader
13
- from src.utils import save_cspca_checkpoint, get_metrics, setup_logging
 
14
  from src.train.train_cspca import train_epoch, val_epoch
15
- import random
16
 
17
 
18
  def main_worker(args):
19
- mil_model = MILModel_3D(num_classes=args.num_classes, mil_mode=args.mil_mode)
20
  cache_dir_path = Path(os.path.join(args.logdir, "cache"))
21
 
22
  if args.mode == "train":
@@ -31,7 +32,7 @@ def main_worker(args):
31
 
32
  train_loader = get_dataloader(args, split="train")
33
  valid_loader = get_dataloader(args, split="test")
34
- cspca_model = csPCa_Model(backbone=mil_model).to(args.device)
35
  for submodule in [
36
  cspca_model.backbone.net,
37
  cspca_model.backbone.myfc,
@@ -45,23 +46,17 @@ def main_worker(args):
45
  )
46
 
47
  old_loss = float("inf")
48
- old_auc = 0.0
49
  for epoch in range(args.epochs):
50
  train_loss, train_auc = train_epoch(
51
  cspca_model, train_loader, optimizer, epoch=epoch, args=args
52
  )
53
- logging.info(
54
- f"EPOCH {epoch} TRAIN loss: {train_loss:.4f} AUC: {train_auc:.4f}"
55
- )
56
  val_metric = val_epoch(cspca_model, valid_loader, epoch=epoch, args=args)
57
  logging.info(
58
  f"EPOCH {epoch} VAL loss: {val_metric['loss']:.4f} AUC: {val_metric['auc']:.4f}"
59
  )
60
  if val_metric["loss"] < old_loss:
61
  old_loss = val_metric["loss"]
62
- old_auc = val_metric["auc"]
63
- sensitivity = val_metric["sensitivity"]
64
- specificity = val_metric["specificity"]
65
  save_cspca_checkpoint(cspca_model, val_metric, model_dir)
66
 
67
  args.checkpoint_cspca = os.path.join(model_dir, "cspca_model.pth")
@@ -69,7 +64,7 @@ def main_worker(args):
69
  shutil.rmtree(cache_dir_path)
70
 
71
 
72
- cspca_model = csPCa_Model(backbone=mil_model).to(args.device)
73
  checkpt = torch.load(args.checkpoint_cspca, map_location="cpu")
74
  cspca_model.load_state_dict(checkpt["state_dict"])
75
  cspca_model = cspca_model.to(args.device)
@@ -96,6 +91,7 @@ def main_worker(args):
96
  get_metrics(metrics_dict)
97
 
98
 
 
99
  def parse_args():
100
  parser = argparse.ArgumentParser(
101
  description="Multiple Instance Learning (MIL) for csPCa risk prediction."
@@ -164,7 +160,7 @@ def parse_args():
164
  )
165
  args = parser.parse_args()
166
  if args.config:
167
- with open(args.config, "r") as config_file:
168
  config = yaml.safe_load(config_file)
169
  args.__dict__.update(config)
170
 
@@ -172,12 +168,13 @@ def parse_args():
172
 
173
 
174
  if __name__ == "__main__":
175
-
176
  args = parse_args()
177
  if args.project_dir is None:
178
- args.project_dir = Path(__file__).resolve().parent # Set project directory
179
 
180
- slurm_job_name = os.getenv('SLURM_JOB_NAME') # If the script is submitted via slurm, job name is the run name
 
 
181
  if slurm_job_name:
182
  args.run_name = slurm_job_name
183
 
@@ -207,10 +204,15 @@ if __name__ == "__main__":
207
 
208
  if args.dry_run:
209
  logging.info("Dry run mode enabled.")
210
- args.epochs = 2
211
  args.batch_size = 2
212
  args.workers = 0
213
- args.num_seeds = 2
214
  args.wandb = False
 
 
215
 
216
  main_worker(args)
 
 
 
 
1
  import argparse
2
+ import logging
3
  import os
4
  import shutil
 
5
  import sys
 
6
  from pathlib import Path
7
+
8
+ import torch
9
+ import yaml
10
  from monai.utils import set_determinism
11
+
 
 
12
  from src.data.data_loader import get_dataloader
13
+ from src.model.cspca_model import CSPCAModel
14
+ from src.model.mil import MILModel3D
15
  from src.train.train_cspca import train_epoch, val_epoch
16
+ from src.utils import get_metrics, save_cspca_checkpoint, setup_logging
17
 
18
 
19
  def main_worker(args):
20
+ mil_model = MILModel3D(num_classes=args.num_classes, mil_mode=args.mil_mode)
21
  cache_dir_path = Path(os.path.join(args.logdir, "cache"))
22
 
23
  if args.mode == "train":
 
32
 
33
  train_loader = get_dataloader(args, split="train")
34
  valid_loader = get_dataloader(args, split="test")
35
+ cspca_model = CSPCAModel(backbone=mil_model).to(args.device)
36
  for submodule in [
37
  cspca_model.backbone.net,
38
  cspca_model.backbone.myfc,
 
46
  )
47
 
48
  old_loss = float("inf")
 
49
  for epoch in range(args.epochs):
50
  train_loss, train_auc = train_epoch(
51
  cspca_model, train_loader, optimizer, epoch=epoch, args=args
52
  )
53
+ logging.info(f"EPOCH {epoch} TRAIN loss: {train_loss:.4f} AUC: {train_auc:.4f}")
 
 
54
  val_metric = val_epoch(cspca_model, valid_loader, epoch=epoch, args=args)
55
  logging.info(
56
  f"EPOCH {epoch} VAL loss: {val_metric['loss']:.4f} AUC: {val_metric['auc']:.4f}"
57
  )
58
  if val_metric["loss"] < old_loss:
59
  old_loss = val_metric["loss"]
 
 
 
60
  save_cspca_checkpoint(cspca_model, val_metric, model_dir)
61
 
62
  args.checkpoint_cspca = os.path.join(model_dir, "cspca_model.pth")
 
64
  shutil.rmtree(cache_dir_path)
65
 
66
 
67
+ cspca_model = CSPCAModel(backbone=mil_model).to(args.device)
68
  checkpt = torch.load(args.checkpoint_cspca, map_location="cpu")
69
  cspca_model.load_state_dict(checkpt["state_dict"])
70
  cspca_model = cspca_model.to(args.device)
 
91
  get_metrics(metrics_dict)
92
 
93
 
94
+
95
  def parse_args():
96
  parser = argparse.ArgumentParser(
97
  description="Multiple Instance Learning (MIL) for csPCa risk prediction."
 
160
  )
161
  args = parser.parse_args()
162
  if args.config:
163
+ with open(args.config) as config_file:
164
  config = yaml.safe_load(config_file)
165
  args.__dict__.update(config)
166
 
 
168
 
169
 
170
  if __name__ == "__main__":
 
171
  args = parse_args()
172
  if args.project_dir is None:
173
+ args.project_dir = Path(__file__).resolve().parent # Set project directory
174
 
175
+ slurm_job_name = os.getenv(
176
+ "SLURM_JOB_NAME"
177
+ ) # If the script is submitted via slurm, job name is the run name
178
  if slurm_job_name:
179
  args.run_name = slurm_job_name
180
 
 
204
 
205
  if args.dry_run:
206
  logging.info("Dry run mode enabled.")
207
+ args.epochs = 1
208
  args.batch_size = 2
209
  args.workers = 0
210
+ args.num_seeds = 1
211
  args.wandb = False
212
+ args.tile_size = 10
213
+ args.tile_count = 5
214
 
215
  main_worker(args)
216
+
217
+ if args.dry_run:
218
+ shutil.rmtree(args.logdir)
run_inference.py CHANGED
@@ -1,18 +1,20 @@
1
  import argparse
 
 
2
  import os
3
- import yaml
4
  import torch
5
- import logging
6
- from src.model.MIL import MILModel_3D
7
- from src.model.csPCa_model import csPCa_Model
8
- from src.utils import setup_logging, get_parent_image, get_patch_coordinate
9
- from src.preprocessing.register_and_crop import register_files
10
- from src.preprocessing.prostate_mask import get_segmask
11
- from src.preprocessing.histogram_match import histmatch
12
- from src.preprocessing.generate_heatmap import get_heatmap
13
- from src.data.data_loader import data_transform, list_data_collate
14
  from monai.data import Dataset
15
- import json
 
 
 
 
 
 
 
 
16
 
17
 
18
  def parse_args():
@@ -36,7 +38,7 @@ def parse_args():
36
 
37
  args = parser.parse_args()
38
  if args.config:
39
- with open(args.config, "r") as config_file:
40
  config = yaml.safe_load(config_file)
41
  args.__dict__.update(config)
42
  return args
@@ -64,14 +66,14 @@ if __name__ == "__main__":
64
  args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
65
 
66
  logging.info("Loading PIRADS model")
67
- pirads_model = MILModel_3D(num_classes=args.num_classes, mil_mode=args.mil_mode)
68
  pirads_checkpoint = torch.load(
69
  os.path.join(args.project_dir, "models", "pirads.pt"), map_location="cpu"
70
  )
71
  pirads_model.load_state_dict(pirads_checkpoint["state_dict"])
72
  pirads_model.to(args.device)
73
  logging.info("Loading csPCa model")
74
- cspca_model = csPCa_Model(backbone=pirads_model).to(args.device)
75
  checkpt = torch.load(
76
  os.path.join(args.project_dir, "models", "cspca_model.pth"), map_location="cpu"
77
  )
@@ -109,7 +111,7 @@ if __name__ == "__main__":
109
  cspca_model.eval()
110
  patches_top_5_list = []
111
  with torch.no_grad():
112
- for idx, batch_data in enumerate(loader):
113
  data = batch_data["image"].as_subclass(torch.Tensor).to(args.device)
114
  logits = pirads_model(data)
115
  pirads_score = torch.argmax(logits, dim=1)
@@ -137,10 +139,10 @@ if __name__ == "__main__":
137
  patches_top_5.append(patch_temp)
138
  patches_top_5_list.append(patches_top_5)
139
  coords_list = []
140
- for i in args.data_list:
141
  parent_image = get_parent_image([i], args)
142
 
143
- coords = get_patch_coordinate(patches_top_5, parent_image)
144
  coords_list.append(coords)
145
  output_dict = {}
146
 
 
1
  import argparse
2
+ import json
3
+ import logging
4
  import os
5
+
6
  import torch
7
+ import yaml
 
 
 
 
 
 
 
 
8
  from monai.data import Dataset
9
+
10
+ from src.data.data_loader import data_transform, list_data_collate
11
+ from src.model.cspca_model import CSPCAModel
12
+ from src.model.mil import MILModel3D
13
+ from src.preprocessing.generate_heatmap import get_heatmap
14
+ from src.preprocessing.histogram_match import histmatch
15
+ from src.preprocessing.prostate_mask import get_segmask
16
+ from src.preprocessing.register_and_crop import register_files
17
+ from src.utils import get_parent_image, get_patch_coordinate, setup_logging
18
 
19
 
20
  def parse_args():
 
38
 
39
  args = parser.parse_args()
40
  if args.config:
41
+ with open(args.config) as config_file:
42
  config = yaml.safe_load(config_file)
43
  args.__dict__.update(config)
44
  return args
 
66
  args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
67
 
68
  logging.info("Loading PIRADS model")
69
+ pirads_model = MILModel3D(num_classes=args.num_classes, mil_mode=args.mil_mode)
70
  pirads_checkpoint = torch.load(
71
  os.path.join(args.project_dir, "models", "pirads.pt"), map_location="cpu"
72
  )
73
  pirads_model.load_state_dict(pirads_checkpoint["state_dict"])
74
  pirads_model.to(args.device)
75
  logging.info("Loading csPCa model")
76
+ cspca_model = CSPCAModel(backbone=pirads_model).to(args.device)
77
  checkpt = torch.load(
78
  os.path.join(args.project_dir, "models", "cspca_model.pth"), map_location="cpu"
79
  )
 
111
  cspca_model.eval()
112
  patches_top_5_list = []
113
  with torch.no_grad():
114
+ for _, batch_data in enumerate(loader):
115
  data = batch_data["image"].as_subclass(torch.Tensor).to(args.device)
116
  logits = pirads_model(data)
117
  pirads_score = torch.argmax(logits, dim=1)
 
139
  patches_top_5.append(patch_temp)
140
  patches_top_5_list.append(patches_top_5)
141
  coords_list = []
142
+ for j, i in enumerate(args.data_list):
143
  parent_image = get_parent_image([i], args)
144
 
145
+ coords = get_patch_coordinate(patches_top_5_list[j], parent_image)
146
  coords_list.append(coords)
147
  output_dict = {}
148
 
run_pirads.py CHANGED
@@ -1,19 +1,21 @@
1
  import argparse
 
2
  import os
3
  import shutil
4
- import time
5
- import yaml
6
  import sys
 
 
 
7
  import numpy as np
8
  import torch
9
- from torch.utils.tensorboard import SummaryWriter
10
- from monai.utils import set_determinism
11
  import wandb
12
- import logging
13
- from pathlib import Path
 
 
14
  from src.data.data_loader import get_dataloader
 
15
  from src.train.train_pirads import train_epoch, val_epoch
16
- from src.model.MIL import MILModel_3D
17
  from src.utils import save_pirads_checkpoint, setup_logging
18
 
19
 
@@ -21,7 +23,7 @@ def main_worker(args):
21
  if args.device == torch.device("cuda"):
22
  torch.backends.cudnn.benchmark = True
23
 
24
- model = MILModel_3D(num_classes=args.num_classes, mil_mode=args.mil_mode)
25
  start_epoch = 0
26
  best_acc = 0.0
27
  if args.checkpoint is not None:
@@ -250,19 +252,20 @@ def parse_args():
250
  )
251
  args = parser.parse_args()
252
  if args.config:
253
- with open(args.config, "r") as config_file:
254
  config = yaml.safe_load(config_file)
255
  args.__dict__.update(config)
256
  return args
257
 
258
 
259
  if __name__ == "__main__":
260
-
261
  args = parse_args()
262
  if args.project_dir is None:
263
- args.project_dir = Path(__file__).resolve().parent # Set project directory
264
 
265
- slurm_job_name = os.getenv('SLURM_JOB_NAME') # If the script is submitted via slurm, job name is the run name
 
 
266
  if slurm_job_name:
267
  args.run_name = slurm_job_name
268
 
@@ -288,11 +291,13 @@ if __name__ == "__main__":
288
 
289
  if args.dry_run:
290
  logging.info("Dry run mode enabled.")
291
- args.epochs = 2
292
  args.batch_size = 2
293
  args.workers = 0
294
- args.num_seeds = 2
295
  args.wandb = False
 
 
296
 
297
  mode_wandb = "online" if args.wandb and args.mode != "test" else "disabled"
298
 
@@ -314,3 +319,6 @@ if __name__ == "__main__":
314
  main_worker(args)
315
 
316
  wandb.finish()
 
 
 
 
1
  import argparse
2
+ import logging
3
  import os
4
  import shutil
 
 
5
  import sys
6
+ import time
7
+ from pathlib import Path
8
+
9
  import numpy as np
10
  import torch
 
 
11
  import wandb
12
+ import yaml
13
+ from monai.utils import set_determinism
14
+ from torch.utils.tensorboard import SummaryWriter
15
+
16
  from src.data.data_loader import get_dataloader
17
+ from src.model.mil import MILModel3D
18
  from src.train.train_pirads import train_epoch, val_epoch
 
19
  from src.utils import save_pirads_checkpoint, setup_logging
20
 
21
 
 
23
  if args.device == torch.device("cuda"):
24
  torch.backends.cudnn.benchmark = True
25
 
26
+ model = MILModel3D(num_classes=args.num_classes, mil_mode=args.mil_mode)
27
  start_epoch = 0
28
  best_acc = 0.0
29
  if args.checkpoint is not None:
 
252
  )
253
  args = parser.parse_args()
254
  if args.config:
255
+ with open(args.config) as config_file:
256
  config = yaml.safe_load(config_file)
257
  args.__dict__.update(config)
258
  return args
259
 
260
 
261
  if __name__ == "__main__":
 
262
  args = parse_args()
263
  if args.project_dir is None:
264
+ args.project_dir = Path(__file__).resolve().parent # Set project directory
265
 
266
+ slurm_job_name = os.getenv(
267
+ "SLURM_JOB_NAME"
268
+ ) # If the script is submitted via slurm, job name is the run name
269
  if slurm_job_name:
270
  args.run_name = slurm_job_name
271
 
 
291
 
292
  if args.dry_run:
293
  logging.info("Dry run mode enabled.")
294
+ args.epochs = 1
295
  args.batch_size = 2
296
  args.workers = 0
297
+ args.num_seeds = 1
298
  args.wandb = False
299
+ args.tile_size = 10
300
+ args.tile_count = 5
301
 
302
  mode_wandb = "online" if args.wandb and args.mode != "test" else "disabled"
303
 
 
319
  main_worker(args)
320
 
321
  wandb.finish()
322
+
323
+ if args.dry_run:
324
+ shutil.rmtree(args.logdir)
src/data/custom_transforms.py CHANGED
@@ -1,18 +1,19 @@
 
 
 
 
1
  import numpy as np
2
  import torch
3
- from typing import Union
4
- from monai.transforms import MapTransform
5
  from monai.config import DtypeLike, KeysCollection
6
  from monai.config.type_definitions import NdarrayOrTensor
7
  from monai.data.meta_obj import get_track_meta
 
8
  from monai.transforms.transform import Transform
9
  from monai.transforms.utils import soft_clip
10
  from monai.transforms.utils_pytorch_numpy_unification import clip, percentile
11
  from monai.utils.enums import TransformBackends
12
  from monai.utils.type_conversion import convert_data_type, convert_to_dst_type, convert_to_tensor
13
  from scipy.ndimage import binary_dilation
14
- import cv2
15
- from collections.abc import Hashable, Mapping, Sequence
16
 
17
 
18
  class DilateAndSaveMaskd(MapTransform):
@@ -100,7 +101,7 @@ class ClipMaskIntensityPercentiles(Transform):
100
  self.channel_wise = channel_wise
101
  self.dtype = dtype
102
 
103
- def _clip(self, img: NdarrayOrTensor, mask_data: NdarrayOrTensor) -> NdarrayOrTensor:
104
  masked_img = img * (mask_data > 0)
105
  if self.sharpness_factor is not None:
106
  lower_percentile = (
@@ -125,8 +126,8 @@ class ClipMaskIntensityPercentiles(Transform):
125
  )
126
  img = clip(img, lower_percentile, upper_percentile)
127
 
128
- img = convert_to_tensor(img, track_meta=False)
129
- return img
130
 
131
  def __call__(self, img: NdarrayOrTensor, mask_data: NdarrayOrTensor) -> NdarrayOrTensor:
132
  """
 
1
+ from collections.abc import Hashable, Mapping, Sequence
2
+ from typing import Union
3
+
4
+ import cv2
5
  import numpy as np
6
  import torch
 
 
7
  from monai.config import DtypeLike, KeysCollection
8
  from monai.config.type_definitions import NdarrayOrTensor
9
  from monai.data.meta_obj import get_track_meta
10
+ from monai.transforms import MapTransform
11
  from monai.transforms.transform import Transform
12
  from monai.transforms.utils import soft_clip
13
  from monai.transforms.utils_pytorch_numpy_unification import clip, percentile
14
  from monai.utils.enums import TransformBackends
15
  from monai.utils.type_conversion import convert_data_type, convert_to_dst_type, convert_to_tensor
16
  from scipy.ndimage import binary_dilation
 
 
17
 
18
 
19
  class DilateAndSaveMaskd(MapTransform):
 
101
  self.channel_wise = channel_wise
102
  self.dtype = dtype
103
 
104
+ def _clip(self, img: NdarrayOrTensor, mask_data: NdarrayOrTensor) -> torch.Tensor:
105
  masked_img = img * (mask_data > 0)
106
  if self.sharpness_factor is not None:
107
  lower_percentile = (
 
126
  )
127
  img = clip(img, lower_percentile, upper_percentile)
128
 
129
+ img_tensor = convert_to_tensor(img, track_meta=False)
130
+ return img_tensor
131
 
132
  def __call__(self, img: NdarrayOrTensor, mask_data: NdarrayOrTensor) -> NdarrayOrTensor:
133
  """
src/data/data_loader.py CHANGED
@@ -1,30 +1,33 @@
 
1
  import os
 
 
2
  import numpy as np
3
- from monai.data import load_decathlon_datalist, ITKReader, PersistentDataset
 
4
  from monai.transforms import (
5
  Compose,
6
- LoadImaged,
7
- ToTensord,
8
  ConcatItemsd,
9
  DeleteItemsd,
10
  EnsureTyped,
11
- RandCropByPosNegLabeld,
12
  NormalizeIntensityd,
13
- Transposed,
14
  RandWeightedCropd,
 
 
 
15
  )
 
 
16
  from .custom_transforms import (
17
- NormalizeIntensity_customd,
18
  ClipMaskIntensityPercentilesd,
19
  ElementwiseProductd,
 
20
  )
21
- import torch
22
- from torch.utils.data.dataloader import default_collate
23
- from typing import Literal
24
- import collections.abc
25
 
26
 
27
- def list_data_collate(batch: collections.abc.Sequence):
28
  """
29
  Combine instances from a list of dicts into a single dict, by stacking them along first dim
30
  [{'image' : 3xHxW}, {'image' : 3xHxW}, {'image' : 3xHxW}...] - > {'image' : Nx3xHxW}
@@ -42,13 +45,13 @@ def list_data_collate(batch: collections.abc.Sequence):
42
  return default_collate(batch)
43
 
44
 
45
- def data_transform(args):
46
  if args.use_heatmap:
47
  transform = Compose(
48
  [
49
  LoadImaged(
50
  keys=["image", "mask", "dwi", "adc", "heatmap"],
51
- reader=ITKReader(),
52
  ensure_channel_first=True,
53
  dtype=np.float32,
54
  ),
@@ -75,7 +78,7 @@ def data_transform(args):
75
  [
76
  LoadImaged(
77
  keys=["image", "mask", "dwi", "adc"],
78
- reader=ITKReader(),
79
  ensure_channel_first=True,
80
  dtype=np.float32,
81
  ),
@@ -101,14 +104,16 @@ def data_transform(args):
101
  return transform
102
 
103
 
104
- def get_dataloader(args, split: Literal["train", "test"]):
 
 
105
  data_list = load_decathlon_datalist(
106
  data_list_file_path=args.dataset_json,
107
  data_list_key=split,
108
  base_dir=args.data_root,
109
  )
110
  if args.dry_run:
111
- data_list = data_list[:8] # Use only 8 samples for dry run
112
  cache_dir_ = os.path.join(args.logdir, "cache")
113
  os.makedirs(os.path.join(cache_dir_, split), exist_ok=True)
114
  transform = data_transform(args)
 
1
+ import argparse
2
  import os
3
+ from typing import Literal
4
+
5
  import numpy as np
6
+ import torch
7
+ from monai.data import PersistentDataset, load_decathlon_datalist
8
  from monai.transforms import (
9
  Compose,
 
 
10
  ConcatItemsd,
11
  DeleteItemsd,
12
  EnsureTyped,
13
+ LoadImaged,
14
  NormalizeIntensityd,
15
+ RandCropByPosNegLabeld,
16
  RandWeightedCropd,
17
+ ToTensord,
18
+ Transform,
19
+ Transposed,
20
  )
21
+ from torch.utils.data.dataloader import default_collate
22
+
23
  from .custom_transforms import (
 
24
  ClipMaskIntensityPercentilesd,
25
  ElementwiseProductd,
26
+ NormalizeIntensity_customd,
27
  )
 
 
 
 
28
 
29
 
30
+ def list_data_collate(batch: list):
31
  """
32
  Combine instances from a list of dicts into a single dict, by stacking them along first dim
33
  [{'image' : 3xHxW}, {'image' : 3xHxW}, {'image' : 3xHxW}...] - > {'image' : Nx3xHxW}
 
45
  return default_collate(batch)
46
 
47
 
48
+ def data_transform(args: argparse.Namespace) -> Transform:
49
  if args.use_heatmap:
50
  transform = Compose(
51
  [
52
  LoadImaged(
53
  keys=["image", "mask", "dwi", "adc", "heatmap"],
54
+ reader="ITKReader",
55
  ensure_channel_first=True,
56
  dtype=np.float32,
57
  ),
 
78
  [
79
  LoadImaged(
80
  keys=["image", "mask", "dwi", "adc"],
81
+ reader="ITKReader",
82
  ensure_channel_first=True,
83
  dtype=np.float32,
84
  ),
 
104
  return transform
105
 
106
 
107
+ def get_dataloader(
108
+ args: argparse.Namespace, split: Literal["train", "test"]
109
+ ) -> torch.utils.data.DataLoader:
110
  data_list = load_decathlon_datalist(
111
  data_list_file_path=args.dataset_json,
112
  data_list_key=split,
113
  base_dir=args.data_root,
114
  )
115
  if args.dry_run:
116
+ data_list = data_list[:2] # Use only 8 samples for dry run
117
  cache_dir_ = os.path.join(args.logdir, "cache")
118
  os.makedirs(os.path.join(cache_dir_, split), exist_ok=True)
119
  transform = data_transform(args)
src/model/{csPCa_model.py → cspca_model.py} RENAMED
@@ -1,4 +1,5 @@
1
  from __future__ import annotations
 
2
  import torch
3
  import torch.nn as nn
4
  from monai.utils.module import optional_import
@@ -17,8 +18,8 @@ class SimpleNN(nn.Module):
17
  input_dim (int): The number of input features.
18
  """
19
 
20
- def __init__(self, input_dim):
21
- super(SimpleNN, self).__init__()
22
  self.net = nn.Sequential(
23
  nn.Linear(input_dim, 256),
24
  nn.ReLU(),
@@ -42,7 +43,7 @@ class SimpleNN(nn.Module):
42
  return self.net(x)
43
 
44
 
45
- class csPCa_Model(nn.Module):
46
  """
47
  Clinically Significant Prostate Cancer (csPCa) risk prediction model using a MIL backbone.
48
 
@@ -67,7 +68,7 @@ class csPCa_Model(nn.Module):
67
  backbone: The MIL based PI-RADS classifier.
68
  """
69
 
70
- def __init__(self, backbone):
71
  super().__init__()
72
  self.backbone = backbone
73
  self.fc_dim = backbone.myfc.in_features
 
1
  from __future__ import annotations
2
+
3
  import torch
4
  import torch.nn as nn
5
  from monai.utils.module import optional_import
 
18
  input_dim (int): The number of input features.
19
  """
20
 
21
+ def __init__(self, input_dim: int) -> None:
22
+ super().__init__()
23
  self.net = nn.Sequential(
24
  nn.Linear(input_dim, 256),
25
  nn.ReLU(),
 
43
  return self.net(x)
44
 
45
 
46
+ class CSPCAModel(nn.Module):
47
  """
48
  Clinically Significant Prostate Cancer (csPCa) risk prediction model using a MIL backbone.
49
 
 
68
  backbone: The MIL based PI-RADS classifier.
69
  """
70
 
71
+ def __init__(self, backbone: nn.Module) -> None:
72
  super().__init__()
73
  self.backbone = backbone
74
  self.fc_dim = backbone.myfc.in_features
src/model/{MIL.py → mil.py} RENAMED
@@ -1,14 +1,16 @@
1
  from __future__ import annotations
 
2
  from typing import cast
 
3
  import torch
4
  import torch.nn as nn
5
- from monai.utils.module import optional_import
6
  from monai.networks.nets import resnet
 
7
 
8
  models, _ = optional_import("torchvision.models")
9
 
10
 
11
- class MILModel_3D(nn.Module):
12
  """
13
  Multiple Instance Learning (MIL) model, with a backbone classification model.
14
  Adapted from MONAI, modified for 3D images. The expected shape of input data is `[B, N, C, D, H, W]`,
@@ -56,6 +58,7 @@ class MILModel_3D(nn.Module):
56
  self.mil_mode = mil_mode.lower()
57
  self.attention = nn.Sequential()
58
  self.transformer: nn.Module | None = None
 
59
 
60
  if backbone is None:
61
  net = resnet.resnet18(
@@ -63,8 +66,9 @@ class MILModel_3D(nn.Module):
63
  n_input_channels=3,
64
  num_classes=5,
65
  )
 
66
  nfc = net.fc.in_features # save the number of final features
67
- net.fc = torch.nn.Identity() # remove final linear layer
68
 
69
  self.extra_outputs: dict[str, torch.Tensor] = {}
70
 
@@ -90,7 +94,7 @@ class MILModel_3D(nn.Module):
90
 
91
  if getattr(net, "fc", None) is not None:
92
  nfc = net.fc.in_features # save the number of final features
93
- net.fc = torch.nn.Identity() # remove final linear layer
94
  else:
95
  raise ValueError(
96
  "Unable to detect FC layer for the torchvision model " + str(backbone),
@@ -100,8 +104,13 @@ class MILModel_3D(nn.Module):
100
  elif isinstance(backbone, nn.Module):
101
  # use a custom backbone
102
  net = backbone
 
 
 
 
 
103
  nfc = backbone_num_features
104
- net.fc = torch.nn.Identity() # remove final linear layer
105
 
106
  if mil_mode == "att_trans_pyramid":
107
  # register hooks to capture outputs of intermediate layers
@@ -109,11 +118,6 @@ class MILModel_3D(nn.Module):
109
  "Cannot use att_trans_pyramid with custom backbone. Have to use the default ResNet 18 backbone."
110
  )
111
 
112
- if backbone_num_features is None:
113
- raise ValueError(
114
- "Number of endencoder features must be provided for a custom backbone model"
115
- )
116
-
117
  else:
118
  raise ValueError("Unsupported backbone")
119
 
 
1
  from __future__ import annotations
2
+
3
  from typing import cast
4
+
5
  import torch
6
  import torch.nn as nn
 
7
  from monai.networks.nets import resnet
8
+ from monai.utils.module import optional_import
9
 
10
  models, _ = optional_import("torchvision.models")
11
 
12
 
13
+ class MILModel3D(nn.Module):
14
  """
15
  Multiple Instance Learning (MIL) model, with a backbone classification model.
16
  Adapted from MONAI, modified for 3D images. The expected shape of input data is `[B, N, C, D, H, W]`,
 
58
  self.mil_mode = mil_mode.lower()
59
  self.attention = nn.Sequential()
60
  self.transformer: nn.Module | None = None
61
+ net: nn.Module
62
 
63
  if backbone is None:
64
  net = resnet.resnet18(
 
66
  n_input_channels=3,
67
  num_classes=5,
68
  )
69
+ assert net.fc is not None
70
  nfc = net.fc.in_features # save the number of final features
71
+ net.fc = torch.nn.Identity() # type: ignore[assignment]
72
 
73
  self.extra_outputs: dict[str, torch.Tensor] = {}
74
 
 
94
 
95
  if getattr(net, "fc", None) is not None:
96
  nfc = net.fc.in_features # save the number of final features
97
+ net.fc = torch.nn.Identity() # type: ignore[assignment]
98
  else:
99
  raise ValueError(
100
  "Unable to detect FC layer for the torchvision model " + str(backbone),
 
104
  elif isinstance(backbone, nn.Module):
105
  # use a custom backbone
106
  net = backbone
107
+
108
+ if backbone_num_features is None:
109
+ raise ValueError(
110
+ "Number of endencoder features must be provided for a custom backbone model"
111
+ )
112
  nfc = backbone_num_features
113
+ net.fc = torch.nn.Identity() # type: ignore[assignment]
114
 
115
  if mil_mode == "att_trans_pyramid":
116
  # register hooks to capture outputs of intermediate layers
 
118
  "Cannot use att_trans_pyramid with custom backbone. Have to use the default ResNet 18 backbone."
119
  )
120
 
 
 
 
 
 
121
  else:
122
  raise ValueError("Unsupported backbone")
123
 
src/preprocessing/center_crop.py CHANGED
@@ -12,7 +12,7 @@
12
 
13
 
14
  # import argparse
15
- from typing import Union
16
 
17
  import SimpleITK as sitk # noqa N813
18
 
@@ -21,7 +21,11 @@ def _flatten(t):
21
  return [item for sublist in t for item in sublist]
22
 
23
 
24
- def crop(image: sitk.Image, margin: Union[int, float], interpolator=sitk.sitkLinear):
 
 
 
 
25
  """
26
  Crops a sitk.Image while retaining correct spacing. Negative margins will lead to zero padding
27
 
@@ -30,13 +34,14 @@ def crop(image: sitk.Image, margin: Union[int, float], interpolator=sitk.sitkLin
30
  margin: margins to crop. Single integer or float (percentage crop),
31
  lists of int/float or nestes lists are supported.
32
  """
33
- if isinstance(margin, (list, tuple)):
34
- assert len(margin) == 3, "expected margin to be of length 3"
 
35
  else:
36
- assert isinstance(margin, (int, float)), "expected margin to be a float value"
37
- margin = [margin, margin, margin]
38
 
39
- margin = [m if isinstance(m, (tuple, list)) else [m, m] for m in margin]
40
  old_size = image.GetSize()
41
 
42
  # calculate new origin and new image size
@@ -46,7 +51,7 @@ def crop(image: sitk.Image, margin: Union[int, float], interpolator=sitk.sitkLin
46
  )
47
  to_crop = [[int(sz * _m) for _m in m] for sz, m in zip(old_size, margin)]
48
  elif all([isinstance(m, int) for m in _flatten(margin)]):
49
- to_crop = margin
50
  else:
51
  raise ValueError("Wrong format of margins.")
52
 
 
12
 
13
 
14
  # import argparse
15
+ from typing import Union, cast
16
 
17
  import SimpleITK as sitk # noqa N813
18
 
 
21
  return [item for sublist in t for item in sublist]
22
 
23
 
24
+ def crop(
25
+ image: sitk.Image,
26
+ margin_: Union[int, float, list[Union[int, float]]],
27
+ interpolator=sitk.sitkLinear,
28
+ ):
29
  """
30
  Crops a sitk.Image while retaining correct spacing. Negative margins will lead to zero padding
31
 
 
34
  margin: margins to crop. Single integer or float (percentage crop),
35
  lists of int/float or nestes lists are supported.
36
  """
37
+ if isinstance(margin_, (list, tuple)):
38
+ margin_temp = margin_
39
+ assert len(margin_) == 3, "expected margin to be of length 3"
40
  else:
41
+ assert isinstance(margin_, (int, float)), "expected margin to be a float value"
42
+ margin_temp = [margin_, margin_, margin_]
43
 
44
+ margin = [m if isinstance(m, (tuple, list)) else [m, m] for m in margin_temp]
45
  old_size = image.GetSize()
46
 
47
  # calculate new origin and new image size
 
51
  )
52
  to_crop = [[int(sz * _m) for _m in m] for sz, m in zip(old_size, margin)]
53
  elif all([isinstance(m, int) for m in _flatten(margin)]):
54
+ to_crop = cast(list[list[int]], margin)
55
  else:
56
  raise ValueError("Wrong format of margins.")
57
 
src/preprocessing/generate_heatmap.py CHANGED
@@ -1,11 +1,13 @@
 
 
1
  import os
2
- import numpy as np
3
  import nrrd
 
4
  from tqdm import tqdm
5
- import logging
6
 
7
 
8
- def get_heatmap(args):
9
  """
10
  Generate heatmaps from DWI (Diffusion Weighted Imaging) and ADC (Apparent Diffusion Coefficient) medical imaging data.
11
  This function processes medical imaging files (DWI and ADC) along with their corresponding
 
1
+ import argparse
2
+ import logging
3
  import os
4
+
5
  import nrrd
6
+ import numpy as np
7
  from tqdm import tqdm
 
8
 
9
 
10
+ def get_heatmap(args: argparse.Namespace) -> argparse.Namespace:
11
  """
12
  Generate heatmaps from DWI (Diffusion Weighted Imaging) and ADC (Apparent Diffusion Coefficient) medical imaging data.
13
  This function processes medical imaging files (DWI and ADC) along with their corresponding
src/preprocessing/histogram_match.py CHANGED
@@ -1,8 +1,10 @@
 
 
1
  import os
 
2
  import nrrd
3
- from skimage import exposure
4
- import logging
5
  import numpy as np
 
6
  from tqdm import tqdm
7
 
8
 
@@ -34,7 +36,7 @@ def get_histmatched(
34
  return matched_img
35
 
36
 
37
- def histmatch(args):
38
  files = os.listdir(args.t2_dir)
39
 
40
  t2_histmatched_dir = os.path.join(args.output_dir, "t2_histmatched")
 
1
+ import argparse
2
+ import logging
3
  import os
4
+
5
  import nrrd
 
 
6
  import numpy as np
7
+ from skimage import exposure
8
  from tqdm import tqdm
9
 
10
 
 
36
  return matched_img
37
 
38
 
39
+ def histmatch(args: argparse.Namespace) -> argparse.Namespace:
40
  files = os.listdir(args.t2_dir)
41
 
42
  t2_histmatched_dir = os.path.join(args.output_dir, "t2_histmatched")
src/preprocessing/prostate_mask.py CHANGED
@@ -1,31 +1,29 @@
 
 
1
  import os
2
- import numpy as np
3
  import nrrd
4
- from tqdm import tqdm
5
- from monai.bundle import ConfigParser
6
  import torch
7
-
8
-
9
  from monai.transforms import (
10
  Compose,
 
 
11
  LoadImaged,
12
- ScaleIntensityd,
13
  NormalizeIntensityd,
14
- )
15
- from monai.utils import set_determinism
16
- from monai.transforms import (
17
- EnsureChannelFirstd,
18
  Orientationd,
 
19
  Spacingd,
20
- EnsureTyped,
21
  )
22
- from monai.data import MetaTensor
23
- import logging
24
 
25
  set_determinism(43)
26
 
27
 
28
- def get_segmask(args):
29
  """
30
  Generate prostate segmentation masks using a pre-trained deep learning model.
31
  This function performs inference on T2-weighted MRI images to segment the prostate gland.
 
1
+ import argparse
2
+ import logging
3
  import os
4
+
5
  import nrrd
6
+ import numpy as np
 
7
  import torch
8
+ from monai.bundle import ConfigParser
9
+ from monai.data import MetaTensor
10
  from monai.transforms import (
11
  Compose,
12
+ EnsureChannelFirstd,
13
+ EnsureTyped,
14
  LoadImaged,
 
15
  NormalizeIntensityd,
 
 
 
 
16
  Orientationd,
17
+ ScaleIntensityd,
18
  Spacingd,
 
19
  )
20
+ from monai.utils import set_determinism
21
+ from tqdm import tqdm
22
 
23
  set_determinism(43)
24
 
25
 
26
+ def get_segmask(args: argparse.Namespace) -> argparse.Namespace:
27
  """
28
  Generate prostate segmentation masks using a pre-trained deep learning model.
29
  This function performs inference on T2-weighted MRI images to segment the prostate gland.
src/preprocessing/register_and_crop.py CHANGED
@@ -1,12 +1,15 @@
1
- import SimpleITK as sitk
 
2
  import os
3
- from tqdm import tqdm
 
4
  from picai_prep.preprocessing import PreprocessingSettings, Sample
 
 
5
  from .center_crop import crop
6
- import logging
7
 
8
 
9
- def register_files(args):
10
  """
11
  Register and crop medical images (T2, DWI, and ADC) to a standardized spacing and size.
12
  This function reads medical images from specified directories, resamples them to a
@@ -55,9 +58,9 @@ def register_files(args):
55
 
56
  pat_case = Sample(
57
  scans=[
58
- images_to_preprocess.get("t2"),
59
- images_to_preprocess.get("hbv"),
60
- images_to_preprocess.get("adc"),
61
  ],
62
  settings=PreprocessingSettings(
63
  spacing=[3.0, 0.4, 0.4], matrix_size=[new_size[2], new_size[1], new_size[0]]
 
1
+ import argparse
2
+ import logging
3
  import os
4
+
5
+ import SimpleITK as sitk
6
  from picai_prep.preprocessing import PreprocessingSettings, Sample
7
+ from tqdm import tqdm
8
+
9
  from .center_crop import crop
 
10
 
11
 
12
+ def register_files(args: argparse.Namespace) -> argparse.Namespace:
13
  """
14
  Register and crop medical images (T2, DWI, and ADC) to a standardized spacing and size.
15
  This function reads medical images from specified directories, resamples them to a
 
58
 
59
  pat_case = Sample(
60
  scans=[
61
+ images_to_preprocess["t2"],
62
+ images_to_preprocess["hbv"],
63
+ images_to_preprocess["adc"],
64
  ],
65
  settings=PreprocessingSettings(
66
  spacing=[3.0, 0.4, 0.4], matrix_size=[new_size[2], new_size[1], new_size[0]]
src/train/train_cspca.py CHANGED
@@ -1,8 +1,7 @@
1
  import torch
2
  import torch.nn as nn
3
  from monai.metrics import Cumulative, CumulativeAverage
4
- from sklearn.metrics import confusion_matrix
5
- from sklearn.metrics import roc_auc_score
6
 
7
 
8
  def train_epoch(cspca_model, loader, optimizer, epoch, args):
@@ -10,10 +9,10 @@ def train_epoch(cspca_model, loader, optimizer, epoch, args):
10
  criterion = nn.BCELoss()
11
  loss = 0.0
12
  run_loss = CumulativeAverage()
13
- TARGETS = Cumulative()
14
- PREDS = Cumulative()
15
 
16
- for idx, batch_data in enumerate(loader):
17
  data = batch_data["image"].as_subclass(torch.Tensor).to(args.device)
18
  target = batch_data["label"].as_subclass(torch.Tensor).to(args.device)
19
 
@@ -24,13 +23,13 @@ def train_epoch(cspca_model, loader, optimizer, epoch, args):
24
  loss.backward()
25
  optimizer.step()
26
 
27
- TARGETS.extend(target.detach().cpu())
28
- PREDS.extend(output.detach().cpu())
29
  run_loss.append(loss.item())
30
 
31
  loss_epoch = run_loss.aggregate()
32
- target_list = TARGETS.get_buffer().cpu().numpy()
33
- pred_list = PREDS.get_buffer().cpu().numpy()
34
  auc_epoch = roc_auc_score(target_list, pred_list)
35
 
36
  return loss_epoch, auc_epoch
@@ -41,10 +40,10 @@ def val_epoch(cspca_model, loader, epoch, args):
41
  criterion = nn.BCELoss()
42
  loss = 0.0
43
  run_loss = CumulativeAverage()
44
- TARGETS = Cumulative()
45
- PREDS = Cumulative()
46
  with torch.no_grad():
47
- for idx, batch_data in enumerate(loader):
48
  data = batch_data["image"].as_subclass(torch.Tensor).to(args.device)
49
  target = batch_data["label"].as_subclass(torch.Tensor).to(args.device)
50
 
@@ -52,13 +51,13 @@ def val_epoch(cspca_model, loader, epoch, args):
52
  output = output.squeeze(1)
53
  loss = criterion(output, target)
54
 
55
- TARGETS.extend(target.detach().cpu())
56
- PREDS.extend(output.detach().cpu())
57
  run_loss.append(loss.item())
58
 
59
  loss_epoch = run_loss.aggregate()
60
- target_list = TARGETS.get_buffer().cpu().numpy()
61
- pred_list = PREDS.get_buffer().cpu().numpy()
62
  auc_epoch = roc_auc_score(target_list, pred_list)
63
  y_pred_categoric = pred_list >= 0.5
64
  tn, fp, fn, tp = confusion_matrix(target_list, y_pred_categoric).ravel()
 
1
  import torch
2
  import torch.nn as nn
3
  from monai.metrics import Cumulative, CumulativeAverage
4
+ from sklearn.metrics import confusion_matrix, roc_auc_score
 
5
 
6
 
7
  def train_epoch(cspca_model, loader, optimizer, epoch, args):
 
9
  criterion = nn.BCELoss()
10
  loss = 0.0
11
  run_loss = CumulativeAverage()
12
+ targets_cumulative = Cumulative()
13
+ preds_cumulative = Cumulative()
14
 
15
+ for _, batch_data in enumerate(loader):
16
  data = batch_data["image"].as_subclass(torch.Tensor).to(args.device)
17
  target = batch_data["label"].as_subclass(torch.Tensor).to(args.device)
18
 
 
23
  loss.backward()
24
  optimizer.step()
25
 
26
+ targets_cumulative.extend(target.detach().cpu())
27
+ preds_cumulative.extend(output.detach().cpu())
28
  run_loss.append(loss.item())
29
 
30
  loss_epoch = run_loss.aggregate()
31
+ target_list = targets_cumulative.get_buffer().cpu().numpy()
32
+ pred_list = preds_cumulative.get_buffer().cpu().numpy()
33
  auc_epoch = roc_auc_score(target_list, pred_list)
34
 
35
  return loss_epoch, auc_epoch
 
40
  criterion = nn.BCELoss()
41
  loss = 0.0
42
  run_loss = CumulativeAverage()
43
+ targets_cumulative = Cumulative()
44
+ preds_cumulative = Cumulative()
45
  with torch.no_grad():
46
+ for _, batch_data in enumerate(loader):
47
  data = batch_data["image"].as_subclass(torch.Tensor).to(args.device)
48
  target = batch_data["label"].as_subclass(torch.Tensor).to(args.device)
49
 
 
51
  output = output.squeeze(1)
52
  loss = criterion(output, target)
53
 
54
+ targets_cumulative.extend(target.detach().cpu())
55
+ preds_cumulative.extend(output.detach().cpu())
56
  run_loss.append(loss.item())
57
 
58
  loss_epoch = run_loss.aggregate()
59
+ target_list = targets_cumulative.get_buffer().cpu().numpy()
60
+ pred_list = preds_cumulative.get_buffer().cpu().numpy()
61
  auc_epoch = roc_auc_score(target_list, pred_list)
62
  y_pred_categoric = pred_list >= 0.5
63
  tn, fp, fn, tp = confusion_matrix(target_list, y_pred_categoric).ravel()
src/train/train_pirads.py CHANGED
@@ -1,20 +1,27 @@
 
 
1
  import time
 
2
  import numpy as np
3
  import torch
4
  import torch.nn as nn
5
  from monai.metrics import Cumulative, CumulativeAverage
6
  from sklearn.metrics import cohen_kappa_score
7
- import logging
8
 
9
 
10
- def get_lambda_att(epoch, max_lambda=2.0, warmup_epochs=10):
11
  if epoch < warmup_epochs:
12
  return (epoch / warmup_epochs) * max_lambda
13
  else:
14
  return max_lambda
15
 
16
 
17
- def get_attention_scores(data, target, heatmap, args):
 
 
 
 
 
18
  """
19
  Compute attention scores from heatmaps and shuffle data accordingly.
20
  This function generates attention scores based on spatial heatmaps, applies
@@ -136,17 +143,7 @@ def train_epoch(model, loader, optimizer, scaler, epoch, args):
136
  run_loss.append(loss.detach().cpu())
137
  run_acc.append(acc.detach().cpu())
138
  logging.info(
139
- "Epoch {}/{} {}/{} loss: {:.4f} attention loss: {:.4f} acc: {:.4f} grad norm: {:.4f} time {:.2f}s".format(
140
- epoch,
141
- args.epochs,
142
- idx,
143
- len(loader),
144
- loss.item(),
145
- attn_loss.item(),
146
- acc,
147
- total_norm,
148
- time.time() - start_time,
149
- )
150
  )
151
  start_time = time.time()
152
 
@@ -164,8 +161,8 @@ def val_epoch(model, loader, epoch, args):
164
 
165
  run_loss = CumulativeAverage()
166
  run_acc = CumulativeAverage()
167
- PREDS = Cumulative()
168
- TARGETS = Cumulative()
169
 
170
  start_time = time.time()
171
  loss, acc = 0.0, 0.0
@@ -188,12 +185,10 @@ def val_epoch(model, loader, epoch, args):
188
 
189
  run_loss.append(loss.detach().cpu())
190
  run_acc.append(acc.detach().cpu())
191
- PREDS.extend(pred.detach().cpu())
192
- TARGETS.extend(target.detach().cpu())
193
  logging.info(
194
- "Val epoch {}/{} {}/{} loss: {:.4f} acc: {:.4f} time {:.2f}s".format(
195
- epoch, args.epochs, idx, len(loader), loss, acc, time.time() - start_time
196
- )
197
  )
198
  start_time = time.time()
199
 
@@ -201,9 +196,11 @@ def val_epoch(model, loader, epoch, args):
201
  torch.cuda.empty_cache()
202
 
203
  # Calculate QWK metric (Quadratic Weigted Kappa) https://en.wikipedia.org/wiki/Cohen%27s_kappa
204
- PREDS = PREDS.get_buffer().cpu().numpy()
205
- TARGETS = TARGETS.get_buffer().cpu().numpy()
206
  loss_epoch = run_loss.aggregate()
207
  acc_epoch = run_acc.aggregate()
208
- qwk = cohen_kappa_score(TARGETS.astype(np.float64), PREDS.astype(np.float64))
 
 
209
  return loss_epoch, acc_epoch, qwk
 
1
+ import argparse
2
+ import logging
3
  import time
4
+
5
  import numpy as np
6
  import torch
7
  import torch.nn as nn
8
  from monai.metrics import Cumulative, CumulativeAverage
9
  from sklearn.metrics import cohen_kappa_score
 
10
 
11
 
12
+ def get_lambda_att(epoch: int, max_lambda: float = 2.0, warmup_epochs: int = 10) -> float:
13
  if epoch < warmup_epochs:
14
  return (epoch / warmup_epochs) * max_lambda
15
  else:
16
  return max_lambda
17
 
18
 
19
+ def get_attention_scores(
20
+ data: torch.Tensor,
21
+ target: torch.Tensor,
22
+ heatmap: torch.Tensor,
23
+ args: argparse.Namespace,
24
+ ) -> tuple[torch.Tensor, torch.Tensor]:
25
  """
26
  Compute attention scores from heatmaps and shuffle data accordingly.
27
  This function generates attention scores based on spatial heatmaps, applies
 
143
  run_loss.append(loss.detach().cpu())
144
  run_acc.append(acc.detach().cpu())
145
  logging.info(
146
+ f"Epoch {epoch}/{args.epochs} {idx}/{len(loader)} loss: {loss.item():.4f} attention loss: {attn_loss.item():.4f} acc: {acc:.4f} grad norm: {total_norm:.4f} time {time.time() - start_time:.2f}s"
 
 
 
 
 
 
 
 
 
 
147
  )
148
  start_time = time.time()
149
 
 
161
 
162
  run_loss = CumulativeAverage()
163
  run_acc = CumulativeAverage()
164
+ preds_cumulative = Cumulative()
165
+ targets_cumulative = Cumulative()
166
 
167
  start_time = time.time()
168
  loss, acc = 0.0, 0.0
 
185
 
186
  run_loss.append(loss.detach().cpu())
187
  run_acc.append(acc.detach().cpu())
188
+ preds_cumulative.extend(pred.detach().cpu())
189
+ targets_cumulative.extend(target.detach().cpu())
190
  logging.info(
191
+ f"Val epoch {epoch}/{args.epochs} {idx}/{len(loader)} loss: {loss:.4f} acc: {acc:.4f} time {time.time() - start_time:.2f}s"
 
 
192
  )
193
  start_time = time.time()
194
 
 
196
  torch.cuda.empty_cache()
197
 
198
  # Calculate QWK metric (Quadratic Weigted Kappa) https://en.wikipedia.org/wiki/Cohen%27s_kappa
199
+ preds_cumulative = preds_cumulative.get_buffer().cpu().numpy()
200
+ targets_cumulative = targets_cumulative.get_buffer().cpu().numpy()
201
  loss_epoch = run_loss.aggregate()
202
  acc_epoch = run_acc.aggregate()
203
+ qwk = cohen_kappa_score(
204
+ targets_cumulative.astype(np.float64), preds_cumulative.astype(np.float64)
205
+ )
206
  return loss_epoch, acc_epoch, qwk
src/utils.py CHANGED
@@ -1,21 +1,31 @@
 
 
1
  import os
2
  import sys
 
 
 
 
3
  import numpy as np
4
  import torch
 
5
  from monai.transforms import (
6
  Compose,
 
7
  LoadImaged,
8
  ToTensord,
9
- EnsureTyped,
10
  )
 
11
  from .data.custom_transforms import ClipMaskIntensityPercentilesd, NormalizeIntensity_customd
12
- from monai.data import Dataset, ITKReader
13
- import logging
14
- from pathlib import Path
15
- import cv2
16
 
17
 
18
- def save_pirads_checkpoint(model, epoch, args, filename="model.pth", best_acc=0):
 
 
 
 
 
 
19
  """Save checkpoint for the PI-RADS model"""
20
 
21
  state_dict = model.state_dict()
@@ -25,7 +35,11 @@ def save_pirads_checkpoint(model, epoch, args, filename="model.pth", best_acc=0)
25
  logging.info(f"Saving checkpoint {filename}")
26
 
27
 
28
- def save_cspca_checkpoint(model, val_metric, model_dir):
 
 
 
 
29
  """Save checkpoint for the csPCa model"""
30
 
31
  state_dict = model.state_dict()
@@ -41,7 +55,7 @@ def save_cspca_checkpoint(model, val_metric, model_dir):
41
  logging.info(f"Saving model with auc: {val_metric['auc']}")
42
 
43
 
44
- def get_metrics(metric_dict: dict):
45
  for metric_name, metric_list in metric_dict.items():
46
  metric_list = np.array(metric_list)
47
  lower = np.percentile(metric_list, 2.5)
@@ -51,7 +65,7 @@ def get_metrics(metric_dict: dict):
51
  logging.info(f"95% CI: ({lower:.3f}, {upper:.3f})")
52
 
53
 
54
- def setup_logging(log_file):
55
  log_file = Path(log_file)
56
  log_file.parent.mkdir(parents=True, exist_ok=True)
57
  if log_file.exists():
@@ -66,13 +80,13 @@ def setup_logging(log_file):
66
 
67
 
68
  def validate_steps(steps):
69
- REQUIRES = {
70
  "get_segmentation_mask": ["register_and_crop"],
71
  "histogram_match": ["get_segmentation_mask", "register_and_crop"],
72
  "get_heatmap": ["get_segmentation_mask", "histogram_match", "register_and_crop"],
73
  }
74
  for i, step in enumerate(steps):
75
- required = REQUIRES.get(step, [])
76
  for req in required:
77
  if req not in steps[:i]:
78
  logging.error(
@@ -81,7 +95,10 @@ def validate_steps(steps):
81
  sys.exit(1)
82
 
83
 
84
- def get_patch_coordinate(patches_top_5, parent_image):
 
 
 
85
  """
86
  Locate the coordinates of top-5 patches within a parent image.
87
 
@@ -90,7 +107,7 @@ def get_patch_coordinate(patches_top_5, parent_image):
90
  coordinates (row, column) and the slice index where each patch is found.
91
 
92
  Args:
93
- patches_top_5 (list): List of top-5 patch tensors, each with shape (C, H, W)
94
  where C is channels, H is height, W is width.
95
  parent_image (np.ndarray): 3D image volume with shape (height, width, slices)
96
  to search within.
@@ -137,12 +154,12 @@ def get_patch_coordinate(patches_top_5, parent_image):
137
  return coords
138
 
139
 
140
- def get_parent_image(temp_data_list, args):
141
  transform_image = Compose(
142
  [
143
  LoadImaged(
144
  keys=["image", "mask"],
145
- reader=ITKReader(),
146
  ensure_channel_first=True,
147
  dtype=np.float32,
148
  ),
@@ -169,9 +186,9 @@ def visualise_patches():
169
  for i in range(rows):
170
  for j in range(slices):
171
  ax = axes[i, j]
172
-
173
  if j == 0:
174
-
175
  for k in range(parent_image.shape[2]):
176
  img_temp = parent_image[:, :, k]
177
  H, W = img_temp.shape
@@ -187,11 +204,11 @@ def visualise_patches():
187
  break
188
  if bool1:
189
  break
190
-
191
  if bool1:
192
  break
193
 
194
-
195
 
196
 
197
  ax.imshow(parent_image[:, :, k+j], cmap='gray')
@@ -199,7 +216,7 @@ def visualise_patches():
199
  linewidth=2, edgecolor='red', facecolor='none')
200
  ax.add_patch(rect)
201
  ax.axis('off')
202
-
203
 
204
  plt.tight_layout()
205
  plt.show()
 
1
+ import argparse
2
+ import logging
3
  import os
4
  import sys
5
+ from pathlib import Path
6
+ from typing import Any, Union
7
+
8
+ import cv2
9
  import numpy as np
10
  import torch
11
+ from monai.data import Dataset
12
  from monai.transforms import (
13
  Compose,
14
+ EnsureTyped,
15
  LoadImaged,
16
  ToTensord,
 
17
  )
18
+
19
  from .data.custom_transforms import ClipMaskIntensityPercentilesd, NormalizeIntensity_customd
 
 
 
 
20
 
21
 
22
+ def save_pirads_checkpoint(
23
+ model: torch.nn.Module,
24
+ epoch: int,
25
+ args: argparse.Namespace,
26
+ filename: str = "model.pth",
27
+ best_acc: float = 0,
28
+ ) -> None:
29
  """Save checkpoint for the PI-RADS model"""
30
 
31
  state_dict = model.state_dict()
 
35
  logging.info(f"Saving checkpoint {filename}")
36
 
37
 
38
+ def save_cspca_checkpoint(
39
+ model: torch.nn.Module,
40
+ val_metric: dict[str, Any],
41
+ model_dir: str,
42
+ ) -> None:
43
  """Save checkpoint for the csPCa model"""
44
 
45
  state_dict = model.state_dict()
 
55
  logging.info(f"Saving model with auc: {val_metric['auc']}")
56
 
57
 
58
+ def get_metrics(metric_dict: dict) -> None:
59
  for metric_name, metric_list in metric_dict.items():
60
  metric_list = np.array(metric_list)
61
  lower = np.percentile(metric_list, 2.5)
 
65
  logging.info(f"95% CI: ({lower:.3f}, {upper:.3f})")
66
 
67
 
68
+ def setup_logging(log_file: Union[str, Path]) -> None:
69
  log_file = Path(log_file)
70
  log_file.parent.mkdir(parents=True, exist_ok=True)
71
  if log_file.exists():
 
80
 
81
 
82
  def validate_steps(steps):
83
+ requires = {
84
  "get_segmentation_mask": ["register_and_crop"],
85
  "histogram_match": ["get_segmentation_mask", "register_and_crop"],
86
  "get_heatmap": ["get_segmentation_mask", "histogram_match", "register_and_crop"],
87
  }
88
  for i, step in enumerate(steps):
89
+ required = requires.get(step, [])
90
  for req in required:
91
  if req not in steps[:i]:
92
  logging.error(
 
95
  sys.exit(1)
96
 
97
 
98
+ def get_patch_coordinate(
99
+ patches_top_5: list[np.ndarray],
100
+ parent_image: np.ndarray,
101
+ ) -> list[tuple[int, int, int]]:
102
  """
103
  Locate the coordinates of top-5 patches within a parent image.
104
 
 
107
  coordinates (row, column) and the slice index where each patch is found.
108
 
109
  Args:
110
+ patches_top_5 (list): List of top-5 patches as np arrays, each with shape (C, H, W)
111
  where C is channels, H is height, W is width.
112
  parent_image (np.ndarray): 3D image volume with shape (height, width, slices)
113
  to search within.
 
154
  return coords
155
 
156
 
157
+ def get_parent_image(temp_data_list, args: argparse.Namespace) -> np.ndarray:
158
  transform_image = Compose(
159
  [
160
  LoadImaged(
161
  keys=["image", "mask"],
162
+ reader="ITKReader",
163
  ensure_channel_first=True,
164
  dtype=np.float32,
165
  ),
 
186
  for i in range(rows):
187
  for j in range(slices):
188
  ax = axes[i, j]
189
+
190
  if j == 0:
191
+
192
  for k in range(parent_image.shape[2]):
193
  img_temp = parent_image[:, :, k]
194
  H, W = img_temp.shape
 
204
  break
205
  if bool1:
206
  break
207
+
208
  if bool1:
209
  break
210
 
211
+
212
 
213
 
214
  ax.imshow(parent_image[:, :, k+j], cmap='gray')
 
216
  linewidth=2, edgecolor='red', facecolor='none')
217
  ax.add_patch(rect)
218
  ax.axis('off')
219
+
220
 
221
  plt.tight_layout()
222
  plt.show()
temp.ipynb CHANGED
@@ -160,7 +160,7 @@
160
  },
161
  {
162
  "cell_type": "code",
163
- "execution_count": 4,
164
  "id": "c91a5802",
165
  "metadata": {},
166
  "outputs": [
@@ -168,9 +168,18 @@
168
  "name": "stderr",
169
  "output_type": "stream",
170
  "text": [
171
- " 0%| | 0/1 [00:16<?, ?it/s]\n",
 
 
 
 
 
 
 
172
  "You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
173
- " 0%| | 0/1 [00:18<?, ?it/s]\n"
 
 
174
  ]
175
  }
176
  ],
@@ -188,8 +197,8 @@
188
  "\n",
189
  "args = register_files(args)\n",
190
  "args = get_segmask(args)\n",
191
- "#args = histmatch(args)\n",
192
- "#args = get_heatmap(args)\n"
193
  ]
194
  },
195
  {
@@ -370,7 +379,7 @@
370
  },
371
  {
372
  "cell_type": "code",
373
- "execution_count": 4,
374
  "id": "8b5d382e",
375
  "metadata": {},
376
  "outputs": [],
@@ -385,7 +394,7 @@
385
  },
386
  {
387
  "cell_type": "code",
388
- "execution_count": 5,
389
  "id": "4cf061ec",
390
  "metadata": {},
391
  "outputs": [
@@ -419,7 +428,7 @@
419
  },
420
  {
421
  "cell_type": "code",
422
- "execution_count": 6,
423
  "id": "fac15515",
424
  "metadata": {},
425
  "outputs": [],
@@ -452,7 +461,7 @@
452
  },
453
  {
454
  "cell_type": "code",
455
- "execution_count": 7,
456
  "id": "eb80047b",
457
  "metadata": {},
458
  "outputs": [],
@@ -493,227 +502,23 @@
493
  },
494
  {
495
  "cell_type": "code",
496
- "execution_count": 8,
497
  "id": "dbcfc97f",
498
  "metadata": {},
499
  "outputs": [
500
  {
501
  "data": {
502
  "text/plain": [
503
- "[array([[[-0.43333182, -0.52243334, -1.4245864 , ..., -1.7698549 ,\n",
504
- " -1.6027895 , -1.257521 ],\n",
505
- " [-0.7786003 , -1.2797965 , -2.0817103 , ..., -2.0371597 ,\n",
506
- " -2.0817103 , -2.059435 ],\n",
507
- " [-1.4802749 , -1.658478 , -1.7809926 , ..., -1.4691372 ,\n",
508
- " -2.1596742 , -1.9591957 ],\n",
509
- " ...,\n",
510
- " [ 0.8697783 , 0.34630668, 0.6358867 , ..., 0.624749 ,\n",
511
- " 0.24606745, 0.03445129],\n",
512
- " [ 0.44654593, 0.23492976, 0.7806767 , ..., 1.2373221 ,\n",
513
- " 0.2906182 , -1.0124918 ],\n",
514
- " [ 0.3017559 , -0.2551287 , 0.27948052, ..., 1.1036698 ,\n",
515
- " -0.16602719, -1.2463834 ]],\n",
516
- " \n",
517
- " [[-0.84542644, -1.3800358 , -1.4357241 , ..., -2.5383558 ,\n",
518
- " -2.0817103 , -0.99021643],\n",
519
- " [-0.4444695 , -1.3466226 , -1.8255434 , ..., -1.335485 ,\n",
520
- " -1.6362027 , -1.4579996 ],\n",
521
- " [-1.2129703 , -1.7921304 , -2.0371597 , ..., 0.14582822,\n",
522
- " -0.04351256, -0.82315105],\n",
523
- " ...,\n",
524
- " [-0.4556072 , 1.4378005 , 1.5603153 , ..., 0.2906182 ,\n",
525
- " -0.85656416, -1.1906949 ],\n",
526
- " [-1.2352457 , -0.35536796, 1.1816336 , ..., -0.39991874,\n",
527
- " -1.3800358 , -1.2129703 ],\n",
528
- " [-1.2909342 , -1.1127311 , 0.5467852 , ..., -0.9568034 ,\n",
529
- " -1.4023111 , -0.9011149 ]],\n",
530
- " \n",
531
- " [[-2.1596742 , -1.914645 , -1.7809926 , ..., -1.0124918 ,\n",
532
- " 0.03445129, 0.44654593],\n",
533
- " [-2.0037465 , -1.8923696 , -1.8700942 , ..., -1.2129703 ,\n",
534
- " -1.1572819 , -0.50015795],\n",
535
- " [-1.8478189 , -1.9369203 , -1.9814711 , ..., -0.65608567,\n",
536
- " -1.2352457 , -1.4134488 ],\n",
537
- " ...,\n",
538
- " [-0.6338103 , -1.0124918 , -0.16602719, ..., 0.41313285,\n",
539
- " 0.13469052, -0.80087566],\n",
540
- " [-0.7451872 , -1.3577603 , -0.2996795 , ..., 0.34630668,\n",
541
- " 0.68043745, 0.45768362],\n",
542
- " [ 0.0121759 , -0.7674626 , -0.33309257, ..., 0.09013975,\n",
543
- " 0.46882132, 1.0034306 ]]], dtype=float32),\n",
544
- " array([[[-0.64494795, -0.789738 , -0.5447087 , ..., 0.0233136 ,\n",
545
- " 0.14582822, -0.35536796],\n",
546
- " [-0.8120134 , -0.7340495 , -0.42219412, ..., -0.31081718,\n",
547
- " 0.46882132, 0.15696591],\n",
548
- " [-0.08806333, 0.03445129, 0.12355283, ..., -0.2774041 ,\n",
549
- " 0.73612595, 0.7249882 ],\n",
550
- " ...,\n",
551
- " [ 0.09013975, 0.0233136 , -0.24399103, ..., 0.34630668,\n",
552
- " 0.914329 , 0.6358867 ],\n",
553
- " [ 0.2906182 , 0.335169 , 0.624749 , ..., 0.24606745,\n",
554
- " 0.9254667 , 0.9922929 ],\n",
555
- " [ 0.44654593, 0.70271283, 1.0479814 , ..., -0.09920102,\n",
556
- " 0.37971976, 0.70271283]],\n",
557
- " \n",
558
- " [[-0.23285334, -0.01009948, -0.1326141 , ..., 1.1593583 ,\n",
559
- " 1.5603153 , 1.5603153 ],\n",
560
- " [-0.23285334, 0.13469052, 0.1792413 , ..., 0.98115516,\n",
561
- " 1.5603153 , 1.5603153 ],\n",
562
- " [-0.36650565, -0.01009948, 0.190379 , ..., 0.8697783 ,\n",
563
- " 1.5603153 , 1.5603153 ],\n",
564
- " ...,\n",
565
- " [-0.16602719, -0.06578795, 0.190379 , ..., -1.4357241 ,\n",
566
- " -1.368898 , -1.6027895 ],\n",
567
- " [-0.2662664 , -0.35536796, 0.190379 , ..., -1.513688 ,\n",
568
- " -1.3800358 , -1.6362027 ],\n",
569
- " [-0.789738 , -0.7786003 , -0.17716487, ..., -1.7141665 ,\n",
570
- " -1.5693765 , -1.9591957 ]],\n",
571
- " \n",
572
- " [[-0.12147641, -0.01009948, 0.0233136 , ..., 0.769539 ,\n",
573
- " 0.8140898 , 0.2906182 ],\n",
574
- " [-0.35536796, -0.2774041 , -0.16602719, ..., 0.55792284,\n",
575
- " 0.68043745, 0.335169 ],\n",
576
- " [-0.18830256, 0.1681036 , 0.7918144 , ..., 0.70271283,\n",
577
- " 0.4910967 , 0.10127745],\n",
578
- " ...,\n",
579
- " [-0.97907877, -0.97907877, -0.9011149 , ..., -0.43333182,\n",
580
- " -0.5447087 , -0.1437518 ],\n",
581
- " [-1.1238688 , -1.1461442 , -1.0347673 , ..., -0.7117741 ,\n",
582
- " -0.12147641, 0.479959 ],\n",
583
- " [-0.8788395 , -0.62267256, -0.9122526 , ..., -0.92339027,\n",
584
- " -0.42219412, 0.22379206]]], dtype=float32),\n",
585
- " array([[[ 2.7948052e-01, 1.0368437e+00, 1.5603153e+00, ...,\n",
586
- " -1.2797965e+00, -8.3428878e-01, -2.4399103e-01],\n",
587
- " [ 6.7864366e-02, 3.4630668e-01, 1.1704960e+00, ...,\n",
588
- " -7.7860028e-01, -8.3428878e-01, -3.1081718e-01],\n",
589
- " [ 2.6834285e-01, 9.0139754e-02, 2.3492976e-01, ...,\n",
590
- " -6.3381028e-01, -9.5680338e-01, -4.4446951e-01],\n",
591
- " ...,\n",
592
- " [-1.3688980e+00, -1.2241080e+00, -9.4566566e-01, ...,\n",
593
- " -1.0570426e+00, -2.6626641e-01, 1.2707351e+00],\n",
594
- " [-1.0793180e+00, -8.8997722e-01, -1.1015934e+00, ...,\n",
595
- " -7.6746261e-01, -5.4650255e-02, 1.0257059e+00],\n",
596
- " [-7.3404950e-01, -6.5608567e-01, -1.0124918e+00, ...,\n",
597
- " -5.3357106e-01, 2.6834285e-01, 8.5864055e-01]],\n",
598
- " \n",
599
- " [[ 1.1593583e+00, 1.5603153e+00, 1.5603153e+00, ...,\n",
600
- " -4.5560721e-01, -1.1033872e-01, 8.6977828e-01],\n",
601
- " [ 8.4750289e-01, 1.5603153e+00, 1.5603153e+00, ...,\n",
602
- " -1.3020718e+00, -1.0459049e+00, 6.3588673e-01],\n",
603
- " [ 7.9002060e-02, 1.3469052e-01, -3.1081718e-01, ...,\n",
604
- " -1.3688980e+00, -1.0570426e+00, 4.9109671e-01],\n",
605
- " ...,\n",
606
- " [ 1.5603153e+00, 6.4702439e-01, -4.3512560e-02, ...,\n",
607
- " 5.2450979e-01, 8.9205366e-01, -2.9967949e-01],\n",
608
- " [ 1.2930106e+00, 6.1361134e-01, -4.6674490e-01, ...,\n",
609
- " 2.7948052e-01, 8.4750289e-01, 7.8067672e-01],\n",
610
- " [ 1.0382106e-03, 2.5720516e-01, -4.1105643e-01, ...,\n",
611
- " 6.2474900e-01, 1.5603153e+00, 1.5603153e+00]],\n",
612
- " \n",
613
- " [[-9.2339027e-01, -1.0236295e+00, -1.0347673e+00, ...,\n",
614
- " 2.6834285e-01, -1.9944026e-01, -1.1033872e-01],\n",
615
- " [-8.2315105e-01, -1.1127311e+00, -9.7907877e-01, ...,\n",
616
- " 1.2818729e+00, 3.2403129e-01, 6.7864366e-02],\n",
617
- " [-8.0087566e-01, -9.7907877e-01, -8.0087566e-01, ...,\n",
618
- " 1.5046268e+00, 6.1361134e-01, -1.7716487e-01],\n",
619
- " ...,\n",
620
- " [ 1.3932499e+00, 6.8043745e-01, 2.3492976e-01, ...,\n",
621
- " -1.8144057e+00, -1.2129703e+00, 2.3313597e-02],\n",
622
- " [ 1.5603153e+00, 1.5603153e+00, 1.5603153e+00, ...,\n",
623
- " -1.7587173e+00, -1.1572819e+00, 5.6906056e-01],\n",
624
- " [ 1.5603153e+00, 1.5603153e+00, 1.5603153e+00, ...,\n",
625
- " -1.8366811e+00, -8.7883949e-01, 1.3598367e+00]]], dtype=float32),\n",
626
- " array([[[-1.4245864 , -1.3466226 , -1.079318 , ..., -0.36650565,\n",
627
- " -0.66722333, -1.079318 ],\n",
628
- " [-1.1350064 , -1.2352457 , -1.4357241 , ..., -0.01009948,\n",
629
- " -0.41105643, -0.7563249 ],\n",
630
- " [-1.0459049 , -0.84542644, -1.368898 , ..., -0.22171564,\n",
631
- " -0.4778826 , -0.82315105],\n",
632
- " ...,\n",
633
- " [-2.2599134 , -2.1151235 , -1.6139272 , ..., 0.5245098 ,\n",
634
- " 0.41313285, 0.37971976],\n",
635
- " [-2.560631 , -2.204225 , -1.7921304 , ..., 0.10127745,\n",
636
- " 0.3128936 , 0.37971976],\n",
637
- " [-2.2933266 , -2.1596742 , -2.0705726 , ..., 0.0233136 ,\n",
638
- " 0.3017559 , 0.4910967 ]],\n",
639
- " \n",
640
- " [[-1.0570426 , -1.3911734 , -1.658478 , ..., 0.769539 ,\n",
641
- " 0.7806767 , 0.95887977],\n",
642
- " [-1.6362027 , -1.4134488 , -1.5693765 , ..., 1.0145682 ,\n",
643
- " 1.1816336 , 1.1370829 ],\n",
644
- " [-1.914645 , -1.1127311 , -1.224108 , ..., 0.68043745,\n",
645
- " 0.9922929 , 0.85864055],\n",
646
- " ...,\n",
647
- " [-2.1262612 , -2.026022 , -1.7921304 , ..., 0.4910967 ,\n",
648
- " 0.9031913 , 1.0702567 ],\n",
649
- " [-2.3712904 , -2.5494936 , -2.304464 , ..., 0.70271283,\n",
650
- " 0.914329 , 1.0479814 ],\n",
651
- " [-2.2710512 , -2.3267395 , -2.4715295 , ..., 1.0145682 ,\n",
652
- " 0.7806767 , 0.7472636 ]],\n",
653
- " \n",
654
- " [[-0.94566566, -0.6003972 , -0.85656416, ..., 1.2373221 ,\n",
655
- " 1.5603153 , 1.5603153 ],\n",
656
- " [-1.1795572 , -0.38878104, -0.37764335, ..., 1.5603153 ,\n",
657
- " 1.5603153 , 1.5603153 ],\n",
658
- " [-1.335485 , -1.3577603 , -0.32195488, ..., 1.5603153 ,\n",
659
- " 1.5603153 , 1.5603153 ],\n",
660
- " ...,\n",
661
- " [-1.224108 , -1.1795572 , -1.4914126 , ..., -0.04351256,\n",
662
- " -0.5112957 , -0.64494795],\n",
663
- " [-1.3132095 , -1.8589565 , -1.7587173 , ..., -0.03237487,\n",
664
- " -0.42219412, -0.1437518 ],\n",
665
- " [-1.6473403 , -2.4381166 , -2.1262612 , ..., 0.37971976,\n",
666
- " 0.34630668, 0.46882132]]], dtype=float32),\n",
667
- " array([[[-0.84542644, -0.8120134 , -0.6894987 , ..., 0.46882132,\n",
668
- " -0.01009948, -0.39991874],\n",
669
- " [-0.65608567, -0.8120134 , -0.6338103 , ..., 0.59133595,\n",
670
- " 0.04558898, -0.5112957 ],\n",
671
- " [-0.96794105, -0.8342888 , -0.35536796, ..., 0.5801982 ,\n",
672
- " 0.13469052, -0.37764335],\n",
673
- " ...,\n",
674
- " [-0.09920102, -0.11033872, -1.224108 , ..., -0.92339027,\n",
675
- " -1.2686588 , -0.85656416],\n",
676
- " [ 0.0121759 , -0.01009948, -1.0236295 , ..., -1.2129703 ,\n",
677
- " -1.4023111 , -0.7340495 ],\n",
678
- " [ 0.23492976, -0.07692564, -0.70063645, ..., -0.70063645,\n",
679
- " -0.934528 , -1.0124918 ]],\n",
680
- " \n",
681
- " [[-1.6139272 , -0.97907877, -0.7340495 , ..., -0.8342888 ,\n",
682
- " -0.6894987 , -1.0347673 ],\n",
683
- " [-1.0459049 , -0.5892595 , -0.789738 , ..., -1.257521 ,\n",
684
- " -1.0124918 , -1.4357241 ],\n",
685
- " [-0.42219412, -0.5892595 , -1.4134488 , ..., -1.5025504 ,\n",
686
- " -1.2463834 , -1.4802749 ],\n",
687
- " ...,\n",
688
- " [ 0.10127745, 0.40199515, 0.13469052, ..., 0.25720516,\n",
689
- " 0.55792284, 0.12355283],\n",
690
- " [-0.2885418 , -0.2551287 , 0.41313285, ..., 0.46882132,\n",
691
- " -0.33309257, -1.2797965 ],\n",
692
- " [-0.2996795 , -0.64494795, 0.9366044 , ..., 0.56906056,\n",
693
- " -0.41105643, -1.3466226 ]],\n",
694
- " \n",
695
- " [[-1.1238688 , -1.4134488 , -1.7364419 , ..., -1.1350064 ,\n",
696
- " -1.6918911 , -2.3156018 ],\n",
697
- " [-1.4914126 , -1.3020718 , -0.99021643, ..., -1.658478 ,\n",
698
- " -1.7698549 , -1.8812319 ],\n",
699
- " [-1.5582387 , -0.85656416, -0.43333182, ..., -1.6027895 ,\n",
700
- " -1.914645 , -1.7698549 ],\n",
701
- " ...,\n",
702
- " [ 0.6358867 , 0.9031913 , 1.0702567 , ..., 0.11241514,\n",
703
- " 0.07900206, 0.34630668],\n",
704
- " [ 0.70271283, 1.1593583 , 0.9254667 , ..., 0.335169 ,\n",
705
- " 0.41313285, 0.23492976],\n",
706
- " [ 0.35744438, 1.1482205 , 0.8697783 , ..., -0.2662664 ,\n",
707
- " 0.56906056, 0.624749 ]]], dtype=float32)]"
708
  ]
709
  },
710
- "execution_count": 8,
711
  "metadata": {},
712
  "output_type": "execute_result"
713
  }
714
  ],
715
  "source": [
716
- "patches_top_5"
717
  ]
718
  },
719
  {
 
160
  },
161
  {
162
  "cell_type": "code",
163
+ "execution_count": 3,
164
  "id": "c91a5802",
165
  "metadata": {},
166
  "outputs": [
 
168
  "name": "stderr",
169
  "output_type": "stream",
170
  "text": [
171
+ " 0%| | 0/1 [00:00<?, ?it/s]"
172
+ ]
173
+ },
174
+ {
175
+ "name": "stderr",
176
+ "output_type": "stream",
177
+ "text": [
178
+ "100%|██████████| 1/1 [00:02<00:00, 2.45s/it]\n",
179
  "You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
180
+ "100%|██████████| 1/1 [00:38<00:00, 38.82s/it]\n",
181
+ "100%|██████████| 1/1 [00:00<00:00, 4.11it/s]\n",
182
+ "100%|██████████| 1/1 [00:00<00:00, 9.82it/s]\n"
183
  ]
184
  }
185
  ],
 
197
  "\n",
198
  "args = register_files(args)\n",
199
  "args = get_segmask(args)\n",
200
+ "args = histmatch(args)\n",
201
+ "args = get_heatmap(args)\n"
202
  ]
203
  },
204
  {
 
379
  },
380
  {
381
  "cell_type": "code",
382
+ "execution_count": 5,
383
  "id": "8b5d382e",
384
  "metadata": {},
385
  "outputs": [],
 
394
  },
395
  {
396
  "cell_type": "code",
397
+ "execution_count": 6,
398
  "id": "4cf061ec",
399
  "metadata": {},
400
  "outputs": [
 
428
  },
429
  {
430
  "cell_type": "code",
431
+ "execution_count": 7,
432
  "id": "fac15515",
433
  "metadata": {},
434
  "outputs": [],
 
461
  },
462
  {
463
  "cell_type": "code",
464
+ "execution_count": 8,
465
  "id": "eb80047b",
466
  "metadata": {},
467
  "outputs": [],
 
502
  },
503
  {
504
  "cell_type": "code",
505
+ "execution_count": 9,
506
  "id": "dbcfc97f",
507
  "metadata": {},
508
  "outputs": [
509
  {
510
  "data": {
511
  "text/plain": [
512
+ "list"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
513
  ]
514
  },
515
+ "execution_count": 9,
516
  "metadata": {},
517
  "output_type": "execute_result"
518
  }
519
  ],
520
  "source": [
521
+ "type(patches_top_5)"
522
  ]
523
  },
524
  {
tests/test_run.py CHANGED
@@ -20,7 +20,15 @@ def test_run_pirads_training():
20
 
21
  # Run the script with the config
22
  result = subprocess.run(
23
- [sys.executable, str(script_path), "--mode", "train", "--config", str(config_path), "--dry_run", "True" ],
 
 
 
 
 
 
 
 
24
  capture_output=True,
25
  text=True,
26
  )
@@ -28,6 +36,7 @@ def test_run_pirads_training():
28
  # Check that it ran without errors
29
  assert result.returncode == 0, f"Script failed with:\n{result.stderr}"
30
 
 
31
  def test_run_pirads_inference():
32
  """
33
  Test that run_cspca.py runs without crashing using an existing YAML config.
@@ -45,14 +54,23 @@ def test_run_pirads_inference():
45
 
46
  # Run the script with the config
47
  result = subprocess.run(
48
- [sys.executable, str(script_path), "--mode", "test", "--config", str(config_path), "--dry_run", "True" ],
 
 
 
 
 
 
 
 
49
  capture_output=True,
50
  text=True,
51
  )
52
 
53
  # Check that it ran without errors
54
  assert result.returncode == 0, f"Script failed with:\n{result.stderr}"
55
-
 
56
  def test_run_cspca_training():
57
  """
58
  Test that run_cspca.py runs without crashing using an existing YAML config.
@@ -70,14 +88,23 @@ def test_run_cspca_training():
70
 
71
  # Run the script with the config
72
  result = subprocess.run(
73
- [sys.executable, str(script_path), "--mode", "train", "--config", str(config_path), "--dry_run", "True" ],
 
 
 
 
 
 
 
 
74
  capture_output=True,
75
  text=True,
76
  )
77
 
78
  # Check that it ran without errors
79
  assert result.returncode == 0, f"Script failed with:\n{result.stderr}"
80
-
 
81
  def test_run_cspca_inference():
82
  """
83
  Test that run_cspca.py runs without crashing using an existing YAML config.
@@ -95,12 +122,18 @@ def test_run_cspca_inference():
95
 
96
  # Run the script with the config
97
  result = subprocess.run(
98
- [sys.executable, str(script_path), "--mode", "test", "--config", str(config_path), "--dry_run", "True" ],
 
 
 
 
 
 
 
 
99
  capture_output=True,
100
  text=True,
101
  )
102
 
103
  # Check that it ran without errors
104
  assert result.returncode == 0, f"Script failed with:\n{result.stderr}"
105
-
106
-
 
20
 
21
  # Run the script with the config
22
  result = subprocess.run(
23
+ [
24
+ sys.executable,
25
+ str(script_path),
26
+ "--mode",
27
+ "train",
28
+ "--config",
29
+ str(config_path),
30
+ "--dry_run",
31
+ ],
32
  capture_output=True,
33
  text=True,
34
  )
 
36
  # Check that it ran without errors
37
  assert result.returncode == 0, f"Script failed with:\n{result.stderr}"
38
 
39
+
40
  def test_run_pirads_inference():
41
  """
42
  Test that run_cspca.py runs without crashing using an existing YAML config.
 
54
 
55
  # Run the script with the config
56
  result = subprocess.run(
57
+ [
58
+ sys.executable,
59
+ str(script_path),
60
+ "--mode",
61
+ "test",
62
+ "--config",
63
+ str(config_path),
64
+ "--dry_run",
65
+ ],
66
  capture_output=True,
67
  text=True,
68
  )
69
 
70
  # Check that it ran without errors
71
  assert result.returncode == 0, f"Script failed with:\n{result.stderr}"
72
+
73
+
74
  def test_run_cspca_training():
75
  """
76
  Test that run_cspca.py runs without crashing using an existing YAML config.
 
88
 
89
  # Run the script with the config
90
  result = subprocess.run(
91
+ [
92
+ sys.executable,
93
+ str(script_path),
94
+ "--mode",
95
+ "train",
96
+ "--config",
97
+ str(config_path),
98
+ "--dry_run",
99
+ ],
100
  capture_output=True,
101
  text=True,
102
  )
103
 
104
  # Check that it ran without errors
105
  assert result.returncode == 0, f"Script failed with:\n{result.stderr}"
106
+
107
+
108
  def test_run_cspca_inference():
109
  """
110
  Test that run_cspca.py runs without crashing using an existing YAML config.
 
122
 
123
  # Run the script with the config
124
  result = subprocess.run(
125
+ [
126
+ sys.executable,
127
+ str(script_path),
128
+ "--mode",
129
+ "test",
130
+ "--config",
131
+ str(config_path),
132
+ "--dry_run",
133
+ ],
134
  capture_output=True,
135
  text=True,
136
  )
137
 
138
  # Check that it ran without errors
139
  assert result.returncode == 0, f"Script failed with:\n{result.stderr}"