| import json |
| import metrics |
| import argparse |
| import numpy as np |
| import multiprocessing |
| from tqdm import trange |
| import signal, functools |
| import re, os, sys, random, time |
| from fraction import Fraction |
| from data_processing.answer_extraction import * |
| from eval.eval_script import * |
| from compute_perp import Evaluator, numberic_compare |
# Largest platform-native int; an "effectively unbounded" sentinel — not used in
# this chunk, presumably consumed elsewhere in the module. TODO confirm usage.
MAX_INT = sys.maxsize
# Marker string for answers that could not be extracted — NOTE(review): unused
# in the visible code; verify against the rest of the file before removing.
INVALID_ANS = "[Invalid]"


# Only DSU is re-exported via `from <module> import *`.
__all__ = ["DSU"]
|
|
class DSU:
    """Disjoint-set union (union-find) with per-set sizes and attribute dicts.

    Each element starts in its own singleton set. ``size[r]`` is meaningful
    only for roots (non-roots are zeroed on merge), and ``attr[r]`` holds a
    dict of attributes that are unioned into the winning root on merge.
    """

    def __init__(self, n):
        # n: number of elements, indexed 0..n-1.
        self.n = n
        self.father = list(range(n))          # parent pointers; father[x] == x at a root
        self.size = [1] * n                   # set sizes, valid at roots only
        self.attr = [{} for _ in range(n)]    # per-root attribute dicts

    def get_father(self, x):
        """Return the root of x's set, with full path compression.

        Iterative two-pass implementation: the previous recursive version
        could raise RecursionError on long parent chains (easily reached
        when n exceeds Python's default recursion limit).
        """
        # Pass 1: find the root.
        root = x
        while self.father[root] != root:
            root = self.father[root]
        # Pass 2: point every node on the path directly at the root.
        while self.father[x] != root:
            self.father[x], x = root, self.father[x]
        return root

    def merge(self, x, y):
        """Union the sets containing x and y; x's root wins.

        Sizes are accumulated into the winning root and the loser's size is
        zeroed. Attribute dicts are merged key-by-key; colliding values are
        combined with ``|=`` (assumes set-like values — TODO confirm callers
        only store sets here).
        """
        fx = self.get_father(x)
        fy = self.get_father(y)
        if fx == fy:
            return  # already in the same set
        self.father[fy] = fx
        self.size[fx] += self.size[fy]
        self.size[fy] = 0
        for key, value in self.attr[fy].items():
            if key not in self.attr[fx]:
                self.attr[fx][key] = value
            else:
                self.attr[fx][key] |= value
        self.attr[fy] = {}  # release the loser's attributes
|
|
|
|
def sc_evaluator(predicts, completions, perplexities, answer, equal_func, check_equal):
    """Score a sample set by self-consistency (majority) voting.

    Clusters the m predictions via union-find using ``equal_func`` to decide
    equivalence, then awards credit if a largest cluster's representative
    matches ``answer`` (ties split the credit evenly across the tied
    clusters).

    Note: ``perplexities`` is accepted for interface compatibility but is not
    read by this function.

    Returns:
        (correct, answers) where ``correct`` is the (possibly fractional)
        score and ``answers`` is a list of
        [representative_prediction, normalized_cluster_share, is_correct]
        entries, one per cluster.
    """
    total = len(predicts)
    dsu = DSU(total)

    # Cluster equivalent predictions. Only current roots start a comparison
    # sweep; merge(i, j) keeps i's root as the cluster representative.
    for i in range(total):
        if dsu.get_father(i) != i:
            continue
        for j in range(i):
            if equal_func(predicts[i], predicts[j], completions[i], completions[j]):
                dsu.merge(i, j)

    roots = [i for i in range(total) if dsu.get_father(i) == i]

    # Size of the largest cluster, and how many clusters tie at that size.
    max_size, max_size_count = 0, 0
    for root in roots:
        sz = dsu.size[root]
        if sz > max_size:
            max_size, max_size_count = sz, 0
        if sz == max_size:
            max_size_count += 1

    # Collect per-cluster summaries; credit is shared among tied winners.
    correct, answers = 0, []
    for root in roots:
        rep = predicts[root]
        answers.append([rep, dsu.size[root] / total, check_equal(rep, answer)])
        if dsu.size[root] >= max_size and check_equal(rep, answer):
            correct += 1.0 / max_size_count
    # Renormalize the cluster shares so they sum to 1. (Root sizes already
    # sum to m, so this is effectively a defensive no-op.)
    sum_share = np.sum([entry[1] for entry in answers])
    for entry in answers:
        entry[1] /= sum_share

    return correct, answers
|
|
|
|
class SCEvaluator(Evaluator):
    """Evaluator that scores generations by self-consistency voting,
    delegating the heavy lifting to the inherited ``process`` with
    ``sc_evaluator`` as the per-sample scorer."""

    def __init__(self):
        self.name = "Self-Consistency"

    def worker(self, args):
        """Evaluate one (json_file, cache_file, K, seed) work item.

        Returns (acc, maximum, average); the bin histograms produced by
        ``process`` are discarded here.
        """
        json_file, cache_file, K, seed = args
        acc, maximum, average, _max_bins, _avg_bins = self.process(
            json_file=json_file,
            cache_file=cache_file,
            equal_func=numberic_compare,
            evaluator=sc_evaluator,
            K=K,
            seed=seed,
        )
        return acc, maximum, average
|
|