import random
import time
import warnings
from typing import Any, Callable, Dict, List, Mapping, Sequence, Tuple, Union

import streamlit as st
import torch
import torch.nn.functional as F
from monai.data.utils import compute_importance_map, dense_patch_slices, get_valid_patch_size
from monai.transforms import Resize
from monai.utils import (
    BlendMode,
    PytorchPadMode,
    convert_data_type,
    ensure_tuple,
    fall_back_tuple,
    look_up_option,
    optional_import,
)

tqdm, _ = optional_import("tqdm", name="tqdm")

__all__ = ["sliding_window_inference"]


def logits2roi_coor(spatial_size, logits_global_single):
    """Compute a ROI bounding box from global logits: take the tight box around
    ``sigmoid(logits) > 0.5`` and expand each side toward `spatial_size`, clipped
    to the volume bounds."""
    pred_global_single = torch.sigmoid(logits_global_single) > 0.5
    nonzero_indices = torch.nonzero(pred_global_single)
    if nonzero_indices.shape[0] == 0:
        return None, None, None, None, None, None
    min_d, max_d = nonzero_indices[:, 0].min(), nonzero_indices[:, 0].max()
    min_h, max_h = nonzero_indices[:, 1].min(), nonzero_indices[:, 1].max()
    min_w, max_w = nonzero_indices[:, 2].min(), nonzero_indices[:, 2].max()
    # Pad the tight crop so each side is at least as large as the sliding window.
    crop_d, crop_h, crop_w = max_d - min_d + 1, max_h - min_h + 1, max_w - min_w + 1
    window_d, window_h, window_w = spatial_size
    padding_d, padding_h, padding_w = max(0, window_d - crop_d), max(0, window_h - crop_h), max(0, window_w - crop_w)
    global_d, global_h, global_w = logits_global_single.shape
    min_d = max(0, min_d - int(padding_d) // 2)
    min_h = max(0, min_h - int(padding_h) // 2)
    min_w = max(0, min_w - int(padding_w) // 2)
    max_d = min(global_d, max_d + int(padding_d) // 2)
    max_h = min(global_h, max_h + int(padding_h) // 2)
    max_w = min(global_w, max_w + int(padding_w) // 2)
    return min_d, min_h, min_w, max_d, max_h, max_w
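
# Illustrative usage (a sketch, not part of the original pipeline; `logits` is a
# hypothetical (D, H, W) logit volume):
#
#     logits = torch.randn(96, 96, 96)
#     min_d, min_h, min_w, max_d, max_h, max_w = logits2roi_coor((32, 32, 32), logits)
#     # each bound is None when sigmoid(logits) > 0.5 selects no voxel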


def build_binary_cube(bbox, binary_cube_shape):
    """Rasterize a (1, 6) bbox tensor (min_d, min_h, min_w, max_d, max_h, max_w)
    into a binary volume of shape `binary_cube_shape`."""
    min_coord = bbox[0][:3].int().tolist()
    max_coord = bbox[0][3:].int().tolist()
    binary_cube = torch.zeros(binary_cube_shape)
    binary_cube[min_coord[0]:max_coord[0] + 1, min_coord[1]:max_coord[1] + 1, min_coord[2]:max_coord[2] + 1] = 1
    return binary_cube
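
# Illustrative usage (a sketch with made-up coordinates):
#
#     bbox = torch.tensor([[10, 10, 10, 20, 20, 20]])  # (min_d, min_h, min_w, max_d, max_h, max_w)
#     cube = build_binary_cube(bbox, (64, 64, 64))      # ones inside the box, zeros elsewhere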


def build_binary_points(points, labels, shape):
    """Rasterize positive (label == 1) point prompts into a binary volume of `shape`."""
    binary_points = torch.zeros(shape, dtype=torch.int16)
    binary_points[points[labels == 1, 0].long(), points[labels == 1, 1].long(), points[labels == 1, 2].long()] = 1
    return binary_points
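
# Illustrative usage (a sketch with made-up points; only label == 1 points are rasterized):
#
#     pts = torch.tensor([[5.0, 6.0, 7.0], [1.0, 2.0, 3.0]])
#     lbl = torch.tensor([1, 0])
#     vol = build_binary_points(pts, lbl, (16, 16, 16))  # vol[5, 6, 7] == 1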


def sliding_window_inference(
    inputs: torch.Tensor,
    prompt_reflection: Union[torch.Tensor, Tuple[torch.Tensor, ...]],
    roi_size: Union[Sequence[int], int],
    sw_batch_size: int,
    predictor: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor], Dict[Any, torch.Tensor]]],
    overlap: float = 0.25,
    mode: Union[BlendMode, str] = BlendMode.CONSTANT,
    sigma_scale: Union[Sequence[float], float] = 0.125,
    padding_mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT,
    cval: float = 0.0,
    sw_device: Union[torch.device, str, None] = None,
    device: Union[torch.device, str, None] = None,
    progress: bool = False,
    roi_weight_map: Union[torch.Tensor, None] = None,
    *args: Any,
    **kwargs: Any,
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], Dict[Any, torch.Tensor]]:
| """ |
| Sliding window inference on `inputs` with `predictor`. |
| |
| The outputs of `predictor` could be a tensor, a tuple, or a dictionary of tensors. |
| Each output in the tuple or dict value is allowed to have different resolutions with respect to the input. |
| e.g., the input patch spatial size is [128,128,128], the output (a tuple of two patches) patch sizes |
| could be ([128,64,256], [64,32,128]). |
| In this case, the parameter `overlap` and `roi_size` need to be carefully chosen to ensure the output ROI is still |
| an integer. If the predictor's input and output spatial sizes are not equal, we recommend choosing the parameters |
| so that `overlap*roi_size*output_size/input_size` is an integer (for each spatial dimension). |
| |
| When roi_size is larger than the inputs' spatial size, the input image are padded during inference. |
| To maintain the same spatial sizes, the output image will be cropped to the original input size. |
| |
| Args: |
| inputs: input image to be processed (assuming NCHW[D]) |
| roi_size: the spatial window size for inferences. |
| When its components have None or non-positives, the corresponding inputs dimension will be used. |
| if the components of the `roi_size` are non-positive values, the transform will use the |
| corresponding components of img size. For example, `roi_size=(32, -1)` will be adapted |
| to `(32, 64)` if the second spatial dimension size of img is `64`. |
| sw_batch_size: the batch size to run window slices. |
| predictor: given input tensor ``patch_data`` in shape NCHW[D], |
| The outputs of the function call ``predictor(patch_data)`` should be a tensor, a tuple, or a dictionary |
| with Tensor values. Each output in the tuple or dict value should have the same batch_size, i.e. NM'H'W'[D']; |
| where H'W'[D'] represents the output patch's spatial size, M is the number of output channels, |
| N is `sw_batch_size`, e.g., the input shape is (7, 1, 128,128,128), |
| the output could be a tuple of two tensors, with shapes: ((7, 5, 128, 64, 256), (7, 4, 64, 32, 128)). |
| In this case, the parameter `overlap` and `roi_size` need to be carefully chosen |
| to ensure the scaled output ROI sizes are still integers. |
| If the `predictor`'s input and output spatial sizes are different, |
| we recommend choosing the parameters so that ``overlap*roi_size*zoom_scale`` is an integer for each dimension. |
| overlap: Amount of overlap between scans. |
| mode: {``"constant"``, ``"gaussian"``} |
| How to blend output of overlapping windows. Defaults to ``"constant"``. |
| |
| - ``"constant``": gives equal weight to all predictions. |
| - ``"gaussian``": gives less weight to predictions on edges of windows. |
| |
| sigma_scale: the standard deviation coefficient of the Gaussian window when `mode` is ``"gaussian"``. |
| Default: 0.125. Actual window sigma is ``sigma_scale`` * ``dim_size``. |
| When sigma_scale is a sequence of floats, the values denote sigma_scale at the corresponding |
| spatial dimensions. |
| padding_mode: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``} |
| Padding mode for ``inputs``, when ``roi_size`` is larger than inputs. Defaults to ``"constant"`` |
| See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html |
| cval: fill value for 'constant' padding mode. Default: 0 |
| sw_device: device for the window data. |
| By default the device (and accordingly the memory) of the `inputs` is used. |
| Normally `sw_device` should be consistent with the device where `predictor` is defined. |
| device: device for the stitched output prediction. |
| By default the device (and accordingly the memory) of the `inputs` is used. If for example |
| set to device=torch.device('cpu') the gpu memory consumption is less and independent of the |
| `inputs` and `roi_size`. Output is on the `device`. |
| progress: whether to print a `tqdm` progress bar. |
| roi_weight_map: pre-computed (non-negative) weight map for each ROI. |
| If not given, and ``mode`` is not `constant`, this map will be computed on the fly. |
| args: optional args to be passed to ``predictor``. |
| kwargs: optional keyword args to be passed to ``predictor``. |
| |
| Note: |
| - input must be channel-first and have a batch dim, supports N-D sliding window. |
| |
| """ |
    print('sliding window inference for ROI')
    text = kwargs['text']
    use_box = kwargs['use_box']
    use_point = kwargs['use_point']
    logits_global_single = kwargs['logits_global_single']
    assert not (use_box and use_point), "only one of box and point prompts may be used"
    compute_dtype = inputs.dtype
    num_spatial_dims = len(inputs.shape) - 2
    if overlap < 0 or overlap >= 1:
        raise ValueError("overlap must be >= 0 and < 1.")

    # determine image spatial size and batch size
    batch_size, _, *image_size_ = inputs.shape

    if device is None:
        device = inputs.device
    if sw_device is None:
        sw_device = inputs.device

    roi_size = fall_back_tuple(roi_size, image_size_)
    # in case that image size is smaller than roi size
    image_size = tuple(max(image_size_[i], roi_size[i]) for i in range(num_spatial_dims))
    pad_size = []
    for k in range(len(inputs.shape) - 1, 1, -1):
        diff = max(roi_size[k - 2] - inputs.shape[k], 0)
        half = diff // 2
        pad_size.extend([half, diff - half])
    inputs = F.pad(inputs, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode).value, value=cval)

    if use_point or use_box:
        binary_prompt_map, global_preds = prompt_reflection
        global_preds = F.pad(global_preds, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode).value, value=cval)

    scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap)

    # store all slices in list
    slices = dense_patch_slices(image_size, roi_size, scan_interval)
    num_win = len(slices)  # number of windows per image
    total_slices = num_win * batch_size  # total number of windows
    # Skip the zoom-in pass when it would require too many windows, unless the user forces it.
    if total_slices > 10 and not st.session_state.enforce_zoom:
        return logits_global_single

    # create window-level importance map
    valid_patch_size = get_valid_patch_size(image_size, roi_size)
    if valid_patch_size == roi_size and (roi_weight_map is not None):
        importance_map = roi_weight_map
    else:
        try:
            importance_map = compute_importance_map(valid_patch_size, mode=mode, sigma_scale=sigma_scale, device=device)
        except BaseException as e:
            raise RuntimeError(
                "Seems to be OOM. Please try smaller patch size or mode='constant' instead of mode='gaussian'."
            ) from e
    importance_map = convert_data_type(importance_map, torch.Tensor, device, compute_dtype)[0]
    # handle non-positive weights
    min_non_zero = max(importance_map[importance_map != 0].min().item(), 1e-3)
    importance_map = torch.clamp(importance_map.to(torch.float32), min=min_non_zero).to(compute_dtype)

    # perform predictions
    dict_key, output_image_list, count_map_list = None, [], []
    _initialized_ss = -1
    is_tensor_output = True  # whether the predictor's output is a tensor (instead of dict/tuple)

    # for each patch
    for slice_g in tqdm(range(0, total_slices, sw_batch_size)) if progress else range(0, total_slices, sw_batch_size):
        st.write(f'zoom in inference {slice_g / total_slices * 100.0:.2f}%')
        slice_range = range(slice_g, min(slice_g + sw_batch_size, total_slices))
        unravel_slice = [
            [slice(int(idx / num_win), int(idx / num_win) + 1), slice(None)] + list(slices[idx % num_win])
            for idx in slice_range
        ]
        window_data = torch.cat([inputs[win_slice] for win_slice in unravel_slice]).to(sw_device)

        # build the point/box prompt for this window from the prompt reflection
        boxes = None
        points = None
        if use_point:
            window_binary_prompt_map = torch.cat([binary_prompt_map[win_slice] for win_slice in unravel_slice]).to(sw_device)
            point, point_label = select_points(window_binary_prompt_map.squeeze())
            points = (point.unsqueeze(0).float(), point_label.unsqueeze(0).float())
            pseudo_label = torch.cat([global_preds[win_slice] for win_slice in unravel_slice]).to(sw_device)
            boxes = generate_box(pseudo_label.squeeze()).unsqueeze(0).float()
        if use_box:
            if num_win == 1:
                window_binary_prompt_map = torch.cat([binary_prompt_map[win_slice] for win_slice in unravel_slice]).to(sw_device)
                boxes = generate_box(window_binary_prompt_map.squeeze()).unsqueeze(0).float()
            else:
                pseudo_label = torch.cat([global_preds[win_slice] for win_slice in unravel_slice]).to(sw_device)
                boxes = generate_box(pseudo_label.squeeze()).unsqueeze(0).float()
        seg_prob_out = predictor(window_data, text, boxes, points)  # batched patch segmentation

        # convert seg_prob_out to tuple seg_prob_tuple, this does not allocate new memory
        seg_prob_tuple: Tuple[torch.Tensor, ...]
        if isinstance(seg_prob_out, torch.Tensor):
            seg_prob_tuple = (seg_prob_out,)
        elif isinstance(seg_prob_out, Mapping):
            if dict_key is None:
                dict_key = sorted(seg_prob_out.keys())  # track predictor's output keys
            seg_prob_tuple = tuple(seg_prob_out[k] for k in dict_key)
            is_tensor_output = False
        else:
            seg_prob_tuple = ensure_tuple(seg_prob_out)
            is_tensor_output = False

        # for each output in multi-output list
        for ss, seg_prob in enumerate(seg_prob_tuple):
            seg_prob = seg_prob.to(device)  # BxCxMxNxP or BxCxMxN

            # compute zoom scale: out_roi_size / in_roi_size
            zoom_scale = []
            for axis, (img_s_i, out_w_i, in_w_i) in enumerate(
                zip(image_size, seg_prob.shape[2:], window_data.shape[2:])
            ):
                _scale = out_w_i / float(in_w_i)
                if not (img_s_i * _scale).is_integer():
                    warnings.warn(
                        f"For spatial axis: {axis}, output[{ss}] will have non-integer shape. Spatial "
                        f"zoom_scale between output[{ss}] and input is {_scale}. Please pad inputs."
                    )
                zoom_scale.append(_scale)

            if _initialized_ss < ss:  # init. the ss-th buffer at the first iteration
                # construct multi-resolution outputs
                output_classes = seg_prob.shape[1]
                output_shape = [batch_size, output_classes] + [
                    int(image_size_d * zoom_scale_d) for image_size_d, zoom_scale_d in zip(image_size, zoom_scale)
                ]
                # allocate memory to store the full output and the count map for overlapping parts
                output_image_list.append(torch.zeros(output_shape, dtype=compute_dtype, device=device))
                count_map_list.append(torch.zeros([1, 1] + output_shape[2:], dtype=compute_dtype, device=device))
                _initialized_ss += 1

            # resizing the importance_map
            resizer = Resize(spatial_size=seg_prob.shape[2:], mode="nearest", anti_aliasing=False)

            # store the result in the proper location of the full output; apply weights from importance map
            for idx, original_idx in zip(slice_range, unravel_slice):
                # zoom roi
                original_idx_zoom = list(original_idx)  # 4D for 2D image, 5D for 3D image
                for axis in range(2, len(original_idx_zoom)):
                    zoomed_start = original_idx[axis].start * zoom_scale[axis - 2]
                    zoomed_end = original_idx[axis].stop * zoom_scale[axis - 2]
                    if not zoomed_start.is_integer() or (not zoomed_end.is_integer()):
                        warnings.warn(
                            f"For axis-{axis - 2} of output[{ss}], the output roi range is not int. "
                            f"Input roi range is ({original_idx[axis].start}, {original_idx[axis].stop}). "
                            f"Spatial zoom_scale between output[{ss}] and input is {zoom_scale[axis - 2]}. "
                            f"Corresponding output roi range is ({zoomed_start}, {zoomed_end}).\n"
                            f"Please change overlap ({overlap}) or roi_size ({roi_size[axis - 2]}) for axis-{axis - 2}. "
                            "Tips: if overlap*roi_size*zoom_scale is an integer, it usually works."
                        )
                    original_idx_zoom[axis] = slice(int(zoomed_start), int(zoomed_end), None)
                importance_map_zoom = resizer(importance_map.unsqueeze(0))[0].to(compute_dtype)
                # store results and weights
                output_image_list[ss][original_idx_zoom] += importance_map_zoom * seg_prob[idx - slice_g]
                count_map_list[ss][original_idx_zoom] += (
                    importance_map_zoom.unsqueeze(0).unsqueeze(0).expand(count_map_list[ss][original_idx_zoom].shape)
                )

    # account for any overlapping sections
    for ss in range(len(output_image_list)):
        output_image_list[ss] = (output_image_list[ss] / count_map_list.pop(0)).to(compute_dtype)

    # remove padding if image_size is smaller than roi_size
    for ss, output_i in enumerate(output_image_list):
        if torch.isnan(output_i).any() or torch.isinf(output_i).any():
            warnings.warn("Sliding window inference results contain NaN or Inf.")

        zoom_scale = [
            seg_prob_map_shape_d / roi_size_d for seg_prob_map_shape_d, roi_size_d in zip(output_i.shape[2:], roi_size)
        ]

        final_slicing: List[slice] = []
        for sp in range(num_spatial_dims):
            slice_dim = slice(pad_size[sp * 2], image_size_[num_spatial_dims - sp - 1] + pad_size[sp * 2])
            slice_dim = slice(
                int(round(slice_dim.start * zoom_scale[num_spatial_dims - sp - 1])),
                int(round(slice_dim.stop * zoom_scale[num_spatial_dims - sp - 1])),
            )
            final_slicing.insert(0, slice_dim)
        while len(final_slicing) < len(output_i.shape):
            final_slicing.insert(0, slice(None))
        output_image_list[ss] = output_i[final_slicing]

    if dict_key is not None:  # if the output of predictor is a dict
        final_output = dict(zip(dict_key, output_image_list))
    else:
        final_output = tuple(output_image_list)
    return final_output[0] if is_tensor_output else final_output
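
# Minimal calling sketch (assumptions: `model` is a prompt-conditioned segmentation
# network with signature model(patch, text, boxes, points), `image`, `binary_prompt_map`,
# `global_preds`, and `global_logits` are pre-computed tensors, and the Streamlit
# session state has `enforce_zoom` set; all names here are illustrative, not part
# of this module):
#
#     logits = sliding_window_inference(
#         image,                              # (1, 1, D, H, W) tensor
#         (binary_prompt_map, global_preds),  # prompt_reflection pair
#         roi_size=(32, 256, 256),
#         sw_batch_size=1,
#         predictor=model,
#         text=None,
#         use_box=True,
#         use_point=False,
#         logits_global_single=global_logits,
#     )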


def _get_scan_interval(
    image_size: Sequence[int], roi_size: Sequence[int], num_spatial_dims: int, overlap: float
) -> Tuple[int, ...]:
    """
    Compute the scan interval according to the image size, roi size and overlap.
    The scan interval will be `int((1 - overlap) * roi_size)`; if the interval is 0,
    use 1 instead to make sure the sliding window works.
    """
    if len(image_size) != num_spatial_dims:
        raise ValueError("image coord different from spatial dims.")
    if len(roi_size) != num_spatial_dims:
        raise ValueError("roi coord different from spatial dims.")

    scan_interval = []
    for i in range(num_spatial_dims):
        if roi_size[i] == image_size[i]:
            scan_interval.append(int(roi_size[i]))
        else:
            interval = int(roi_size[i] * (1 - overlap))
            scan_interval.append(interval if interval > 0 else 1)
    return tuple(scan_interval)
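
# Worked example: for image_size=(128, 128, 128), roi_size=(64, 64, 64) and
# overlap=0.25, each interval is int(64 * 0.75) = 48, so windows start every
# 48 voxels along each axis; an axis where roi_size equals image_size gets a
# single window.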


def generate_box(pred_pre, bbox_shift=None):
    """Compute the foreground bounding box of `pred_pre` as a (6,) tensor
    (min_d, min_h, min_w, max_d, max_h, max_w), optionally jittered by up to
    `bbox_shift` voxels per corner."""
    ones_idx = (pred_pre > 0).nonzero(as_tuple=True)
    if all(tensor.nelement() == 0 for tensor in ones_idx):
        # all background: return a sentinel box
        return torch.tensor([-1, -1, -1, -1, -1, -1])
    min_coords = [dim.min() for dim in ones_idx]
    max_coords = [dim.max() for dim in ones_idx]

    def shift():
        # random perturbation of a box corner, disabled when bbox_shift is None
        return 0 if bbox_shift is None else random.randint(-bbox_shift, bbox_shift)

    shape = pred_pre.shape
    corner_min = torch.tensor([max(0, coor + shift()) for coor in min_coords])
    corner_max = torch.tensor([min(shape[idx], coor + shift()) for idx, coor in enumerate(max_coords)])
    return torch.cat((corner_min, corner_max), dim=0)
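
# Illustrative usage (a sketch; `mask` is a hypothetical binary volume):
#
#     mask = torch.zeros(64, 64, 64)
#     mask[10:20, 12:22, 14:24] = 1
#     generate_box(mask)                 # tensor([10, 12, 14, 19, 21, 23])
#     generate_box(mask, bbox_shift=2)   # corners jittered by up to +/-2 voxels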


def select_points(preds, num_positive_extra=4, num_negative_extra=0, fix_extra_point_num=None):
    """Sample point prompts from a prediction volume: positive points where
    `preds` > 0.9 and negative points where `preds` < 0.1, padded with
    (-1, -1, -1) ignore points so the returned count is fixed."""
    spatial_dim = 3
    points = torch.zeros((0, 3))
    labels = torch.zeros(0)
    pos_thresh = 0.9
    neg_thresh = 0.1

    # collect positive/negative candidate voxels as index tuples ([d], [h], [w])
    positive_indices = torch.nonzero(preds > pos_thresh, as_tuple=True)
    negative_indices = torch.nonzero(preds < neg_thresh, as_tuple=True)

    if all(tmp.nelement() == 0 for tmp in positive_indices):
        # no positive candidates: emit a single ignore point
        num_positive_extra = 0
        selected_positive_point = torch.tensor([-1, -1, -1]).unsqueeze(dim=0)
        points = torch.cat((points, selected_positive_point), dim=0)
        labels = torch.cat((labels, torch.tensor([-1]).reshape(1)))
    else:
        # randomly select one positive point
        random_idx = torch.randint(len(positive_indices[0]), (1,))
        selected_positive_point = torch.tensor([positive_indices[i][random_idx] for i in range(spatial_dim)]).unsqueeze(dim=0)
        points = torch.cat((points, selected_positive_point), dim=0)
        labels = torch.cat((labels, torch.ones(1)))

    if num_positive_extra > 0:
        pos_idx_list = torch.randperm(len(positive_indices[0]))[:num_positive_extra]
        extra_positive_points = [[positive_indices[i][pos_idx] for i in range(spatial_dim)] for pos_idx in pos_idx_list]
        extra_positive_points = torch.tensor(extra_positive_points).reshape(-1, 3)
        points = torch.cat((points, extra_positive_points), dim=0)
        labels = torch.cat((labels, torch.ones(extra_positive_points.shape[0])))

    if num_negative_extra > 0:
        neg_idx_list = torch.randperm(len(negative_indices[0]))[:num_negative_extra]
        extra_negative_points = [[negative_indices[i][neg_idx] for i in range(spatial_dim)] for neg_idx in neg_idx_list]
        extra_negative_points = torch.tensor(extra_negative_points).reshape(-1, 3)
        points = torch.cat((points, extra_negative_points), dim=0)
        labels = torch.cat((labels, torch.zeros(extra_negative_points.shape[0])))

    # pad with ignore points so every call returns the same number of points
    if fix_extra_point_num is None:
        left_point_num = num_positive_extra + num_negative_extra + 1 - labels.shape[0]
    else:
        left_point_num = fix_extra_point_num + 1 - labels.shape[0]

    for _ in range(left_point_num):
        ignore_point = torch.tensor([-1, -1, -1]).unsqueeze(dim=0)
        points = torch.cat((points, ignore_point), dim=0)
        labels = torch.cat((labels, torch.tensor([-1]).reshape(1)))

    return points, labels
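
# Illustrative usage (a sketch; `preds` is a hypothetical probability volume):
#
#     preds = torch.rand(32, 32, 32)
#     points, labels = select_points(preds, num_positive_extra=4)
#     # points: (5, 3) coordinates; labels: 1 = positive, 0 = negative,
#     # -1 = padding/ignore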