| import re |
| import os |
| import sys |
| from tqdm import tqdm |
|
|
|
|
def remove_bpe(line, bpe_symbol="@@ "):
    """Undo BPE segmentation on a single line of text.

    Every newline in *line* is dropped and every occurrence of
    *bpe_symbol* is deleted. The result is right-stripped and terminated
    with exactly one newline.

    A single space is appended before the replacement. This lets a
    marker at the very end of the line (e.g. ``"foo@@"``) still match a
    symbol that ends in a space, such as the default ``"@@ "``.

    Args:
        line: text of one line, with or without its trailing newline.
        bpe_symbol: marker string to delete.

    Returns:
        The cleaned line, ending in ``"\\n"``.
    """
    padded = "".join(line.split("\n")) + " "
    merged = padded.replace(bpe_symbol, "")
    return merged.rstrip() + "\n"
|
|
|
|
def remove_bpe_fn(i=None, o=None, bpe="@@ "):
    """Copy lines from *i* to *o*, removing BPE markers from each one.

    A ``tqdm`` progress bar is shown while reading. Each line is cleaned
    with :func:`remove_bpe`.

    Args:
        i: iterable of input lines (typically an open text file).
            ``None`` means the current ``sys.stdin``.
        o: object with a ``write`` method. ``None`` means the current
            ``sys.stdout``.
        bpe: marker string passed to :func:`remove_bpe`.
    """
    # Resolve the standard streams at call time, not at import time.
    # Then redirection done after import (contextlib.redirect_stdout,
    # test output capture, ...) is honoured.
    src = sys.stdin if i is None else i
    dst = sys.stdout if o is None else o
    for raw in tqdm(src):
        dst.write(remove_bpe(raw, bpe))
|
|
|
|
def reprocess(fle):
    """Reprocess fairseq ``generate.py`` output into per-sentence dicts.

    A record line starts with a one-letter tag, a dash and an integer
    sentence id, followed by spaces/tabs and the payload:

    * ``S-<id> <source>``: ``source_dict[id]`` is set to the text
      (including its trailing newline).
    * ``H-<id> <score> <hypothesis>``: the text (including its trailing
      newline) is appended to ``hypothesis_dict[id]``, and the float
      score to ``score_dict[id]``.
    * ``P-<id> <s1> <s2> ...``: the list of floats is appended to
      ``pos_score_dict[id]``.
    * ``T-<id> ...``: skipped, so ``target_dict`` is always returned empty.

    Every other line (including lines whose tag is not at column 0) is
    ignored.

    Args:
        fle: path of the generate.py output file.

    Returns:
        Tuple ``(source_dict, hypothesis_dict, score_dict, target_dict,
        pos_score_dict)``, each keyed by the int sentence id.

    Raises:
        ValueError: if an ``H`` line does not start with a parseable score.
    """
    with open(fle, 'r') as f:
        txt = f.read()

    # The tag is anchored at line start. Otherwise text such as "T-800"
    # inside a sentence, or on an untracked record type like "D-", would
    # be parsed as a record. Only spaces/tabs are consumed after the id,
    # so the line's newline is never swallowed.
    prefix_re = re.compile(r"([STHP])-(\d+)[ \t]*")
    # Score: signed int/decimal with optional exponent, or inf/nan.
    # Single-digit values such as "-1" or "1e-05" are accepted.
    # Trailing whitespace is again limited to spaces/tabs. That keeps an
    # empty hypothesis as "\n" rather than "", which keeps the output
    # files line-aligned.
    score_re = re.compile(
        r"[ \t]*([-+]?(?:inf|nan|(?:\d+\.?\d*|\.\d+)(?:[eE][+-]?\d+)?))[ \t]*"
    )

    source_dict = {}
    hypothesis_dict = {}
    score_dict = {}
    target_dict = {}
    pos_score_dict = {}

    for line in txt.split("\n"):
        # Restore the newline removed by split(); payloads keep it.
        line += "\n"
        prefix = prefix_re.match(line)
        if prefix is None:
            continue
        line_type, id_str = prefix.groups()
        id_num = int(id_str)
        rest = line[prefix.end():]

        if line_type == "H":
            scored = score_re.match(rest)
            if scored is None:
                raise ValueError(
                    "regular expression failed to find the hypothesis "
                    "scoring: {!r}".format(line))
            hypothesis_dict.setdefault(id_num, []).append(rest[scored.end():])
            score_dict.setdefault(id_num, []).append(float(scored.group(1)))
        elif line_type == "S":
            source_dict[id_num] = rest
        elif line_type == "P":
            pos_score_dict.setdefault(id_num, []).append(
                [float(x) for x in rest.split()])
        # "T" (target) lines are deliberately skipped; target_dict stays empty.

    return source_dict, hypothesis_dict, score_dict, target_dict, pos_score_dict
