|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| from __future__ import annotations
|
|
|
| import argparse
|
| import json
|
| import logging
|
| import os
|
| from pathlib import Path
|
|
|
| import monai
|
| import nibabel as nib
|
| import numpy as np
|
| import torch
|
| import torch.distributed as dist
|
| from monai.transforms import Compose
|
| from monai.utils import set_determinism
|
|
|
| from .diff_model_setting import initialize_distributed, load_config, setup_logging
|
| from .utils import define_instance
|
|
|
|
|
# Fix all random seeds at import time so data creation is reproducible across runs.
set_determinism(seed=0)
|
|
|
|
|
def create_transforms(dim: tuple = None) -> Compose:
    """
    Build the MONAI preprocessing pipeline.

    Args:
        dim (tuple, optional): Target spatial size for resizing. When falsy,
            only the loading/orientation transforms are returned. Defaults to None.

    Returns:
        Compose: Composed MONAI transforms.
    """
    # Common front-end shared by both pipelines: load, make channel-first,
    # reorient to RAS.
    steps = [
        monai.transforms.LoadImaged(keys="image"),
        monai.transforms.EnsureChannelFirstd(keys="image"),
        monai.transforms.Orientationd(keys="image", axcodes="RAS"),
    ]
    if dim:
        # Full pipeline: cast to float32, clip intensities to [-1000, 1000]
        # and rescale to [0, 1], then resize to the requested spatial size.
        steps += [
            monai.transforms.EnsureTyped(keys="image", dtype=torch.float32),
            monai.transforms.ScaleIntensityRanged(
                keys="image", a_min=-1000, a_max=1000, b_min=0, b_max=1, clip=True
            ),
            monai.transforms.Resized(keys="image", spatial_size=dim, mode="trilinear"),
        ]
    return Compose(steps)
|
|
|
|
|
def round_number(number: int, base_number: int = 128) -> int:
    """
    Round *number* to the nearest multiple of *base_number*, never below it.

    Args:
        number (int): Value to round.
        base_number (int): The multiple to round to. Defaults to 128.

    Returns:
        int: The rounded value (at least ``base_number``).
    """
    base = float(base_number)
    # Nearest multiple count, clamped so the result is never below one base unit.
    multiples = round(float(number) / base)
    if multiples < 1.0:
        multiples = 1.0
    return int(multiples * base)
|
|
|
|
|
def load_filenames(data_list_path: str) -> list:
    """
    Read the "training" image paths from a JSON data list.

    Args:
        data_list_path (str): Path to the JSON data list file.

    Returns:
        list: Image filenames listed under the "training" key.
    """
    json_data = json.loads(Path(data_list_path).read_text())
    return [entry["image"] for entry in json_data["training"]]
|
|
|
|
|
def process_file(
    filepath: str,
    args: argparse.Namespace,
    autoencoder: torch.nn.Module,
    device: torch.device,
    plain_transforms: Compose,
    new_transforms: Compose,
    logger: logging.Logger,
) -> None:
    """
    Encode one image into a latent embedding and save it as a NIfTI file.

    Skips files whose embedding already exists on disk. Encoding/saving
    failures are logged and swallowed so one bad file does not abort the
    whole dataset pass.

    Args:
        filepath (str): Image path relative to ``args.data_base_dir``.
        args (argparse.Namespace): Configuration; must provide
            ``embedding_base_dir`` and ``data_base_dir``.
        autoencoder (torch.nn.Module): Trained autoencoder used for encoding.
        device (torch.device): Device on which to run the encoder.
        plain_transforms (Compose): Loading/orientation-only transforms.
        new_transforms (Compose): Transforms that also scale and resize.
        logger (logging.Logger): Logger for progress and errors.
    """
    # Mirror the input path under embedding_base_dir, replacing the
    # .nii/.nii.gz extension with the _emb.nii.gz suffix.
    out_filename_base = filepath.replace(".gz", "").replace(".nii", "")
    out_filename_base = os.path.join(args.embedding_base_dir, out_filename_base)
    out_filename = out_filename_base + "_emb.nii.gz"

    # Resume-friendly: never recompute an embedding that already exists.
    if os.path.isfile(out_filename):
        return

    test_data = {"image": os.path.join(args.data_base_dir, filepath)}
    transformed_data = plain_transforms(test_data)
    nda = transformed_data["image"]

    # Original spatial metadata (NIfTI header entries 1..3), logged for traceability.
    dim = [int(nda.meta["dim"][_i]) for _i in range(1, 4)]
    spacing = [float(nda.meta["pixdim"][_i]) for _i in range(1, 4)]

    logger.info(f"old dim: {dim}, old spacing: {spacing}")

    new_data = new_transforms(test_data)
    nda_image = new_data["image"]

    new_affine = nda_image.meta["affine"].numpy()
    nda_image = nda_image.numpy().squeeze()

    logger.info(f"new dim: {nda_image.shape}, new affine: {new_affine}")

    try:
        out_path = Path(out_filename)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        logger.info(f"out_filename: {out_filename}")

        # Fix: use the actual device type instead of hard-coding "cuda" so
        # autocast matches wherever the caller placed the model.
        with torch.amp.autocast(device.type):
            # Add batch and channel dims: (D, H, W) -> (1, 1, D, H, W).
            pt_nda = torch.from_numpy(nda_image).float().to(device).unsqueeze(0).unsqueeze(0)
            z = autoencoder.encode_stage_2_inputs(pt_nda)
            logger.info(f"z: {z.size()}, {z.dtype}")

        # Move latent channels last: (C, D, H, W) -> (D, H, W, C) for NIfTI storage.
        out_nda = z.squeeze().cpu().detach().numpy().transpose(1, 2, 3, 0)
        out_img = nib.Nifti1Image(np.float32(out_nda), affine=new_affine)
        nib.save(out_img, out_filename)
    except Exception as e:
        # Best-effort over the dataset: record the failure and move on.
        logger.error(f"Error processing {filepath}: {e}")
