|
|
|
|
# Module setup: imports, NLTK data packages, working directory, compute
# device, BERT tokenizer and shared constants.
#
# NOTE(review): HIDDEN_OUTPUT_FEATURES, NUM_CLASSES, TRAINED_WEIGHTS and
# project_dir are referenced later in this file but are not defined in this
# section; they must be defined before the code that uses them runs.
import itertools
import os
import os.path
import pickle
import string
import time
from ast import literal_eval
from typing import Dict

import nltk
import numpy as np
import pandas as pd
import parmap
import torch
import torch.nn as nn
from anytree import Node, PreOrderIter, RenderTree
from gensim.models.phrases import Phrases, Phraser
from pathos.multiprocessing import ProcessingPool as Pool
from sklearn import metrics
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader, Dataset
from transformers import *  # noqa: F401,F403 -- also supplies BertConfig
from transformers import BertModel, BertTokenizer, get_linear_schedule_with_warmup

# Fetch each NLTK data package once (the original requested 'punkt' twice).
nltk.download('punkt')
nltk.download('stopwords')
nltk.download('wordnet')
nltk.download('omw-1.4')

# Colab's default working directory.
os.chdir('/content/')

# Fall back to the CPU when no GPU is available, instead of failing later on
# the first tensor/module .to(device) call.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

MAX_SEQ_LEN = 256  # every encoded sequence is padded/truncated to this length
MASK_TOKEN = '[MASK]'
BATCH_SIZE = 32
|
|
def generate_production_batch(batch):
    """Collate a list of instances into BERT model inputs.

    Args:
        batch: list of instances exposing ``tokens`` (WordPiece tokens as
            produced by ``tokenizer.tokenize``) and ``entity_range``.

    Returns:
        Tuple ``(input_ids, attn_mask, entity_indices, batch)``:
        ``input_ids``/``attn_mask`` are tokenizer tensors padded/truncated to
        ``MAX_SEQ_LEN``, ``entity_indices`` is built by
        ``indices_for_entity_ranges`` from each instance's ``entity_range``,
        and ``batch`` is passed through unchanged.
    """
    # The stored tokens are WordPiece pieces (e.g. 'play', '##ing'). Joining
    # them with plain spaces and re-tokenizing splits '##ing' into '#', '#',
    # 'ing'; convert_tokens_to_string merges continuation pieces back first.
    texts = [tokenizer.convert_tokens_to_string(instance.tokens) for instance in batch]

    # padding='max_length' + truncation=True is the non-deprecated spelling of
    # pad_to_max_length=True with a max_length.
    encoded = tokenizer(
        texts,
        add_special_tokens=True,
        max_length=MAX_SEQ_LEN,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
    )

    entity_indices = indices_for_entity_ranges([instance.entity_range for instance in batch])
    return encoded['input_ids'], encoded['attention_mask'], entity_indices, batch
|
|
|
|
def indices_for_entity_ranges(ranges, hidden_size=None):
    """Build ``torch.gather`` indices that select each entity span's positions.

    For every ``(start, end)`` pair the output row lists the positions
    ``start, start + 1, ..., start + max_e_len``, clipped to ``end``. Here
    ``max_e_len`` is the largest ``end - start`` in the batch. Clipping repeats
    ``end``, which pads shorter spans to a common length. Each position is
    repeated ``hidden_size`` times along the last axis so the tensor can be
    used as the ``index`` of ``torch.gather(hidden_states, dim=1, ...)``.

    Args:
        ranges: sequence of ``(start, end)`` integer pairs, one per instance;
            ``end`` is inclusive.
        hidden_size: length of the last axis. Defaults to the module-level
            ``HIDDEN_OUTPUT_FEATURES``.

    Returns:
        int64 tensor of shape ``(len(ranges), max_e_len + 1, hidden_size)``.

    Raises:
        ValueError: if ``ranges`` is empty.
    """
    ranges = list(ranges)
    if not ranges:
        raise ValueError('indices_for_entity_ranges() requires at least one (start, end) range')
    if hidden_size is None:
        hidden_size = HIDDEN_OUTPUT_FEATURES

    starts = torch.tensor([start for start, _ in ranges], dtype=torch.long)
    ends = torch.tensor([end for _, end in ranges], dtype=torch.long)
    max_e_len = int((ends - starts).max())

    # (batch, max_e_len + 1) positions, clipped to each span's end.
    offsets = torch.arange(max(max_e_len + 1, 0), dtype=torch.long)
    positions = torch.minimum(starts.unsqueeze(1) + offsets, ends.unsqueeze(1))

    # Broadcast along the hidden axis; contiguous() materialises the view so
    # the result is an ordinary dense tensor like the original nested list.
    return positions.unsqueeze(-1).expand(-1, -1, hidden_size).contiguous()
|
|
|
|
import os
import pickle

# Load the label list and the label map once each; the original loaded the
# same two files twice and left the first handle management to manual close().
# NOTE(review): pickle.load can execute arbitrary code -- only load these files
# from a trusted project_dir. project_dir must be defined before this runs.
with open(os.path.join(project_dir, 'labels.pkl'), 'rb') as f:
    LABELS = pickle.load(f)
with open(os.path.join(project_dir, 'labels_map.pkl'), 'rb') as f:
    LABEL_MAP = pickle.load(f)
