| import re |
| import numpy as np |
|
|
def is_uni_punctuation(word):
    """Return True if ``word`` is non-empty and made only of punctuation/symbols.

    A character qualifies when it is neither a word character (letters,
    digits, underscore) nor whitespace, using Unicode-aware classes. ``bytes``
    input is decoded as UTF-8 first, because a ``str`` pattern cannot be
    matched against bytes.
    """
    if isinstance(word, bytes):
        word = word.decode('utf8')
    # fullmatch anchors both ends. The previous pattern had a stray ']' after
    # '$' and so could never match; '$' also tolerated a trailing newline.
    return re.fullmatch(r"[^\w\s]+", word, flags=re.UNICODE) is not None
|
|
|
|
def is_punctuation(word, pos, punct_set=None):
    """Decide whether a token counts as punctuation for evaluation.

    :param word: surface form of the token.
    :param pos: POS tag of the token.
    :param punct_set: collection of POS tags regarded as punctuation. When it
        is provided, only ``pos`` is consulted; when it is None, the decision
        is delegated to ``is_uni_punctuation(word)``.
    :return: whether the token is treated as punctuation.
    """
    if punct_set is not None:
        return pos in punct_set
    return is_uni_punctuation(word)
|
|
|
|
def eval_(words, postags, heads_pred, arc_tag_pred, heads, arc_tag, word_alphabet, pos_alphabet, lengths,
          punct_set=None, symbolic_root=False, symbolic_end=False):
    """Accumulate dependency-parsing accuracy counts over one batch.

    All per-token arrays are indexed as ``[sentence, position]``.

    :param words: 2D array of word ids; its first dimension is the batch size.
    :param postags: POS-tag ids aligned with ``words``.
    :param heads_pred: predicted head index of every token.
    :param arc_tag_pred: predicted arc label of every token.
    :param heads: gold head index of every token.
    :param arc_tag: gold arc label of every token.
    :param word_alphabet: object whose ``get_instance(id)`` returns the word
        for an id (presumably a str on Python 3 -- confirm with the alphabet).
    :param pos_alphabet: same, for POS tags.
    :param lengths: number of positions of each sentence.
    :param punct_set: collection of POS tags treated as punctuation; when
        None, punctuation is detected from the word form instead.
    :param symbolic_root: if True, skip position 0 of every sentence.
    :param symbolic_end: if True, skip the last position of every sentence.
    :return: ``((ucorr, lcorr, total, ucomplete_match, lcomplete_match),
        (ucorr_nopunc, lcorr_nopunc, total_nopunc, ucomplete_match_nopunc,
        lcomplete_match_nopunc), (corr_root, total_root), batch_size)``.
        The ``u*`` counts are correct heads and the ``l*`` counts are correct
        heads with the correct label. The ``*complete_match`` counts are
        sentences in which every scored token is correct. The ``nopunc``
        variants ignore punctuation tokens. ``total_root`` counts tokens whose
        gold head is 0, and ``corr_root`` counts those that are also
        predicted with head 0.
    """
    batch_size, _ = words.shape
    ucorr = 0.
    lcorr = 0.
    total = 0.
    ucomplete_match = 0.
    lcomplete_match = 0.

    ucorr_nopunc = 0.
    lcorr_nopunc = 0.
    total_nopunc = 0.
    ucomplete_match_nopunc = 0.
    lcomplete_match_nopunc = 0.

    corr_root = 0.
    total_root = 0.
    start = 1 if symbolic_root else 0
    end = 1 if symbolic_end else 0
    for i in range(batch_size):
        # Per-sentence complete-match flags; cleared by the first error.
        ucm = 1.
        lcm = 1.
        ucm_nopunc = 1.
        lcm_nopunc = 1.
        for j in range(start, lengths[i] - end):
            # Keep the alphabet entries as returned. The former
            # ``.encode('utf8')`` (a Python 2 leftover) produced bytes on
            # Python 3. A str regex cannot match bytes (TypeError when
            # punct_set is None), and bytes never equal the str tags in
            # punct_set (punctuation silently never excluded).
            word = word_alphabet.get_instance(words[i, j])
            pos = pos_alphabet.get_instance(postags[i, j])

            head_ok = heads[i, j] == heads_pred[i, j]
            # A label only counts when the head is also correct.
            label_ok = head_ok and arc_tag[i, j] == arc_tag_pred[i, j]

            total += 1
            if head_ok:
                ucorr += 1
            else:
                ucm = 0
            if label_ok:
                lcorr += 1
            else:
                lcm = 0

            if not is_punctuation(word, pos, punct_set):
                total_nopunc += 1
                if head_ok:
                    ucorr_nopunc += 1
                else:
                    ucm_nopunc = 0
                if label_ok:
                    lcorr_nopunc += 1
                else:
                    lcm_nopunc = 0

            if heads[i, j] == 0:
                total_root += 1
                corr_root += 1 if heads_pred[i, j] == 0 else 0

        ucomplete_match += ucm
        lcomplete_match += lcm
        ucomplete_match_nopunc += ucm_nopunc
        lcomplete_match_nopunc += lcm_nopunc

    return (ucorr, lcorr, total, ucomplete_match, lcomplete_match), \
           (ucorr_nopunc, lcorr_nopunc, total_nopunc, ucomplete_match_nopunc, lcomplete_match_nopunc), \
           (corr_root, total_root), batch_size
