| import random |
| from datetime import datetime |
|
|
| import streamlit as st |
| from openai import OpenAI |
| from pymongo.mongo_client import MongoClient |
| from pymongo.server_api import ServerApi |
|
|
|
|
| |
| |
| |
# Page chrome: browser-tab title and icon, wide layout with the sidebar
# collapsed initially, and the entries shown in Streamlit's main menu.
st.set_page_config(
    page_title="Bot",
    page_icon="🤖",
    layout="wide",
    initial_sidebar_state="collapsed",
    menu_items={"Report a bug": "mailto:yk408@cam.ac.uk", "About": "Bot"},
)

# Hide the first option of every radio group on the page. The condition picker
# in ensure_query_params lists "" as its first (default) option, so this keeps
# that empty placeholder from being displayed.
_HIDE_FIRST_RADIO_OPTION_CSS = """
<style>
div[role="radiogroup"] > :first-child {
display: none !important;
}
</style>
"""
st.markdown(_HIDE_FIRST_RADIO_OPTION_CSS, unsafe_allow_html=True)
|
|
|
|
| |
| |
| |
# Conversation limits.
# Cap on stored messages (all roles) before the chat input is withdrawn.
MAX_MESSAGES_DEFAULT = 50
# Participant messages required before the "Submit Interaction" button appears.
SUBMIT_AFTER_USER_TURNS = 5

# Opening messages that are rejected when they are the whole first message.
GREETINGS = {
    # single-word hellos
    "hi", "hello", "hey", "heya", "hiya", "yo", "howdy", "sup",
    # time-of-day greetings
    "good morning", "good afternoon", "good evening", "good day",
    # "what's up" and friends, with and without apostrophes
    "what's up", "whats up", "how's it going", "hows it going",
    "how are you", "how are you doing", "how do you do",
    # formal / addressed greetings
    "greetings", "salutations", "hi there", "hello there", "hey there",
}

_OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"

# Experimental conditions, keyed by the value of the ``p`` query parameter.
# Each entry gives an OpenAI-compatible endpoint, the model identifier to
# request, and the name of the st.secrets entry holding that endpoint's API
# key. Condition "1" has no endpoint: no model is called in that condition.
CONDITIONS = {
    "1": {"label": "true control", "base_url": None, "model": None, "api_secret": None},
    "2": {
        "label": "base",
        "base_url": _OPENROUTER_BASE_URL,
        "model": "meta-llama/llama-3.3-70b-instruct",
        "api_secret": "OPENROUTER_API_KEY",
    },
    "3": {
        "label": "bridging",
        "base_url": "https://tinker.thinkingmachines.dev/services/tinker-prod/oai/api/v1",
        "model": "tinker://808e4f02-e847-54ae-bc75-f14ee885ce5a:train:0/sampler_weights/final_sampler",
        "api_secret": "TINKER_API_KEY",
    },
    "4": {
        "label": "gpt",
        "base_url": _OPENROUTER_BASE_URL,
        "model": "openai/gpt-5.4",
        "api_secret": "OPENROUTER_API_KEY",
    },
}
|
|
|
|
| |
| |
| |
def init_session_state() -> None:
    """Populate ``st.session_state`` with defaults on a session's first run.

    Later reruns find the ``initialized`` flag and return immediately, so the
    randomly drawn six-digit ``user_id`` stays the same for the whole session.
    """
    state = st.session_state
    if "initialized" in state:
        return

    fallback_id = str(random.randint(100000, 999999))

    for key, value in (
        ("initialized", True),
        ("inserted", 0),  # number of conversations saved so far
        ("max_messages", MAX_MESSAGES_DEFAULT),
        ("messages", []),
        ("client", None),
        ("setup", False),
        ("convo_start_time", None),
    ):
        state[key] = value

    # Per-participant record; a copy of it becomes the saved document.
    state.user_data = {
        "BASE_URL": "",
        "MODEL_PATH": "",
        "url_id": True,  # set to False when no id came in the URL
        "user_id": fallback_id,
        "start_time": datetime.now(),
        "random_pid": None,
        "condition": None,
    }
