| import os |
| import pytorch_lightning as pl |
| import matplotlib.pyplot as plt |
| import csv |
| import torch |
| from monai.transforms import AsDiscrete, Activations, Compose, EnsureType |
| from models.SegTranVAE.SegTranVAE import SegTransVAE |
| from loss.loss import Loss_VAE, DiceScore |
| from monai.losses import DiceLoss |
| import pytorch_lightning as pl |
| from monai.inferers import sliding_window_inference |
|
|
|
|
|
|
|
|
|
|
class BRATS(pl.LightningModule):
    """LightningModule wrapping SegTransVAE for 3-region (TC/WT/ET) segmentation.

    Batches are dicts with ``'image'`` and ``'label'`` entries. Output and
    label channel 0 is treated as TC, channel 1 as WT and channel 2 as ET.

    Args:
        train_loader: returned as-is by ``train_dataloader``.
        val_loader: returned as-is by ``val_dataloader``.
        test_loader: returned as-is by ``test_dataloader``.
        use_VAE: when True the model's training forward also returns a
            reconstruction plus ``mu``/``sigma``, and a weighted VAE loss is
            added to the Dice loss.
        lr: Adam learning rate.
    """

    def __init__(self, train_loader, val_loader, test_loader, use_VAE=True, lr=1e-4):
        super().__init__()
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.use_vae = use_VAE
        self.lr = lr
        # The positional hyper-parameters are interpreted by SegTransVAE's
        # constructor (defined elsewhere).
        self.model = SegTransVAE(
            (128, 128, 128), 8, 4, 3, 768, 8, 4, 3072,
            in_channels_vae=128, use_VAE=use_VAE,
        )

        self.loss_vae = Loss_VAE()
        self.dice_loss = DiceLoss(to_onehot_y=False, sigmoid=True, squared_pred=True)
        # Logits -> sigmoid probabilities -> binary masks at 0.5.
        # MONAI deprecated ``AsDiscrete(threshold_values=True)`` in favour of
        # ``threshold=...``; recent releases no longer apply the old keyword,
        # which would leave the "discrete" output as raw probabilities.
        self.post_trans_images = Compose(
            [
                EnsureType(),
                Activations(sigmoid=True),
                AsDiscrete(threshold=0.5),
            ]
        )

        self.best_val_dice = 0
        # Initialised up front so the epoch summary can always print it, even
        # before any epoch has beaten the initial best of 0.
        self.best_val_epoch = 0

        # Lightning >= 2.0 no longer hands step outputs to the epoch-end
        # hooks, so per-step values are buffered here and cleared each epoch.
        self.training_step_outputs = []
        self.val_step_loss = []
        self.val_step_dice = []
        self.val_step_dice_tc = []
        self.val_step_dice_wt = []
        self.val_step_dice_et = []
        self.test_step_loss = []
        self.test_step_dice = []
        self.test_step_dice_tc = []
        self.test_step_dice_wt = []
        self.test_step_dice_et = []

    def forward(self, x, is_validation=True):
        """Run the wrapped model; ``is_validation`` is forwarded unchanged."""
        return self.model(x, is_validation)

    def training_step(self, batch, batch_index):
        """Compute and log the training loss for one batch.

        Without the VAE branch the loss is the Dice loss alone. With it, the
        VAE loss is added with weight ``1 / (4 * 128**3)``, and on batch 10 a
        visualisation figure is sent to the logger's experiment.
        """
        inputs, labels = batch['image'], batch['label']

        if not self.use_vae:
            outputs = self.forward(inputs, is_validation=False)
            loss = self.dice_loss(outputs, labels)
        else:
            outputs, recon_batch, mu, sigma = self.forward(inputs, is_validation=False)
            vae_loss = self.loss_vae(recon_batch, inputs, mu, sigma)
            dice_loss = self.dice_loss(outputs, labels)
            loss = dice_loss + 1 / (4 * 128 * 128 * 128) * vae_loss
            self.log('train/vae_loss', vae_loss)
            self.log('train/dice_loss', dice_loss)
            if batch_index == 10:
                self._log_train_figure(inputs, recon_batch, labels, outputs)

        # Buffer the loss in BOTH branches (otherwise on_train_epoch_end stacks
        # an empty list when use_VAE=False) and detach it so each step's
        # autograd graph is not kept alive until the end of the epoch.
        self.training_step_outputs.append(loss.detach())
        self.log('train/loss', loss)
        return loss

    def _log_train_figure(self, inputs, recon_batch, labels, outputs, slice_idx=80):
        """Send a 6-panel figure (slice ``slice_idx`` of the last axis, first sample) to the logger."""
        probs = outputs.sigmoid()
        # (tensor, channel, title) per panel; channel 0 = TC, channel 2 = ET.
        panels = [
            (inputs, 0, "Input"),
            (recon_batch, 0, "Reconstruction"),
            (labels, 0, "Labels TC"),
            (probs, 0, "TC"),
            (labels, 2, "Labels ET"),
            (probs, 2, "ET"),
        ]
        fig, axes = plt.subplots(nrows=1, ncols=len(panels), figsize=(10, 5))
        for axis, (tensor, channel, title) in zip(axes, panels):
            axis.imshow(tensor.detach().cpu().float()[0][channel][:, :, slice_idx], cmap='gray')
            axis.set_title(title)
        self.logger.experiment.add_figure('train_visualize', fig, self.current_epoch)

    def on_train_epoch_end(self):
        """Log the mean training loss of the epoch and reset the buffer."""
        if self.training_step_outputs:
            epoch_average = torch.stack(self.training_step_outputs).mean()
            self.log("training_epoch_average", epoch_average)
        self.training_step_outputs.clear()

    @staticmethod
    def _region_dice(preds, labels):
        """Return ``(mean, tc, wt, et)`` Dice for channels 0, 1, 2 of binarised predictions.

        ``DiceScore`` is a project helper; its results are assumed to be
        scalar tensors (the epoch-end hooks ``torch.stack`` them).
        """
        metric_tc = DiceScore(y_pred=preds[:, 0:1], y=labels[:, 0:1], include_background=True)
        metric_wt = DiceScore(y_pred=preds[:, 1:2], y=labels[:, 1:2], include_background=True)
        metric_et = DiceScore(y_pred=preds[:, 2:3], y=labels[:, 2:3], include_background=True)
        mean_dice = (metric_tc + metric_wt + metric_et) / 3
        return mean_dice, metric_tc, metric_wt, metric_et

    def validation_step(self, batch, batch_index):
        """Sliding-window inference (128^3 ROI, 50% overlap), Dice loss and per-region Dice."""
        inputs, labels = batch['image'], batch['label']
        roi_size = (128, 128, 128)
        sw_batch_size = 1
        outputs = sliding_window_inference(
            inputs, roi_size, sw_batch_size, self.model, overlap=0.5)
        loss = self.dice_loss(outputs, labels)

        val_outputs = self.post_trans_images(outputs)
        mean_val_dice, metric_tc, metric_wt, metric_et = self._region_dice(val_outputs, labels)

        self.val_step_loss.append(loss)
        self.val_step_dice.append(mean_val_dice)
        self.val_step_dice_tc.append(metric_tc)
        self.val_step_dice_wt.append(metric_wt)
        self.val_step_dice_et.append(metric_et)
        return {'val_loss': loss, 'val_mean_dice': mean_val_dice, 'val_dice_tc': metric_tc,
                'val_dice_wt': metric_wt, 'val_dice_et': metric_et}

    def _clear_val_buffers(self):
        """Empty all per-step validation buffers."""
        for buffer in (self.val_step_loss, self.val_step_dice, self.val_step_dice_tc,
                       self.val_step_dice_wt, self.val_step_dice_et):
            buffer.clear()

    def on_validation_epoch_end(self):
        """Aggregate, log and persist validation metrics; track the best mean Dice."""
        # Lightning's pre-training sanity check runs a few validation batches
        # at epoch 0; don't let untrained-model scores become the "best"
        # result or be written to the CSV.
        if self.trainer.sanity_checking or not self.val_step_loss:
            self._clear_val_buffers()
            return None

        loss = torch.stack(self.val_step_loss).mean()
        mean_val_dice = torch.stack(self.val_step_dice).mean()
        metric_tc = torch.stack(self.val_step_dice_tc).mean()
        metric_wt = torch.stack(self.val_step_dice_wt).mean()
        metric_et = torch.stack(self.val_step_dice_et).mean()
        self.log('val/Loss', loss)
        self.log('val/MeanDiceScore', mean_val_dice)
        self.log('val/DiceTC', metric_tc)
        self.log('val/DiceWT', metric_wt)
        self.log('val/DiceET', metric_et)

        log_dir = self.logger.log_dir
        os.makedirs(log_dir, exist_ok=True)
        csv_path = os.path.join(log_dir, 'metric_log.csv')
        # Fresh file with header on the first epoch, append afterwards.
        first_epoch = self.current_epoch == 0
        with open(csv_path, 'w' if first_epoch else 'a', newline='') as f:
            writer = csv.writer(f)
            if first_epoch:
                writer.writerow(['Epoch', 'Mean Dice Score', 'Dice TC', 'Dice WT', 'Dice ET'])
            writer.writerow([self.current_epoch, mean_val_dice.item(), metric_tc.item(),
                             metric_wt.item(), metric_et.item()])

        if mean_val_dice > self.best_val_dice:
            # Store a plain float rather than keeping a (device) tensor around.
            self.best_val_dice = mean_val_dice.item()
            self.best_val_epoch = self.current_epoch
        print(
            f"\n Current epoch: {self.current_epoch} Current mean dice: {mean_val_dice:.4f}"
            f" tc: {metric_tc:.4f} wt: {metric_wt:.4f} et: {metric_et:.4f}"
            f"\n Best mean dice: {self.best_val_dice:.4f}"
            f" at epoch: {self.best_val_epoch}"
        )

        self._clear_val_buffers()
        return {'val_MeanDiceScore': mean_val_dice}

    def test_step(self, batch, batch_index):
        """Sliding-window inference on a test batch; buffer loss and per-region Dice."""
        inputs, labels = batch['image'], batch['label']
        roi_size = (128, 128, 128)
        sw_batch_size = 1
        test_outputs = sliding_window_inference(
            inputs, roi_size, sw_batch_size, self.forward, overlap=0.5)
        loss = self.dice_loss(test_outputs, labels)

        test_outputs = self.post_trans_images(test_outputs)
        mean_test_dice, metric_tc, metric_wt, metric_et = self._region_dice(test_outputs, labels)

        self.test_step_loss.append(loss)
        self.test_step_dice.append(mean_test_dice)
        self.test_step_dice_tc.append(metric_tc)
        self.test_step_dice_wt.append(metric_wt)
        self.test_step_dice_et.append(metric_et)
        return {'test_loss': loss, 'test_mean_dice': mean_test_dice, 'test_dice_tc': metric_tc,
                'test_dice_wt': metric_wt, 'test_dice_et': metric_et}

    def _clear_test_buffers(self):
        """Empty all per-step test buffers."""
        for buffer in (self.test_step_loss, self.test_step_dice, self.test_step_dice_tc,
                       self.test_step_dice_wt, self.test_step_dice_et):
            buffer.clear()

    def on_test_epoch_end(self):
        """Aggregate, log and write test metrics to ``test_log.csv``.

        Named ``on_test_epoch_end`` because Lightning 2.0 removed the
        ``test_epoch_end`` hook and refuses to run modules that override it.
        """
        if not self.test_step_loss:
            self._clear_test_buffers()
            return None

        loss = torch.stack(self.test_step_loss).mean()
        mean_test_dice = torch.stack(self.test_step_dice).mean()
        metric_tc = torch.stack(self.test_step_dice_tc).mean()
        metric_wt = torch.stack(self.test_step_dice_wt).mean()
        metric_et = torch.stack(self.test_step_dice_et).mean()
        self.log('test/Loss', loss)
        self.log('test/MeanDiceScore', mean_test_dice)
        self.log('test/DiceTC', metric_tc)
        self.log('test/DiceWT', metric_wt)
        self.log('test/DiceET', metric_et)

        log_dir = self.logger.log_dir
        os.makedirs(log_dir, exist_ok=True)
        with open(os.path.join(log_dir, 'test_log.csv'), 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(["Mean Test Dice", "Dice TC", "Dice WT", "Dice ET"])
            # .item() so the CSV holds numbers rather than "tensor(...)" reprs.
            writer.writerow([mean_test_dice.item(), metric_tc.item(),
                             metric_wt.item(), metric_et.item()])

        self._clear_test_buffers()
        return {'test_MeanDiceScore': mean_test_dice}

    def configure_optimizers(self):
        """Adam (amsgrad, weight decay 1e-5) with cosine annealing over ``T_max=200`` scheduler steps."""
        optimizer = torch.optim.Adam(
            self.model.parameters(), self.lr, weight_decay=1e-5, amsgrad=True
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 200)
        return [optimizer], [scheduler]

    def train_dataloader(self):
        return self.train_loader

    def val_dataloader(self):
        return self.val_loader

    def test_dataloader(self):
        return self.test_loader