saim1309 committed on
Commit
c04ec27
·
verified ·
1 Parent(s): b6ea77d

Upload 2 files

Browse files
Files changed (2) hide show
  1. MEDIARFormer.py +102 -0
  2. Predictor.py +234 -0
MEDIARFormer.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from segmentation_models_pytorch import MAnet
5
+ from segmentation_models_pytorch.base.modules import Activation
6
+
7
+ __all__ = ["MEDIARFormer"]
8
+
9
+
10
class MEDIARFormer(MAnet):
    """MEDIAR-Former model.

    An MAnet-based encoder-decoder whose stock segmentation head is
    discarded in favor of two task-specific heads: a two-channel
    gradient-flow head and a one-channel cell-probability head. The
    forward pass returns both, concatenated as
    [gradflow(2ch), cellprob(1ch)] along the channel axis.
    """

    def __init__(
        self,
        encoder_name="mit_b5",  # backbone encoder identifier
        encoder_weights="imagenet",  # pre-trained weight set
        decoder_channels=(1024, 512, 256, 128, 64),  # decoder stage widths
        decoder_pab_channels=256,  # Pyramid Attention Block channels
        in_channels=3,  # number of input image channels
        classes=3,  # forwarded to MAnet; its head is removed below
    ):
        # Build the underlying MAnet encoder/decoder.
        super().__init__(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            decoder_channels=decoder_channels,
            decoder_pab_channels=decoder_pab_channels,
            in_channels=in_channels,
            classes=classes,
        )

        # The default MAnet head is unused; predictions come from the
        # two custom heads created below.
        self.segmentation_head = None

        # Replace every ReLU in the backbone with Mish (stateless, so a
        # fresh instance per sub-network is equivalent to sharing one).
        for part in (self.encoder, self.decoder):
            _convert_activations(part, nn.ReLU, nn.Mish(inplace=True))

        # Task-specific heads fed by the last decoder stage.
        head_width = decoder_channels[-1]
        self.cellprob_head = DeepSegmentationHead(
            in_channels=head_width, out_channels=1
        )
        self.gradflow_head = DeepSegmentationHead(
            in_channels=head_width, out_channels=2
        )

    def forward(self, x):
        """Run the network; returns a mask tensor with 3 channels."""
        # Validate spatial dimensions against encoder requirements.
        self.check_input_shape(x)

        decoded = self.decoder(*self.encoder(x))

        # Gradient-flow channels first, cell probability last — the
        # order downstream post-processing expects.
        return torch.cat(
            [self.gradflow_head(decoded), self.cellprob_head(decoded)], dim=1
        )
64
+
65
+
66
class DeepSegmentationHead(nn.Sequential):
    """Two-convolution segmentation head.

    Pipeline: conv(in -> in//2) -> Mish -> BatchNorm -> conv(in//2 -> out),
    optionally followed by bilinear upsampling (when upsampling > 1) and
    a final activation (when one is given); otherwise identity layers.
    """

    def __init__(
        self, in_channels, out_channels, kernel_size=3, activation=None, upsampling=1
    ):
        mid_channels = in_channels // 2
        same_padding = kernel_size // 2  # keeps spatial size unchanged

        modules = [
            nn.Conv2d(
                in_channels,
                mid_channels,
                kernel_size=kernel_size,
                padding=same_padding,
            ),
            nn.Mish(inplace=True),
            # NOTE: normalization comes after the activation here,
            # mirroring the original MEDIAR head layout.
            nn.BatchNorm2d(mid_channels),
            nn.Conv2d(
                mid_channels,
                out_channels,
                kernel_size=kernel_size,
                padding=same_padding,
            ),
        ]

        if upsampling > 1:
            modules.append(nn.UpsamplingBilinear2d(scale_factor=upsampling))
        else:
            modules.append(nn.Identity())

        if activation:
            modules.append(Activation(activation))
        else:
            modules.append(nn.Identity())

        super().__init__(*modules)
