| import streamlit as st
|
| import io
|
| import os
|
| import yaml
|
| import pyarrow
|
| import tokenizers
|
|
|
# Opt in to HuggingFace tokenizers' thread parallelism; set early so it is
# in effect before any tokenizer is instantiated.
os.environ["TOKENIZERS_PARALLELISM"] = "true"

# Use the full browser width for the demo layout.
st.set_page_config(layout="wide")
|
|
|
@st.cache_resource
def from_library():
    """Lazily import the retro_reader package and return its public pieces.

    Returns a ``(constants_module, RetroReader_class)`` tuple. Cached with
    ``st.cache_resource`` so the import cost is paid once per server process.
    """
    from retro_reader import RetroReader, constants as C
    return C, RetroReader
|
|
|
# Materialize the lazily-imported library handles once at module load.
C, RetroReader = from_library()

# hash_funcs for st.cache_resource: map each of these types to a constant so
# Streamlit's cache ignores them when computing cache keys (they are not
# usefully hashable — open text streams, Arrow buffers, tokenizer objects).
my_hash_func = {
    io.TextIOWrapper: lambda _: None,
    pyarrow.lib.Buffer: lambda _: 0,
    tokenizers.Tokenizer: lambda _: None,
    tokenizers.AddedToken: lambda _: None
}
|
|
|
@st.cache_resource(hash_funcs=my_hash_func)
def load_en_electra_base_model():
    """Build the English ELECTRA-base retro reader from its YAML config.

    Cached by Streamlit (with ``my_hash_func`` handling unhashable types)
    so the model is constructed only once per server process.
    """
    return RetroReader.load(config_file="configs/inference_en_electra_base.yaml")
|
|
|
@st.cache_resource(hash_funcs=my_hash_func)
def load_en_electra_large_model():
    """Build the English ELECTRA-large retro reader from its YAML config.

    Cached by Streamlit (with ``my_hash_func`` handling unhashable types)
    so the model is constructed only once per server process.
    """
    return RetroReader.load(config_file="configs/inference_en_electra_large.yaml")
|
|
|
# Instantiated readers keyed by the model name the UI select box exposes.
# NOTE(review): both models are loaded eagerly at import time even though a
# session only ever uses one — confirm the startup time / memory cost is
# acceptable, or consider deferring the load until a model is selected.
RETRO_READER_HOST = {
    "google/electra-base-discriminator": load_en_electra_base_model(),
    "google/electra-large-discriminator": load_en_electra_large_model(),
}
|
|
|
def display_top_predictions(nbest_preds, top_k=10):
    """Render the ``top_k`` n-best answer spans, highest probability first.

    ``nbest_preds`` is either a list of prediction dicts (each with ``text``
    and ``probability`` keys) or a mapping holding that list under "id-01".
    """
    preds = nbest_preds if isinstance(nbest_preds, list) else nbest_preds['id-01']
    ranked = sorted(preds, key=lambda entry: entry['probability'], reverse=True)
    st.markdown("### Top Predictions")
    for rank, entry in enumerate(ranked[:top_k], start=1):
        st.markdown(f"**{rank}. {entry['text']}** - Probability: {entry['probability']*100:.2f}%")
|
|
|
def main():
    """Streamlit entry point: render the sidebar, the model selector, and the
    question-answering form, then run the selected retro reader and show its
    answer plus the top n-best predictions.
    """
    # NOTE(review): several literals below ("π", "π¨", "β", "π¬") look like
    # mojibake'd emoji from an encoding mix-up — confirm the intended glyphs
    # against the original source before changing them.

    # --- Sidebar: project description, illustration, contributors ---
    st.sidebar.title("π Welcome to Retro Reader")
    st.sidebar.write("""
MRC-RetroReader is a machine reading comprehension (MRC) model designed for reading comprehension tasks. The model leverages advanced neural network architectures to provide high accuracy in understanding and responding to textual queries.
""")
    image_url = "img.jpg"
    st.sidebar.image(image_url, use_column_width=True)
    st.sidebar.title("Contributors")
    st.sidebar.write("""
- Phan Van Hoang
- Pham Long Khanh
""")

    # --- Model selection ---
    st.title("Retrospective Reader Demo")
    st.markdown("## Model nameπ¨")
    option = st.selectbox(
        label="Choose the model used in retro reader",
        options=(
            "[1] google/electra-base-discriminator",
            "[2] google/electra-large-discriminator"
        ),
        index=1,
    )
    # Options look like "[n] <model-name>"; only the model name keys the host
    # table. maxsplit=1 keeps the lookup correct even if a future model name
    # contains spaces. The "[n]" prefix itself is unused.
    _, model_name = option.split(" ", 1)
    retro_reader = RETRO_READER_HOST[model_name]

    lang_prefix = "EN"  # example/help constants in C are English-only here
    height = 200  # pixel height of the context text area
    return_submodule_outputs = True

    # --- Query / context input form ---
    with st.form(key="my_form"):
        st.markdown("## Type your query β")
        query = st.text_input(
            label="",
            value=getattr(C, f"{lang_prefix}_EXAMPLE_QUERY"),
            max_chars=None,
            help=getattr(C, f"{lang_prefix}_QUERY_HELP_TEXT"),
        )
        # Fix: this heading previously duplicated "Type your query" — it
        # introduces the context field, not a second query field.
        st.markdown("## Type your context π¬")
        context = st.text_area(
            label="",
            value=getattr(C, f"{lang_prefix}_EXAMPLE_CONTEXTS"),
            height=height,
            max_chars=None,
            help=getattr(C, f"{lang_prefix}_CONTEXT_HELP_TEXT"),
        )
        submit_button = st.form_submit_button(label="Submit")

    # --- Inference and result rendering ---
    if submit_button:
        with st.spinner("π Please wait.."):
            outputs = retro_reader(
                query=query,
                context=context,
                return_submodule_outputs=return_submodule_outputs,
            )
        # outputs[0] maps example ids to answer text; this demo feeds a single
        # example, keyed "id-01". outputs[1] is its score.
        answer, score = outputs[0]["id-01"], outputs[1]
        if not answer:
            answer = "No answer"
        st.markdown("## π Results")
        st.write(answer)
        if return_submodule_outputs:
            # Trailing submodule outputs: external-verifier score, n-best
            # predictions, and the answerability score difference.
            score_ext, nbest_preds, score_diff = outputs[2:]
            display_top_predictions(nbest_preds)
|
|
|
# Run the Streamlit app when executed as a script (e.g. `streamlit run app.py`).
if __name__ == "__main__":
    main()
|
|
|