| import torch |
| import numpy as np |
| import sys |
| import os |
| from transformers import RobertaTokenizer, AutoModelForTokenClassification, RobertaForSequenceClassification |
| import spacy |
| import tokenizations |
| from numpy import asarray |
| from numpy import savetxt, loadtxt |
| import numpy as np |
| import json |
| from copy import deepcopy |
| from sty import fg, bg, ef, rs, RgbBg, Style |
| import re |
| from tqdm import tqdm |
| import gradio as gr |
|
|
# Shared resources, loaded once when the module is imported.

# spaCy pipeline; auto_split() iterates over its sentence boundaries.
nlp = spacy.load(
    "en_core_web_sm",
)

# One RoBERTa tokenizer feeds both models below.
tokenizer = RobertaTokenizer.from_pretrained(
    "roberta-base",
)

# Token-level tagger with 3 output labels, used to find clause boundaries.
clause_model = AutoModelForTokenClassification.from_pretrained(
    "./clause_model_512",
    num_labels=3,
)

# Sequence classifier with 18 output labels (one per key of labels2attrs).
# NOTE: the directory name really is spelled "classfication_model".
classification_model = RobertaForSequenceClassification.from_pretrained(
    "./classfication_model",
    num_labels=18,
)
|
|
|
|
# Maps every classifier label to a 3-tuple of attribute values. In the data
# below the first slot is "specific"/"generic", the second is
# "dynamic"/"stative" and the third is "episodic"/"static"/"habitual";
# labels without attributes use "NA" in all three slots.
# Insertion order matters: it defines the label indices derived below.
labels2attrs = {
    "##BOUNDED EVENT (SPECIFIC)": ("specific", "dynamic", "episodic"),
    "##BOUNDED EVENT (GENERIC)": ("generic", "dynamic", "episodic"),
    "##UNBOUNDED EVENT (SPECIFIC)": ("specific", "dynamic", "static"),
    "##UNBOUNDED EVENT (GENERIC)": ("generic", "dynamic", "static"),
    "##BASIC STATE": ("specific", "stative", "static"),
    "##COERCED STATE (SPECIFIC)": ("specific", "dynamic", "static"),
    "##COERCED STATE (GENERIC)": ("generic", "dynamic", "static"),
    "##PERFECT COERCED STATE (SPECIFIC)": ("specific", "dynamic", "episodic"),
    "##PERFECT COERCED STATE (GENERIC)": ("generic", "dynamic", "episodic"),
    "##GENERIC SENTENCE (DYNAMIC)": ("generic", "dynamic", "habitual"),
    "##GENERIC SENTENCE (STATIC)": ("generic", "stative", "static"),
    "##GENERIC SENTENCE (HABITUAL)": ("generic", "stative", "habitual"),
    "##GENERALIZING SENTENCE (DYNAMIC)": ("specific", "dynamic", "habitual"),
    "##GENERALIZING SENTENCE (STATIVE)": ("specific", "stative", "habitual"),
    "##QUESTION": ("NA", "NA", "NA"),
    "##IMPERATIVE": ("NA", "NA", "NA"),
    "##NONSENSE": ("NA", "NA", "NA"),
    "##OTHER": ("NA", "NA", "NA"),
}


# Label name -> position in labels2attrs. Plain Python ints are used (rather
# than NumPy integers) so the mapping is serialisable; NumPy integer lookups
# into index2label still resolve because they hash and compare like ints.
label2index = {label: index for index, label in enumerate(labels2attrs)}
# Inverse mapping: position -> label name.
index2label = {index: label for label, index in label2index.items()}
|
|
def auto_split(text, max_words=200):
    """Group the sentences of ``text`` into snippets of bounded word count.

    Sentences come from the module-level spaCy pipeline ``nlp``. They are
    accumulated, in order, into the current snippet until adding the next
    sentence would push the whitespace-separated word count above
    ``max_words``; at that point the current snippet is emitted and the
    sentence starts a new one. A single sentence longer than ``max_words``
    becomes a snippet of its own (it is never split).

    Args:
        text: Raw input text.
        max_words: Maximum number of whitespace-separated words per snippet.

    Returns:
        A list of snippet strings. Blank snippets are never emitted, so empty
        or whitespace-only input yields ``[]``.
    """
    snippets = []
    current_snippet = ""
    current_len = 0
    for sent in nlp(text).sents:
        sent_text = sent.text
        n_words = len(sent_text.split())
        if current_len + n_words > max_words:
            # Flush what has been collected so far. Nothing has been collected
            # yet when the very first sentence is already over the limit, so
            # guard against emitting an empty snippet.
            if current_snippet.strip():
                snippets.append(current_snippet)
            current_snippet = sent_text
            current_len = n_words
        else:
            # NOTE(review): when the snippet is still empty this yields a
            # leading space. It is kept as-is because the snippet text is fed
            # verbatim to the tokenizer downstream -- confirm before
            # normalising.
            current_snippet += " " + sent_text
            current_len += n_words
    if current_snippet.strip():
        snippets.append(current_snippet)
    return snippets
