"""Evaluation script: reads per-sample predictions from <result-dir>/result.jsonl,
scores open-ended answers with ROUGE-L and multiple-choice answers with exact
match, then writes per-dataset and per-category summaries as JSON."""
import argparse
import json
import os
import re

import numpy as np
from rouge import Rouge

# Dataset groupings used for category-level score aggregation.
spot_the_diff = ["Spot-the-Diff", "Birds-to-Words", "CLEVR-Change"]
image_edit_instruct = ["IEdit", "HQ-Edit", "MagicBrush"]
visual_story_telling = ["AESOP", "FlintstonesSV", "PororoSV", "VIST"]
visual_cloze = ["COMICS_Dialogue", "RecipeQA_VisualCloze"]
text_rich_vqa = ["WebQA", "TQA", "OCR-VQA", "DocVQA"]
multi_image_vqa = ["MIT-States_StateCoherence", "MIT-States_PropertyCoherence", "VISION", "RecipeQA_ImageCoherence"]

puzzle = ["RAVEN"]
nlrv2 = ["NLVR2_Mantis"]
qbench = ["QBench"]


class Eval:
    """Text normalization and scoring utilities."""

    def __init__(self):
        # Strip periods that are not decimal points (e.g. "3.5" is preserved).
        self.periodStrip = re.compile(r"(?<!\d)(\.)(?!\d)")
        # Detect commas used as thousands separators, e.g. "1,000".
        self.commaStrip = re.compile(r"(\d)(\,)(\d)")
        self.punct = [
            ";",
            r"/",
            "[",
            "]",
            '"',
            "{",
            "}",
            "(",
            ")",
            "=",
            "+",
            "\\",
            "_",
            "-",
            ">",
            "<",
            "@",
            "`",
            ",",
            "?",
            "!",
        ]

    def processPunctuation(self, inText):
        outText = inText
        for p in self.punct:
            if (p + " " in inText or " " + p in inText) or (
                re.search(self.commaStrip, inText) is not None
            ):
                outText = outText.replace(p, "")
            else:
                outText = outText.replace(p, " ")
        outText = self.periodStrip.sub("", outText)
        return outText

    def process(self, answer):
        """Normalize an answer string before comparison."""
        answer = answer.replace("\n", " ")
        answer = answer.replace("\t", " ")
        answer = answer.strip()
        answer = self.processPunctuation(answer)
        answer = answer.strip("'")
        answer = answer.strip('"')
        answer = answer.strip(")")
        answer = answer.strip("(")
        answer = answer.strip().lower()
        return answer
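
    # Illustrative normalization (the input string is hypothetical):
    #   process("  The Answer!\n") -> "the answer"
    # Newlines and tabs become spaces, most punctuation is removed or replaced
    # with spaces, surrounding whitespace/quotes/parentheses are stripped, and
    # the result is lower-cased.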

    def evaluate_rouge(self, preds):
        """Score open-ended predictions with the ROUGE-L F measure."""
        rouge = Rouge()
        acc = {'f': []}
        eval_list = []
        for res in preds:
            sample_id = res['sample_id']
            gt_ans = self.process(res["gt_response"])
            pred_ans = self.process(res["pred_response"])

            # Skip samples with an empty ground truth.
            if gt_ans == '':
                continue

            if pred_ans == '':
                s = 0
            else:
                # Truncate very long predictions before scoring.
                if len(pred_ans) > 512:
                    pred_ans = pred_ans[0:512]
                s = rouge.get_scores(pred_ans, gt_ans)[0]['rouge-l']['f']
            acc['f'].append(s)
            eval_list.append({'id': str(sample_id), 'score': str(round(s, 3))})
        results = {'Rouge-L f': np.mean(acc['f'])}
        return results, eval_list

    def judge_multi_choice(self, sample):
        """Return 1 if the predicted option matches the ground truth, else 0."""
        gt_ans = sample["gt_response"]
        pred_ans = sample["pred_response"]

        # If the prediction looks like "answer: b", keep only the option letter.
        if ":" in pred_ans:
            a_list = pred_ans.split(":")
            a_list = [a.strip() for a in a_list]
            for a in a_list:
                if len(a) == 1 and a in ["a", "b", "c", "d", "e", "f", "g", "h"]:
                    pred_ans = a

        if pred_ans == gt_ans:
            return 1
        else:
            return 0

    def process_sample(self, sample):
        sample["gt_response"] = self.process(sample["gt_response"])
        sample["pred_response"] = self.process(sample["pred_response"])

    def evaluate_multichoice(self, predictions):
        """Score multiple-choice predictions by exact match after normalization."""
        correct = 0
        eval_list = []
        for sample in predictions:
            self.process_sample(sample)
            score = self.judge_multi_choice(sample)
            sample_id = sample['sample_id']
            sample['result'] = score
            eval_list.append({'id': str(sample_id), 'score': str(score)})
            correct += score
        return {'Accuracy': correct / len(predictions)}, eval_list
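
    # Illustrative multiple-choice judging on a hypothetical sample: after
    # normalization a prediction such as "Answer: B" becomes "answer: b", the
    # text is split on ":" and the lone option letter "b" is kept, so it
    # matches a ground truth of "B" (also normalized to "b").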

    def evaluate_multi_choice_image(self, predictions):
        """Score image-choice predictions; answers are normalized here directly."""
        correct = 0
        eval_list = []
        for sample in predictions:
            gt_ans = self.process(sample["gt_response"])
            pred_ans = self.process(sample["pred_response"])
            sample_id = sample['sample_id']

            # If the prediction looks like "answer: b", keep only the option letter.
            if ":" in pred_ans:
                a_list = pred_ans.split(":")
                a_list = [a.strip() for a in a_list]
                for a in a_list:
                    if len(a) == 1 and a in ["a", "b", "c", "d", "e", "f", "g", "h"]:
                        pred_ans = a

            score = 1 if gt_ans == pred_ans else 0
            sample['result'] = score
            eval_list.append({'id': str(sample_id), 'score': str(score)})
            correct += score
        return {'Accuracy': correct / len(predictions)}, eval_list
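

# Each line of result.jsonl is expected to be a JSON object carrying at least
# the fields read below. An illustrative (hypothetical) record:
#   {"sample_id": 0, "dataset": "MagicBrush", "question_type": "open-ended",
#    "gt_response": "a red hat was added", "pred_response": "the hat is now red"}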


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--result-dir', type=str, required=True)
    args = parser.parse_args()

    result_file = os.path.join(args.result_dir, "result.jsonl")

    if not os.path.exists(result_file):
        print('No prediction file found')
        exit(0)
    with open(result_file, 'r') as f:
        preds_all = [json.loads(line) for line in f]

    # Group predictions by dataset.
    preds_all_dict = dict()
    for pred in preds_all:
        if pred["dataset"] not in preds_all_dict:
            preds_all_dict[pred["dataset"]] = list()
        preds_all_dict[pred["dataset"]].append(pred)

    # Datasets scored with the image-choice variant of multiple-choice evaluation.
    image_choice_dataset_list = ["recipeqa-RecipeQA_VisualCloze", "RecipeQA_ImageCoherence", "COMICS_Panel"]
    E = Eval()

    eval_result_list = dict()
    eval_result_list_detail = dict()

    for dataset in preds_all_dict:
        preds = preds_all_dict[dataset]
        question_type = preds[0]["question_type"]

        if question_type == 'open-ended':
            eval_result, eval_list = E.evaluate_rouge(preds)
        elif question_type == 'multi-choice' or dataset in nlrv2:
            if dataset in image_choice_dataset_list:
                eval_result, eval_list = E.evaluate_multi_choice_image(preds)
            else:
                eval_result, eval_list = E.evaluate_multichoice(preds)
        else:
            eval_result = 'Dataset not supported'
            print('Dataset not supported')
            exit(0)

        print(dataset, end=': ')
        print(eval_result)

        eval_result_list[dataset] = eval_result
        eval_result_list_detail[dataset] = eval_list

    os.makedirs(args.result_dir, exist_ok=True)
    with open(os.path.join(args.result_dir, 'eval_dataset.json'), 'w') as f:
        json.dump(eval_result_list, f, indent=4)

    with open(os.path.join(args.result_dir, 'eval_dataset_details.json'), 'w') as f:
        json.dump(eval_result_list_detail, f, indent=4)

    # Aggregate per-dataset scores into category-level averages.
    category_datasets = {
        "spot_the_diff": spot_the_diff,
        "image_edit_instruct": image_edit_instruct,
        "visual_story_telling": visual_story_telling,
        "visual_cloze": visual_cloze,
        "text_rich_vqa": text_rich_vqa,
        "multi_image_vqa": multi_image_vqa,
        "puzzle": puzzle,
        "nlrv2": nlrv2,
        "qbench": qbench,
    }

    eval_cat_list = dict()
    print()

    for category, datasets in category_datasets.items():
        score = 0
        count = 0
        for dataset in eval_result_list:
            if dataset in datasets:
                count += 1
                # Each eval_result dict holds a single metric (ROUGE-L f or Accuracy).
                score += list(eval_result_list[dataset].values())[0]
        if count > 0:
            score /= count
        eval_cat_list[category] = score
        print(category, end=': ')
        print('{:.2f}'.format(100 * score))

    with open(os.path.join(args.result_dir, 'eval_cat.json'), 'w') as f:
        json.dump(eval_cat_list, f, indent=4)
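

# Example invocation (the script filename and result directory are illustrative):
#   python eval.py --result-dir outputs/run1
# Outputs are written next to result.jsonl: eval_dataset.json,
# eval_dataset_details.json and eval_cat.json.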