| """BERT NER Inference.""" |
|
|
| from __future__ import absolute_import, division, print_function |
|
|
| import json |
| import os |
|
|
| import torch |
| import torch.nn.functional as F |
| from torch.nn import CrossEntropyLoss |
| from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset |
| from torch.utils.data.distributed import DistributedSampler |
| from tqdm import tqdm, trange |
| from nltk import word_tokenize |
| |
| |
| from pytorch_transformers import (BertForTokenClassification, BertTokenizer) |
|
|
|
|
| class BertNer(BertForTokenClassification): |
|
|
| def forward(self, input_ids, token_type_ids=None, attention_mask=None, valid_ids=None): |
| sequence_output = self.bert(input_ids, token_type_ids, attention_mask, head_mask=None)[0] |
| batch_size,max_len,feat_dim = sequence_output.shape |
| valid_output = torch.zeros(batch_size,max_len,feat_dim,dtype=torch.float32,device='cpu') |
| |
| for i in range(batch_size): |
| jj = -1 |
| for j in range(max_len): |
| if valid_ids[i][j].item() == 1: |
| jj += 1 |
| valid_output[i][jj] = sequence_output[i][j] |
| sequence_output = self.dropout(valid_output) |
| logits = self.classifier(sequence_output) |
| return logits |
|
|
| class Ner: |
|
|
| def __init__(self,model_dir: str): |
| self.model , self.tokenizer, self.model_config = self.load_model(model_dir) |
| self.label_map = self.model_config["label_map"] |
| self.max_seq_length = self.model_config["max_seq_length"] |
| self.label_map = {int(k):v for k,v in self.label_map.items()} |
| self.device = "cpu" |
| |
| self.model = self.model.to(self.device) |
| self.model.eval() |
|
|
| def load_model(self, model_dir: str, model_config: str = "model_config.json"): |
| model_config = os.path.join(model_dir,model_config) |
| model_config = json.load(open(model_config)) |
| model = BertNer.from_pretrained(model_dir) |
| tokenizer = BertTokenizer.from_pretrained(model_dir, do_lower_case=model_config["do_lower"]) |
| return model, tokenizer, model_config |
|
|
| def tokenize(self, text: str): |
| """ tokenize input""" |
| words = word_tokenize(text) |
| tokens = [] |
| valid_positions = [] |
| for i,word in enumerate(words): |
| token = self.tokenizer.tokenize(word) |
| tokens.extend(token) |
| for i in range(len(token)): |
| if i == 0: |
| valid_positions.append(1) |
| else: |
| valid_positions.append(0) |
| |
| return tokens, valid_positions |
|
|
| def preprocess(self, text: str): |
| """ preprocess """ |
| tokens, valid_positions = self.tokenize(text) |
| |
| tokens.insert(0,"[CLS]") |
| valid_positions.insert(0,1) |
| |
| tokens.append("[SEP]") |
| valid_positions.append(1) |
| segment_ids = [] |
| for i in range(len(tokens)): |
| segment_ids.append(0) |
| input_ids = self.tokenizer.convert_tokens_to_ids(tokens) |
| |
| input_mask = [1] * len(input_ids) |
| while len(input_ids) < self.max_seq_length: |
| input_ids.append(0) |
| input_mask.append(0) |
| segment_ids.append(0) |
| valid_positions.append(0) |
| return input_ids,input_mask,segment_ids,valid_positions |
|
|
| def predict_entity(self, B_lab, I_lab, words, labels, entity_list): |
| temp=[] |
| entity=[] |
|
|
| for word, (label, confidence), B_l, I_l in zip(words, labels, B_lab, I_lab): |
|
|
| if ((label==B_l) or (label==I_l)) and label!='O': |
| if label==B_l: |
| entity.append(temp) |
| temp=[] |
| temp.append(label) |
| |
| temp.append(word) |
|
|
| entity.append(temp) |
| |
|
|
| entity_name_label = [] |
| for entity_name in entity[1:]: |
| for ent_key, ent_value in entity_list.items(): |
| if (ent_key==entity_name[0]): |
| |
| entity_name_label.append([' '.join(entity_name[1:]), ent_value]) |
| |
| return entity_name_label |
|
|
| def predict(self, text: str): |
| input_ids,input_mask,segment_ids,valid_ids = self.preprocess(text) |
| |
| input_ids = torch.tensor([input_ids],dtype=torch.long,device=self.device) |
| input_mask = torch.tensor([input_mask],dtype=torch.long,device=self.device) |
| segment_ids = torch.tensor([segment_ids],dtype=torch.long,device=self.device) |
| valid_ids = torch.tensor([valid_ids],dtype=torch.long,device=self.device) |
|
|
| with torch.no_grad(): |
| logits = self.model(input_ids, segment_ids, input_mask,valid_ids) |
| |
| logits = F.softmax(logits,dim=2) |
| |
| logits_label = torch.argmax(logits,dim=2) |
| logits_label = logits_label.detach().cpu().numpy().tolist()[0] |
| |
|
|
| logits_confidence = [values[label].item() for values,label in zip(logits[0],logits_label)] |
|
|
| logits = [] |
| pos = 0 |
| for index,mask in enumerate(valid_ids[0]): |
| if index == 0: |
| continue |
| if mask == 1: |
| logits.append((logits_label[index-pos],logits_confidence[index-pos])) |
| else: |
| pos += 1 |
| logits.pop() |
| labels = [(self.label_map[label],confidence) for label,confidence in logits] |
| words = word_tokenize(text) |
|
|
| entity_list = {'B-PER':'Person', |
| 'B-FAC':'Facility', |
| 'B-LOC':'Location', |
| 'B-ORG':'Organization', |
| 'B-ART':'Work Of Art', |
| 'B-EVENT':'Event', |
| 'B-DATE':'Date-Time Entity', |
| 'B-TIME':'Date-Time Entity', |
| 'B-LAW':'Law Terms', |
| 'B-PRODUCT':'Product', |
| 'B-PERCENT':'Percentage', |
| 'B-MONEY':'Currency', |
| 'B-LANGUAGE':'Langauge', |
| 'B-NORP':'Nationality / Religion / Political group', |
| 'B-QUANTITY':'Quantity', |
| 'B-ORDINAL':'Ordinal Number', |
| 'B-CARDINAL':'Cardinal Number'} |
| |
| B_labels=[] |
| I_labels=[] |
| for label, confidence in labels: |
| if (label[:1]=='B'): |
| B_labels.append(label) |
| I_labels.append('O') |
| elif (label[:1]=='I'): |
| I_labels.append(label) |
| B_labels.append('O') |
| else: |
| B_labels.append('O') |
| I_labels.append('O') |
|
|
| assert len(labels) == len(words) == len(I_labels) == len(B_labels) |
|
|
| output = self.predict_entity(B_labels, I_labels, words, labels, entity_list) |
| print(output) |
|
|
| |
| return output |
|
|
|
|
|
|