|
|
|
|
def majority_vote(array):
    """Return the most frequent value in ``array``.

    Ties are resolved in favour of the smallest value: ``np.unique`` returns
    the distinct values sorted, and ``argmax`` picks the first maximal count.
    """
    values, frequencies = np.unique(np.asarray(array), return_counts=True)
    winner = frequencies.argmax()
    return values[winner]
|
|
def get_pred_clause_labels(text, words):
    """Predict one clause-tagging label per entry of ``words``.

    ``text`` is tokenized with the module-level RoBERTa ``tokenizer`` (padded
    and truncated to 512 tokens) and tagged by ``clause_model``. Sub-token
    predictions are projected back onto ``words`` through the alignment
    computed by ``tokenizations.get_alignments``:

    * a word aligned to one or more sub-tokens gets the majority label of
      those sub-tokens (ties go to the smallest label, see ``majority_vote``);
    * a word aligned to no sub-token gets the default label 1
      (presumably this covers words beyond the 512-token truncation --
      TODO confirm).

    Args:
        text: The snippet to tag.
        words: The words of ``text`` that labels should be produced for.

    Returns:
        A list with exactly ``len(words)`` labels.
    """
    model_inputs = tokenizer(text, padding='max_length', max_length=512, truncation=True, return_tensors='pt')
    roberta_tokens = tokenizer.convert_ids_to_tokens(model_inputs['input_ids'][0])
    # Only the words -> sub-tokens direction of the alignment is needed.
    word_to_subtokens, _ = tokenizations.get_alignments(words, roberta_tokens)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = clause_model(**model_inputs)[0]
    tagging = logits.argmax(-1)[0].numpy()

    pred_labels = [
        majority_vote(tagging[alignment]) if alignment else 1
        for alignment in word_to_subtokens
    ]
    assert len(pred_labels) == len(words)
    return pred_labels
|
|
def seg_clause(text):
    """Split ``text`` into clauses using per-word labels from the tagger.

    The text is split on whitespace and each word is labelled by
    ``get_pred_clause_labels``. Words are collected into a pending clause;
    a word labelled 2 that follows a word labelled 0 or 1 closes the pending
    clause (the closing word is included). Whatever is still pending after
    the last word is emitted as a final clause.

    Note that a word labelled 2 which directly follows another word labelled
    2 (or is the very first word) does not close anything, and the pending
    clause is restarted on the next word -- so such a word is dropped unless
    it is the last word of ``text``.

    Returns:
        A list of clause strings, each with its words joined by single spaces.
    """
    words = text.strip().split()
    labels = get_pred_clause_labels(text, words)

    clauses = []
    pending = []
    last_label = 2  # behave as if a clause had just been closed
    for word, label in zip(words, labels):
        if last_label == 2:
            # Previous word carried the boundary label: start afresh,
            # discarding anything still pending.
            pending = []
        pending.append(word)
        if label == 2 and last_label in (0, 1):
            clauses.append(pending)
            pending = []
        last_label = label

    # Leftover words that were never closed by a boundary label.
    if pending:
        clauses.append(pending)
    return [" ".join(clause) for clause in clauses]
|
|
def pretty_print_segmented_clause(segmented_clauses):
    """Print the clauses on one line, each on a coloured terminal background.

    Colours are drawn pseudo-randomly from a six-colour palette so that two
    consecutive clauses never share a colour. The draw uses a private
    generator seeded with 42, so output is reproducible and the process-wide
    NumPy random state is left untouched.

    Args:
        segmented_clauses: Iterable of clause strings.
    """
    # Local generator: same stream as seeding the global one with 42, but
    # without the side effect on other users of np.random.
    rng = np.random.RandomState(42)

    # Register two extra background colours on sty's ``bg`` register.
    bg.orange = Style(RgbBg(255, 150, 50))
    bg.purple = Style(RgbBg(180, 130, 225))
    colors = [bg.red, bg.orange, bg.yellow, bg.green, bg.blue, bg.purple]
    palette_indices = np.arange(len(colors))

    prev_color = 0  # the first clause is therefore never drawn with colors[0]
    to_print = []
    for clause in segmented_clauses:
        # Exclude the previous colour so neighbours are distinguishable.
        color_choice = rng.choice(np.delete(palette_indices, prev_color))
        prev_color = color_choice
        to_print.append(colors[color_choice] + clause + bg.rs)
    print(*to_print, sep=" ")