|
|
|
|
def reset_conversation_state() -> None:
    """Discard the current conversation so a new one can be set up.

    Only conversation-scoped keys are reset; ``user_data`` and ``inserted``
    are left untouched.
    """
    fresh_values = {
        "messages": [],
        "max_messages": MAX_MESSAGES_DEFAULT,
        "convo_start_time": None,
        "setup": False,
        "client": None,
    }
    for key, value in fresh_values.items():
        st.session_state[key] = value
|
|
|
|
| |
| |
| |
def ensure_query_params() -> None:
    """Make sure the URL carries a valid condition (``p``) and participant id (``id``).

    If ``p`` is missing or is not a key of ``CONDITIONS``, a radio picker is
    shown and its current selection is written back to ``p``. If ``id`` is
    missing or empty, the session's random ``user_id`` is written to the URL,
    and ``user_data["url_id"]`` is set to False to record that the id was
    not supplied in the link.
    """
    if st.query_params.get("p") not in CONDITIONS:
        # Options and help text are derived from CONDITIONS so every
        # configured condition is selectable. The leading "" is the default
        # (first) option, meaning "nothing chosen yet".
        st.query_params["p"] = st.radio(
            "Select a condition for the conversation",
            ["", *CONDITIONS],
            help=", ".join(f"{key} = {cfg['label']}" for key, cfg in CONDITIONS.items()),
        )

    if not st.query_params.get("id"):
        st.session_state.user_data["url_id"] = False
        st.query_params["id"] = st.session_state.user_data["user_id"]
|
|
|
|
def setup_conversation() -> None:
    """Bind this session to the condition named by the ``p`` query parameter.

    Does nothing when ``p`` is not a key of ``CONDITIONS``. Otherwise:
    - records the participant id, condition, endpoint and model in ``user_data``;
    - stamps ``convo_start_time``;
    - builds an OpenAI-compatible client for every condition except "1"
      (``client`` stays None for "1");
    - marks setup as complete.
    """
    chosen = st.query_params.get("p", "")
    settings = CONDITIONS.get(chosen)
    if settings is None:
        return

    st.session_state.user_data.update(
        random_pid=st.query_params["id"],
        condition=chosen,
        BASE_URL=settings["base_url"] or "",
        MODEL_PATH=settings["model"] or "",
    )
    st.session_state.convo_start_time = datetime.now()

    uses_model = chosen != "1"
    st.session_state.client = (
        OpenAI(
            base_url=settings["base_url"],
            api_key=st.secrets[settings["api_secret"]],
        )
        if uses_model
        else None
    )

    st.session_state.setup = True
|
|
|
|
| |
| |
| |
def _normalize_greeting(text: str) -> str:
    """Lowercase *text*, drop apostrophes, turn other non-alphanumerics into spaces, and collapse whitespace."""
    # Both straight and curly (U+2019) apostrophes are removed, so
    # "what's up", "what’s up" and "whats up" all compare equal.
    lowered = text.lower().replace("\u2019", "").replace("'", "")
    spaced = "".join(ch if ch.isalnum() else " " for ch in lowered)
    return " ".join(spaced.split())


def is_greeting_only(text: str, greetings: "set[str] | None" = None) -> bool:
    """Return True when *text* is nothing but a greeting.

    Both *text* and each known greeting are normalized before comparison, so
    spacing, punctuation, emoji and apostrophe variants do not matter:
    "hi !", "Hello, there!!" and "What’s up?" all count as greetings.
    Anything with extra content ("hi, what about taxes?") does not.

    Args:
        text: The participant's message.
        greetings: Iterable of greeting phrases to match against; defaults to
            the module-level ``GREETINGS``.
    """
    if greetings is None:
        greetings = GREETINGS
    known = {_normalize_greeting(phrase) for phrase in greetings}
    return _normalize_greeting(text) in known
|
|
|
|
def user_turn_count() -> int:
    """Return how many messages in the current conversation came from the participant."""
    participant_messages = [msg for msg in st.session_state.messages if msg["role"] == "user"]
    return len(participant_messages)