94
+
95
+
96
+ def _convert_activations(module, from_activation, to_activation):
97
+ """Recursively convert activation functions in a module"""
98
+ for name, child in module.named_children():
99
+ if isinstance(child, from_activation):
100
+ setattr(module, name, to_activation)
101
+ else:
102
+ _convert_activations(child, from_activation, to_activation)
Predictor.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import os, sys
4
+ from monai.inferers import sliding_window_inference
5
+
6
+ sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../")))
7
+
8
+ from core.BasePredictor import BasePredictor
9
+ from core.MEDIAR.utils import compute_masks
10
+
11
+ __all__ = ["Predictor"]
12
+
13
+
14
class Predictor(BasePredictor):
    """MEDIAR inference engine.

    Runs sliding-window model prediction (optionally averaged over
    horizontal- and vertical-flip test-time augmentation) and
    post-processes the predicted gradient flows / cell probabilities
    into instance masks via ``compute_masks``.
    """

    def __init__(
        self,
        model,
        device,
        input_path,
        output_path,
        make_submission=False,
        exp_name=None,
        algo_params=None,
    ):
        super(Predictor, self).__init__(
            model,
            device,
            input_path,
            output_path,
            make_submission,
            exp_name,
            algo_params,
        )
        # Flip transforms used for test-time augmentation (defined at
        # the bottom of this module, adapted from ttach).
        self.hflip_tta = HorizontalFlip()
        self.vflip_tta = VerticalFlip()

    @torch.no_grad()
    def _inference(self, img_data):
        """Conduct model prediction.

        Returns the squeezed prediction tensor on CPU; channels 0-2 are
        indexed below (presumably flow-y, flow-x, cell-probability —
        confirm against compute_masks). ``self.use_tta`` is expected to
        be set by the base class / algo_params — not defined here.
        """

        img_data = img_data.to(self.device)
        img_base = img_data
        outputs_base = self._window_inference(img_base)
        outputs_base = outputs_base.cpu().squeeze()
        # NOTE(review): the returned CPU copy is discarded, so this line
        # does not actually move the input off the device (unlike the
        # `img_hflip = img_hflip.cpu()` pattern used below).
        img_base.cpu()

        if not self.use_tta:
            pred_mask = outputs_base
            return pred_mask

        else:
            # HorizontalFlip TTA: flip input, predict, flip prediction back.
            img_hflip = self.hflip_tta.apply_aug_image(img_data, apply=True)
            outputs_hflip = self._window_inference(img_hflip)
            outputs_hflip = self.hflip_tta.apply_deaug_mask(outputs_hflip, apply=True)
            outputs_hflip = outputs_hflip.cpu().squeeze()
            img_hflip = img_hflip.cpu()

            # VerticalFlip TTA: same scheme along the vertical axis.
            img_vflip = self.vflip_tta.apply_aug_image(img_data, apply=True)
            outputs_vflip = self._window_inference(img_vflip)
            outputs_vflip = self.vflip_tta.apply_deaug_mask(outputs_vflip, apply=True)
            outputs_vflip = outputs_vflip.cpu().squeeze()
            img_vflip = img_vflip.cpu()

            # Merge results. A flow component changes sign when flipped
            # along its own axis, hence the mixed +/- signs (presumably
            # channel 0 is the vertical and channel 1 the horizontal flow
            # component — confirm against compute_masks). Channel 2 (cell
            # probability) is flip-invariant and simply averaged.
            pred_mask = torch.zeros_like(outputs_base)
            pred_mask[0] = (outputs_base[0] + outputs_hflip[0] - outputs_vflip[0]) / 3
            pred_mask[1] = (outputs_base[1] - outputs_hflip[1] + outputs_vflip[1]) / 3
            pred_mask[2] = (outputs_base[2] + outputs_hflip[2] + outputs_vflip[2]) / 3

        return pred_mask

    def _window_inference(self, img_data, aux=False):
        """Inference on RoI-sized windows (MONAI sliding-window helper).

        aux: when True use ``self.model_aux`` instead of ``self.model``
        (the auxiliary model is not set anywhere in this file — it is
        presumably provided by a subclass or the base class; confirm).
        """
        outputs = sliding_window_inference(
            img_data,
            roi_size=512,
            sw_batch_size=4,
            predictor=self.model if not aux else self.model_aux,
            padding_mode="constant",
            mode="gaussian",  # Gaussian-weighted blending of overlapping windows
            overlap=0.6,
        )

        return outputs

    def _post_process(self, pred_mask):
        """Generate cell instance masks.

        pred_mask: array whose first two channels are the gradient flows
        (dP) and whose last channel is the cell-probability logit, which
        is passed through a sigmoid before thresholding.
        """
        dP, cellprob = pred_mask[:2], self._sigmoid(pred_mask[-1])
        H, W = pred_mask.shape[-2], pred_mask.shape[-1]

        # Small enough image: reconstruct instances in a single pass.
        # NOTE(review): H * W is already a scalar, so np.prod is a no-op.
        if np.prod(H * W) < (5000 * 5000):
            pred_mask = compute_masks(
                dP,
                cellprob,
                use_gpu=True,
                flow_threshold=0.4,
                device=self.device,
                cellprob_threshold=0.5,
            )[0]

        else:
            # Whole-slide image: too large for one compute_masks call, so
            # tile into roi_size x roi_size grid cells.
            print("\n[Whole Slide] Grid Prediction starting...")
            roi_size = 2000

            # Get patch grid by roi_size: round each dimension up to the
            # next multiple of roi_size.
            if H % roi_size != 0:
                n_H = H // roi_size + 1
                new_H = roi_size * n_H
            else:
                n_H = H // roi_size
                new_H = H

            if W % roi_size != 0:
                n_W = W // roi_size + 1
                new_W = roi_size * n_W
            else:
                n_W = W // roi_size
                new_W = W

            # Allocate zero-padded buffers covering the full grid.
            pred_pad = np.zeros((new_H, new_W), dtype=np.uint32)
            dP_pad = np.zeros((2, new_H, new_W), dtype=np.float32)
            cellprob_pad = np.zeros((new_H, new_W), dtype=np.float32)

            dP_pad[:, :H, :W], cellprob_pad[:H, :W] = dP, cellprob

            for i in range(n_H):
                for j in range(n_W):
                    print("Pred on Grid (%d, %d) processing..." % (i, j))
                    dP_roi = dP_pad[
                        :,
                        roi_size * i : roi_size * (i + 1),
                        roi_size * j : roi_size * (j + 1),
                    ]
                    cellprob_roi = cellprob_pad[
                        roi_size * i : roi_size * (i + 1),
                        roi_size * j : roi_size * (j + 1),
                    ]

                    # NOTE(review): instance labels restart in every tile,
                    # so labels are only unique within a grid cell, and
                    # cells spanning a tile boundary will be split.
                    pred_mask = compute_masks(
                        dP_roi,
                        cellprob_roi,
                        use_gpu=True,
                        flow_threshold=0.4,
                        device=self.device,
                        cellprob_threshold=0.5,
                    )[0]

                    pred_pad[
                        roi_size * i : roi_size * (i + 1),
                        roi_size * j : roi_size * (j + 1),
                    ] = pred_mask

            # Crop the padding back off to the original size.
            pred_mask = pred_pad[:H, :W]

        return pred_mask

    def _sigmoid(self, z):
        """Plain logistic function; np.exp may warn for large negative z."""
        return 1 / (1 + np.exp(-z))
