| |
| ''' |
| @File : infer_with_medim.py |
| @Time : 2024/09/08 11:31:02 |
| @Author : Haoyu Wang |
| @Contact : small_dark@sina.com |
| @Brief : Example code for inference with MedIM |
| ''' |
|
|
| import medim |
| import torch |
| import numpy as np |
| import torch.nn.functional as F |
| import torchio as tio |
| import os.path as osp |
| import os |
| from torchio.data.io import sitk_to_nib |
| import SimpleITK as sitk |
|
|
|
|
| def random_sample_next_click(prev_mask, gt_mask): |
| """ |
| Randomly sample one click from ground-truth mask and previous seg mask |
| |
| Arguements: |
| prev_mask: (torch.Tensor) [H,W,D] previous mask that SAM-Med3D predict |
| gt_mask: (torch.Tensor) [H,W,D] ground-truth mask for this image |
| """ |
| prev_mask = prev_mask > 0 |
| true_masks = gt_mask > 0 |
|
|
| if (not true_masks.any()): |
| raise ValueError("Cannot find true value in the ground-truth!") |
|
|
| fn_masks = torch.logical_and(true_masks, torch.logical_not(prev_mask)) |
| fp_masks = torch.logical_and(torch.logical_not(true_masks), prev_mask) |
|
|
| to_point_mask = torch.logical_or(fn_masks, fp_masks) |
|
|
| all_points = torch.argwhere(to_point_mask) |
| point = all_points[np.random.randint(len(all_points))] |
|
|
| if fn_masks[point[0], point[1], point[2]]: |
| is_positive = True |
| else: |
| is_positive = False |
|
|
| sampled_point = point.clone().detach().reshape(1, 1, 3) |
| sampled_label = torch.tensor([ |
| int(is_positive), |
| ]).reshape(1, 1) |
|
|
| return sampled_point, sampled_label |
|
|
|
|
| def sam_model_infer(model, |
| roi_image, |
| prompt_generator=random_sample_next_click, |
| roi_gt=None, |
| prev_low_res_mask=None): |
| ''' |
| Inference for SAM-Med3D, inputs prompt points with its labels (positive/negative for each points) |
| |
| # roi_image: (torch.Tensor) cropped image, shape [1,1,128,128,128] |
| # prompt_points_and_labels: (Tuple(torch.Tensor, torch.Tensor)) |
| ''' |
|
|
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| print("using device", device) |
| model = model.to(device) |
|
|
| with torch.no_grad(): |
| input_tensor = roi_image.to(device) |
| image_embeddings = model.image_encoder(input_tensor) |
|
|
| points_coords, points_labels = torch.zeros(1, 0, |
| 3).to(device), torch.zeros( |
| 1, 0).to(device) |
| new_points_co, new_points_la = torch.Tensor( |
| [[[64, 64, 64]]]).to(device), torch.Tensor([[1]]).to(torch.int64) |
| if (roi_gt is not None): |
| prev_low_res_mask = prev_low_res_mask if ( |
| prev_low_res_mask is not None) else torch.zeros( |
| 1, 1, roi_image.shape[2] // 4, roi_image.shape[3] // |
| 4, roi_image.shape[4] // 4) |
| new_points_co, new_points_la = prompt_generator( |
| torch.zeros_like(roi_image)[0, 0], roi_gt[0, 0]) |
| new_points_co, new_points_la = new_points_co.to( |
| device), new_points_la.to(device) |
| points_coords = torch.cat([points_coords, new_points_co], dim=1) |
| points_labels = torch.cat([points_labels, new_points_la], dim=1) |
|
|
| sparse_embeddings, dense_embeddings = model.prompt_encoder( |
| points=[points_coords, points_labels], |
| boxes=None, |
| masks=prev_low_res_mask.to(device), |
| |
| ) |
|
|
| low_res_masks, _ = model.mask_decoder( |
| image_embeddings=image_embeddings, |
| image_pe=model.prompt_encoder.get_dense_pe(), |
| sparse_prompt_embeddings=sparse_embeddings, |
| dense_prompt_embeddings=dense_embeddings, |
| ) |
|
|
| prev_mask = F.interpolate(low_res_masks, |
| size=roi_image.shape[-3:], |
| mode='trilinear', |
| align_corners=False) |
|
|
| |
| medsam_seg_prob = torch.sigmoid(prev_mask) |
| medsam_seg_prob = medsam_seg_prob.cpu().numpy().squeeze() |
| medsam_seg_mask = (medsam_seg_prob > 0.5).astype(np.uint8) |
|
|
| return medsam_seg_mask |
|
|
|
|
| def resample_nii(input_path: str, |
| output_path: str, |
| target_spacing: tuple = (1.5, 1.5, 1.5), |
| n=None, |
| reference_image=None, |
| mode="linear"): |
| """ |
| Resample a nii.gz file to a specified spacing using torchio. |
| |
| Parameters: |
| - input_path: Path to the input .nii.gz file. |
| - output_path: Path to save the resampled .nii.gz file. |
| - target_spacing: Desired spacing for resampling. Default is (1.5, 1.5, 1.5). |
| """ |
| |
| subject = tio.Subject(img=tio.ScalarImage(input_path)) |
| resampler = tio.Resample(target=target_spacing, image_interpolation=mode) |
| resampled_subject = resampler(subject) |
|
|
| if (n != None): |
| image = resampled_subject.img |
| tensor_data = image.data |
| if (isinstance(n, int)): |
| n = [n] |
| for ni in n: |
| tensor_data[tensor_data == ni] = -1 |
| tensor_data[tensor_data != -1] = 0 |
| tensor_data[tensor_data != 0] = 1 |
| save_image = tio.ScalarImage(tensor=tensor_data, affine=image.affine) |
| reference_size = reference_image.shape[ |
| 1:] |
| cropper_or_padder = tio.CropOrPad(reference_size) |
| save_image = cropper_or_padder(save_image) |
| else: |
| save_image = resampled_subject.img |
|
|
| save_image.save(output_path) |
|
|
|
|
| def read_data_from_nii(img_path, gt_path): |
| sitk_image = sitk.ReadImage(img_path) |
| sitk_label = sitk.ReadImage(gt_path) |
|
|
| if sitk_image.GetOrigin() != sitk_label.GetOrigin(): |
| sitk_image.SetOrigin(sitk_label.GetOrigin()) |
| if sitk_image.GetDirection() != sitk_label.GetDirection(): |
| sitk_image.SetDirection(sitk_label.GetDirection()) |
|
|
| sitk_image_arr, _ = sitk_to_nib(sitk_image) |
| sitk_label_arr, _ = sitk_to_nib(sitk_label) |
|
|
| subject = tio.Subject( |
| image=tio.ScalarImage(tensor=sitk_image_arr), |
| label=tio.LabelMap(tensor=sitk_label_arr), |
| ) |
| crop_transform = tio.CropOrPad(mask_name='label', |
| target_shape=(128, 128, 128)) |
| padding_params, cropping_params = crop_transform.compute_crop_or_pad( |
| subject) |
| if (cropping_params is None): cropping_params = (0, 0, 0, 0, 0, 0) |
| if (padding_params is None): padding_params = (0, 0, 0, 0, 0, 0) |
|
|
| infer_transform = tio.Compose([ |
| crop_transform, |
| tio.ZNormalization(masking_method=lambda x: x > 0), |
| ]) |
| subject_roi = infer_transform(subject) |
|
|
| img3D_roi, gt3D_roi = subject_roi.image.data.clone().detach().unsqueeze( |
| 1), subject_roi.label.data.clone().detach().unsqueeze(1) |
| ori_roi_offset = ( |
| cropping_params[0], |
| cropping_params[0] + 128 - padding_params[0] - padding_params[1], |
| cropping_params[2], |
| cropping_params[2] + 128 - padding_params[2] - padding_params[3], |
| cropping_params[4], |
| cropping_params[4] + 128 - padding_params[4] - padding_params[5], |
| ) |
|
|
| meta_info = { |
| "image_path": img_path, |
| "image_shape": sitk_image_arr.shape[1:], |
| "origin": sitk_label.GetOrigin(), |
| "direction": sitk_label.GetDirection(), |
| "spacing": sitk_label.GetSpacing(), |
| "padding_params": padding_params, |
| "cropping_params": cropping_params, |
| "ori_roi": ori_roi_offset, |
| } |
| return ( |
| img3D_roi, |
| gt3D_roi, |
| meta_info, |
| ) |
|
|
|
|
| def save_numpy_to_nifti(in_arr: np.array, out_path, meta_info): |
| |
| |
| ori_arr = np.transpose(in_arr.squeeze(), (2, 1, 0)) |
| out = sitk.GetImageFromArray(ori_arr) |
| sitk_meta_translator = lambda x: [float(i) for i in x] |
| out.SetOrigin(sitk_meta_translator(meta_info["origin"])) |
| out.SetDirection(sitk_meta_translator(meta_info["direction"])) |
| out.SetSpacing(sitk_meta_translator(meta_info["spacing"])) |
| sitk.WriteImage(out, out_path) |
|
|
|
|
| def data_preprocess(img_path, gt_path, category_index): |
| target_img_path = osp.join( |
| osp.dirname(img_path), |
| osp.basename(img_path).replace(".nii.gz", "_resampled.nii.gz")) |
| target_gt_path = osp.join( |
| osp.dirname(gt_path), |
| osp.basename(gt_path).replace(".nii.gz", "_resampled.nii.gz")) |
| resample_nii(img_path, target_img_path) |
| resample_nii(gt_path, |
| target_gt_path, |
| n=category_index, |
| reference_image=tio.ScalarImage(target_img_path), |
| mode="nearest") |
| roi_image, roi_label, meta_info = read_data_from_nii( |
| target_img_path, target_gt_path) |
| return roi_image, roi_label, meta_info |
|
|
|
|
| def data_postprocess(roi_pred, meta_info, output_path, ori_img_path): |
| os.makedirs(osp.dirname(output_path), exist_ok=True) |
| pred3D_full = np.zeros(meta_info["image_shape"]) |
| padding_params = meta_info["padding_params"] |
| unpadded_pred = roi_pred[padding_params[0] : 128-padding_params[1], |
| padding_params[2] : 128-padding_params[3], |
| padding_params[4] : 128-padding_params[5]] |
| ori_roi = meta_info["ori_roi"] |
| pred3D_full[ori_roi[0]:ori_roi[1], ori_roi[2]:ori_roi[3], |
| ori_roi[4]:ori_roi[5]] = unpadded_pred |
|
|
| sitk_image = sitk.ReadImage(ori_img_path) |
| ori_meta_info = { |
| "image_path": ori_img_path, |
| "image_shape": sitk_image.GetSize(), |
| "origin": sitk_image.GetOrigin(), |
| "direction": sitk_image.GetDirection(), |
| "spacing": sitk_image.GetSpacing(), |
| } |
| pred3D_full_ori = F.interpolate( |
| torch.Tensor(pred3D_full)[None][None], |
| size=ori_meta_info["image_shape"], |
| mode='nearest').cpu().numpy().squeeze() |
| save_numpy_to_nifti(pred3D_full_ori, output_path, meta_info) |
|
|
|
|
| if __name__ == "__main__": |
| ''' 1. read and pre-process your input data ''' |
| img_path = "./test_data/kidney_right/AMOS/imagesVal/amos_0013.nii.gz" |
| gt_path = "./test_data/kidney_right/AMOS/labelsVal/amos_0013.nii.gz" |
| category_index = 3 |
| output_dir = "./test_data/kidney_right/AMOS/pred/" |
| roi_image, roi_label, meta_info = data_preprocess(img_path, gt_path, category_index=category_index) |
| |
| ''' 2. prepare the pre-trained model with local path or huggingface url ''' |
| ckpt_path = "https://huggingface.co/blueyo0/SAM-Med3D/blob/main/sam_med3d_turbo.pth" |
| |
| model = medim.create_model("SAM-Med3D", |
| pretrained=True, |
| checkpoint_path=ckpt_path) |
| |
| ''' 3. infer with the pre-trained SAM-Med3D model ''' |
| roi_pred = sam_model_infer(model, roi_image, roi_gt=roi_label) |
|
|
| ''' 4. post-process and save the result ''' |
| output_path = osp.join(output_dir, osp.basename(img_path).replace(".nii.gz", "_pred.nii.gz")) |
| data_postprocess(roi_pred, meta_info, output_path, img_path) |
|
|
| print("result saved to", output_path) |
|
|