import asyncio
import logging
import os

import instructor
import streamlit as st
from dotenv import load_dotenv

from atomic_agents.agents.base_agent import BaseAgent, BaseAgentConfig, BaseAgentInputSchema, BaseAgentOutputSchema
from atomic_agents.lib.components.agent_memory import AgentMemory
from atomic_agents.lib.components.system_prompt_generator import SystemPromptGenerator
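# Basic logging so client setup and streaming errors are traceable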
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
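# Load environment variables (e.g. OPENAI_API_KEY) from a .env file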
load_dotenv()
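# Page header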
st.title("Math Reasoning Chatbot")
st.write("Select a provider and chat with the bot to solve math problems!")
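# Build an instructor-patched async client for the chosen provider.
# Returns (client, model, display_model), or (None, None, None) on failure.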
def setup_client(provider):
    if provider == "openai":
        from openai import AsyncOpenAI

        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            st.warning("OpenAI unavailable: OPENAI_API_KEY not set. Using Ollama.")
            return setup_client("ollama")
        client = instructor.from_openai(AsyncOpenAI(api_key=api_key))
        model = "gpt-4o-mini"
        display_model = "OpenAI (gpt-4o-mini)"
    elif provider == "ollama":
        from openai import AsyncOpenAI as OllamaClient

        try:
            client = instructor.from_openai(
                OllamaClient(base_url="http://localhost:11434/v1", api_key="ollama"),
                mode=instructor.Mode.JSON,
            )
            model = "llama3"
            display_model = "Ollama (llama3)"
            logger.info("Ollama client initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize Ollama client: {e}")
            st.error(f"Ollama connection failed: {e}")
            return None, None, None
    else:
        st.error(f"Unsupported provider: {provider}")
        return None, None, None
    return client, model, display_model
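# System prompt: the agent's persona, reasoning steps, and output rules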
system_prompt_generator = SystemPromptGenerator(
    background=["You are a math genius."],
    steps=["Think logically, step by step, to solve the given math problem."],
    output_instructions=[
        "Summarise your reasoning into numbered steps covering the key sub-problems and their solutions; do not narrate every intermediate thought.",
        "Answer in plain English plus formulas.",
        "Always respond using the proper JSON schema.",
        "Always use the available additional information and context to enhance the response.",
    ],
)
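# Provider picker; the first entry ("ollama") is the default selection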
providers_list = ["ollama", "openai"]
selected_provider = st.selectbox("Choose a provider:", providers_list, key="provider_select")
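# Initialize the client and halt the app if setup failed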
client, model, display_model = setup_client(selected_provider)
if client is None:
    st.stop()
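# Create (or recreate) the agent whenever the model changes, reusing the same
# memory so the conversation survives provider switches and Streamlit reruns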
st.session_state.display_model = display_model
if "agent" not in st.session_state or st.session_state.get("current_model") != model:
    if "memory" not in st.session_state:
        st.session_state.memory = AgentMemory()
        initial_message = BaseAgentOutputSchema(
            chat_message="Hello! I'm here to help with math problems. What can I assist you with today?"
        )
        st.session_state.memory.add_message("assistant", initial_message)
        st.session_state.conversation = [("assistant", initial_message.chat_message)]
    st.session_state.agent = BaseAgent(
        config=BaseAgentConfig(
            client=client,
            model=model,
            system_prompt_generator=system_prompt_generator,
            memory=st.session_state.memory,
            system_role="developer",
        )
    )
    st.session_state.current_model = model
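# Show which model is currently active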
st.markdown(f"**Selected Model:** {st.session_state.display_model}")
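# Let users inspect the fully generated system prompt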
with st.expander("View System Prompt"):
    system_prompt = system_prompt_generator.generate_prompt()
    st.text(system_prompt)
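# Replay the conversation history on every rerun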
for role, message in st.session_state.conversation:
    with st.chat_message(role):
        st.markdown(message)
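# Chat input box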
user_input = st.chat_input(placeholder="e.g., x^4 + a^4 = 0 find cf")
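# Handle a new user message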
if user_input:
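    # Record the message in both the UI history and the agent's memory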
    st.session_state.conversation.append(("user", user_input))
    input_schema = BaseAgentInputSchema(chat_message=user_input)
    st.session_state.memory.add_message("user", input_schema)
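    # Echo the user's message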
    with st.chat_message("user"):
        st.markdown(user_input)
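    # Stream the assistant's reply, updating a placeholder as partial responses arrive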
    with st.chat_message("assistant"):
        response_container = st.empty()

        async def stream_response() -> str:
            current_response = ""
            try:
                async for partial_response in st.session_state.agent.run_async(input_schema):
                    if hasattr(partial_response, "chat_message") and partial_response.chat_message:
                        if partial_response.chat_message != current_response:
                            current_response = partial_response.chat_message
                            response_container.markdown(current_response)
            except Exception as e:
                logger.error(f"Error streaming response: {e}")
                response_container.error(f"Error: {e}")
            return current_response
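    # Run the stream to completion, then store the final response in history and
    # memory; this must happen after streaming, and only if a response arrived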
    final_response = asyncio.run(stream_response())
    if final_response:
        st.session_state.conversation.append(("assistant", final_response))
        st.session_state.memory.add_message("assistant", BaseAgentOutputSchema(chat_message=final_response))