| import os |
| import numpy as np |
| import jsonlines |
| from collections import defaultdict |
| from sklearn.metrics import classification_report |
|
|
|
|
def get_links(sample_string, sample_index, max_dist=15):
    """
    Parse a discourse-structure string into attachment and relation lists.

    Takes a sample string of space-separated tokens of the form
    ``LABEL(src,tgt)`` (e.g. ``"QAP(0,1) ELAB(1,2)"``) and returns the
    parsed links, keeping only well-formed tokens whose label is a known
    relation and whose EDU distance is within ``max_dist``.

    Parameters
    ----------
    sample_string : str
        Space-separated relation tokens.
    sample_index : int
        Index of the sample; used only in error messages.
    max_dist : int, optional
        Maximum allowed distance ``tgt - src`` for a link to be kept
        (default 15, the original hard-coded window).

    Returns
    -------
    endpoints : list[tuple[int, int]]
        De-duplicated (src, tgt) attachment pairs, in first-seen order.
    full_list : list[tuple[str, int, int]]
        (relation, src, tgt) triples, one per unique endpoint pair
        (the first relation seen for each pair wins).
    counts : list[int]
        ``[good, bad]`` — counts of kept vs. malformed/filtered tokens.
    """
    labels = ['COM', 'CONTR', 'CORR', 'QAP', 'ACK', 'ELAB', 'CLARIFQ', 'COND',
              'CONTIN', 'RES', 'EXPL', 'QELAB', 'ALT', 'NARR', 'CONFQ', 'SEQ']

    rel_list = []
    attach_list = []
    bad = 0
    good = 0
    for token in (t.strip() for t in sample_string.split(' ')):
        s_tuple = None
        rel = None
        try:
            args = token.split('(')[1].split(')')[0].split(',')
            label = token.split('(')[0].strip()
        except IndexError:
            print('split error at ', sample_index)
        else:
            try:
                # Reuse this parsed pair below instead of re-parsing.
                s_tuple = (int(args[0]), int(args[1]))
            except IndexError:
                print('split error at ', sample_index)
            except ValueError:
                print('value error at ', sample_index)
            if label in labels:
                rel = label

        # Keep only well-formed links within the distance window.
        if rel is not None and s_tuple is not None and (s_tuple[1] - s_tuple[0]) <= max_dist:
            attach_list.append(s_tuple)
            rel_list.append(rel)
            good += 1
        else:
            bad += 1

    # De-duplicate endpoint pairs, keeping the first relation seen for each.
    endpoints = []
    full_list = []
    for rel, pair in zip(rel_list, attach_list):
        if pair not in endpoints:
            endpoints.append(pair)
            full_list.append((rel, pair[0], pair[1]))
    return endpoints, full_list, [good, bad]
| |
|
|
# Paths for gold annotations, model predictions, and the evaluation report.
current_folder = os.getcwd()

gold_path = '/path/to/jsonl'
pred_path = '/path/to/llamipa_output.txt'
save_results = '/path/to/eval_.txt'

# Collect every predicted discourse-structure line from the model output:
# the payload follows the '### DS:' marker on lines that carry it.
with open(pred_path, 'r') as txt:
    raw_lines = txt.read().split('\n')

pred_outputs = [
    line.split('### DS:')[1].strip()
    for line in raw_lines
    if line.startswith(' ### DS:')
]
print(len(pred_outputs))
|
|
| |
# Gold parse structures, skipping dialogue-boundary marker rows.
with jsonlines.open(gold_path) as reader:
    gold_outputs = [
        record['PS']
        for record in reader
        if not record['sample'].startswith('NEW DIALOGUE')
    ]
|
|
# Per-sample attachment (unlabelled link) scores.
att_f1_l = []
att_prec_l = []
att_rec_l = []

# Corpus-level counts for micro-averaged attachment F1.
total_attach_tp = 0
total_attach_fp = 0
total_attach_fn = 0

# Per-sample attachment + relation (labelled link) scores.
type_f1_l = []
type_prec_l = []
type_rec_l = []

# All labelled links predicted exactly right, across the corpus.
total_TP = []

