| import time |
| import matplotlib as mpl |
| mpl.use('Agg') |
|
|
| import numpy as np |
| import torch |
| import torch.nn.parallel |
| import torch.optim |
| from torch.autograd import Variable |
| from torch.cuda.amp import autocast as autocast |
|
|
| from model.model import * |
| from dataset.data_loader import * |
| from utils.losses import * |
| from utils.parsing_metrics import * |
| from utils.utils import * |
| from utils.utils import dice_loss, sigmoid_focal_loss |
|
|
# Module-level CUDA availability flag. NOTE(review): the functions below
# unconditionally call .cuda(...) on tensors, so they assume this is True;
# this flag is only printed here, not checked — confirm callers gate on it.
use_cuda = torch.cuda.is_available()
print("use_cuda, ", use_cuda)
|
|
|
|
def train_epoch(rank, args, train_loader, model, optimizer, epoch, scaler, logger):
    """Run one training epoch with mixed-precision (AMP) updates.

    Args:
        rank: CUDA device index for this process; also gates console/logger
            output to rank 0.
        args: namespace providing at least ``seg_thresh`` and ``print_freq``.
        train_loader: yields ``(imgs, word_id, word_mask, bbox, seg_map)``;
            ``bbox`` is not used by this function.
        model: segmentation model invoked as ``model(imgs, word_id, word_mask)``.
        optimizer: optimizer stepped through the GradScaler.
        epoch: current epoch index (logging only).
        scaler: ``torch.cuda.amp.GradScaler`` instance.
        logger: logger that receives the periodic progress strings.

    Returns:
        The epoch-average of the total (dice + focal) loss.
    """
    print('train at epoch %d'%epoch)
    batch_time = AverageMeter()
    losses = AverageMeter()
    dice_losses = AverageMeter()
    sigmoid_focal_losses = AverageMeter()
    cos_losses = AverageMeter()  # tracks segmentation IoU, despite the name
    model.train()
    end = time.time()

    for batch_idx, (imgs, word_id, word_mask, bbox, seg_map) in enumerate(train_loader):
        # Deprecated Variable() wrapping removed: since PyTorch 0.4 plain
        # tensors are autograd-enabled, so the wrap was a no-op.
        imgs = imgs.cuda(rank, non_blocking=True)
        word_id = word_id.cuda(rank, non_blocking=True)
        word_mask = word_mask.cuda(rank, non_blocking=True)
        seg_map = seg_map.cuda(rank, non_blocking=True)

        with autocast():
            mask_out = model(imgs, word_id, word_mask)
            dice_loss_ = dice_loss(mask_out, seg_map)
            sigmoid_focal_loss_ = sigmoid_focal_loss(mask_out, seg_map)
            loss = dice_loss_ + sigmoid_focal_loss_

        # Monitoring-only IoU on the CPU; detach() (not the deprecated .data)
        # so no autograd graph is retained for the numpy copy.
        # NOTE(review): mask_out is pre-sigmoid here while validation applies
        # .sigmoid() first — confirm cal_seg_iou_loss expects raw logits.
        mask_out_np = mask_out.detach().cpu().numpy()
        seg_map_np = seg_map.cpu().numpy()
        seg_iou = cal_seg_iou_loss(seg_map_np, mask_out_np, args.seg_thresh)

        # Standard AMP update: scale, backward, (unscale+)step, update scale.
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        losses.update(loss.item(), imgs.size(0))
        dice_losses.update(dice_loss_.item(), imgs.size(0))
        sigmoid_focal_losses.update(sigmoid_focal_loss_.item(), imgs.size(0))
        cos_losses.update(seg_iou.mean().item(), imgs.size(0))

        batch_time.update(time.time() - end)
        end = time.time()

        if rank == 0 and batch_idx % args.print_freq == 0:
            print_str = 'Epoch: [{0}][{1}/{2}]\t' \
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                'Loss {loss.val:.4f} ({loss.avg:.4f})\t' \
                'dice_losses {dice_losses.val:.4f} ({dice_losses.avg:.4f})\t' \
                'sigmoid_focal_losses {sigmoid_focal_losses.val:.4f} ({sigmoid_focal_losses.avg:.4f})\t' \
                'IoU {cos_loss.val:.4f} ({cos_loss.avg:.4f})\t' \
                .format(epoch, batch_idx, len(train_loader), batch_time=batch_time, loss=losses, dice_losses=dice_losses, sigmoid_focal_losses=sigmoid_focal_losses, cos_loss=cos_losses)
            print(print_str)
            logger.info(print_str)

    return losses.avg
