| import gradio as gr
|
| import io
|
| import os
|
| import yaml
|
| import pyarrow
|
| import tokenizers
|
| from retro_reader import RetroReader
|
|
|
# Enable parallelism in HuggingFace tokenizers' Rust backend.
# NOTE(review): "true" can trigger deadlock warnings when tokenizers are used
# after fork (e.g. in DataLoader workers); "false" is the usual silencer — confirm intent.
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
|
|
def from_library():
    """Expose the project's constants module together with ``RetroReader``.

    Returns:
        tuple: ``(constants, RetroReader)`` — the ``retro_reader.constants``
        module and the ``RetroReader`` class, for rebinding at module level.
    """
    from retro_reader import constants
    return constants, RetroReader
|
|
|
# Rebind the constants module (C) and RetroReader class at module scope.
C, RetroReader = from_library()
|
|
|
|
|
def load_model(config_path):
    """Instantiate a :class:`RetroReader` from a YAML config file.

    Args:
        config_path: Path to the inference configuration file.

    Returns:
        A loaded ``RetroReader`` instance.
    """
    reader = RetroReader.load(config_file=config_path)
    return reader
|
|
|
|
|
# Eagerly load all four reader variants at import time so the UI can switch
# between them without reload latency.
# NOTE(review): this holds four models in memory simultaneously — confirm the
# deployment host has room for all of them.
model_electra_base = load_model("configs/inference_en_electra_base.yaml")

model_electra_large = load_model("configs/inference_en_electra_large.yaml")

model_roberta = load_model("configs/inference_en_roberta.yaml")

model_distilbert = load_model("configs/inference_en_distilbert.yaml")
|
|
|
def retro_reader_demo(query, context, model_choice):
    """Answer *query* against *context* with the selected RetroReader model.

    Args:
        query: The question to answer.
        context: The passage to read the answer from.
        model_choice: One of ``"Electra Base"``, ``"Electra Large"``,
            ``"Roberta"``, ``"DistilBERT"`` (the Radio choices in the UI).

    Returns:
        str: The predicted answer, ``"No answer found"`` when the model
        yields an empty prediction, or ``"Invalid model choice"`` for an
        unrecognized ``model_choice``.
    """
    # Dispatch table instead of an if/elif chain; keeps the UI choice list
    # and the model mapping in one obvious place.
    models = {
        "Electra Base": model_electra_base,
        "Electra Large": model_electra_large,
        "Roberta": model_roberta,
        "DistilBERT": model_distilbert,
    }
    model = models.get(model_choice)
    if model is None:
        return "Invalid model choice"

    outputs = model(query=query, context=context, return_submodule_outputs=True)

    # .get avoids the KeyError the original double ["id-01"] lookup could
    # raise when the reader produced no prediction for that id; empty/falsy
    # answers still fall back to the placeholder, as before.
    answer = outputs[0].get("id-01")
    return answer if answer else "No answer found"
|
|
|
|
|
# Gradio UI: two free-text inputs plus a model selector, one text output.
iface = gr.Interface(

    fn=retro_reader_demo,

    inputs=[

        gr.Textbox(label="Query", placeholder="Type your query here..."),

        gr.Textbox(label="Context", placeholder="Provide the context here...", lines=10),

        # Choice labels must match the keys handled inside retro_reader_demo.
        gr.Radio(choices=["Electra Base", "Electra Large", "Roberta", "DistilBERT"], label="Model Choice")

    ],

    outputs=gr.Textbox(label="Answer"),

    title="Retrospective Reader Demo",

    description="This interface uses the RetroReader model to perform reading comprehension tasks."

)
|
|
|
if __name__ == "__main__":

    # share=True asks Gradio to create a public tunnel URL in addition to
    # the local server — anyone with the link can reach the demo.
    iface.launch(share=True)
|
|
|