| import torch |
| import torch.nn.functional as F |
| from tqdm import tqdm |
|
|
| from utils.endtoend import multiclass_dice_coeff, multiclass_iou |
# Number of channels used when one-hot encoding both predicted and ground-truth
# masks in every evaluation function below; predicted class indices (argmax over
# the network output) must be smaller than this value for F.one_hot to succeed.
num_classes: int = 8
|
|
def evaluate(net, dataloader, device, eval_class):
    """Compute the mean Dice and IoU of ``net`` over a validation dataloader.

    Each batch must be a mapping with ``'image'`` and ``'mask_ete'`` entries.
    The hard prediction (argmax over dim 1 of the network output) and the
    ground-truth mask are both one-hot encoded into ``num_classes`` channels.
    Only the channel(s) selected by ``eval_class`` are passed to the metric
    helpers.

    The network is put in eval mode for the duration of the call. Its previous
    train/eval mode is restored afterwards, even if evaluation raises.

    :param net: segmentation network; dim 1 of its output is treated as the
        class axis.
    :param dataloader: sized iterable of batch dicts; ``len(dataloader)`` is
        used as the number of batches.
    :param device: device the image and mask tensors are moved to.
    :param eval_class: index (or indices) into the one-hot class axis to score.
    :return: tuple ``(mean_dice, mean_iou)``, i.e. the per-batch scores summed
        and divided by the number of batches. Both are ``0.0`` when the
        dataloader is empty.
    """
    was_training = net.training
    net.eval()
    num_val_batches = len(dataloader)
    dice_score = 0
    iou_score = 0

    try:
        with torch.no_grad():
            for batch in tqdm(dataloader, desc='Validation round', unit='batch', leave=False):
                image = batch['image'].to(device=device, dtype=torch.float32)
                mask_true = batch['mask_ete'].to(device=device, dtype=torch.long)

                # The permute needs a 4-D one-hot result, so mask_true must be
                # 3-D (N, H, W). The output is (N, num_classes, H, W).
                mask_true_vector = F.one_hot(mask_true, num_classes).permute(0, 3, 1, 2).float()

                # Hard prediction: one class index per pixel, one-hot encoded
                # the same way as the target.
                mask_pred = net(image).argmax(dim=1)
                mask_pred_vector = F.one_hot(mask_pred, num_classes).permute(0, 3, 1, 2).float()

                pred_selected = mask_pred_vector[:, eval_class, ...]
                true_selected = mask_true_vector[:, eval_class, ...]
                dice_score += multiclass_dice_coeff(pred_selected, true_selected,
                                                    reduce_batch_first=False)
                iou_score += multiclass_iou(pred_selected, true_selected)
    finally:
        # Restore the caller's mode rather than unconditionally forcing train mode.
        net.train(was_training)

    # max(..., 1) avoids ZeroDivisionError when the dataloader is empty.
    denominator = max(num_val_batches, 1)
    return dice_score / denominator, iou_score / denominator