|
|
|
|
def _with_newline(text):
    """Return *text* guaranteed to end with a single trailing newline."""
    return text if text.endswith("\n") else text + "\n"


def get_hypo_and_ref(fle, hyp_file, ref_input, ref_file, rank=0):
    """Write line-aligned hypothesis and reference files.

    For every sentence id that has a hypothesis in *fle*, in ascending id
    order, two lines are written:

    * the ``rank``-th hypothesis of that sentence, to *hyp_file*;
    * line number ``id`` of *ref_input*, to *ref_file*.

    Args:
        fle: fairseq generate.py output, parsed with :func:`reprocess`.
        hyp_file: output path for the selected hypotheses.
        ref_input: reference text, one line per sentence id.
        ref_file: output path for the matching reference lines.
        rank: index into each sentence's hypothesis list (0 = first).

    Raises:
        ValueError: if *rank* is out of range for some sentence, or if
            *ref_input* has fewer lines than the largest sentence id
            requires.
    """
    with open(ref_input, 'r') as f:
        refs = f.readlines()
    _, hypo_dict, _, _, _ = reprocess(fle)

    # Validate everything before opening the outputs, so a bad call
    # leaves no half-written files behind. Every sentence is checked,
    # not just id 0; the old check raised KeyError when id 0 was absent.
    bad_ids = sorted(idx for idx, hyps in hypo_dict.items()
                     if not -len(hyps) <= rank < len(hyps))
    if bad_ids:
        raise ValueError("rank {} out of range for sentence ids {}".format(
            rank, bad_ids[:10]))
    if hypo_dict and max(hypo_dict) >= len(refs):
        raise ValueError(
            "reference file {} has {} lines but sentence id {} was "
            "found".format(ref_input, len(refs), max(hypo_dict)))

    with open(hyp_file, "w") as f_hyp, open(ref_file, "w") as f_ref:
        for idx in sorted(hypo_dict):
            # Force a newline on both sides so the files stay one line
            # per sentence, even for an empty hypothesis or an
            # unterminated last ref line.
            f_hyp.write(_with_newline(hypo_dict[idx][rank]))
            f_ref.write(_with_newline(refs[idx]))
|
|
|
|
def recover_bpe(hyp_file):
    """Write a BPE-free copy of *hyp_file* next to it as ``<hyp_file>.nobpe``.

    Each line is passed through :func:`remove_bpe_fn` with its default
    marker (``"@@ "``).

    Args:
        hyp_file: path of the tokenized hypothesis file to clean.
    """
    # Both files are closed even if remove_bpe_fn raises. The former
    # eval()-based lookup over a one-element list is replaced by
    # direct use of the file objects.
    with open(hyp_file, "r") as f_hyp, \
            open(hyp_file + ".nobpe", "w") as f_hyp_out:
        remove_bpe_fn(i=f_hyp, o=f_hyp_out)
|
|
|
|
if __name__ == "__main__":
    # Usage: <script> GENERATE_OUTPUT REF_INPUT
    # hypo.out, ref.out and hypo.out.nobpe are written next to
    # GENERATE_OUTPUT.
    if len(sys.argv) < 3:
        sys.exit("usage: {} GENERATE_OUTPUT REF_INPUT".format(sys.argv[0]))
    filename = sys.argv[1]
    ref_in = sys.argv[2]
    out_dir = os.path.dirname(filename)
    hypo_file = os.path.join(out_dir, "hypo.out")
    ref_out = os.path.join(out_dir, "ref.out")
    get_hypo_and_ref(filename, hypo_file, ref_in, ref_out)
    recover_bpe(hypo_file)
|
|