|
|
| import os |
| import asyncio |
| |
# Placeholder credentials, presumably so that modules imported below (agent
# tools, OpenAI clients) can be loaded before the user has entered real keys.
# NOTE(review): load_dotenv() is called later in this script without
# override=True, so .env values for these keys are shadowed by the
# placeholders written here - confirm that is intended.
if os.environ.get('OPENAI_API_KEY') is None:
    # The base URL placeholder is only written when no key is configured.
    os.environ.update({'OPENAI_API_KEY': 'none', "OPENAI_API_BASE": 'none'})
# These two are always overwritten, even when already present in the env.
os.environ.update({"SERP_API_KEY": 'none', "SEMANTIC_SCHOLAR_API_KEY": 'none'})

# On Windows, switch asyncio to the Proactor event loop policy.
if os.name == 'nt':
    asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
| |
| import openai |
| import pandas as pd |
| import streamlit as st |
| |
| from PIL import Image |
| from agent import TeLLAgent, make_tools |
| from streamlit_callback_handler import \ |
| StreamlitCallbackHandlerChem |
| import base64 |
| import pandas as pd |
| from dotenv import load_dotenv |
| from langchain_openai import ChatOpenAI , OpenAI |
| import base64 |
| from io import BytesIO |
| from PIL import Image |
| import tempfile |
| |
|
|
def convert_to_base64(pil_image):
    """Serialize an image to PNG and return it as a base64 text string.

    Args:
        pil_image: Object with PIL's ``save(fp, format=...)`` interface.

    Returns:
        The PNG bytes written by ``pil_image.save``, base64-encoded and
        decoded to a UTF-8 ``str``.
    """
    with BytesIO() as png_buffer:
        pil_image.save(png_buffer, format="PNG")
        raw_png = png_buffer.getvalue()
    return base64.b64encode(raw_png).decode("utf-8")
|
|
def oai_key_isvalid(api_key):
    """Check whether an OpenAI API key can complete a chat request.

    Sends a one-off "This is a test" prompt through ``ChatOpenAI``. When the
    ``OPENAI_API_BASE`` environment variable is set to a non-empty value it is
    passed as the client's ``base_url``; otherwise the client default is used.
    This makes a real request to the configured endpoint.

    Args:
        api_key: The key to test, forwarded as ``openai_api_key``.

    Returns:
        True if the request succeeds; False if building the client or the
        request raises an ``Exception``.
    """
    client_kwargs = {"openai_api_key": api_key}
    base_url = os.getenv("OPENAI_API_BASE")
    if base_url:
        client_kwargs["base_url"] = base_url
    try:
        ChatOpenAI(**client_kwargs).invoke("This is a test")
    except Exception:
        # Any failure (auth, network, bad base URL) means the key is not usable.
        # Catch Exception rather than everything so Ctrl-C and Streamlit's
        # BaseException-based rerun/stop signals still propagate.
        return False
    return True
| |
# Load variables from a .env file.
# NOTE(review): load_dotenv() does not override variables that are already set
# by default, so placeholders written at import time take precedence over .env.
load_dotenv()

# Short alias for Streamlit's per-session state, used throughout the script.
ss = st.session_state
# NOTE(review): reset on every rerun and not read elsewhere in this file -
# confirm it is still needed.
ss.prompt = None
# Prompt queued by an example button; created once per session and consumed
# (reset to None) when the queued prompt is run.
if 'pending_prompt' not in ss:
    ss.pending_prompt = None

# Fix the expanded sidebar width at 500px. The <style> element is now closed.
st.markdown(
    """
<style>
[data-testid="stSidebar"][aria-expanded="true"]{
    min-width: 500px;
    max-width: 500px;
}
</style>
""",
    unsafe_allow_html=True,
)
|
|
|
|
def instantiate_agent(model1, model2, file_path='...', image_path='...', tools=None):
    """Build a TeLLAgent, store it on the session as ``ss.agent``, and return it.

    Args:
        model1: Name of the global planning model, forwarded as ``model1``.
        model2: Name of the local execution model, forwarded as ``model2``.
        file_path: Path of an uploaded data file. The default ``'...'``
            appears to act as a "no file" placeholder.
        image_path: Path of an uploaded image; same ``'...'`` placeholder.
        tools: Tool-set selector forwarded to TeLLAgent. Callers in this file
            pass ``'drug'`` or leave it as ``None``.

    Returns:
        The newly created agent, which also replaces ``ss.agent``.
    """
    agent_config = dict(
        tools=tools,
        model1=model1,
        model2=model2,
        tools_model='gpt-4o-2024-11-20',
        temp=0.1,
        openai_api_key=ss.get('api_key'),
        file_path=file_path,
        image_path=image_path,
    )
    new_agent = TeLLAgent(**agent_config)
    ss.agent = new_agent
    return new_agent