|
|
|
|
def evaluate_3d_iou(net, dataset, device, eval_class):
    """Compute the mean IoU of ``net`` over the items of ``dataset.get_3d_iter()``.

    Each yielded item must provide ``'image'`` and ``'mask_ete'``. The whole
    item is fed to ``net`` as a single batch. Its first dimension acts as the
    batch axis (presumably the slices of one volume; TODO confirm against
    ``get_3d_iter``).

    The metric is called with ``reduce_batch_first=True``, presumably so that
    one score is computed per item rather than per slice. Verify this against
    ``multiclass_iou``.

    The network's previous train/eval mode is restored on exit, even if
    evaluation raises.

    :param net: segmentation network; dim 1 of its output is the class axis.
    :param dataset: object exposing ``get_3d_iter()``.
    :param device: device the image and mask tensors are moved to.
    :param eval_class: index (or indices) into the one-hot class axis to score.
    :return: sum of the per-item IoU divided by the number of items; ``0.0``
        when the iterator yields nothing.
    """
    was_training = net.training
    net.eval()
    iou_score = 0
    num_items = 0

    try:
        with torch.no_grad():
            for image_3d in tqdm(dataset.get_3d_iter(), desc='3D Evaluation', unit='image(s)', leave=False):
                num_items += 1
                image = image_3d['image'].to(device=device, dtype=torch.float32)
                mask_true = image_3d['mask_ete'].to(device=device, dtype=torch.long)

                # (N, H, W) -> (N, num_classes, H, W)
                mask_true_vector = F.one_hot(mask_true, num_classes).permute(0, 3, 1, 2).float()

                mask_pred = net(image).argmax(dim=1)
                mask_pred_vector = F.one_hot(mask_pred, num_classes).permute(0, 3, 1, 2).float()

                iou_score += multiclass_iou(mask_pred_vector[:, eval_class, ...],
                                            mask_true_vector[:, eval_class, ...],
                                            reduce_batch_first=True)
    finally:
        # Restore the caller's mode rather than unconditionally forcing train mode.
        net.train(was_training)

    # max(..., 1) avoids ZeroDivisionError when the iterator yields nothing.
    return iou_score / max(num_items, 1)
|
|
def evaluate_3d_iou_large(net, dataset, device, eval_class):
    """Compute the mean IoU of ``net`` over the items of ``dataset.get_3d_iter()``.

    Each yielded item must provide ``'image'`` and ``'mask'``.

    The network is moved to ``device`` once before the loop. The image is moved
    to ``device`` but keeps whatever dtype the dataset produced. The mask is
    cast to int64 because ``F.one_hot`` only accepts int64 index tensors.

    The metric is called with ``reduce_batch_first=True``, presumably to get
    one score per item. Verify this against ``multiclass_iou``.

    The network's previous train/eval mode is restored on exit, even if
    evaluation raises.

    :param net: segmentation network; dim 1 of its output is the class axis.
    :param dataset: object exposing ``get_3d_iter()``.
    :param device: device the network, image and mask are moved to.
    :param eval_class: index (or indices) into the one-hot class axis to score.
    :return: sum of the per-item IoU divided by the number of items; ``0.0``
        when the iterator yields nothing.
    """
    was_training = net.training
    net.eval()
    # Move the network once rather than on every item.
    net.to(device=device)
    iou_score = 0
    num_items = 0

    try:
        with torch.no_grad():
            for image_3d in tqdm(dataset.get_3d_iter(), desc='3D Evaluation', unit='image(s)', leave=False):
                num_items += 1
                # NOTE(review): this variant reads 'mask' while evaluate_3d_iou
                # reads 'mask_ete'; confirm which key the dataset provides.
                image = image_3d['image'].to(device=device)
                mask_true = image_3d['mask'].to(device=device, dtype=torch.long)

                # (N, H, W) -> (N, num_classes, H, W)
                mask_true_vector = F.one_hot(mask_true, num_classes).permute(0, 3, 1, 2).float()

                mask_pred = net(image).argmax(dim=1)
                mask_pred_vector = F.one_hot(mask_pred, num_classes).permute(0, 3, 1, 2).float()

                iou_score += multiclass_iou(mask_pred_vector[:, eval_class, ...],
                                            mask_true_vector[:, eval_class, ...],
                                            reduce_batch_first=True)
    finally:
        # Restore the caller's mode rather than unconditionally forcing train mode.
        net.train(was_training)

    # max(..., 1) avoids ZeroDivisionError when the iterator yields nothing.
    return iou_score / max(num_items, 1)
