BoSAM / medim_infer.py
ziyanlu's picture
Upload folder using huggingface_hub
9859ea2 verified
# -*- encoding: utf-8 -*-
'''
@File : infer_with_medim.py
@Time : 2024/09/08 11:31:02
@Author : Haoyu Wang
@Contact : small_dark@sina.com
@Brief : Example code for inference with MedIM
'''
import medim
import torch
import numpy as np
import torch.nn.functional as F
import torchio as tio
import os.path as osp
import os
from torchio.data.io import sitk_to_nib
import SimpleITK as sitk
def random_sample_next_click(prev_mask, gt_mask):
"""
Randomly sample one click from ground-truth mask and previous seg mask
Arguements:
prev_mask: (torch.Tensor) [H,W,D] previous mask that SAM-Med3D predict
gt_mask: (torch.Tensor) [H,W,D] ground-truth mask for this image
"""
prev_mask = prev_mask > 0
true_masks = gt_mask > 0
if (not true_masks.any()):
raise ValueError("Cannot find true value in the ground-truth!")
fn_masks = torch.logical_and(true_masks, torch.logical_not(prev_mask))
fp_masks = torch.logical_and(torch.logical_not(true_masks), prev_mask)
to_point_mask = torch.logical_or(fn_masks, fp_masks)
all_points = torch.argwhere(to_point_mask)
point = all_points[np.random.randint(len(all_points))]
if fn_masks[point[0], point[1], point[2]]:
is_positive = True
else:
is_positive = False
sampled_point = point.clone().detach().reshape(1, 1, 3)
sampled_label = torch.tensor([
int(is_positive),
]).reshape(1, 1)
return sampled_point, sampled_label
def sam_model_infer(model,
roi_image,
prompt_generator=random_sample_next_click,
roi_gt=None,
prev_low_res_mask=None):
'''
Inference for SAM-Med3D, inputs prompt points with its labels (positive/negative for each points)
# roi_image: (torch.Tensor) cropped image, shape [1,1,128,128,128]
# prompt_points_and_labels: (Tuple(torch.Tensor, torch.Tensor))
'''
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("using device", device)
model = model.to(device)
with torch.no_grad():
input_tensor = roi_image.to(device)
image_embeddings = model.image_encoder(input_tensor)
points_coords, points_labels = torch.zeros(1, 0,
3).to(device), torch.zeros(
1, 0).to(device)
new_points_co, new_points_la = torch.Tensor(
[[[64, 64, 64]]]).to(device), torch.Tensor([[1]]).to(torch.int64)
if (roi_gt is not None):
prev_low_res_mask = prev_low_res_mask if (
prev_low_res_mask is not None) else torch.zeros(
1, 1, roi_image.shape[2] // 4, roi_image.shape[3] //
4, roi_image.shape[4] // 4)
new_points_co, new_points_la = prompt_generator(
torch.zeros_like(roi_image)[0, 0], roi_gt[0, 0])
new_points_co, new_points_la = new_points_co.to(
device), new_points_la.to(device)
points_coords = torch.cat([points_coords, new_points_co], dim=1)
points_labels = torch.cat([points_labels, new_points_la], dim=1)
sparse_embeddings, dense_embeddings = model.prompt_encoder(
points=[points_coords, points_labels],
boxes=None, # we currently not support bbox prompt
masks=prev_low_res_mask.to(device),
# masks=None,
)
low_res_masks, _ = model.mask_decoder(
image_embeddings=image_embeddings, # (1, 384, 8, 8, 8)
image_pe=model.prompt_encoder.get_dense_pe(), # (1, 384, 8, 8, 8)
sparse_prompt_embeddings=sparse_embeddings, # (1, 2, 384)
dense_prompt_embeddings=dense_embeddings, # (1, 384, 8, 8, 8)
)
prev_mask = F.interpolate(low_res_masks,
size=roi_image.shape[-3:],
mode='trilinear',
align_corners=False)
# convert prob to mask
medsam_seg_prob = torch.sigmoid(prev_mask) # (1, 1, 64, 64, 64)
medsam_seg_prob = medsam_seg_prob.cpu().numpy().squeeze()
medsam_seg_mask = (medsam_seg_prob > 0.5).astype(np.uint8)
return medsam_seg_mask
def resample_nii(input_path: str,
output_path: str,
target_spacing: tuple = (1.5, 1.5, 1.5),
n=None,
reference_image=None,
mode="linear"):
"""
Resample a nii.gz file to a specified spacing using torchio.
Parameters:
- input_path: Path to the input .nii.gz file.
- output_path: Path to save the resampled .nii.gz file.
- target_spacing: Desired spacing for resampling. Default is (1.5, 1.5, 1.5).
"""
# Load the nii.gz file using torchio
subject = tio.Subject(img=tio.ScalarImage(input_path))
resampler = tio.Resample(target=target_spacing, image_interpolation=mode)
resampled_subject = resampler(subject)
if (n != None):
image = resampled_subject.img
tensor_data = image.data
if (isinstance(n, int)):
n = [n]
for ni in n:
tensor_data[tensor_data == ni] = -1
tensor_data[tensor_data != -1] = 0
tensor_data[tensor_data != 0] = 1
save_image = tio.ScalarImage(tensor=tensor_data, affine=image.affine)
reference_size = reference_image.shape[
1:] # omitting the channel dimension
cropper_or_padder = tio.CropOrPad(reference_size)
save_image = cropper_or_padder(save_image)
else:
save_image = resampled_subject.img
save_image.save(output_path)
def read_data_from_nii(img_path, gt_path):
sitk_image = sitk.ReadImage(img_path)
sitk_label = sitk.ReadImage(gt_path)
if sitk_image.GetOrigin() != sitk_label.GetOrigin():
sitk_image.SetOrigin(sitk_label.GetOrigin())
if sitk_image.GetDirection() != sitk_label.GetDirection():
sitk_image.SetDirection(sitk_label.GetDirection())
sitk_image_arr, _ = sitk_to_nib(sitk_image)
sitk_label_arr, _ = sitk_to_nib(sitk_label)
subject = tio.Subject(
image=tio.ScalarImage(tensor=sitk_image_arr),
label=tio.LabelMap(tensor=sitk_label_arr),
)
crop_transform = tio.CropOrPad(mask_name='label',
target_shape=(128, 128, 128))
padding_params, cropping_params = crop_transform.compute_crop_or_pad(
subject)
if (cropping_params is None): cropping_params = (0, 0, 0, 0, 0, 0)
if (padding_params is None): padding_params = (0, 0, 0, 0, 0, 0)
infer_transform = tio.Compose([
crop_transform,
tio.ZNormalization(masking_method=lambda x: x > 0),
])
subject_roi = infer_transform(subject)
img3D_roi, gt3D_roi = subject_roi.image.data.clone().detach().unsqueeze(
1), subject_roi.label.data.clone().detach().unsqueeze(1)
ori_roi_offset = (
cropping_params[0],
cropping_params[0] + 128 - padding_params[0] - padding_params[1],
cropping_params[2],
cropping_params[2] + 128 - padding_params[2] - padding_params[3],
cropping_params[4],
cropping_params[4] + 128 - padding_params[4] - padding_params[5],
)
meta_info = {
"image_path": img_path,
"image_shape": sitk_image_arr.shape[1:],
"origin": sitk_label.GetOrigin(),
"direction": sitk_label.GetDirection(),
"spacing": sitk_label.GetSpacing(),
"padding_params": padding_params,
"cropping_params": cropping_params,
"ori_roi": ori_roi_offset,
}
return (
img3D_roi,
gt3D_roi,
meta_info,
)
def save_numpy_to_nifti(in_arr: np.array, out_path, meta_info):
# torchio turn 1xHxWxD -> DxWxH
# so we need to squeeze and transpose back to HxWxD
ori_arr = np.transpose(in_arr.squeeze(), (2, 1, 0))
out = sitk.GetImageFromArray(ori_arr)
sitk_meta_translator = lambda x: [float(i) for i in x]
out.SetOrigin(sitk_meta_translator(meta_info["origin"]))
out.SetDirection(sitk_meta_translator(meta_info["direction"]))
out.SetSpacing(sitk_meta_translator(meta_info["spacing"]))
sitk.WriteImage(out, out_path)
def data_preprocess(img_path, gt_path, category_index):
target_img_path = osp.join(
osp.dirname(img_path),
osp.basename(img_path).replace(".nii.gz", "_resampled.nii.gz"))
target_gt_path = osp.join(
osp.dirname(gt_path),
osp.basename(gt_path).replace(".nii.gz", "_resampled.nii.gz"))
resample_nii(img_path, target_img_path)
resample_nii(gt_path,
target_gt_path,
n=category_index,
reference_image=tio.ScalarImage(target_img_path),
mode="nearest")
roi_image, roi_label, meta_info = read_data_from_nii(
target_img_path, target_gt_path)
return roi_image, roi_label, meta_info
def data_postprocess(roi_pred, meta_info, output_path, ori_img_path):
os.makedirs(osp.dirname(output_path), exist_ok=True)
pred3D_full = np.zeros(meta_info["image_shape"])
padding_params = meta_info["padding_params"]
unpadded_pred = roi_pred[padding_params[0] : 128-padding_params[1],
padding_params[2] : 128-padding_params[3],
padding_params[4] : 128-padding_params[5]]
ori_roi = meta_info["ori_roi"]
pred3D_full[ori_roi[0]:ori_roi[1], ori_roi[2]:ori_roi[3],
ori_roi[4]:ori_roi[5]] = unpadded_pred
sitk_image = sitk.ReadImage(ori_img_path)
ori_meta_info = {
"image_path": ori_img_path,
"image_shape": sitk_image.GetSize(),
"origin": sitk_image.GetOrigin(),
"direction": sitk_image.GetDirection(),
"spacing": sitk_image.GetSpacing(),
}
pred3D_full_ori = F.interpolate(
torch.Tensor(pred3D_full)[None][None],
size=ori_meta_info["image_shape"],
mode='nearest').cpu().numpy().squeeze()
save_numpy_to_nifti(pred3D_full_ori, output_path, meta_info)
if __name__ == "__main__":
''' 1. read and pre-process your input data '''
img_path = "./test_data/kidney_right/AMOS/imagesVal/amos_0013.nii.gz"
gt_path = "./test_data/kidney_right/AMOS/labelsVal/amos_0013.nii.gz"
category_index = 3 # the index of your target category in the gt annotation
output_dir = "./test_data/kidney_right/AMOS/pred/"
roi_image, roi_label, meta_info = data_preprocess(img_path, gt_path, category_index=category_index)
''' 2. prepare the pre-trained model with local path or huggingface url '''
ckpt_path = "https://huggingface.co/blueyo0/SAM-Med3D/blob/main/sam_med3d_turbo.pth"
# or you can use the local path like: ckpt_path = "./ckpt/sam_med3d_turbo.pth"
model = medim.create_model("SAM-Med3D",
pretrained=True,
checkpoint_path=ckpt_path)
''' 3. infer with the pre-trained SAM-Med3D model '''
roi_pred = sam_model_infer(model, roi_image, roi_gt=roi_label)
''' 4. post-process and save the result '''
output_path = osp.join(output_dir, osp.basename(img_path).replace(".nii.gz", "_pred.nii.gz"))
data_postprocess(roi_pred, meta_info, output_path, img_path)
print("result saved to", output_path)