|
|
|
|
def can_submit() -> bool:
    """Return True once the conversation may be submitted.

    That is either after ``SUBMIT_AFTER_USER_TURNS`` participant messages, or
    once the message count has reached the session's ``max_messages`` limit
    (which is lowered to close a conversation early).
    """
    if user_turn_count() >= SUBMIT_AFTER_USER_TURNS:
        return True
    return len(st.session_state.messages) >= st.session_state.max_messages
|
|
|
|
def render_sidebar() -> None:
    """Render the participant instructions in the sidebar.

    The number of replies quoted to participants is taken from
    ``SUBMIT_AFTER_USER_TURNS``, so the text always matches the threshold
    that ``can_submit`` enforces.
    """
    with st.sidebar:
        st.markdown("# **Let's talk!**")
        st.markdown("# **Step 1. Type in the chat box to start a conversation**")

        st.success(
            "Ask, request, or talk to the chatbot about something you consider "
            "**politically polarizing** or something that people from different "
            "US political parties might disagree about.",
            icon="🎯",
        )

        # In Markdown a single "\n" is a soft break and renders as a space;
        # "\n\n" is needed for the note to appear on its own line.
        st.markdown(
            "🚫 Please avoid greetings and start the conversation with a question "
            "or a statement about a politically polarizing topic.\n\n"
            "**Note: the chatbot's knowledge only goes up to late August 2025.**"
        )

        st.markdown(
            "# **Step 2. Use the *Submit Interaction* button to get your chatbot word**\n\n"
            f"⚠️ You must respond **at least {SUBMIT_AFTER_USER_TURNS} times** before you will see a *Submit Interaction* button. "
            "You can continue before submitting, but **you must Submit Interaction and enter your chatbot word "
            "to proceed with the survey**.\n\n"
            "❗ Do not share any personal information (e.g., name or address). "
            "Do not use AI tools to write your responses. "
            "If you encounter any technical issues, please let us know. "
            "It might sometimes take 30 seconds or more to generate a response, so please be patient."
        )
|
|
|
|
def render_messages() -> None:
    """Redraw the conversation history as chat bubbles, skipping system messages."""
    visible = (msg for msg in st.session_state.messages if msg["role"] != "system")
    for msg in visible:
        with st.chat_message(msg["role"]):
            st.markdown(msg["content"])
|
|
|
|
def save_conversation() -> None:
    """Persist the current conversation to MongoDB as one document.

    The document is a copy of ``user_data`` plus:
    - the message history;
    - the conversation start time;
    - the current time as end time;
    - ``inserted + 1``.

    It is inserted into the ``app2`` collection of the ``bridge`` database,
    using the connection string in ``st.secrets["mongo"]``. Any connection or
    write error propagates to the caller.
    """
    document = {
        **st.session_state.user_data,
        "messages": st.session_state.messages,
        "convo_start_time": st.session_state.convo_start_time,
        "convo_end_time": datetime.now(),
        "inserted": st.session_state.inserted + 1,
    }

    # The client is opened per save and closed when the ``with`` exits.
    with MongoClient(st.secrets["mongo"], server_api=ServerApi("1")) as mongo_client:
        mongo_client["bridge"]["app2"].insert_one(document)
|
|
|
|
def get_assistant_response() -> str:
    """Request one non-streamed chat completion over the full message history.

    Uses the session's client and the model stored in ``user_data``, and caps
    the reply at 512 tokens. Returns "" when the first choice carries no text
    content. API errors propagate to the caller.
    """
    client = st.session_state.client
    request = {
        "model": st.session_state.user_data["MODEL_PATH"],
        "messages": st.session_state.messages,
        "max_tokens": 512,
        "stream": False,
    }
    completion = client.chat.completions.create(**request)
    first_choice = completion.choices[0]
    return first_choice.message.content or ""
|
|
|
|
| |
| |
| |
# --- Per-run bootstrap -------------------------------------------------------
init_session_state()
ensure_query_params()

# Bind the session to a condition once a valid one is in the URL. This happens
# on the first run, and again after a submission resets the conversation.
if not st.session_state.setup:
    if st.query_params.get("p") in CONDITIONS:
        setup_conversation()