|
|
|
|
|
|
def on_api_key_change():
    """Streamlit ``on_change`` callback for the sidebar API-key input.

    Validates the key typed into the sidebar, falling back to the
    ``OPENAI_API_KEY`` environment variable when the field is empty. Writes a
    warning to the page if the key is rejected.
    """
    candidate_key = ss.get('api_key') or os.getenv('OPENAI_API_KEY')
    if oai_key_isvalid(candidate_key):
        return
    st.write("Please input a valid OpenAI API key.")
|
|
def _render_agent_response(response):
    """Render an agent response inside the current chat message.

    In CSV mode, a non-string response is tabulated and offered as a CSV
    download. Anything else, including a response that cannot be tabulated,
    is written verbatim.
    """
    # list() on a str would yield one row per character, so plain-text
    # answers are written verbatim even in CSV mode.
    if ss.get('file_type') != 'CSV (.csv)' or isinstance(response, str):
        st.write(response)
        return
    try:
        predictions = pd.DataFrame(list(response))
        st.markdown(":red[Prediction finished! ]")
        st.download_button("⬇️Download the predicted files as .csv", predictions.to_csv(), "predict results.csv", use_container_width=True)
    except Exception:
        # Best effort: fall back to showing the raw response.
        st.write(response)


def run_prompt(prompt, file_path='...', image_path='...'):
    """Run ``prompt`` through a freshly built agent and render the exchange.

    The agent uses the model names from the sidebar and the ``'drug'`` tool
    set when the selected domain is "Drug discovery". The prompt is shown as
    a user chat message and the result as an assistant message.

    Args:
        prompt: Text sent to ``agent.run``.
        file_path: Path of an uploaded CSV/PDF, forwarded to the agent.
        image_path: Path of an uploaded image, forwarded to the agent.

    OpenAI authentication and API errors are reported in the chat instead of
    being raised.
    """
    agent = instantiate_agent(
        model1=ss.get('model1_select'),
        model2=ss.get('model2_select'),
        file_path=file_path,
        image_path=image_path,
        tools='drug' if ss.get('domain') == 'Drug discovery' else None,
    )
    st.chat_message("user").write(prompt)
    with st.chat_message("assistant"):
        try:
            response = agent.run(prompt)
        # AuthenticationError subclasses APIError, so it must be caught first.
        except openai.AuthenticationError:
            st.write("Please input a valid OpenAI API key")
        except openai.APIError:
            # Show the failure in the chat; print() only reached the server log.
            st.write("OpenAI API error, please try again!")
        else:
            _render_agent_response(response)
