import numpy as np
import random
import datetime
import logging
import matplotlib.pyplot as plt
import os
join = os.path.join
from tqdm import tqdm
from torch.backends import cudnn
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torchio as tio
from torch.utils.data.distributed import DistributedSampler
from segment_anything.build_sam3D import sam_model_registry3D
import argparse
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from monai.losses import DiceCELoss
from contextlib import nullcontext
from utils.click_method import get_next_click3D_torch_2
from utils.data_loader import Dataset_Union_ALL, Union_Dataloader
from utils.data_paths import img_datas
|
|
|
|
parser = argparse.ArgumentParser()
parser.add_argument('--task_name', type=str, default='union_train')
parser.add_argument('--click_type', type=str, default='random')
parser.add_argument('--multi_click', action='store_true', default=False)
parser.add_argument('--model_type', type=str, default='vit_b_ori')
parser.add_argument('--checkpoint', type=str, default='ckpt/sam_med3d.pth')
parser.add_argument('--device', type=str, default='cuda')
parser.add_argument('--work_dir', type=str, default='work_dir')

parser.add_argument('--num_workers', type=int, default=24)
parser.add_argument('--gpu_ids', type=int, nargs='+', default=[0, 1])
parser.add_argument('--multi_gpu', action='store_true', default=False)
parser.add_argument('--resume', action='store_true', default=False)
parser.add_argument('--allow_partial_weight', action='store_true', default=False)

parser.add_argument('--lr_scheduler', type=str, default='multisteplr')
parser.add_argument('--step_size', type=int, nargs='+', default=[120, 180])
parser.add_argument('--gamma', type=float, default=0.1)
parser.add_argument('--num_epochs', type=int, default=200)
parser.add_argument('--img_size', type=int, default=128)
parser.add_argument('--batch_size', type=int, default=12)
parser.add_argument('--accumulation_steps', type=int, default=20)
parser.add_argument('--lr', type=float, default=8e-4)
parser.add_argument('--weight_decay', type=float, default=0.1)
parser.add_argument('--port', type=int, default=12361)

args = parser.parse_args()
|
|
device = args.device
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join([str(i) for i in args.gpu_ids])
logger = logging.getLogger(__name__)
LOG_OUT_DIR = join(args.work_dir, args.task_name)
click_methods = {
    'random': get_next_click3D_torch_2,
}
MODEL_SAVE_PATH = join(args.work_dir, args.task_name)
os.makedirs(MODEL_SAVE_PATH, exist_ok=True)
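# Example invocations (illustrative; the script name and paths depend on your setup):
#   python train.py --task_name union_train --checkpoint ckpt/sam_med3d.pth --gpu_ids 0
#   python train.py --multi_gpu --gpu_ids 0 1 --port 12361   # spawns one process per GPU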
|
|
def build_model(args):
    sam_model = sam_model_registry3D[args.model_type](checkpoint=None).to(device)
    if args.multi_gpu:
        sam_model = DDP(sam_model, device_ids=[args.rank], output_device=args.rank)
    return sam_model
|
|
|
|
def get_dataloaders(args):
    train_dataset = Dataset_Union_ALL(
        paths=img_datas,
        transform=tio.Compose([
            tio.ToCanonical(),
            tio.CropOrPad(mask_name='label', target_shape=(args.img_size, args.img_size, args.img_size)),
            tio.RandomFlip(axes=(0, 1, 2)),
        ]),
        threshold=1000)

    if args.multi_gpu:
        train_sampler = DistributedSampler(train_dataset)
        shuffle = False
    else:
        train_sampler = None
        shuffle = True

    train_dataloader = Union_Dataloader(
        dataset=train_dataset,
        sampler=train_sampler,
        batch_size=args.batch_size,
        shuffle=shuffle,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    return train_dataloader
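# Sketch of one batch (keys match the training loop below; the exact layout is
# an assumption based on the CropOrPad target shape and the channel squeeze/unsqueeze):
#   batch = next(iter(train_dataloader))
#   batch["image"]  # float volume, approx. (batch_size, 1, img_size, img_size, img_size)
#   batch["label"]  # binary mask with the same shape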
|
|
class BaseTrainer:
    def __init__(self, model, dataloaders, args):
        self.model = model
        self.dataloaders = dataloaders
        self.args = args
        self.best_loss = np.inf
        self.best_dice = 0.0
        self.step_best_loss = np.inf
        self.step_best_dice = 0.0
        self.losses = []
        self.dices = []
        self.ious = []
        self.set_loss_fn()
        self.set_optimizer()
        self.set_lr_scheduler()
        if args.resume:
            self.init_checkpoint(join(self.args.work_dir, self.args.task_name, 'sam_model_latest.pth'))
        else:
            self.init_checkpoint(self.args.checkpoint)

        self.norm_transform = tio.ZNormalization(masking_method=lambda x: x > 0)

    def set_loss_fn(self):
        self.seg_loss = DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')

    def set_optimizer(self):
        if self.args.multi_gpu:
            sam_model = self.model.module
        else:
            sam_model = self.model

        # Full learning rate for the image encoder; the prompt encoder and mask
        # decoder are fine-tuned at a 10x smaller rate.
        self.optimizer = torch.optim.AdamW([
            {'params': sam_model.image_encoder.parameters()},
            {'params': sam_model.prompt_encoder.parameters(), 'lr': self.args.lr * 0.1},
            {'params': sam_model.mask_decoder.parameters(), 'lr': self.args.lr * 0.1},
        ], lr=self.args.lr, betas=(0.9, 0.999), weight_decay=self.args.weight_decay)
|
|
    def set_lr_scheduler(self):
        if self.args.lr_scheduler == "multisteplr":
            self.lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer,
                                                                     self.args.step_size,
                                                                     self.args.gamma)
        elif self.args.lr_scheduler == "steplr":
            self.lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer,
                                                                self.args.step_size[0],
                                                                self.args.gamma)
        elif self.args.lr_scheduler == 'coswarm':
            # T_0 (epochs until the first warm restart) is required; 10 is an assumed value.
            self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(self.optimizer, T_0=10)
        else:
            self.lr_scheduler = torch.optim.lr_scheduler.LinearLR(self.optimizer, 0.1)
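    # With the defaults (lr=8e-4, gamma=0.1, milestones=[120, 180]), the encoder
    # lr decays 8e-4 -> 8e-5 at epoch 120 and 8e-5 -> 8e-6 at epoch 180.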
|
|
    def init_checkpoint(self, ckp_path):
        last_ckpt = None
        if os.path.exists(ckp_path):
            if self.args.multi_gpu:
                dist.barrier()
            last_ckpt = torch.load(ckp_path, map_location=self.args.device)

        if last_ckpt:
            if self.args.multi_gpu:
                model = self.model.module
            else:
                model = self.model
            model.load_state_dict(last_ckpt['model_state_dict'],
                                  strict=not self.args.allow_partial_weight)
            if not self.args.resume:
                self.start_epoch = 0
            else:
                self.start_epoch = last_ckpt['epoch']
                self.optimizer.load_state_dict(last_ckpt['optimizer_state_dict'])
                self.lr_scheduler.load_state_dict(last_ckpt['lr_scheduler_state_dict'])
                self.losses = last_ckpt['losses']
                self.dices = last_ckpt['dices']
                self.best_loss = last_ckpt['best_loss']
                self.best_dice = last_ckpt['best_dice']
            print(f"Loaded checkpoint from {ckp_path} (epoch {self.start_epoch})")
        else:
            self.start_epoch = 0
            print(f"No checkpoint found at {ckp_path}, start training from scratch")
|
|
    def save_checkpoint(self, epoch, state_dict, describe="last"):
        torch.save({
            "epoch": epoch + 1,
            "model_state_dict": state_dict,
            "optimizer_state_dict": self.optimizer.state_dict(),
            "lr_scheduler_state_dict": self.lr_scheduler.state_dict(),
            "losses": self.losses,
            "dices": self.dices,
            "best_loss": self.best_loss,
            "best_dice": self.best_dice,
            "args": self.args,
            "used_datas": img_datas,
        }, join(MODEL_SAVE_PATH, f"sam_model_{describe}.pth"))

    def batch_forward(self, sam_model, image_embedding, gt3D, low_res_masks, points=None):
        sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
            points=points,
            boxes=None,
            masks=low_res_masks,
        )
        low_res_masks, iou_predictions = sam_model.mask_decoder(
            image_embeddings=image_embedding.to(device),
            image_pe=sam_model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=False,
        )
        # Upsample the low-res decoder output back to the input volume size.
        prev_masks = F.interpolate(low_res_masks, size=gt3D.shape[-3:], mode='trilinear', align_corners=False)
        return low_res_masks, prev_masks
|
|
    def get_points(self, prev_masks, gt3D):
        batch_points, batch_labels = click_methods[self.args.click_type](prev_masks, gt3D)

        points_co = torch.cat(batch_points, dim=0).to(device)
        points_la = torch.cat(batch_labels, dim=0).to(device)

        self.click_points.append(points_co)
        self.click_labels.append(points_la)

        points_multi = torch.cat(self.click_points, dim=1).to(device)
        labels_multi = torch.cat(self.click_labels, dim=1).to(device)

        if self.args.multi_click:
            points_input = points_multi
            labels_input = labels_multi
        else:
            points_input = points_co
            labels_input = points_la
        return points_input, labels_input
|
|
    def interaction(self, sam_model, image_embedding, gt3D, num_clicks):
        return_loss = 0
        prev_masks = torch.zeros_like(gt3D).to(gt3D.device)
        low_res_masks = F.interpolate(prev_masks.float(),
                                      size=(self.args.img_size // 4,
                                            self.args.img_size // 4,
                                            self.args.img_size // 4))
        random_insert = np.random.randint(2, 9)
        for num_click in range(num_clicks):
            points_input, labels_input = self.get_points(prev_masks, gt3D)

            if num_click == random_insert or num_click == num_clicks - 1:
                low_res_masks, prev_masks = self.batch_forward(sam_model, image_embedding, gt3D,
                                                               low_res_masks, points=None)
            else:
                low_res_masks, prev_masks = self.batch_forward(sam_model, image_embedding, gt3D,
                                                               low_res_masks, points=[points_input, labels_input])
            loss = self.seg_loss(prev_masks, gt3D)
            return_loss += loss
        return prev_masks, return_loss
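    # Click-simulation recap: prev_masks starts empty; each round samples a
    # corrective click from the current error region via the configured click
    # method, except at one random round and at the final round, where the
    # model is re-queried from the accumulated mask prompt alone (points=None).
    # The DiceCE losses of all rounds are summed.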

    def get_dice_score(self, prev_masks, gt3D):
        def compute_dice(mask_pred, mask_gt):
            mask_threshold = 0.5

            mask_pred = (mask_pred > mask_threshold)
            mask_gt = (mask_gt > 0)

            volume_sum = mask_gt.sum() + mask_pred.sum()
            if volume_sum == 0:
                return np.nan
            volume_intersect = (mask_gt & mask_pred).sum()
            return 2 * volume_intersect / volume_sum

        pred_masks = (prev_masks > 0.5)
        true_masks = (gt3D > 0)
        dice_list = []
        for i in range(true_masks.shape[0]):
            dice_list.append(compute_dice(pred_masks[i], true_masks[i]))
        return (sum(dice_list) / len(dice_list)).item()
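    # Dice = 2*|P ∩ G| / (|P| + |G|). Example: a 60-voxel prediction overlapping
    # a 40-voxel ground truth on 30 voxels scores 2*30 / (60+40) = 0.6.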
|
|
|
|
    def train_epoch(self, epoch, num_clicks):
        epoch_loss = 0
        epoch_iou = 0
        self.model.train()
        if self.args.multi_gpu:
            sam_model = self.model.module
        else:
            sam_model = self.model
            self.args.rank = -1  # single-process run: mark rank as "not distributed"

        if not self.args.multi_gpu or (self.args.multi_gpu and self.args.rank == 0):
            tbar = tqdm(self.dataloaders)
        else:
            tbar = self.dataloaders
|
|
        self.optimizer.zero_grad()
        step_loss = 0
        epoch_dice = 0
        for step, data3D in enumerate(tbar):
            try:
                image3D, gt3D = data3D["image"], data3D["label"]
            except Exception as e:
                print(f"Error processing batch at step {step}: {e}")
                continue  # skip malformed batches instead of reusing stale tensors

            # Skip gradient all-reduce on non-boundary steps when running DDP.
            my_context = self.model.no_sync if self.args.rank != -1 and step % self.args.accumulation_steps != 0 else nullcontext
|
|
            with my_context():
                image3D = self.norm_transform(image3D.squeeze(dim=1))  # z-normalize over foreground
                image3D = image3D.unsqueeze(dim=1)

                image3D = image3D.to(device)
                gt3D = gt3D.to(device).type(torch.long)
                with torch.amp.autocast("cuda"):
                    image_embedding = sam_model.image_encoder(image3D)

                self.click_points = []
                self.click_labels = []

                pred_list = []

                prev_masks, loss = self.interaction(sam_model, image_embedding, gt3D, num_clicks=num_clicks)
|
|
                epoch_loss += loss.item()
                epoch_dice += self.get_dice_score(prev_masks, gt3D)
                cur_loss = loss.item()

                loss /= self.args.accumulation_steps

                self.scaler.scale(loss).backward()

            if step % self.args.accumulation_steps == 0 and step != 0:
                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.optimizer.zero_grad()

                print_loss = step_loss / self.args.accumulation_steps
                step_loss = 0
                print_dice = self.get_dice_score(prev_masks, gt3D)
            else:
                step_loss += cur_loss
|
|
            if not self.args.multi_gpu or (self.args.multi_gpu and self.args.rank == 0):
                if step % self.args.accumulation_steps == 0 and step != 0:
                    print(f'Epoch: {epoch}, Step: {step}, Loss: {print_loss}, Dice: {print_dice}')
                    if print_dice > self.step_best_dice:
                        self.step_best_dice = print_dice
                        if print_dice > 0.9:
                            self.save_checkpoint(
                                epoch,
                                sam_model.state_dict(),
                                describe=f'{epoch}_step_dice:{print_dice}_best'
                            )
                    if print_loss < self.step_best_loss:
                        self.step_best_loss = print_loss

        epoch_loss /= step + 1
        epoch_dice /= step + 1

        return epoch_loss, epoch_iou, epoch_dice, pred_list
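    # Effective batch size: with the default batch_size=12 and
    # accumulation_steps=20, each optimizer step above aggregates gradients
    # from 12 * 20 = 240 volumes per process.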
|
|
    def eval_epoch(self, epoch, num_clicks):
        # Placeholder: no validation loop is implemented.
        return 0

    def plot_result(self, plot_data, description, save_name):
        plt.plot(plot_data)
        plt.title(description)
        plt.xlabel('Epoch')
        plt.ylabel(f'{save_name}')
        plt.savefig(join(MODEL_SAVE_PATH, f'{save_name}.png'))
        plt.close()
|
|
|
|
    def train(self):
        self.scaler = torch.amp.GradScaler("cuda")
        for epoch in range(self.start_epoch, self.args.num_epochs):
            print(f'Epoch: {epoch}/{self.args.num_epochs - 1}')

            if self.args.multi_gpu:
                dist.barrier()
                self.dataloaders.sampler.set_epoch(epoch)
            num_clicks = np.random.randint(1, 21)
            epoch_loss, epoch_iou, epoch_dice, pred_list = self.train_epoch(epoch, num_clicks)

            if self.lr_scheduler is not None:
                self.lr_scheduler.step()
            if self.args.multi_gpu:
                dist.barrier()

            if not self.args.multi_gpu or (self.args.multi_gpu and self.args.rank == 0):
                self.losses.append(epoch_loss)
                self.dices.append(epoch_dice)
                print(f'EPOCH: {epoch}, Loss: {epoch_loss}')
                print(f'EPOCH: {epoch}, Dice: {epoch_dice}')
                logger.info(f'Epoch\t {epoch}\t : loss: {epoch_loss}, dice: {epoch_dice}')

                if self.args.multi_gpu:
                    state_dict = self.model.module.state_dict()
                else:
                    state_dict = self.model.state_dict()

                # save the latest checkpoint
                self.save_checkpoint(
                    epoch,
                    state_dict,
                    describe='latest'
                )

                # save the checkpoint with the best training loss
                if epoch_loss < self.best_loss:
                    self.best_loss = epoch_loss
                    self.save_checkpoint(
                        epoch,
                        state_dict,
                        describe='loss_best'
                    )

                # save the checkpoint with the best training dice
                if epoch_dice > self.best_dice:
                    self.best_dice = epoch_dice
                    self.save_checkpoint(
                        epoch,
                        state_dict,
                        describe='dice_best'
                    )

                self.plot_result(self.losses, 'Dice + Cross Entropy Loss', 'Loss')
                self.plot_result(self.dices, 'Dice', 'Dice')
        logger.info('=====================================================================')
        logger.info(f'Best loss: {self.best_loss}')
        logger.info(f'Best dice: {self.best_dice}')
        logger.info(f'Total loss: {self.losses}')
        logger.info(f'Total dice: {self.dices}')
        logger.info('=====================================================================')
        logger.info(f'args : {self.args}')
        logger.info(f'Used datasets : {img_datas}')
        logger.info('=====================================================================')
|
|
def init_seeds(seed=0, cuda_deterministic=True):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Speed/reproducibility trade-off: deterministic cuDNN is slower but repeatable.
    if cuda_deterministic:
        cudnn.deterministic = True
        cudnn.benchmark = False
    else:
        cudnn.deterministic = False
        cudnn.benchmark = True


def device_config(args):
    try:
        if not args.multi_gpu:
            # single-process run: pick a single device
            if args.device == 'mps':
                args.device = torch.device('mps')
            else:
                args.device = torch.device(f"cuda:{args.gpu_ids[0]}")
        else:
            args.nodes = 1
            args.ngpus_per_node = len(args.gpu_ids)
            args.world_size = args.nodes * args.ngpus_per_node
    except RuntimeError as e:
        print(e)
|
|
|
|
def main():
    mp.set_sharing_strategy('file_system')
    device_config(args)
    if args.multi_gpu:
        mp.spawn(
            main_worker,
            nprocs=args.world_size,
            args=(args, )
        )
    else:
        random.seed(2023)
        np.random.seed(2023)
        torch.manual_seed(2023)

        dataloaders = get_dataloaders(args)
        model = build_model(args)
        trainer = BaseTrainer(model, dataloaders, args)
        trainer.train()
|
|
def main_worker(rank, args):
    setup(rank, args.world_size)

    torch.cuda.set_device(rank)
    args.num_workers = int(args.num_workers / args.ngpus_per_node)
    args.device = torch.device(f"cuda:{rank}")
    args.rank = rank

    init_seeds(2023 + rank)

    cur_time = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    logging.basicConfig(
        format='[%(asctime)s] - %(message)s',
        datefmt='%Y/%m/%d %H:%M:%S',
        level=logging.INFO if rank in [-1, 0] else logging.WARN,
        filemode='w',
        filename=os.path.join(LOG_OUT_DIR, f'output_{cur_time}.log'))

    dataloaders = get_dataloaders(args)
    model = build_model(args)
    trainer = BaseTrainer(model, dataloaders, args)
    trainer.train()
    cleanup()
|
|
|
|
def setup(rank, world_size):
    # initialize the NCCL process group; `args` (for the port) is module-level
    dist.init_process_group(
        backend='nccl',
        init_method=f'tcp://127.0.0.1:{args.port}',
        world_size=world_size,
        rank=rank
    )
|
|
def cleanup():
    dist.destroy_process_group()
|
|
|
|
if __name__ == '__main__':
    main()
|
|