render_sidebar()

# Repeat the task prompt in the main pane until the first message is sent.
awaiting_first_message = (
    st.session_state.setup
    and not st.session_state.messages
    and st.session_state.inserted == 0
)
if awaiting_first_message:
    st.success(
        "Ask, request, or talk to the chatbot about something you consider "
        "**politically polarizing** or something that people from different US political parties might disagree about.",
        icon="🎯",
    )

render_messages()
|
|
|
|
| |
| |
| |
_NO_CHATBOT_REPLY = (
    "Thank you for your question. You have been randomly assigned to a condition "
    "without a chatbot. **Please submit your interaction anyway** to get your chatbot word "
    "and proceed with the survey. Do not worry, this will not influence your compensation."
)

_ERROR_REPLY = (
    "An error has occurred or you've reached the maximum conversation length. "
    "Please submit the conversation."
)


def _show_chatbot_word() -> None:
    """Show the completion word participants copy into the survey."""
    for line in (
        "## Copy your WORD!",
        "**Your chatbot WORD is:**",
        "## TOMATOES",
        "**Please copy the WORD and enter it into the survey field below.**",
    ):
        st.markdown(line)


def _stop_unless_prompt_allowed(prompt: str) -> None:
    """Halt the run with an error if no condition is set, or if the opening message is only a greeting."""
    if not st.session_state.setup:
        st.error("Please select a condition first.")
        st.stop()

    if not st.session_state.messages and is_greeting_only(prompt):
        st.error(
            "Please avoid greetings and start the conversation with a question or a statement about a politically polarizing topic.",
            icon="🚫",
        )
        st.stop()


def _reply_without_chatbot() -> None:
    """Control condition: add the fixed reply, close the conversation, and rerun."""
    history = st.session_state.messages
    history.append({"role": "assistant", "content": _NO_CHATBOT_REPLY})
    # Setting the cap to the current length ends the conversation immediately.
    st.session_state.max_messages = len(history)
    st.rerun()


def _reply_with_chatbot() -> None:
    """Ask the model for a reply and show it; on any failure, record an error reply and close the conversation."""
    history = st.session_state.messages
    with st.chat_message("assistant"):
        try:
            with st.spinner("Typing..."):
                reply = get_assistant_response()
            st.markdown(reply)
            history.append({"role": "assistant", "content": reply})
        except Exception as e:
            history.append({"role": "assistant", "content": _ERROR_REPLY})
            st.session_state.max_messages = len(history)
            st.error(f"Request failed: {e}")


# --- Chat turn ----------------------------------------------------------------
# The chat input is offered only while the conversation is open and nothing
# has been submitted yet in this session.
if len(st.session_state.messages) >= st.session_state.max_messages:
    st.info("You have reached the limit of messages for this conversation. Please end and submit the conversation.")
elif st.session_state.inserted > 0:
    _show_chatbot_word()
elif prompt := st.chat_input("Type to ask a question or respond..."):
    _stop_unless_prompt_allowed(prompt)

    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    if st.session_state.user_data["condition"] == "1":
        _reply_without_chatbot()
    else:
        _reply_with_chatbot()
|
|
|
|
| |
| |
| |
# --- Submission ----------------------------------------------------------------
# Offer the submit button once can_submit() allows it, and only if this session
# has not already saved a conversation.
if can_submit() and st.session_state.inserted == 0:
    _, _, submit_col = st.columns((1, 1, 1))
    with submit_col:
        if st.button("Submit Interaction", use_container_width=True):
            # Only the database write belongs in the try: that is the step
            # expected to fail, and the error is shown so the participant can retry.
            try:
                save_conversation()
            except Exception as e:
                st.error(f"Failed to save conversation: {e}")
            else:
                st.session_state.inserted += 1
                reset_conversation_state()
                # st.rerun() halts the run by raising Streamlit's internal
                # control-flow exception. Keeping it outside the try means a
                # generic handler can never intercept it and misreport it as
                # a save failure.
                st.rerun()