| |
# Canned example prompts; the example buttons queue them by index (0-3).
# Note: the first entry keeps its trailing space.
pre_prompts = [
    'Generate a donor with PCE = 10% ',
    'The history and development of Y6',
    'Predict the LogP of PM6',
    'Predict the PCE of Y6',
]
|
|
| |
# NOTE(review): the original indentation was lost. This layout assumes the
# example buttons, upload controls and tool table live in the sidebar.
with st.sidebar:
    st.header("🤖 :blue[TeLLAgent] ")

    # Credentials / endpoint. Editing the API key re-validates it.
    st.text_input(
        'Input your OpenAI API key.',
        placeholder='Input your OpenAI API key.',
        type='password',
        key='api_key',
        on_change=on_api_key_change,
        label_visibility="collapsed",
    )
    st.text_input(
        'Input base url (optional).',
        placeholder='Input base url (optional)',
        key='base_url',
        type='password',
        label_visibility="collapsed",
    )

    # Model names for the two agent stages; prompts are refused until both are set.
    st.text_input('Input global planning model to use', key='model1_select')
    st.text_input('Input local execution model to use', key='model2_select')

    st.text_input(
        'Input SERP API KEY (optional).',
        placeholder='Input SERP API KEY (optional)',
        key='serp_api',
        type='password',
        label_visibility="collapsed",
    )
    st.text_input(
        'Input SEMANTIC SCHOLAR API KEY (optional).',
        placeholder='Input SEMANTIC SCHOLAR API KEY (optional)',
        key='semantic_scholar_url',
        type='password',
        label_visibility="collapsed",
    )

    # Mirror the sidebar inputs into the process environment on every rerun.
    # `or ''` guards against a missing key (os.environ rejects None).
    # NOTE(review): an empty field writes '' and replaces any value from the
    # shell or .env - confirm that clearing is intended.
    for env_name, state_key in (
        ('OPENAI_API_KEY', 'api_key'),
        ("OPENAI_API_BASE", 'base_url'),
        ("SERP_API_KEY", 'serp_api'),
        ("SEMANTIC_SCHOLAR_API_KEY", 'semantic_scholar_url'),
    ):
        os.environ[env_name] = ss.get(state_key) or ''

    # Example prompts: a click queues the prompt, which runs after the sidebar.
    st.markdown('# What can I ask?')
    cols = st.columns(2)
    with cols[0]:
        if st.button(r'👑 Generate a donor with PCE = 10% 🧨 '):
            ss.pending_prompt = pre_prompts[0]
        if st.button(r'📚 The history and development of Y6 '):
            ss.pending_prompt = pre_prompts[1]
    with cols[1]:
        if st.button(r"🎄Predict the LogP of PM6 "):
            ss.pending_prompt = pre_prompts[2]
        if st.button(r'💎 Predict the PCE of Y6'):
            ss.pending_prompt = pre_prompts[3]

    # Attachment type; the matching uploader is shown below it.
    st.selectbox(
        'Select the file type ',
        ['None', 'CSV (.csv)', 'Figure (.jpg, .png, .jpeg)', 'PDF (.pdf)'],
        key='file_type',
    )
    selected_file_type = ss.get('file_type')
    uploaded_file = None  # stays None when the type is 'None'
    if selected_file_type == 'Figure (.jpg, .png, .jpeg)':
        uploaded_file = st.file_uploader("Choose a Figure", type=["jpg", "jpeg", "png"])
    elif selected_file_type == 'PDF (.pdf)':
        # Restrict to PDFs like the other uploaders; the bytes are later saved
        # with a .pdf suffix for the agent.
        uploaded_file = st.file_uploader("Choose a PDF file", type=["pdf"])
    elif selected_file_type == 'CSV (.csv)':
        uploaded_file = st.file_uploader("Choose a csv file", type='csv')

    st.selectbox(
        r'📚 Choose the domain ',
        ['Organic solar cell', 'Drug discovery'],
        key='domain',
    )

    # Build a default agent for the chosen domain. In this file it is only
    # used to list the available tools; run_prompt builds its own agent.
    instantiate_agent(
        model1='gpt-4o-2024-11-20',
        model2='gpt-4o-2024-11-20',
        tools='drug' if ss.get('domain') == 'Drug discovery' else None,
    )
    tools = ss.agent.agent_executor2.tools

    tool_list = pd.Series({f"✅ {t.name}": t.description for t in tools}).reset_index()
    tool_list.columns = ['Tool', 'Description']
    st.markdown(f"# {len(tool_list)} available tools")
    st.dataframe(
        tool_list,
        width='stretch',
        hide_index=True,
        height=200,
    )
|
|
# Run a prompt queued by an example button. It is consumed (reset to None)
# before anything else, so it fires at most once even if the model check fails.
if ss.pending_prompt is not None:
    prompt_to_run, ss.pending_prompt = ss.pending_prompt, None
    if ss.get('model1_select') and ss.get('model2_select'):
        run_prompt(prompt_to_run)
    else:
        st.error("⚠️ Please input both model names in the sidebar first!")
|
|
| |
def _save_upload_to_tempfile(upload, suffix):
    """Write an uploaded file's full contents to a named temp file; return its path.

    The file is flushed and closed before the path is returned, so anything
    that reopens it by name sees the complete contents. The caller owns the
    file and should delete it when done.
    """
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as handle:
        # getvalue() returns the whole buffer regardless of the read position.
        handle.write(upload.getvalue())
    return handle.name


def _discard_tempfile(path):
    """Best-effort removal of a temp upload; a failure only leaves a stray file."""
    try:
        os.remove(path)
    except OSError:
        pass


# Handle a typed chat prompt, attaching the uploaded file (if any) for the agent.
if prompt := st.chat_input("Say something and/or attach files"):
    attached_type = ss.get('file_type')
    if not ss.get('model1_select') or not ss.get('model2_select'):
        st.error("⚠️ Please input both model names in the sidebar first!")
    elif uploaded_file is not None:
        if attached_type == 'CSV (.csv)':
            csv_path = _save_upload_to_tempfile(uploaded_file, '.csv')
            try:
                # The path is also appended to the prompt text for the agent.
                run_prompt(f"{prompt} {csv_path}", file_path=csv_path)
            finally:
                _discard_tempfile(csv_path)
        elif attached_type == 'Figure (.jpg, .png, .jpeg)':
            st.image(uploaded_file, width=500)
            # NOTE(review): saved with a .png suffix whatever the upload's format.
            image_path = _save_upload_to_tempfile(uploaded_file, ".png")
            try:
                run_prompt(f"{prompt} {image_path}", image_path=image_path)
            finally:
                _discard_tempfile(image_path)
        elif attached_type == 'PDF (.pdf)':
            pdf_path = _save_upload_to_tempfile(uploaded_file, '.pdf')
            try:
                run_prompt(prompt, file_path=pdf_path)
            finally:
                _discard_tempfile(pdf_path)
    else:
        run_prompt(prompt)
|
|
|
|