"""Train the SegTransVAE model (``BRATS`` LightningModule) on the BraTS data.

Builds the train/val/test loaders with ``get_loader``, wraps them in the
``BRATS`` module, and runs ``pytorch_lightning.Trainer.fit`` with top-k
checkpointing and early stopping, both driven by ``val/MeanDiceScore``.
TensorBoard logs are written under ``logs/SegTransVAE`` and checkpoints under
``./checkpoints/SegTransVAE``.
"""
import os

import pytorch_lightning as pl
import torch
from monai.utils import set_determinism
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

from dataset.utils import get_loader
from trainer import BRATS

# Number GPUs by PCI bus order (the order nvidia-smi reports), so that
# ``devices=[0]`` on the Trainer refers to a predictable physical GPU.
# This must be set before CUDA is first initialised (the line below).
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

# NOTE(review): ``device`` is not referenced anywhere else in this script;
# device placement is chosen by ``devices=`` on the Trainer. Kept in case
# other code imports it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed MONAI / torch / numpy RNGs for reproducible runs.
set_determinism(seed=0)

# Clear the terminal ("cls" on Windows, "clear" elsewhere) before training.
os.system('cls||clear')
print("Training ...")

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
data_dir = "/app/brats_2021_task1"
json_list = "/app/info.json"
roi = (128, 128, 128)
batch_size = 1
fold = 1
max_epochs = 500
val_every = 10  # run validation every ``val_every`` training epochs

# NOTE(review): the semantics of ``volume=1`` and ``test_size=0.2`` are defined
# by ``dataset.utils.get_loader``; presumably ``test_size`` is the fraction of
# cases held out for testing - confirm against that helper.
train_loader, val_loader, test_loader = get_loader(
    batch_size, data_dir, json_list, fold, roi, volume=1, test_size=0.2
)
print("Done initialize dataloader !! ")

# ---------------------------------------------------------------------------
# Model, callbacks, logger
# ---------------------------------------------------------------------------
model = BRATS(
    use_VAE=True,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
)

checkpoint_callback = ModelCheckpoint(
    monitor='val/MeanDiceScore',
    dirpath='./checkpoints/SegTransVAE',
    # NOTE(review): ``{epoch:3d}`` pads with spaces (e.g. "Epoch  9-...");
    # ``{epoch:03d}`` would avoid spaces in filenames. Left unchanged so any
    # tooling that matches existing checkpoint names keeps working.
    filename='Epoch{epoch:3d}-MeanDiceScore{val/MeanDiceScore:.4f}',
    save_top_k=3,
    mode='max',
    save_last=True,
    # The metric name contains "/", which would create sub-directories if
    # Lightning auto-inserted "val/MeanDiceScore=" into the filename.
    auto_insert_metric_name=False,
)

# ``patience`` counts validation checks, not epochs. With validation every
# ``val_every`` epochs, training stops after ``patience * val_every`` epochs
# without a ``min_delta`` improvement.
early_stop_callback = EarlyStopping(
    monitor='val/MeanDiceScore',
    min_delta=0.0001,
    patience=15,
    verbose=False,
    mode='max',
)

tensorboardlogger = TensorBoardLogger(
    'logs',
    name="SegTransVAE",
    default_hp_metric=False,  # don't log the placeholder "hp_metric" scalar
)

# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------
trainer = pl.Trainer(
    devices=[0],
    precision=16,
    max_epochs=max_epochs,
    enable_progress_bar=True,
    callbacks=[checkpoint_callback, early_stop_callback],
    num_sanity_val_steps=1,
    logger=tensorboardlogger,
    # Driven by the config constant so validation frequency and the
    # early-stopping horizon stay in sync with ``val_every``.
    check_val_every_n_epoch=val_every,
)

trainer.fit(model)