| import streamlit as st |
| import numpy as np |
| import torch |
| from transformers import DistilBertTokenizer, DistilBertForMaskedLM |
| from qa_model import ReuseQuestionDistilBERT |
|
|
@st.cache_resource
def load_model():
    """Load the fine-tuned QA model and its tokenizer for CPU inference.

    The DistilBERT backbone is taken from the pretrained
    ``distilbert-base-uncased`` masked-LM checkpoint, wrapped in
    ``ReuseQuestionDistilBERT``, and then populated with the weights stored
    in ``distilbert_reuse.model``. The tokenizer is read from the
    ``qa_tokenizer`` directory.

    Returns:
        A ``(model, tokenizer)`` tuple on success, or ``(None, None)`` if
        anything fails while loading (the error is shown in the Streamlit UI).

    Note:
        The function is wrapped in ``st.cache_resource``, so whatever it
        returns -- including the ``(None, None)`` failure result -- is reused
        on later reruns until the cache is cleared or the app restarts.
    """
    try:
        backbone = DistilBertForMaskedLM.from_pretrained("distilbert-base-uncased").distilbert
        model = ReuseQuestionDistilBERT(backbone)

        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # that come from a trusted source.
        state_dict = torch.load("distilbert_reuse.model", map_location=torch.device("cpu"))
        model.load_state_dict(state_dict)

        # Switch to inference mode so layers such as dropout are disabled;
        # a freshly constructed module starts in training mode.
        model.eval()

        tokenizer = DistilBertTokenizer.from_pretrained("qa_tokenizer")
        return model, tokenizer
    except Exception as e:
        # Top-level UI boundary: report the problem instead of crashing the app.
        st.error(f"Error loading model: {e}")
        return None, None
|
|
def get_answer(question, text, tokenizer, model):
    """Extract the answer to ``question`` from ``text`` with the QA model.

    Args:
        question: The question string; surrounding whitespace is ignored.
        text: The context passage to search; surrounding whitespace is ignored.
        tokenizer: Tokenizer used to encode the (question, text) pair and to
            decode the predicted span. ``None`` means loading failed.
        model: Callable model that returns a mapping containing
            ``"start_logits"`` and ``"end_logits"``. ``None`` means loading
            failed.

    Returns:
        The predicted answer string, or one of the messages
        ``"Model not loaded properly."``,
        ``"Error: Model output structure is incorrect."`` or
        ``"No answer found."``.
    """
    if model is None or tokenizer is None:
        return "Model not loaded properly."

    question = question.strip()
    text = text.strip()
    # Whitespace-only input cannot contain an answer; skip the model pass.
    if not question or not text:
        return "No answer found."

    # The tokenizer is called with single-item batches (lists of one string).
    inputs = tokenizer(
        [question],
        [text],
        max_length=512,
        truncation="only_second",  # only the context may be truncated
        padding="max_length",
        return_tensors="pt",
    )

    with torch.no_grad():
        outputs = model(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            start_positions=None,
            end_positions=None,
        )

    if "start_logits" not in outputs or "end_logits" not in outputs:
        return "Error: Model output structure is incorrect."

    # Convert the single-element argmax results to plain ints so they can be
    # compared and used as ordinary slice bounds.
    start = torch.argmax(outputs["start_logits"], dim=1).item()
    end = torch.argmax(outputs["end_logits"], dim=1).item()
    if end < start:
        # The predicted end precedes the start: there is no valid span.
        return "No answer found."

    answer_ids = inputs["input_ids"][0, start:end + 1].tolist()
    answer_tokens = tokenizer.convert_ids_to_tokens(answer_ids, skip_special_tokens=True)
    predicted = tokenizer.convert_tokens_to_string(answer_tokens)
    return predicted or "No answer found."
|
|
def main():
    """Render the Streamlit question-answering page.

    Configures the page, loads the model and tokenizer through
    ``load_model``, and shows a form with a context text area and a question
    field. On submit, either warns about missing input or displays the answer
    returned by ``get_answer``.
    """
    st.set_page_config(page_title="Question Answering Tool", page_icon=":mag_right:")
    st.write("# Question Answering Tool")

    model, tokenizer = load_model()

    with st.form("qa_form"):
        text = st.text_area("Enter your text here")
        question = st.text_input("Enter your question here")

        if st.form_submit_button("Submit"):
            # Validate on stripped values so whitespace-only input is
            # treated the same as empty input.
            if not text.strip() or not question.strip():
                st.warning("Please enter both text and a question.")
            else:
                st.text("Processing...")
                answer = get_answer(question, text, tokenizer, model)
                st.text(f"Answer: {answer}")
|
|
# Script entry point: build the page only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|