|
|
def validate_epoch(args, val_loader, model, logger, mode='val'):
    """Evaluate segmentation quality on ``val_loader``.

    For each sample the prediction is sigmoided, cropped back out of the
    letterbox padding (``dw``/``dh`` offsets, ``ratio`` resize factor from the
    loader) and resized to the ground-truth resolution before IoU scoring.

    Args:
        args: namespace providing at least ``size`` and ``seg_thresh``.
        val_loader: yields ``(imgs, word_id, word_mask, bbox, seg_map, ratio,
            dw, dh, im_id, phrase, draw_img)``; only the first five plus the
            letterbox parameters are used here.
        model: segmentation model invoked as ``model(imgs, word_id, word_mask)``.
        logger: logger receiving progress and final metrics.
        mode: unused here; kept for caller compatibility.

    Returns:
        Tuple ``(mean IoU, overall IoU, prec)`` where ``prec`` maps each IoU
        threshold in [0.5, 0.95] to its precision AverageMeter.
    """
    print('begin test')
    batch_time = AverageMeter()
    miou_seg = AverageMeter()

    # precision@threshold meters for IoU thresholds 0.5, 0.55, ..., 0.95
    thresholds = np.arange(0.5, 1, 0.05)
    prec = {thresh: AverageMeter() for thresh in thresholds}

    model.eval()
    end = time.time()

    total_intersection = 0.0
    total_union = 0.0

    for batch_idx, (imgs, word_id, word_mask, bbox, seg_map, ratio, dw, dh, im_id, phrase, draw_img) in enumerate(val_loader):
        # NOTE(review): only element 0 of each batch is scored below, so this
        # assumes the validation loader uses batch size 1 — confirm.
        imgs = imgs.cuda(0)
        word_id = word_id.cuda(0)
        word_mask = word_mask.cuda(0)
        seg_map = seg_map.cuda(0)

        with torch.no_grad():
            mask_out = model(imgs, word_id, word_mask)
            mask_out = mask_out.sigmoid()

        # Letterbox geometry: (nh, nw) is the resized content size inside the
        # padded square; (top, left) are the pad offsets.
        ih = seg_map.shape[-2]
        iw = seg_map.shape[-1]
        ratio = float(ratio)
        nh = int(ih * ratio)
        nw = int(iw * ratio)
        top, bottom = int(dh[0]), nh + int(dh[0])
        left, right = int(dw[0]), nw + int(dw[0])
        new_shape = (iw, ih)  # cv2.resize takes (width, height)

        seg_map_np = seg_map[0].cpu().numpy().transpose(1, 2, 0)
        seg_map_np = cv2.resize(seg_map_np, new_shape, interpolation=cv2.INTER_CUBIC)

        # Crop the prediction back to the un-padded region, then resize to the
        # ground-truth resolution.
        mask_out = mask_out[0].cpu().numpy().transpose(1, 2, 0)
        mask_out = cv2.resize(mask_out, (args.size, args.size))
        mask_out_np = mask_out[top:bottom, left:right]
        mask_out_np = cv2.resize(mask_out_np, new_shape)

        seg_iou, seg_prec, inter_sum, union_sum = cal_seg_iou2(seg_map_np, mask_out_np, args.seg_thresh)

        miou_seg.update(seg_iou, imgs.size(0))
        total_intersection += inter_sum
        total_union += union_sum

        for thresh in thresholds:
            prec[thresh].update(seg_prec[thresh], imgs.size(0))

        batch_time.update(time.time() - end)
        end = time.time()
        if batch_idx % 1000 == 0:
            print_str = '[{0}/{1}]\t' \
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                'seg_iu {seg.val:.4f} ({seg.avg:.4f})\t' \
                .format( \
                batch_idx, len(val_loader), batch_time=batch_time, seg=miou_seg)
            print(print_str)
            logger.info(print_str)

    # Overall IoU pools intersections/unions over the whole set; epsilon keeps
    # the ratio defined when the set is empty.
    overall_iou = (total_intersection + 1e-10) / (total_union + 1e-10)

    print("Mean IoU:", miou_seg.avg)
    print("Overall IoU:", overall_iou)
    logger.info("Mean IoU: %.4f" % miou_seg.avg)
    logger.info("Overall IoU: %.4f" % overall_iou)

    for thresh in thresholds:
        print("prec@%f: %f"%(thresh,float(prec[thresh].avg)))
        logger.info("prec@%f:%f"%(thresh,float(prec[thresh].avg)))

    return miou_seg.avg, overall_iou, prec
|
|
|
|