| import numpy as np | |
| from monai.transforms import MapTransform | |
class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):
    """
    Convert BraTS integer labels to a 3-channel binary mask.

    BraTS label values:
        label 1 is the necrotic and non-enhancing tumor core
        label 2 is the peritumoral edema
        label 4 is the GD-enhancing tumor

    Output channels, stacked in this order along a new leading axis 0:
        0: TC (Tumor core)      = label 1 or label 4
        1: WT (Whole tumor)     = label 1, 2 or 4
        2: ET (Enhancing tumor) = label 4

    For each selected key, a label array of shape ``S`` is replaced with a
    ``float32`` array of shape ``(3, *S)``. Any existing leading axis is
    kept. For example, a ``(1, H, W, D)`` label becomes ``(3, 1, H, W, D)``.
    """

    def __call__(self, data):
        """
        Build the TC / WT / ET masks for every selected key.

        Args:
            data: mapping holding the label arrays under this transform's keys.

        Returns:
            A shallow copy of ``data`` in which each processed key is replaced
            by its stacked ``float32`` multi-channel mask.
        """
        d = dict(data)
        # key_iterator respects ``allow_missing_keys``. Iterating
        # ``self.keys`` directly raises KeyError on absent keys even when
        # missing keys were explicitly allowed.
        for key in self.key_iterator(d):
            label = d[key]
            # merge label 1 and label 4 to construct TC
            tumor_core = np.logical_or(label == 1, label == 4)
            # WT is TC plus the peritumoral edema (label 2)
            whole_tumor = np.logical_or(tumor_core, label == 2)
            # label 4 alone is ET
            enhancing_tumor = label == 4
            d[key] = np.stack(
                [tumor_core, whole_tumor, enhancing_tumor], axis=0
            ).astype(np.float32)
        return d