# [gold_label, pred_label] rows feeding the classification report.
matrix_list = []
# Counts of malformed vs. well-formed tokens in the model output.
bad_output = 0
good_output = 0
|
|
for i, s in enumerate(pred_outputs):

    # Parse predicted and gold link structures for sample i.
    # Bug fix: both calls previously assigned their counts to the same
    # name `malform`, so the gold counts silently overwrote the
    # prediction's before the accumulation below — the malformed-output
    # counters were tracking the gold side instead of the model's.
    pred_att, pred_all, pred_malform = get_links(s, i)
    gold_att, gold_all, gold_malform = get_links(gold_outputs[i], i)

    # Count well-formed vs malformed tokens in the model's output.
    bad_output += pred_malform[1]
    good_output += pred_malform[0]

    # Endpoint pairs present in exactly one of pred/gold each contribute
    # one NULL row to the confusion matrix; checked by the assert below.
    common = len(set(pred_att).intersection(set(gold_att)))
    expected_nulls = (len(pred_att) - common) + (len(gold_att) - common)

    # --- Unlabelled attachment precision/recall/F1 for this sample ---
    if len(gold_att) > 0 and len(pred_att) > 0:
        prec = len([e for e in pred_att if e in gold_att])/len(pred_att)
        rec = len([e for e in pred_att if e in gold_att])/len(gold_att)
        total_attach_tp += len([e for e in pred_att if e in gold_att])
        total_attach_fp += len([e for e in pred_att if e not in gold_att])
        total_attach_fn += len([e for e in gold_att if e not in pred_att])
    else:
        prec = 0
        rec = 0
    att_prec_l.append(prec)
    att_rec_l.append(rec)
    if prec+rec==0:
        att_f1_l.append(0)
    else:
        att_f1_l.append(2*prec*rec/(prec+rec))

    # --- Labelled (attachment + relation) scores for this sample ---
    if len(gold_all) > 0 and len(pred_all) > 0:
        prec = len([e for e in pred_all if e in gold_all])/len(pred_all)
        rec = len([e for e in pred_all if e in gold_all])/len(gold_all)
    else:
        prec = 0
        rec = 0
    type_prec_l.append(prec)
    type_rec_l.append(rec)
    if prec+rec==0:
        type_f1_l.append(0)
    else:
        type_f1_l.append(2*prec*rec/(prec+rec))

    # Exact labelled matches, plus the leftovers on each side.
    TP = [e for e in pred_all if e in gold_all]
    leftover_pred = [p for p in pred_all if p not in TP]
    leftover_gold = [p for p in gold_all if p not in TP]

    total_TP.extend(TP)

    # Group leftover links by endpoint pair so that a pred label and a
    # gold label on the same pair collapse into one confusion-matrix row.
    rem_dict = defaultdict(list)
    for x in TP:
        matrix_list.append([x[0], x[0]])
    for x in leftover_pred:
        rem_dict[(x[1], x[2])].append(('p', x[0]))
    for x in leftover_gold:
        rem_dict[(x[1], x[2])].append(('g', x[0]))

    p_count = 0
    g_count = 0
    null_count = 0
    for k in rem_dict.keys():
        p = 'NULL'
        t = 'NULL'
        # Loop variable renamed from `re`, which shadowed the name of the
        # stdlib regex module.
        for entry in rem_dict[k]:
            if entry[0] == 'p':
                p = entry[1]
                p_count += 1
            elif entry[0] == 'g':
                t = entry[1]
                g_count += 1
        matrix_list.append([t,p])
        if 'NULL' in [t,p]:
            null_count += 1

    # Sanity checks: every link is either a TP or a leftover, and every
    # one-sided endpoint pair produced exactly one NULL row.
    assert(len(TP) + p_count == len(pred_all))
    assert(len(TP) + g_count == len(gold_all))
    assert null_count == expected_nulls
|
|
| |
# Build the label vocabulary from every gold and pred label that occurs.
# (Insertion order — gold labels first, then pred — is kept identical so
# the resulting set/list ordering matches the original.)
observed = [row[0] for row in matrix_list]
observed.extend(row[1] for row in matrix_list)
labels = list(set(observed))

# Micro-averaged attachment F1 from the corpus-level TP/FP/FN counts.
microf1 = total_attach_tp / (total_attach_tp + 0.5 * (total_attach_fp + total_attach_fn))

# Integer-encode each matrix row for sklearn's classification_report.
label_id = labels.index
gold_list = [label_id(row[0]) for row in matrix_list]
pred_list = [label_id(row[1]) for row in matrix_list]
|
|
# Write all metrics to the results file.
f = open(save_results, "w")
print("Attachment F1:", np.mean(att_f1_l), len(att_f1_l), file=f)
print("Attachment Average Precision:", np.mean(att_prec_l), file=f)
print("Attachment Average Recall:", np.mean(att_rec_l), file=f)
print('Micro F1: ', microf1, file=f)
print('--------------------------------', file=f)
# Bug fix: the four lines below previously printed to stdout (missing
# file=f), so the labelled-attachment scores never reached the results
# file, unlike every sibling metric above and below.
print("Attachment + Rel F1:", np.mean(type_f1_l), len(type_f1_l), file=f)
print("Attachment + Rel Average Precision:", np.mean(type_prec_l), file=f)
print("Attachment + Rel Average Recall:", np.mean(type_rec_l), file=f)
print('---------------------------------------', file=f)
print(classification_report(gold_list, pred_list, target_names=labels), file=f)
|
|
| |
| |
| |
| |
| |
| |
| |
# Support-weighted averages over the real relation labels only: NULL
# rows stand for missing links rather than a relation class, so they
# are excluded from the weighting.
d = classification_report(gold_list, pred_list, target_names=labels, output_dict=True)

prec = 0
rec = 0
f1 = 0
count = 0
for stats in (d[lab] for lab in labels if lab != "NULL"):
    support = stats["support"]
    prec += stats["precision"] * support
    rec += stats["recall"] * support
    f1 += stats["f1-score"] * support
    count += support
| |
| |
| |
# Append the NULL-excluded weighted averages to the results file.
f.write('--------------------------------\n')
for metric_name, weighted_sum in (("Precision", prec), ("Recall", rec), ("F1 score", f1)):
    f.write(f"Weighted Average {metric_name}: {weighted_sum / count}\n")

f.close()
| |
| |
|
|