| import re |
| import os |
| from dotenv import load_dotenv |
|
|
| import gradio as gr |
| import pandas as pd |
| from pandas import DataFrame as PandasDataFrame |
|
|
| from llm import MessageChatCompletion |
| from customization import css, js |
| from examples import example_1, example_2, example_3, example_4 |
| from prompt_template import system_message_template, user_message_template |
|
|
load_dotenv()

# Server-side API key; used below as the UI's initial key value (None if unset).
API_KEY = os.getenv("API_KEY")

# Subsector reference data. Look for the CSV next to this script first so the
# app still starts when launched from a different working directory; fall back
# to the plain CWD-relative name for backward compatibility.
_SUBSECTORS_CSV = 'subsectors.csv'
_subsectors_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), _SUBSECTORS_CSV)
if not os.path.exists(_subsectors_path):
    _subsectors_path = _SUBSECTORS_CSV
df = pd.read_csv(_subsectors_path)

# In-memory log of classifications; click_button appends one row per call.
logs_columns = ['Abstract', 'Model', 'Results']
logs_df = PandasDataFrame(columns=logs_columns)
|
|
|
|
def download_logs():
    """Save the accumulated classification log as ``classification_logs.csv``.

    The file is written to the ``Desktop`` folder inside the user's home
    directory when that folder exists, otherwise directly into the home
    directory. The data written is the module-level ``logs_df`` that
    ``click_button`` appends to; it is written with its index column.
    """
    # os.path.expanduser('~') reads USERPROFILE on Windows and HOME elsewhere,
    # so a single code path covers both platforms without a KeyError when
    # USERPROFILE happens to be unset.
    home = os.path.expanduser('~')
    desktop = os.path.join(home, 'Desktop')

    # Headless hosts and some user setups have no ~/Desktop folder; writing
    # there would raise instead of saving the log, so fall back to home.
    target_dir = desktop if os.path.isdir(desktop) else home

    file_path = os.path.join(target_dir, 'classification_logs.csv')
    logs_df.to_csv(file_path)
|
|
|
|
def build_context(row):
    """Render one subsector record as a single prompt paragraph.

    Args:
        row: mapping-like record (in this file, a row from ``df.iterrows()``)
            providing the keys 'Subsector', 'Definition', 'Keywords',
            'Does include' and 'Does not include'.

    Returns:
        Five "<label>: <value>." sentences separated by single spaces and
        terminated by a newline.
    """
    name = row['Subsector']
    # (label used in the sentence, column it is read from); note the
    # lower-case "keywords" label for the 'Keywords' column.
    field_labels = (
        ("Definition", 'Definition'),
        ("keywords", 'Keywords'),
        ("Does include", 'Does include'),
        ("Does not include", 'Does not include'),
    )
    sentences = [f"Subsector name: {name}."]
    sentences.extend(f"{name} {label}: {row[column]}." for label, column in field_labels)
    return " ".join(sentences) + "\n"
|
|
|
|
def _parse_match_scores(text):
    """Return the first ``{...}`` literal found in *text* as a dict, else ``{}``.

    The model's reply is untrusted text, so the matched span is parsed with
    ``ast.literal_eval`` (Python literals only) instead of ``eval``, which
    would execute arbitrary expressions. Returns an empty dict when *text* is
    not a string, has no brace-delimited span, the span is not a valid
    literal, or the literal is not a dict (e.g. a set) — the result is shown
    in a ``gr.Label``, which is given a mapping elsewhere in this file.
    """
    import ast

    if not isinstance(text, str):
        return {}
    match = re.search(r'\{.*?\}', text, re.DOTALL)
    if match is None:
        return {}
    try:
        parsed = ast.literal_eval(match.group(0))
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        return {}
    return parsed if isinstance(parsed, dict) else {}


def click_button(model, api_key, abstract):
    """Classify *abstract* against the subsectors in ``df`` and log the result.

    Args:
        model: model name passed to ``MessageChatCompletion``.
        api_key: API key passed to ``MessageChatCompletion``.
        abstract: patent abstract text to classify.

    Returns:
        ``(scores, reply, logs)``: the dict parsed from the model's reply
        (empty when the call reported an error or no dict literal could be
        parsed), the raw reply text, and the updated module-level ``logs_df``.
    """
    global logs_df

    labels = df['Subsector'].tolist()
    # Join the per-subsector paragraphs into one string. Formatting the list
    # object itself would put its repr (brackets, quotes, escaped "\n") into
    # the system prompt instead of real line breaks.
    prompt_context = "".join(build_context(row) for _, row in df.iterrows())

    language_model = MessageChatCompletion(model=model, api_key=api_key)
    system_message = system_message_template.format(prompt_context=prompt_context)
    user_message = user_message_template.format(labels=labels, abstract=abstract)
    language_model.new_system_message(content=system_message)
    language_model.new_user_message(content=user_message)
    language_model.send_message()

    response_reasoning = language_model.get_last_message()

    if language_model.error is False:
        match_score_dict = _parse_match_scores(response_reasoning)
    else:
        match_score_dict = {}

    new_log_entry = pd.DataFrame({'Abstract': [abstract], 'Model': [model], 'Results': [str(match_score_dict)]})
    logs_df = pd.concat([logs_df, new_log_entry], ignore_index=True)

    return match_score_dict, response_reasoning, logs_df