|
|
|
|
class EntityDataset(Dataset):
    """torch ``Dataset`` over a DataFrame of sentences with entity mentions.

    Columns read by this class: ``entityMentions`` (a list of dicts, or its
    string representation, whose first element has a ``'text'`` key) and
    ``sentText`` (the raw sentence).
    """

    def __init__(self, df, size=None):
        """Keep the rows that convert to an instance, then optionally subsample.

        Args:
            df: source DataFrame.
            size: if given and smaller than the filtered length, draw that many
                rows at random without replacement (no fixed seed).
        """
        if df.empty:
            # Nothing to filter; skip the row-wise apply entirely.
            self.df = df.copy()
        else:
            # Every row is converted once up front, so a malformed row raises
            # here rather than later during iteration.
            keep = df.apply(lambda row: EntityDataset.instance_from_row(row) is not None, axis=1)
            self.df = df[keep]
        print(len(self.df))

        if size is not None and size < len(self):
            self.df = self.df.sample(size, replace=False)

    @staticmethod
    def from_df(df, size=None):
        """Build an ``EntityDataset`` from ``df`` and report its size."""
        dataset = EntityDataset(df, size=size)
        print('Obtained dataset of size', len(dataset))
        return dataset

    @staticmethod
    def instance_from_row(row):
        """Convert one DataFrame row into a ``PairRelInstance``.

        Uses the first entry of ``entityMentions`` as the entity.
        """
        mentions = row['entityMentions']
        # The column may hold the list itself or its string repr. isinstance
        # also accepts str subclasses such as numpy.str_, which `type(x) is str`
        # missed. literal_eval only evaluates Python literals.
        if isinstance(mentions, str):
            mentions = literal_eval(mentions)
        entity = mentions[0]['text']
        return EntityDataset.get_instance(row['sentText'], entity)

    @staticmethod
    def get_instance(text, entity, label=None):
        """Tokenize ``text`` and wrap it with ``entity`` and ``label`` in a ``PairRelInstance``."""
        tokens = tokenizer.tokenize(text)
        # NOTE(review): the entity is not searched for in ``tokens``; every
        # instance gets the fixed span (0, 100). TODO confirm this placeholder.
        entity_range = (0, 100)
        # Pass the caller's label through (the original always passed None).
        return PairRelInstance(tokens, entity, entity_range, label, text)

    def __len__(self):
        return len(self.df.index)

    def __getitem__(self, idx):
        return EntityDataset.instance_from_row(self.df.iloc[idx])
|
|
|
|
|
|
class PairRelInstance:
    """Plain record for one example: tokens, entity, entity span, label, raw text."""

    def __init__(self, tokens, entity, entity_range, label, text):
        # Assigned left to right, in the same order as the constructor arguments.
        self.tokens, self.entity, self.entity_range, self.label, self.text = (
            tokens,
            entity,
            entity_range,
            label,
            text,
        )
|
|
| |
| |
|
|
class PreTrainedPipeline:
    """Callable inference wrapper around a pretrained ``BertModel``.

    NOTE(review): ``path`` is accepted but not used -- weights come from the
    module-level ``TRAINED_WEIGHTS``. Confirm whether ``path`` was meant to be
    the weights location.
    """

    def __init__(self, path):
        config = BertConfig.from_pretrained(TRAINED_WEIGHTS)
        self.model = BertModel.from_pretrained(TRAINED_WEIGHTS, config=config)

    # Annotated with the builtin ``dict``: ``typing.Dict`` is never imported in
    # this file, so the original ``-> Dict[str, str]`` raised NameError when
    # the class was defined.
    def __call__(self, inputs) -> dict:
        """Return a fixed placeholder result.

        TODO: ``inputs`` and ``self.model`` are not used yet; this is a stub.
        """
        return {
            "text": "hello"
        }
|
|
class EntityBertNet(nn.Module):
    """BERT encoder, max-pooling over entity token positions, then a linear classifier."""

    def __init__(self):
        super().__init__()
        config = BertConfig.from_pretrained(TRAINED_WEIGHTS)
        self.bert_base = BertModel.from_pretrained(TRAINED_WEIGHTS, config=config)
        # Maps the pooled entity vector (HIDDEN_OUTPUT_FEATURES wide) to NUM_CLASSES logits.
        self.fc = nn.Linear(HIDDEN_OUTPUT_FEATURES, NUM_CLASSES)

    def forward(self, input_ids, attn_mask, entity_indices):
        """Return class logits for each sequence in the batch.

        Args:
            input_ids: token ids as produced by the tokenizer.
            attn_mask: attention mask matching ``input_ids``.
            entity_indices: gather indices from ``indices_for_entity_ranges``.

        Returns:
            Output of ``self.fc``; its last dimension is ``NUM_CLASSES``.
        """
        # return_dict=False makes BertModel return a tuple; the first item is
        # the per-token hidden states, the second (pooler output) is unused.
        bert_output, _ = self.bert_base(input_ids=input_ids, attention_mask=attn_mask, return_dict=False)
        entity_pooled_output = EntityBertNet.pooled_output(bert_output, entity_indices)
        return self.fc(entity_pooled_output)

    @staticmethod
    def pooled_output(bert_output, indices):
        """Max-pool hidden states over the token positions named by ``indices``.

        ``outputs[b, i, h] = bert_output[b, indices[b, i, h], h]``; the max is
        then taken over ``i`` (dim 1).
        """
        # torch.gather requires the index on the same device as the input;
        # indices_for_entity_ranges builds its tensor without a device.
        indices = indices.to(bert_output.device)
        outputs = torch.gather(input=bert_output, dim=1, index=indices)
        pooled_output, _ = torch.max(outputs, dim=1)
        return pooled_output
|
|