| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """ |
| This example shows how to efficiently compute Dice scores for pairs of segmentation prediction |
| and references in multi-processing based on MONAI's metrics API. |
| It can even run on multi-nodes. |
| Main steps to set up the distributed data parallel: |
| |
| - Execute `torchrun` to create processes on every node for every process. |
| It receives parameters as below: |
| `--nproc_per_node=NUM_PROCESSES_PER_NODE` |
| `--nnodes=NUM_NODES` |
| For more details, refer to https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py. |
| Alternatively, we can also use `torch.multiprocessing.spawn` to start program, but it that case, need to handle |
| all the above parameters and compute `rank` manually, then set to `init_process_group`, etc. |
| `torchrun` is even more efficient than `torch.multiprocessing.spawn`. |
| - Use `init_process_group` to initialize every process. |
| - Partition the saved predictions and labels into ranks for parallel computation. |
| - Compute `Dice Metric` on every process, reduce the results after synchronization. |
| |
| Note: |
| `torchrun` will launch `nnodes * nproc_per_node = world_size` processes in total. |
| Example script to execute this program on a single node with 2 processes: |
| `torchrun --nproc_per_node=2 compute_metric.py` |
| |
| Referring to: https://pytorch.org/tutorials/intermediate/ddp_tutorial.html |
| |
| """ |
|
|
| import os |
|
|
| import torch |
| import torch.distributed as dist |
| from monai.data import partition_dataset |
| from monai.handlers import write_metrics_reports |
| from monai.metrics import DiceMetric |
| from monai.transforms import ( |
| AddLabelNamesd, |
| AsDiscreted, |
| Compose, |
| EnsureChannelFirstd, |
| LoadImaged, |
| Orientationd, |
| ToDeviced, |
| ) |
| from monai.utils import string_list_all_gather |
| from scripts.monai_utils import CopyFilenamesd |
|
|
|
|
def compute(datalist, output_dir):
    """Compute the mean Dice score for `datalist` across all distributed ranks.

    Each item in `datalist` is a dict with at least ``"pred"`` and ``"label"``
    image paths. The list is partitioned across ranks, each rank loads and
    one-hot encodes its shard on its own GPU, and the per-rank Dice buffers are
    reduced by `metric.aggregate()`. Rank 0 prints the mean Dice and writes the
    per-image metric report into `output_dir`.

    Must be launched via `torchrun` (relies on the LOCAL_RANK env var).
    """
    local_rank = int(os.environ["LOCAL_RANK"])
    dist.init_process_group(backend="nccl", init_method="env://")

    # Shard the work deterministically: no shuffle, uneven shards allowed so
    # every item is processed exactly once.
    data_part = partition_dataset(
        data=datalist, num_partitions=dist.get_world_size(), shuffle=False, even_divisible=False
    )[dist.get_rank()]

    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)

    transforms = Compose(
        [
            CopyFilenamesd(keys="label"),
            LoadImaged(keys=["pred", "label"]),
            ToDeviced(keys=["pred", "label"], device=device),
            EnsureChannelFirstd(keys=["pred", "label"]),
            Orientationd(keys=("pred", "label"), axcodes="RAS"),
            # NOTE: 4-class one-hot (background + 3 organs) is hard-coded here.
            AsDiscreted(keys=("pred", "label"), argmax=(False, False), to_onehot=(4, 4)),
        ]
    )
    data_part = [transforms(item) for item in data_part]

    metric = DiceMetric(include_background=False, reduction="mean", get_not_nans=False)
    metric(y_pred=[i["pred"] for i in data_part], y=[i["label"] for i in data_part])
    filenames = [item["filename"] for item in data_part]

    # aggregate() performs the cross-rank all-reduce; all ranks must call it.
    result = metric.aggregate().item()
    filenames = string_list_all_gather(strings=filenames)

    # BUG FIX: gate on the GLOBAL rank, not local_rank. `local_rank` is 0 on
    # every node, so on multi-node runs each node would print and write the
    # report, racing on the same output files.
    if dist.get_rank() == 0:
        print("mean dice: ", result)
        write_metrics_reports(
            save_dir=output_dir,
            images=filenames,
            metrics={"mean_dice": result},
            metric_details={"mean_dice": metric.get_buffer()},
            summary_ops="*",
        )

    metric.reset()
    dist.destroy_process_group()
|
|
|
|
def compute_single_node(datalist, output_dir):
    """Compute the mean Dice score for `datalist` on a single process/GPU.

    Same contract as `compute`, but without process-group setup: the whole
    `datalist` is processed locally, one item at a time, and the report is
    always written. Still expects LOCAL_RANK in the environment (set by
    `torchrun`) to pick the GPU.
    """
    local_rank = int(os.environ["LOCAL_RANK"])

    filenames = [d["label"].split("/")[-1] for d in datalist]

    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)

    labels = {"background": 0, "liver": 1, "spleen": 2, "pancreas": 3}
    transforms = Compose(
        [
            LoadImaged(keys=["pred", "label"]),
            ToDeviced(keys=["pred", "label"], device=device),
            EnsureChannelFirstd(keys=["pred", "label"]),
            Orientationd(keys=("pred", "label"), axcodes="RAS"),
            AddLabelNamesd(keys=("pred", "label"), label_names=labels),
            AsDiscreted(keys=("pred", "label"), argmax=(False, False), to_onehot=(4, 4)),
        ]
    )

    metric = DiceMetric(include_background=False, reduction="mean", get_not_nans=False)
    # FIX: the original pre-transformed every item into a throwaway list and
    # then re-ran the same transforms inside this loop, loading each image from
    # disk and moving it to GPU twice. Transform each item exactly once,
    # feeding the metric incrementally to keep only one item on GPU at a time.
    for item in datalist:
        item = transforms(item)
        metric(y_pred=[item["pred"]], y=[item["label"]])

    result = metric.aggregate().item()

    print("mean dice: ", result)
    write_metrics_reports(
        save_dir=output_dir,
        images=filenames,
        metrics={"mean_dice": result},
        metric_details={"mean_dice": metric.get_buffer()},
        summary_ops="*",
    )

    metric.reset()
|