|
|
|
|
def decode_MST(energies, lengths, leading_symbolic=0, labeled=True):
    """
    Decode the best dependency tree of each sentence with the Chu-Liu/Edmonds
    maximum spanning tree algorithm.

    :param energies: numpy tensor of edge scores, indexed as
        ``[batch, label, head, child]`` (4D) when ``labeled`` is True, or
        ``[batch, head, child]`` (3D) when ``labeled`` is False. The dummy
        root is at index 0. The array is not modified.
    :param lengths: sequence of ints
        number of positions (root included) to decode for each sentence.
    :param leading_symbolic: int
        number of leading label ids reserved for symbolic arcs; these labels
        are excluded when picking the best label of each edge.
    :param labeled: bool
        whether ``energies`` carries a label dimension.
    :return: ``(pars, arc_tags)``.
        ``pars`` is an int32 array of shape ``[batch, max_length]``, where
        ``pars[b, c]`` is the predicted head of position ``c``. It is 0 for the
        root and for positions at or beyond ``lengths[b]``.
        ``arc_tags`` is an int32 array of the same shape holding the label of
        each predicted arc. It is 0 for the root and 1 for positions at or
        beyond the sentence length. ``arc_tags`` is None when ``labeled`` is
        False.
    """

    def find_cycle(par):
        # Return (True, set_of_nodes) for the first cycle in the parent array
        # ``par`` among the currently active nodes, or (False, empty set).
        added = np.zeros([length], bool)
        added[0] = True
        cycle = set()
        findcycle = False
        for i in range(1, length):
            if findcycle:
                break

            if added[i] or not curr_nodes[i]:
                continue

            # Follow parent links from i until we revisit a node of this walk
            # (cycle) or reach a node already explored by an earlier walk.
            tmp_cycle = set()
            tmp_cycle.add(i)
            added[i] = True
            findcycle = True
            l = i

            while par[l] not in tmp_cycle:
                l = par[l]
                if added[l]:
                    findcycle = False
                    break
                added[l] = True
                tmp_cycle.add(l)

            if findcycle:
                # Collect only the nodes on the cycle itself, not the tail
                # that led into it.
                lorg = l
                cycle.add(lorg)
                l = par[lorg]
                while l != lorg:
                    cycle.add(l)
                    l = par[l]
                break

        return findcycle, cycle

    def chuLiuEdmonds():
        # Fill ``final_edges`` (child -> head in original node ids) for the
        # current, possibly contracted, graph described by the enclosing
        # score_matrix / oldI / oldO / curr_nodes / reps.
        par = np.zeros([length], dtype=np.int32)
        # Greedily pick the best-scoring incoming edge of every active node.
        par[0] = -1
        for i in range(1, length):
            if curr_nodes[i]:
                max_score = score_matrix[0, i]
                par[i] = 0
                for j in range(1, length):
                    if j == i or not curr_nodes[j]:
                        continue

                    new_score = score_matrix[j, i]
                    if new_score > max_score:
                        max_score = new_score
                        par[i] = j

        findcycle, cycle = find_cycle(par)
        # No cycle: the greedy choice is a tree; map edges back to the
        # original endpoints and stop.
        if not findcycle:
            final_edges[0] = -1
            for i in range(1, length):
                if not curr_nodes[i]:
                    continue

                pr = oldI[par[i], i]
                ch = oldO[par[i], i]
                final_edges[ch] = pr
            return

        cyc_len = len(cycle)
        cyc_weight = 0.0
        cyc_nodes = np.zeros([cyc_len], dtype=np.int32)
        idx = 0
        for cyc_node in cycle:
            cyc_nodes[idx] = cyc_node
            idx += 1
            cyc_weight += score_matrix[par[cyc_node], cyc_node]

        # Contract the cycle into its first node ``rep``: re-score edges
        # between every outside node and the cycle, remembering the original
        # endpoints in oldI / oldO.
        rep = cyc_nodes[0]
        for i in range(length):
            if not curr_nodes[i] or i in cycle:
                continue

            max1 = float("-inf")
            wh1 = -1
            max2 = float("-inf")
            wh2 = -1

            for j in range(cyc_len):
                j1 = cyc_nodes[j]
                # Best edge leaving the cycle towards i.
                if score_matrix[j1, i] > max1:
                    max1 = score_matrix[j1, i]
                    wh1 = j1

                # Entering the cycle at j1 from i replaces j1's cycle edge.
                scr = cyc_weight + score_matrix[i, j1] - score_matrix[par[j1], j1]

                if scr > max2:
                    max2 = scr
                    wh2 = j1

            score_matrix[rep, i] = max1
            oldI[rep, i] = oldI[wh1, i]
            oldO[rep, i] = oldO[wh1, i]
            score_matrix[i, rep] = max2
            oldO[i, rep] = oldO[i, wh2]
            oldI[i, rep] = oldI[i, wh2]

        # Snapshot which original nodes each cycle member stood for before
        # merging them into rep.
        rep_cons = []
        for i in range(cyc_len):
            rep_cons.append(set())
            cyc_node = cyc_nodes[i]
            for cc in reps[cyc_node]:
                rep_cons[i].add(cc)

        for i in range(1, cyc_len):
            cyc_node = cyc_nodes[i]
            curr_nodes[cyc_node] = False
            for cc in reps[cyc_node]:
                reps[rep].add(cc)

        chuLiuEdmonds()

        # The cycle member whose represented node received an incoming edge
        # in the contracted solution is where the cycle is broken.
        found = False
        wh = -1
        for i in range(cyc_len):
            for repc in rep_cons[i]:
                if repc in final_edges:
                    wh = cyc_nodes[i]
                    found = True
                    break
            if found:
                break

        # Keep every other cycle edge.
        l = par[wh]
        while l != wh:
            ch = oldO[par[l], l]
            pr = oldI[par[l], l]
            final_edges[ch] = pr
            l = par[l]

    if labeled:
        assert energies.ndim == 4, 'dimension of energies is not equal to 4'
    else:
        assert energies.ndim == 3, 'dimension of energies is not equal to 3'
    input_shape = energies.shape
    batch_size = input_shape[0]
    max_length = input_shape[2]

    pars = np.zeros([batch_size, max_length], dtype=np.int32)
    arc_tags = np.zeros([batch_size, max_length], dtype=np.int32) if labeled else None
    for i in range(batch_size):
        energy = energies[i]
        length = lengths[i]

        if labeled:
            # Drop symbolic labels, then keep the best label score per edge.
            energy = energy[leading_symbolic:, :length, :length]
            label_id_matrix = energy.argmax(axis=0) + leading_symbolic
            energy = energy.max(axis=0)
        else:
            energy = energy[:length, :length]
            label_id_matrix = None

        # Work on a private copy. In the unlabeled case ``energy`` is a view
        # into the caller's ``energies``, and the diagonal reset below used to
        # write through it.
        score_matrix = np.array(energy, copy=True)

        # oldI[h, c] / oldO[h, c]: original head / child of the (possibly
        # contracted) edge h -> c.
        oldI = np.zeros([length, length], dtype=np.int32)
        oldO = np.zeros([length, length], dtype=np.int32)
        curr_nodes = np.zeros([length], dtype=bool)
        reps = []

        for s in range(length):
            score_matrix[s, s] = 0.0
            curr_nodes[s] = True
            reps.append(set())
            reps[s].add(s)
            for t in range(s + 1, length):
                oldI[s, t] = s
                oldO[s, t] = t

                oldI[t, s] = t
                oldO[t, s] = s

        final_edges = dict()
        chuLiuEdmonds()
        par = np.zeros([max_length], np.int32)
        if labeled:
            arc_tag = np.ones([max_length], np.int32)
            arc_tag[0] = 0
        else:
            arc_tag = None

        for ch, pr in final_edges.items():
            par[ch] = pr
            if labeled and ch != 0:
                arc_tag[ch] = label_id_matrix[pr, ch]

        par[0] = 0
        pars[i] = par
        if labeled:
            arc_tags[i] = arc_tag

    return pars, arc_tags
|
|