| |
|
|
def get_pred_classification_labels(clauses, batch_size=32):
    """Classify each clause and pair it with its predicted label name.

    Clauses are processed in batches of ``batch_size``. Each batch is
    tokenized with the module-level ``tokenizer`` (padded and truncated to 128
    tokens), scored by ``classification_model``, and the arg-max class index
    is translated to a label name through ``index2label``.

    Args:
        clauses: List of clause strings.
        batch_size: Number of clauses sent through the model at once.

    Returns:
        A list of ``(clause, label_name)`` tuples in input order; ``[]`` when
        ``clauses`` is empty.
    """
    clause2labels = []
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        # Stop at len(clauses): going one past it would send an empty batch
        # to the tokenizer/model whenever the clause count is 0 or an exact
        # multiple of batch_size.
        for start in range(0, len(clauses), batch_size):
            batch_examples = clauses[start : start + batch_size]
            model_inputs = tokenizer(batch_examples, padding='max_length', max_length=128, truncation=True, return_tensors='pt')
            logits = classification_model(**model_inputs)[0]
            pred_labels = [index2label[index] for index in logits.argmax(-1).tolist()]
            clause2labels.extend(
                (clause, str(label)) for clause, label in zip(batch_examples, pred_labels)
            )
    return clause2labels
|
|
|
|
|
|
def run_pipeline(text):
    """Run the full demo pipeline: split, segment into clauses, classify.

    Returns:
        A pair ``(output_clauses, clause2labels)``:

        * ``output_clauses`` -- ``(clause, position)`` tuples where
          ``position`` is the clause's 1-based index as a string;
        * ``clause2labels`` -- the result of
          ``get_pred_classification_labels`` for the same clauses.
    """
    # Segment every snippet and flatten the results, keeping document order.
    all_clauses = [
        clause
        for snippet in auto_split(text)
        for clause in seg_clause(snippet)
    ]
    clause2labels = get_pred_classification_labels(all_clauses)
    output_clauses = [
        (clause, str(position))
        for position, clause in enumerate(all_clauses, start=1)
    ]
    return output_clauses, clause2labels
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
# Colours for the clause-segmentation view, keyed by the clause's 1-based
# position rendered as a string ("1" .. "99999"); the palette is cycled.
color_panel_1 = ["red", "green", "yellow", "DodgerBlue", "orange", "DarkSalmon", "pink", "cyan", "gold", "aqua", "violet"]
index_colormap = {str(i): color_panel_1[i % len(color_panel_1)] for i in range(1, 100000)}

# Colours for the attribute view: one per distinct attribute tuple.
color_panel_2 = ["Violet", "DodgerBlue", "Wheat", "OliveDrab", "DarkKhaki", "DarkSalmon", "Orange", "Gold", "Aqua", "Tomato", "Gray"]
# dict.fromkeys de-duplicates while keeping first-seen order. A set would
# iterate in hash-randomised order, so the attribute -> colour assignment
# would differ from one run to the next.
str_attrs = [str(v) for v in dict.fromkeys(labels2attrs.values())]
print(str_attrs, len(str_attrs), len(color_panel_2))
assert len(str_attrs) == len(color_panel_2)
# NOTE(review): the keys here are stringified attribute tuples, whereas
# get_pred_classification_labels emits label names ("##..."); confirm the
# colour map keys match what the attribute view actually displays.
attr_colormap = dict(zip(str_attrs, color_panel_2))
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
# First output: each clause highlighted and tagged with its position.
_segmentation_view = gr.HighlightedText(
    label="Clause Segmentation",
    show_label=True,
    combine_adjacent=False,
).style(color_map=index_colormap)

# Second output: each clause highlighted with its predicted label, with a
# legend.
_classification_view = gr.HighlightedText(
    label="Attribute Classification",
    show_label=True,
    show_legend=True,
    combine_adjacent=False,
).style(color_map=attr_colormap)

# Single text box in, the two highlighted views out; run_pipeline returns
# one value per output component, in this order.
demo = gr.Interface(
    fn=run_pipeline,
    inputs=["text"],
    outputs=[_segmentation_view, _classification_view],
)


# Starts the app at import time with share=True.
demo.launch(share=True)