162
+
163
+
164
+ """
165
+ Adapted from the following references:
166
+ [1] https://github.com/qubvel/ttach/blob/master/ttach/transforms.py
167
+
168
+ """
169
+
170
+
171
def hflip(x):
    """Flip a batch of images horizontally (mirror along the width axis)."""
    return torch.flip(x, dims=(3,))
174
+
175
+
176
def vflip(x):
    """Flip a batch of images vertically (mirror along the height axis)."""
    return torch.flip(x, dims=(2,))
179
+
180
+
181
class DualTransform:
    """Base class for paired image/mask test-time transforms.

    Subclasses implement ``apply_aug_image`` (forward transform of the
    input) and ``apply_deaug_mask`` (inverse transform applied to the
    prediction).
    """

    # Parameter value that makes the transform a no-op; set by subclasses.
    identity_param = None

    def __init__(
        self, name: str, params,
    ):
        self.params = params
        self.pname = name

    def apply_aug_image(self, image, *args, **params):
        """Transform the input image; must be overridden."""
        raise NotImplementedError

    def apply_deaug_mask(self, mask, *args, **params):
        """Invert the transform on a predicted mask; must be overridden."""
        raise NotImplementedError
195
+
196
+
197
class HorizontalFlip(DualTransform):
    """Flip images horizontally (left -> right)."""

    identity_param = False

    def __init__(self):
        super().__init__("apply", [False, True])

    def apply_aug_image(self, image, apply=False, **kwargs):
        """Return `image` horizontally flipped when `apply` is True."""
        return hflip(image) if apply else image

    def apply_deaug_mask(self, mask, apply=False, **kwargs):
        """Undo the horizontal flip on `mask` when `apply` is True."""
        return hflip(mask) if apply else mask
214
+
215
+
216
class VerticalFlip(DualTransform):
    """Flip images vertically (up -> down)."""

    identity_param = False

    def __init__(self):
        super().__init__("apply", [False, True])

    def apply_aug_image(self, image, apply=False, **kwargs):
        """Return `image` vertically flipped when `apply` is True."""
        return vflip(image) if apply else image

    def apply_deaug_mask(self, mask, apply=False, **kwargs):
        """Undo the vertical flip on `mask` when `apply` is True."""
        return vflip(mask) if apply else mask