|
|
|
|
|
@torch.inference_mode()
def diff_model_create_training_data(
    env_config_path: str, model_config_path: str, model_def_path: str, num_gpus: int
) -> None:
    """
    Create latent-embedding training data for the diffusion model.

    Work is sharded round-robin across distributed ranks; each rank encodes
    its share of the images listed in ``args.json_data_list``.

    Args:
        env_config_path (str): Path to the environment configuration file.
        model_config_path (str): Path to the model configuration file.
        model_def_path (str): Path to the model definition file.
        num_gpus (int): Number of GPUs to use for distributed processing.

    Raises:
        Exception: If the trained autoencoder checkpoint cannot be loaded.
    """
    args = load_config(env_config_path, model_config_path, model_def_path)
    local_rank, world_size, device = initialize_distributed(num_gpus=num_gpus)
    logger = setup_logging("creating training data")
    logger.info(f"Using device {device}")

    autoencoder = define_instance(args, "autoencoder_def").to(device)
    try:
        checkpoint_autoencoder = torch.load(args.trained_autoencoder_path, weights_only=True)
        autoencoder.load_state_dict(checkpoint_autoencoder)
    except Exception:
        # Fix: a missing/corrupt checkpoint must abort the run — continuing
        # with a randomly-initialized autoencoder would silently write
        # useless embeddings.
        logger.error("The trained_autoencoder_path does not exist!")
        raise

    Path(args.embedding_base_dir).mkdir(parents=True, exist_ok=True)

    filenames_raw = load_filenames(args.json_data_list)
    logger.info(f"filenames_raw: {filenames_raw}")

    plain_transforms = create_transforms(dim=None)

    for _iter in range(len(filenames_raw)):
        # Round-robin sharding: each rank handles its own slice of the list.
        if _iter % world_size != local_rank:
            continue

        filepath = filenames_raw[_iter]
        # Fix: load the image metadata once per file (the previous generator
        # re-applied plain_transforms — reloading the image — once per axis),
        # then round each spatial dimension to a multiple of 128.
        meta_dim = plain_transforms({"image": os.path.join(args.data_base_dir, filepath)})["image"].meta["dim"]
        new_dim = tuple(round_number(int(meta_dim[_i])) for _i in range(1, 4))
        new_transforms = create_transforms(new_dim)

        process_file(filepath, args, autoencoder, device, plain_transforms, new_transforms, logger)

    if dist.is_initialized():
        dist.destroy_process_group()
|
|
|
|
|
if __name__ == "__main__":
    cli = argparse.ArgumentParser(description="Diffusion Model Training Data Creation")
    # The three config paths share the same option shape; register them table-driven.
    for flag, default, help_text in (
        (
            "--env_config",
            "./configs/environment_maisi_diff_model_train.json",
            "Path to environment configuration file",
        ),
        (
            "--model_config",
            "./configs/config_maisi_diff_model_train.json",
            "Path to model training/inference configuration",
        ),
        (
            "--model_def",
            "./configs/config_maisi.json",
            "Path to model definition file",
        ),
    ):
        cli.add_argument(flag, type=str, default=default, help=help_text)
    cli.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs to use for distributed training")

    args = cli.parse_args()
    diff_model_create_training_data(args.env_config, args.model_config, args.model_def, args.num_gpus)
|
|
|