| import streamlit as st |
| import torch |
| from transformers import BertForTokenClassification, BertTokenizerFast |
|
|
@st.cache_resource
def load_model(model_name='dejanseo/LinkBERT'):
    """Load the LinkBERT token-classification model (2 labels) in eval mode.

    Streamlit re-executes the entire script on every widget interaction, so
    without caching the model would be rebuilt (and possibly re-downloaded)
    on each rerun. ``st.cache_resource`` keeps a single instance alive for
    the session.

    Args:
        model_name: Hugging Face model id to load. Defaults to
            ``'dejanseo/LinkBERT'``.

    Returns:
        The ``BertForTokenClassification`` model, switched to eval mode.
    """
    model = BertForTokenClassification.from_pretrained(model_name, num_labels=2)
    model.eval()
    return model
|
|
def predict_and_annotate(model, tokenizer, text):
    """Run token classification over *text* and wrap predicted anchor spans
    in ``<u>...</u>`` HTML tags.

    Args:
        model: Token-classification model returning per-token ``logits``
            over two labels (label 1 = anchor text).
        tokenizer: Fast tokenizer supporting ``return_offsets_mapping``.
        text: Raw input string to analyze.

    Returns:
        *text* with every token predicted as label 1 underlined via HTML;
        all other characters (including inter-token whitespace) are passed
        through unchanged.
    """
    inputs = tokenizer(text, return_tensors="pt", padding=True,
                       truncation=True, return_offsets_mapping=True)

    # Inference only — no gradients needed.
    with torch.no_grad():
        outputs = model(input_ids=inputs["input_ids"],
                        attention_mask=inputs["attention_mask"])
    predictions = torch.argmax(outputs.logits, dim=-1).squeeze().tolist()
    offset_mapping = inputs["offset_mapping"].squeeze().tolist()

    # Build the annotated string with join() instead of quadratic +=.
    parts = []
    previous_end = 0
    for (start, end), label in zip(offset_mapping, predictions):
        if start == end:
            # Special tokens ([CLS]/[SEP]/padding) have empty offsets.
            continue
        if start > previous_end:
            # Preserve the untokenized gap (whitespace/punctuation) verbatim.
            parts.append(text[previous_end:start])
        piece = text[start:end]
        parts.append(f"<u>{piece}</u>" if label == 1 else piece)
        previous_end = end
    # Append any trailing text beyond the last token (e.g. after truncation).
    parts.append(text[previous_end:])

    return "".join(parts)
|
|
| |
# --- Streamlit UI -----------------------------------------------------------
st.title("BERT Token Classification for Anchor Text Prediction")

# Model and tokenizer are module-level so they are ready before any interaction.
model = load_model('dejanseo/LinkBERT')
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

# Free-form input area for the text to be analyzed.
source_text = st.text_area("Paste the text you want to analyze:", "Type or paste text here.")

if st.button("Predict Anchor Texts"):
    if not source_text:
        # Nothing to analyze — prompt the user instead of running the model.
        st.write("Please paste some text into the text area.")
    else:
        # Render the model's <u>-annotated output as HTML.
        st.markdown(predict_and_annotate(model, tokenizer, source_text), unsafe_allow_html=True)
|
|