|
|
def evaluate_3d_iou_fast(net, dataset, device, eval_class):
    """Compute the mean IoU of ``net`` by iterating ``dataset`` directly.

    This is like ``evaluate_3d_iou``, except that each item comes from
    iterating ``dataset`` itself and carries a leading batch dimension, e.g.
    ``[batch_size, depth, H, W]``. Only the first entry (``[0]``) of
    ``'image'`` and ``'mask'`` is evaluated, and it is fed to ``net`` as a
    single batch.

    The metric is called with ``reduce_batch_first=True``, presumably to get
    one score per item. Verify this against ``multiclass_iou``.

    The network's previous train/eval mode is restored on exit, even if
    evaluation raises.

    :param net: segmentation network; dim 1 of its output is the class axis.
    :param dataset: iterable of dicts with ``'image'`` and ``'mask'`` entries
        (for example a DataLoader).
    :param device: device the image and mask tensors are moved to.
    :param eval_class: index (or indices) into the one-hot class axis to score.
    :return: sum of the per-item IoU divided by the number of items; ``0.0``
        when ``dataset`` yields nothing.
    """
    was_training = net.training
    net.eval()
    iou_score = 0
    num_items = 0

    try:
        with torch.no_grad():
            for image_3d in tqdm(dataset, desc='3D Evaluation', unit='image(s)', leave=False):
                num_items += 1
                # Drop the leading batch dimension; only the first entry is scored.
                image = image_3d['image'][0].to(device=device, dtype=torch.float32)
                mask_true = image_3d['mask'][0].to(device=device, dtype=torch.long)

                # (N, H, W) -> (N, num_classes, H, W)
                mask_true_vector = F.one_hot(mask_true, num_classes).permute(0, 3, 1, 2).float()

                mask_pred = net(image).argmax(dim=1)
                mask_pred_vector = F.one_hot(mask_pred, num_classes).permute(0, 3, 1, 2).float()

                iou_score += multiclass_iou(mask_pred_vector[:, eval_class, ...],
                                            mask_true_vector[:, eval_class, ...],
                                            reduce_batch_first=True)
    finally:
        # Restore the caller's mode rather than unconditionally forcing train mode.
        net.train(was_training)

    # max(..., 1) avoids ZeroDivisionError when the dataset yields nothing.
    return iou_score / max(num_items, 1)
|
|
def evaluate_3d_dice(net, dataset, device, eval_class):
    """Compute the mean Dice coefficient of ``net`` over ``dataset.get_3d_iter()``.

    Each yielded item must provide ``'image'`` and ``'mask_ete'``. The whole
    item is fed to ``net`` as a single batch.

    The metric is called with ``reduce_batch_first=True``, presumably to get
    one score per item. Verify this against ``multiclass_dice_coeff``.

    The network's previous train/eval mode is restored on exit, even if
    evaluation raises.

    :param net: segmentation network; dim 1 of its output is the class axis.
    :param dataset: object exposing ``get_3d_iter()``.
    :param device: device the image and mask tensors are moved to.
    :param eval_class: index (or indices) into the one-hot class axis to score.
    :return: sum of the per-item Dice divided by the number of items; ``0.0``
        when the iterator yields nothing.
    """
    was_training = net.training
    net.eval()
    dice_score = 0
    num_items = 0

    try:
        with torch.no_grad():
            for image_3d in tqdm(dataset.get_3d_iter(), desc='3D Evaluation', unit='image(s)', leave=False):
                num_items += 1
                image = image_3d['image'].to(device=device, dtype=torch.float32)
                mask_true = image_3d['mask_ete'].to(device=device, dtype=torch.long)

                # (N, H, W) -> (N, num_classes, H, W)
                mask_true_vector = F.one_hot(mask_true, num_classes).permute(0, 3, 1, 2).float()

                mask_pred = net(image).argmax(dim=1)
                mask_pred_vector = F.one_hot(mask_pred, num_classes).permute(0, 3, 1, 2).float()

                dice_score += multiclass_dice_coeff(mask_pred_vector[:, eval_class, ...],
                                                    mask_true_vector[:, eval_class, ...],
                                                    reduce_batch_first=True)
    finally:
        # Restore the caller's mode rather than unconditionally forcing train mode.
        net.train(was_training)

    # max(..., 1) avoids ZeroDivisionError when the iterator yields nothing.
    return dice_score / max(num_items, 1)
|
|