|
|
|
|
def on_select(evt: gr.SelectData):
    """Fill the subsector detail panel for the table row the user clicked.

    ``evt.index[0]`` is used as a positional row index into ``df`` (the table
    displays ``df``'s 'Subsector' column in order).

    Returns:
        An accordion relabelled with the subsector name, followed by that
        row's 'Definition', 'Keywords', 'Does include' and 'Does not include'
        values, matching the event's output components.
    """
    record = df.iloc[evt.index[0]]
    detail_columns = ('Definition', 'Keywords', 'Does include', 'Does not include')
    details = tuple(record[column] for column in detail_columns)
    return (gr.Accordion(label=record['Subsector']),) + details
|
|
|
|
| |
def _first_subsector_value(column):
    """Return *column* of ``df``'s first row, for the detail panel's initial state.

    Returns "" when ``df`` has no rows or the cell is missing (NaN), so the
    panel starts blank instead of failing or showing "nan".
    """
    if df.empty:
        return ""
    value = df.iloc[0][column]
    return "" if pd.isna(value) else value


with gr.Blocks(css=css, js=js) as demo:
    with gr.Tab("Patent Discovery"):
        with gr.Row():
            with gr.Column(scale=5):
                dropdown_model = gr.Dropdown(
                    label="Model",
                    choices=["gpt-5-mini", "gpt-5", "gpt-4o-mini", "gpt-4o"],
                    value="gpt-5-mini",
                    multiselect=False,
                    interactive=True,
                )
            with gr.Column(scale=5):
                # NOTE(review): the server's API_KEY is used as the initial
                # value. Gradio presumably sends initial component values to
                # the browser, so type="password" would only mask it visually —
                # confirm this is acceptable wherever the app is served.
                api_key = gr.Textbox(
                    label="API Key",
                    interactive=True,
                    lines=1,
                    max_lines=1,
                    type="password",
                    value=API_KEY,
                )
        with gr.Row(equal_height=True):
            abstract_description = gr.Textbox(
                label="Abstract description",
                lines=5,
                max_lines=10000,
                interactive=True,
                placeholder="Input a patent abstract",
            )
        with gr.Row():
            with gr.Accordion(label="Example Abstracts", open=False):
                # Examples only fill the abstract box; no fn is attached since
                # click_button also needs the model and key inputs.
                gr.Examples(
                    examples=[example_1, example_2, example_3, example_4],
                    inputs=abstract_description,
                    label="",
                )
        with gr.Row():
            btn_get_result = gr.Button("Classify")
        with gr.Row(elem_classes=['all_results']):
            with gr.Column(scale=4):
                label_result = gr.Label(num_top_classes=None)
            with gr.Column(scale=6):
                reasoning = gr.Markdown(label="Reasoning", elem_classes=['reasoning_results'])

    with gr.Tab("Subsector definitions"):
        with gr.Row():
            with gr.Column(scale=4):
                df_subsectors = gr.DataFrame(df[['Subsector']], interactive=False, height=800)
            with gr.Column(scale=6):
                # The panel starts with the first subsector in the table so its
                # title and text agree; on_select replaces them on row click.
                with gr.Accordion(label=_first_subsector_value('Subsector')) as subsector_name:
                    s1_definition = gr.Textbox(
                        label="Definition",
                        lines=5,
                        max_lines=100,
                        value=_first_subsector_value('Definition'),
                    )
                    s1_keywords = gr.Textbox(
                        label="Keywords",
                        lines=5,
                        max_lines=100,
                        value=_first_subsector_value('Keywords'),
                    )
                    does_include = gr.Textbox(
                        label="Does include",
                        lines=4,
                        value=_first_subsector_value('Does include'),
                    )
                    does_not_include = gr.Textbox(
                        label="Does not include",
                        lines=3,
                        value=_first_subsector_value('Does not include'),
                    )

    with gr.Tab("Logs"):
        output_dataframe = gr.Dataframe(
            value=logs_df,
            type="pandas",
            height=500,
            headers=list(logs_columns),
            interactive=False,
            column_widths=["45%", "10%", "45%"],
        )
        btn_export = gr.Button(
            value="Export to CSV",
            size="sm",
        )

    # Classify: (model, key, abstract) -> (scores label, reasoning text, log table).
    btn_get_result.click(
        fn=click_button,
        inputs=[dropdown_model, api_key, abstract_description],
        outputs=[label_result, reasoning, output_dataframe],
    )

    # Export writes the log file on the server side; there are no UI outputs.
    btn_export.click(
        fn=download_logs,
    )

    # Clicking a subsector row refreshes the detail panel.
    df_subsectors.select(
        fn=on_select,
        outputs=[subsector_name, s1_definition, s1_keywords, does_include, does_not_include],
    )


if __name__ == "__main__":
    demo.launch()
|
|