| import time |
| import matplotlib as mpl |
| mpl.use('Agg') |
|
|
| import numpy as np |
| import torch |
| import torch.nn.parallel |
| import torch.optim |
| from torch.autograd import Variable |
| from torch.cuda.amp import autocast as autocast |
|
|
| from model.model_sbert_gref import * |
| from dataset.data_loader import * |
| from utils.losses import * |
| from utils.parsing_metrics import * |
| from utils.utils import * |
| from utils.utils import dice_loss, sigmoid_focal_loss |
|
|
# Module-level CUDA availability flag (informational; the training/validation
# code below selects devices explicitly via .cuda(rank)).
use_cuda = torch.cuda.is_available()
print("use_cuda, ", use_cuda)
|
|
|
|
def return_mask(emb_distance, verb_mask=None, rows_to_filter=None, cols_to_filter=None):
    """Build positive/negative masks for a square (B_, B_) similarity matrix.

    Positives are the diagonal plus interleaved (2i, 2i+1) anchor/hard-positive
    pairs; every other entry is a negative.  Optionally zeroes out the negative
    entries of pair blocks whose hard-positive embeddings were flagged as
    near-duplicates.

    Args:
        emb_distance: square (B_, B_) tensor; only its shape/dtype/device are
            used to allocate the masks.
        verb_mask: per-sample 0/1 indicator (1 where the sample is part of an
            anchor/hard-positive pair).  When the matrix is smaller than
            ``verb_mask`` the rows are assumed to be already filtered down to
            interleaved pairs.  Required in practice (None raises at len()).
        rows_to_filter, cols_to_filter: pair indices whose 2x2 negative
            sub-blocks are removed from the negative mask; may be None.

    Returns:
        (positive_mask, negative_mask) float tensors shaped like
        ``emb_distance``.
    """
    # emb_distance is square, so one side is enough (the original
    # `B_, B_ = shape` bound the same name twice).
    B_ = emb_distance.shape[0]
    positive_mask = torch.zeros_like(emb_distance)
    positive_mask.fill_diagonal_(1)

    if B_ < len(verb_mask):
        # Rows were pre-filtered to interleaved (anchor, hard-positive) pairs:
        # every consecutive couple of rows is a positive pair.
        for i in range(B_ // 2):
            positive_mask[2 * i, 2 * i + 1] = 1
            positive_mask[2 * i + 1, 2 * i] = 1
    else:
        # Full batch layout: consecutive rows form a pair only where
        # verb_mask is set; unpaired rows advance by one.
        i = 0
        while i < B_:
            if verb_mask[i] == 1:
                positive_mask[i, i + 1] = 1
                positive_mask[i + 1, i] = 1
                i += 2
            else:
                i += 1

    # Everything that is not a positive is a candidate negative.  The
    # subtraction already yields a fresh tensor (the old .clone() was
    # redundant).
    negative_mask = torch.ones_like(emb_distance) - positive_mask

    if rows_to_filter is not None and cols_to_filter is not None:
        # Drop the 2x2 blocks of near-duplicate hard-positive pairs so they
        # are not pushed apart as negatives.
        for row, col in zip(rows_to_filter, cols_to_filter):
            negative_mask[row * 2, col * 2] = 0
            negative_mask[row * 2, col * 2 + 1] = 0
            negative_mask[row * 2 + 1, col * 2] = 0
            negative_mask[row * 2 + 1, col * 2 + 1] = 0

    return positive_mask, negative_mask
|
|
|
|
def UniAngularLogitContrastLoss(total_fq, verb_mask, rows_to_filter, cols_to_filter, alpha=0.5, verbonly=True, m=0.5, tau=0.05, args=None):
    """Angular-margin contrastive (InfoNCE-style) loss over pooled features.

    Embeddings are obtained by global-average-pooling ``total_fq`` over its
    spatial dims.  Pairwise cosine similarity is converted to an angle via
    pi/2 - acos(sim) (== asin(sim)); positives get an additive angular margin
    ``m`` (in degrees), and a temperature-scaled log-ratio of positive vs.
    all retained logits is averaged over rows.

    Args:
        total_fq: (B, C, H, W) feature tensor.
        verb_mask: 0/1 (or bool) mask selecting samples that belong to
            anchor/hard-positive pairs.
        rows_to_filter, cols_to_filter: pair indices whose negatives are
            excluded (near-duplicate hard positives); may be None.
        alpha: unused; kept for interface compatibility.
        verbonly: if True, compute the loss only over ``verb_mask`` samples
            (expected to form an even count of interleaved pairs).
        m: angular margin in degrees applied to positive logits.
        tau: softmax temperature.
        args: unused; kept for interface compatibility.

    Returns:
        (angular_loss, B_) where B_ is the number of embeddings used.
    """
    _, C, H, W = total_fq.shape

    if verbonly:
        B = total_fq[verb_mask].shape[0]
        emb = torch.mean(total_fq[verb_mask], dim=(-1, -2)).reshape(B, C)
        assert emb.shape[0] % 2 == 0, f"Embedding count {emb.shape[0]} is not divisible by 2."
    else:
        # Fix: pool over BOTH spatial dims so emb is (B, C).  Pooling only the
        # last dim left a (B, C, H) tensor that broke the (B_, B_) reshape of
        # the cosine-similarity matrix below.
        emb = torch.mean(total_fq, dim=(-1, -2))

    B_ = emb.shape[0]
    emb_i = emb.unsqueeze(1).repeat(1, B_, 1)
    emb_j = emb.unsqueeze(0).repeat(B_, 1, 1)

    sim = nn.CosineSimilarity(dim=-1, eps=1e-6)
    sim_matrix = sim(emb_i, emb_j).reshape(B_, B_)
    # Clamp away from +-1: the gradient of acos diverges at the endpoints.
    sim_matrix = torch.clamp(sim_matrix, min=-0.9999, max=0.9999)

    # Exact degrees -> radians (replaces the approximate constant 57.2958).
    margin_in_radians = m * torch.pi / 180.0
    # pi/2 - acos(x) == asin(x): higher similarity -> larger angle logit.
    theta_matrix = (torch.pi / 2) - torch.acos(sim_matrix)

    positive_mask, negative_mask = return_mask(sim_matrix, verb_mask, rows_to_filter, cols_to_filter)

    # Subtract the margin from positive logits only (makes positives harder).
    theta_with_margin = theta_matrix.clone()
    theta_with_margin[positive_mask.bool()] -= margin_in_radians
    logits = theta_with_margin / tau

    exp_logits = torch.exp(logits)
    pos_exp_logits = (exp_logits * positive_mask).sum(dim=-1)
    neg_exp_logits = (exp_logits * negative_mask).sum(dim=-1)
    total_exp_logits = pos_exp_logits + neg_exp_logits

    # InfoNCE: maximize the positive mass relative to all retained logits.
    positive_loss = -torch.log(pos_exp_logits / total_exp_logits)
    angular_loss = positive_loss.mean()

    return angular_loss, B_
|
|
|
|
def train_epoch(rank, args, train_loader, model, optimizer, epoch, scaler, logger):
    """Run one training epoch with mixed precision and optional metric learning.

    For every sample whose ``pos_type`` is 'hardpos', the batch is expanded by
    interleaving the hard-positive caption directly after its anchor (same
    image and target).  The duplicate row is excluded from the segmentation
    loss (cl_mask=0) but participates in the angular contrastive loss
    (verb_mask=1).  Total loss = dice + sigmoid-focal (+ weighted metric loss).

    Args:
        rank: device index / process rank for .cuda() placement and logging.
        args: namespace providing metric_loss_weight, metric_mode,
            filter_thres, metric_learning, seg_thresh, margin_value,
            temperature, print_freq.
        train_loader: yields (imgs, word_id, word_mask, bbox, seg_map, params);
            params carries hp_word_id, hp_word_mask, hardpos_emb, pos_type.
        model: returns (mask_logits, metric_feature_tensors).
        optimizer: optimizer stepped through the AMP GradScaler.
        epoch: current epoch number (logging only).
        scaler: torch.cuda.amp.GradScaler instance.
        logger: logging.Logger for periodic progress lines.

    Returns:
        Average total loss over the epoch.
    """
    print('train at epoch %d'%epoch)
    batch_time = AverageMeter()
    losses = AverageMeter()
    dice_losses = AverageMeter()
    sigmoid_focal_losses = AverageMeter()
    # NOTE(review): despite the name, this meter accumulates seg IoU values.
    cos_losses = AverageMeter()
    model.train()
    end = time.time()

    mlw = args.metric_loss_weight
    metric_mode = args.metric_mode  # currently unread below
    filter_thres = args.filter_thres
    metric_learning = args.metric_learning

    for batch_idx, (imgs, word_id, word_mask, bbox, seg_map, params) in enumerate(train_loader):
        B = imgs.size(0)

        # Hard-positive caption tokens and their precomputed BERT embeddings.
        hp_word_id = params['hp_word_id']
        hp_word_mask = params['hp_word_mask']
        hp_bert_embs = params['hardpos_emb'].cuda(non_blocking=True).squeeze(1)
        pos_type = np.array(params['pos_type'])

        # 1 where the sample has a usable hard positive, else 0.
        pos_mask = torch.tensor(np.where(pos_type == 'hardpos', 1, 0))

        # Build the expanded batch with interleaved (anchor, hard-positive) rows.
        verb_masks = []        # 1 for rows that belong to a contrastive pair
        cl_masks = []          # 1 for rows supervised by the segmentation loss
        images = []
        targets = []
        sentences_ = []
        sentences_masked_ = []

        for idx in range(len(imgs)):
            sentences_.append(word_id[idx])
            sentences_masked_.append(word_mask[idx])
            images.append(imgs[idx])
            targets.append(seg_map[idx])

            if pos_mask[idx]:
                # Duplicate image/target for the hard-positive caption; the
                # duplicate is excluded from the segmentation loss.
                verb_masks.extend([1, 1])
                cl_masks.extend([1, 0])
                sentences_.append(hp_word_id[idx])
                sentences_masked_.append(hp_word_mask[idx])
                images.append(imgs[idx])
                targets.append(seg_map[idx])
            else:
                verb_masks.append(0)
                cl_masks.append(1)

        imgs, seg_map, word_id, word_mask, verb_masks, cl_masks = \
            torch.stack(images).cuda(rank, non_blocking=True),\
            torch.stack(targets).cuda(rank, non_blocking=True),\
            torch.stack(sentences_).cuda(rank, non_blocking=True),\
            torch.stack(sentences_masked_).cuda(rank, non_blocking=True),\
            torch.tensor(verb_masks, dtype=torch.bool).cuda(rank, non_blocking=True),\
            torch.tensor(cl_masks, dtype=torch.bool).cuda(rank, non_blocking=True)

        image = Variable(imgs)
        word_id = Variable(word_id)
        word_mask = Variable(word_mask)
        seg_map = Variable(seg_map)
        verb_masks = Variable(verb_masks)
        cl_masks = Variable(cl_masks)

        # Fix: reset every batch.  Previously these names were only assigned
        # inside the `numel() > 0` branch, so an empty hardpos_emb raised a
        # NameError on the first batch and reused stale indices afterwards.
        rows_to_filter, cols_to_filter = None, None
        if hp_bert_embs.numel() > 0:
            # Drop all-zero (padding) embeddings, then flag pairs of hard
            # positives that are near-duplicates in BERT space; such pairs
            # are excluded as negatives in the contrastive loss.
            mask = ~torch.all(hp_bert_embs == 0, dim=1)
            hp_bert_embs = hp_bert_embs[mask]

            norms = torch.norm(hp_bert_embs, dim=-1, keepdim=True)
            normed_embs = hp_bert_embs / norms
            cosime_sim = torch.mm(normed_embs, normed_embs.T)
            rows_to_filter, cols_to_filter = torch.where(cosime_sim > filter_thres)

        with autocast():
            mask_out_all, metric_tensors = model(image, word_id, word_mask)
            loss = 0.

            # Segmentation supervision only on rows with cl_mask set.
            mask_out = mask_out_all[cl_masks]
            seg_map_cl = seg_map[cl_masks]

            mask_out_np = mask_out.data.cpu().numpy()
            seg_map_np = seg_map_cl.cpu().numpy()
            seg_iou = cal_seg_iou_loss(seg_map_np, mask_out_np, args.seg_thresh)

            dice_loss_ = dice_loss(mask_out, seg_map_cl)
            sigmoid_focal_loss_ = sigmoid_focal_loss(mask_out, seg_map_cl)

            dice_weight, focal_weight = 1.0, 1.0
            loss = (dice_weight * dice_loss_) + (focal_weight * sigmoid_focal_loss_)

            # The contrastive loss needs at least two hard-positive pairs.
            if metric_learning and sum(pos_mask) > 1:
                metric_weight = mlw

                metric_loss, NS = UniAngularLogitContrastLoss(metric_tensors, verb_masks, rows_to_filter, cols_to_filter, m=args.margin_value, tau=args.temperature, verbonly=True, args=args)
                loss += metric_weight * metric_loss

        # Standard AMP step: scale, backward, step through scaler, update scale.
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Meters are weighted by the ORIGINAL batch size B, not the expanded one.
        losses.update(loss.item(), B)
        dice_losses.update(dice_loss_.item(), B)
        sigmoid_focal_losses.update(sigmoid_focal_loss_.item(), B)
        cos_losses.update(seg_iou.mean().item(), B)

        batch_time.update(time.time() - end)
        end = time.time()

        if rank == 0 and batch_idx % args.print_freq == 0:
            print_str = 'Epoch: [{0}][{1}/{2}]\t' \
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                'Loss {loss.val:.4f} ({loss.avg:.4f})\t' \
                'dice_losses {dice_losses.val:.4f} ({dice_losses.avg:.4f})\t' \
                'sigmoid_focal_losses {sigmoid_focal_losses.val:.4f} ({sigmoid_focal_losses.avg:.4f})\t' \
                'IoU {cos_loss.val:.4f} ({cos_loss.avg:.4f})\t' \
                .format(epoch, batch_idx, len(train_loader), batch_time=batch_time, loss=losses, dice_losses=dice_losses, sigmoid_focal_losses=sigmoid_focal_losses, cos_loss=cos_losses)
            print(print_str)
            logger.info(print_str)

    return losses.avg
|
|
|
|
|
|
def validate_epoch(args, val_loader, model, logger, mode='val'):
    """Evaluate segmentation quality on the validation/test loader.

    Reports mean IoU, overall (dataset-pooled) IoU, and precision at IoU
    thresholds 0.5..0.95.  NOTE(review): the un-letterboxing below indexes
    batch element 0 only (seg_map[0], imgs[0], dh[0], dw[0]), so this assumes
    val_loader uses batch_size == 1 — confirm against the caller.

    Args:
        args: needs .size (network input resolution) and .seg_thresh.
        val_loader: yields (imgs, word_id, word_mask, bbox, seg_map, ratio,
            dw, dh, im_id, phrase, draw_img) per sample.
        model: returns (mask_logits, metric_tensors); logits are sigmoided here.
        logger: logging.Logger receiving the result lines.
        mode: unused; kept for interface compatibility.

    Returns:
        (mean_iou, prec) where prec maps each threshold to its AverageMeter.
    """
    print('begin test')
    batch_time = AverageMeter()
    miou = AverageMeter()
    miou_seg = AverageMeter()

    # precision@threshold accumulators for thresholds 0.5, 0.55, ..., 0.95
    prec=dict()
    thresholds = np.arange(0.5, 1, 0.05)

    for thresh in thresholds:
        prec[thresh]= AverageMeter()

    model.eval()
    end = time.time()
    idx = 0

    t_all = []  # per-sample forward-pass wall times
    total_intersection = 0.0
    total_union = 0.0

    for batch_idx, (imgs, word_id, word_mask, bbox, seg_map, ratio, dw, dh, im_id, phrase, draw_img) in enumerate(val_loader):

        imgs = imgs.cuda(0)
        word_id = word_id.cuda(0)
        word_mask = word_mask.cuda(0)
        seg_map = seg_map.cuda(0)
        image = Variable(imgs)
        word_id = Variable(word_id)
        word_mask = Variable(word_mask)
        seg_map = Variable(seg_map)

        t1 = time.time()
        with torch.no_grad():
            mask_out, _ = model(image, word_id, word_mask)
            mask_out = mask_out.sigmoid()

        t2 = time.time()
        t_all.append(t2-t1)

        # Undo the letterbox: recover the unpadded region (top:bottom,
        # left:right) of the network input, then resize back to the GT size.
        ih = seg_map.shape[-2]
        iw = seg_map.shape[-1]
        nh = int(ih * ratio)
        nw = int(iw * ratio)
        top, bottom = int(dh[0]), nh + int(dh[0])
        left, right = int(dw[0]), nw + int(dw[0])
        ratio = float(ratio)
        new_shape = (iw, ih)  # cv2.resize takes (width, height)

        seg_map_np = seg_map[0,:,:,:].data.cpu().numpy().transpose(1,2,0)
        seg_map_np = cv2.resize(seg_map_np, new_shape, interpolation=cv2.INTER_CUBIC)
        img_np = imgs[0,:,top:bottom,left:right].data.cpu().numpy().transpose(1,2,0)
        img_np = cv2.resize(img_np, new_shape, interpolation=cv2.INTER_CUBIC)

        # Re-wrapped but unused afterwards — presumably kept for visualization.
        img_np = Variable(torch.from_numpy(img_np.transpose(2,0,1)).cuda().unsqueeze(0))

        # Crop/resize the prediction the same way before comparing against GT.
        mask_out = mask_out[0].data.cpu().numpy().transpose(1,2,0)
        mask_out = cv2.resize(mask_out, (args.size, args.size))
        mask_out_np = mask_out[top:bottom, left:right]
        mask_out_np = cv2.resize(mask_out_np, new_shape)

        seg_iou, seg_prec, inter_sum, union_sum = cal_seg_iou2(seg_map_np, mask_out_np, args.seg_thresh)

        miou_seg.update(seg_iou, imgs.size(0))
        total_intersection += inter_sum
        total_union += union_sum

        for thresh in thresholds:
            prec[thresh].update(seg_prec[thresh], imgs.size(0))

        batch_time.update(time.time() - end)
        end = time.time()
        if batch_idx % 1000 == 0:
            print_str = '[{0}/{1}]\t' \
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                'seg_iu {seg.val:.4f} ({seg.avg:.4f})\t' \
                .format( \
                batch_idx, len(val_loader), batch_time=batch_time, seg=miou_seg)
            print(print_str)
            logger.info(print_str)
        idx = idx + 1
    # Overall IoU pools intersections/unions across the whole set; the
    # epsilon guards the empty-dataset case.
    overall_iou = (total_intersection + 1e-10) / (total_union + 1e-10)

    print("Mean IoU:", miou_seg.avg)
    print("Overall IoU:", overall_iou)
    logger.info("Mean IoU: %.4f" % miou_seg.avg)
    logger.info("Overall IoU: %.4f" % overall_iou)

    for thresh in thresholds:
        print("prec@%f: %f"%(thresh,float(prec[thresh].avg)))
        logger.info("prec@%f:%f"%(thresh,float(prec[thresh].avg)))

    return miou_seg.avg, prec
|
|
|
|
|
|