import itertools
import logging

import torch
import torch.distributed as dist

from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
from pycocoevalcap.bleu.bleu import Bleu
from pycocoevalcap.cider.cider import Cider
from pycocoevalcap.meteor.meteor import Meteor
from pycocoevalcap.rouge.rouge import Rouge
from pycocoevalcap.spice.spice import Spice

from utils.basic_utils import MetricLogger

logger = logging.getLogger(__name__)


def eval_nlp_scores(pred, gt, verbose=False):
    """
    Adapted from https://github.com/zohrehghaderi/VASTA/blob/ede0461fd0fc00da575bfca2399532e8eb7607ac/nlp_metrics/cocval_evalution.py#L9
    Evaluates the corpus-level NLP metrics BLEU-1 to BLEU-4, METEOR,
    ROUGE-L, and CIDEr.

    Args:
        pred (dict): predictions as {id: [{'caption': str}]} (COCO format)
        gt (dict): ground truths as {id: [{'caption': str}]} (COCO format)
        verbose (bool): if True, print progress information

    Returns:
        dict: mapping from metric name to corpus-level score
    """
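    # PTBTokenizer consumes COCO-style dicts ({id: [{'caption': str}, ...]})
    # and returns {id: [str]} with tokenized, lower-cased captions.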
    tokenizer = PTBTokenizer()

    gts = tokenizer.tokenize(gt)
    res = tokenizer.tokenize(pred)

    if verbose:
        print('Setting up scorers...')
    scorers = [
        (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
        (Meteor(), "METEOR"),
        (Rouge(), "ROUGE_L"),
        (Cider(), "CIDEr"),
    ]
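    # Note: Spice is imported above but not included in `scorers`;
    # pycocoevalcap's SPICE scorer shells out to a bundled Java jar, which is
    # slow and requires a JVM. Add (Spice(), "SPICE") here to enable it.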
    results = {}
    for scorer, method in scorers:
        # compute_score returns (corpus_score, per_sample_scores)
        score, scores = scorer.compute_score(gts, res)
        if isinstance(method, list):
            # Bleu(4) yields one corpus score per n-gram order
            for sc, m in zip(score, method):
                results[m] = sc
        else:
            results[method] = score

    return results


@torch.no_grad()
def evaluation_wrapper(model, data_loader, tokenizer, device, config, prefix=""):
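    """
    Generates captions for all images in `data_loader`, gathers predictions
    and ground truths across distributed ranks, and computes corpus-level
    NLP metrics on rank 0 (see eval_nlp_scores).

    Note: the `tokenizer` argument is unused in this function.

    Returns:
        dict: {prefix: metric_dict}, identical on every rank after broadcast
    """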
    model.eval()

    metric_logger = MetricLogger(delimiter="  ")
    header = "[evaluation] Generating captions:"
    log_freq = config.log_freq // 2

    logger.info("Start generating results.")

    iterator = metric_logger.log_every(data_loader, log_freq, header)

    all_texts = data_loader.dataset.text
    img2txt = data_loader.dataset.img2txt
    for v in img2txt.values():
        assert len(v) == 1, "Only support one caption per image"
    # Flatten {img_id: [txt_id]} into {img_id: txt_id} now that the
    # single-caption assumption has been checked.
    img2txt = {k: v[0] for k, v in img2txt.items()}

    all_pred_caption = []
    all_gt_caption = []

    for image, idx in iterator:
        image = image.to(device, non_blocking=True)
        # Look up the ground-truth caption for every image in the batch
        caption = [all_texts[img2txt[idx[i].item()]] for i in range(len(idx))]

        pred_tokens, pred_caption = model(
            image,
            train=False,
            raw_caption=caption,
        )

        all_pred_caption += pred_caption
        all_gt_caption += caption

    logger.info("Finish generating results.")
    logger.info("Computing accuracy.")
    # Gather predictions and ground truths from every rank; this requires an
    # initialized torch.distributed process group.
    preds = [None] * dist.get_world_size()
    gts = [None] * dist.get_world_size()
    dist.all_gather_object(preds, all_pred_caption)
    dist.all_gather_object(gts, all_gt_caption)

    preds = list(itertools.chain(*preds))
    gts = list(itertools.chain(*gts))

    # Re-key into the COCO-style format expected by eval_nlp_scores
    preds = {k: [{'caption': v}] for k, v in enumerate(preds)}
    gts = {k: [{'caption': v}] for k, v in enumerate(gts)}

    # Only rank 0 computes the (potentially slow) caption metrics; the result
    # is then broadcast so that every rank returns the same dict.
    if dist.get_rank() == 0:
        results = eval_nlp_scores(preds, gts, verbose=False)
        results = [results]
    else:
        results = [None]

    dist.broadcast_object_list(results, src=0)

    return {prefix: results[0]}
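if __name__ == "__main__":
    # Minimal smoke test for eval_nlp_scores on hypothetical captions.
    # Assumes pycocoevalcap is installed and a Java runtime is available
    # (PTBTokenizer and METEOR both shell out to bundled jars).
    dummy_preds = {
        0: [{"caption": "a dog runs across the grass"}],
        1: [{"caption": "a man rides a bicycle"}],
    }
    dummy_gts = {
        0: [{"caption": "a dog is running on the grass"}],
        1: [{"caption": "a man is riding a bicycle down the road"}],
    }
    print(eval_nlp_scores(dummy_preds, dummy_gts, verbose=True))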