| from transformers import AutoModel, AutoTokenizer
|
| import torch
|
| from tqdm import tqdm
|
| from torch.utils.data import Dataset, DataLoader
|
| import os
|
| import spacy
|
| import certifi
|
| import streamlit as st
|
|
|
# Point the SSL stack at certifi's CA bundle so the HuggingFace model /
# tokenizer downloads below succeed even where the system certificates
# are missing or outdated.
os.environ['SSL_CERT_FILE'] = certifi.where()

# spaCy pipeline used only for sentence segmentation in extract_summary;
# the en_core_web_lg model must be installed separately.
nlp = spacy.load("en_core_web_lg")

# Backbone transformer shared by the sentence and document encoders
# (384-dim hidden size, matching SentenceBERTClassifier's input_dim).
model_name = "microsoft/MiniLM-L12-H384-uncased"

# Tokenizer paired with the backbone; used by get_tokens/predict below.
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
|
|
def mean_pooling(model_output, attention_mask):
    """Mask-aware average of token embeddings.

    Args:
        model_output: transformer output; index 0 holds the token
            embeddings of shape (batch, seq_len, hidden).
        attention_mask: (batch, seq_len) mask; 1 = real token, 0 = padding.

    Returns:
        (batch, hidden) tensor: mean of the non-padding token embeddings.
    """
    embeddings = model_output[0]
    # Broadcast the mask over the hidden dimension so padded positions
    # contribute nothing to the sum.
    mask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
    summed = (embeddings * mask).sum(dim=1)
    # Clamp avoids division by zero for an all-padding row.
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts
|
|
|
class SentenceBERTClassifier(torch.nn.Module):
    """Binary relevance scorer for (sentence, document) pairs.

    Both texts are encoded with a shared transformer backbone and
    mean-pooled; the classifier head sees the concatenation
    [sentence, document, sentence * document] (3 * input_dim features)
    and emits a probability in (0, 1).

    NOTE: attribute names (model, dense1, relu1, ...) are part of the
    checkpoint's state_dict keys — do not rename them.
    """

    def __init__(self, model_name="microsoft/MiniLM-L12-H384-uncased", input_dim=384):
        super().__init__()
        self.model = AutoModel.from_pretrained(model_name)
        # Head: 3*input_dim -> 768 -> 384 -> 1, ReLU + dropout between layers.
        self.dense1 = torch.nn.Linear(input_dim * 3, 768)
        self.relu1 = torch.nn.ReLU()
        self.dropout1 = torch.nn.Dropout(0.1)
        self.dense2 = torch.nn.Linear(768, 384)
        self.relu2 = torch.nn.ReLU()
        self.dropout2 = torch.nn.Dropout(0.1)
        self.classifier = torch.nn.Linear(384, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, sent_ids, doc_ids, sent_mask, doc_mask):
        """Return a (batch, 1) probability that each sentence belongs in the summary."""
        # Encode sentence and document with the same backbone.
        sent_vec = mean_pooling(
            self.model(input_ids=sent_ids, attention_mask=sent_mask), sent_mask
        )
        doc_vec = mean_pooling(
            self.model(input_ids=doc_ids, attention_mask=doc_mask), doc_mask
        )
        # Element-wise interaction plus both raw embeddings.
        features = torch.cat((sent_vec, doc_vec, sent_vec * doc_vec), dim=1)
        hidden = self.dropout1(self.relu1(self.dense1(features)))
        hidden = self.dropout2(self.relu2(self.dense2(hidden)))
        return self.sigmoid(self.classifier(hidden))
|
|
|
# Choose the device used for map_location when loading the checkpoint.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

extractive_model = SentenceBERTClassifier(model_name=model_name)
# os.path.join keeps the checkpoint path portable. The original hard-coded
# "model_path\minilm_bal_exsum.pth": "\m" is an invalid escape sequence
# (DeprecationWarning) and the backslash separator breaks on POSIX systems.
extractive_model.load_state_dict(
    torch.load(
        os.path.join("model_path", "minilm_bal_exsum.pth"),
        map_location=torch.device(device),
    )
)
# Inference only: disable dropout layers.
extractive_model.eval()
|
|
|
def get_tokens(text, tokenizer):
    """Tokenize a batch of strings into padded id/mask tensors.

    Args:
        text: list of strings to encode as one batch.
        tokenizer: HuggingFace tokenizer exposing batch_encode_plus.

    Returns:
        (input_ids, attention_mask) — each padded/truncated to 512 tokens;
        returned as PyTorch tensors via return_tensors="pt".
    """
    encoded = tokenizer.batch_encode_plus(
        text,
        add_special_tokens=True,
        max_length=512,
        padding="max_length",
        return_token_type_ids=True,
        truncation=True,
        return_tensors="pt",
    )
    return encoded["input_ids"], encoded["attention_mask"]
|
|
|
|
|
def predict(model, sents, doc):
    """Score each candidate sentence against the full document.

    Args:
        model: SentenceBERTClassifier instance.
        sents: list of sentence strings to score.
        doc: the full document string (encoded once, repeated per sentence).

    Returns:
        (len(sents), 1) tensor of probabilities from the model.
    """
    # get_tokens already returns long tensors (return_tensors="pt"), so the
    # original torch.tensor(...) re-wrapping was redundant and triggered the
    # "use clone().detach()" UserWarning.
    sent_id, sent_mask = get_tokens(sents, tokenizer)
    doc_id, doc_mask = get_tokens([doc], tokenizer)

    # Repeat the single document row so it pairs with every sentence.
    doc_id = doc_id.repeat(len(sents), 1)
    doc_mask = doc_mask.repeat(len(sents), 1)

    # Map any out-of-vocabulary ids onto [UNK] to avoid embedding
    # index errors.
    sent_id[sent_id >= tokenizer.vocab_size] = tokenizer.unk_token_id
    doc_id[doc_id >= tokenizer.vocab_size] = tokenizer.unk_token_id

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        preds = model(sent_id, doc_id, sent_mask, doc_mask)
    return preds
|
|
|
def extract_summary(doc, model=extractive_model, min_sentence_length=14, top_k=4, batch_size=4):
    """Extractively summarize *doc* by selecting its top-scoring sentences.

    Args:
        doc: document text to summarize.
        model: scorer passed through to predict().
        min_sentence_length: minimum sentence length (in spaCy tokens) for a
            sentence to be considered.
        top_k: number of sentences kept in the summary.
        batch_size: number of sentences scored per model call.

    Returns:
        The top_k highest-scoring sentences, re-joined in document order.
    """
    # NOTE(review): stripping newlines with "" (not " ") glues words that were
    # split across lines — kept as-is to preserve existing behavior.
    doc = doc.replace("\n", "")
    doc_sentences = [
        str(sent) for sent in nlp(doc).sents if len(sent) > min_sentence_length
    ]

    # Score sentences in batches. range(start, stop, step) replaces the
    # original int(len/batch_size)+1 loop with its manual bound clamping
    # and empty-batch guard.
    scores = []
    for start in tqdm(range(0, len(doc_sentences), batch_size)):
        batch = doc_sentences[start:start + batch_size]
        scores.extend(predict(model, batch, doc).tolist())

    ranked = [
        {"sentence": sentence, "score": scores[i][0], "index": i}
        for i, sentence in enumerate(doc_sentences)
    ]
    # Keep the top_k by score, then restore original document order.
    top = sorted(ranked, key=lambda k: k["score"], reverse=True)[:top_k]
    top = sorted(top, key=lambda k: k["index"])

    return " ".join(entry["sentence"] for entry in top)