""" Streamlit App: AI Product Willingness User Study ================================================= Run locally (single category): streamlit run src/streamlit_app.py -- --category groceries streamlit run src/streamlit_app.py -- --category groceries --debug Run locally (mixed mode — movies + groceries): streamlit run src/streamlit_app.py -- --mode mixed streamlit run src/streamlit_app.py -- --mode mixed --debug On HuggingFace Spaces, set these environment variables in Space Settings → Variables: HF_TOKEN - HuggingFace token TINKER_API_KEY - Tinker AI API key DATASET_REPO_ID - HuggingFace dataset repo to upload results CATEGORY - groceries | books | movies | health (single-category mode) MODE - mixed (overrides CATEGORY; runs movies + groceries together) DEBUG_MODE - "true" to skip validation (optional) """ import csv import json import os import random import re import sys import tempfile import time import uuid from datetime import datetime from pathlib import Path import streamlit as st from dotenv import load_dotenv from filelock import FileLock from huggingface_hub import HfApi load_dotenv() # --------------------------------------------------------------------------- # CLI args # --------------------------------------------------------------------------- import argparse parser = argparse.ArgumentParser(add_help=False) parser.add_argument("--category", choices=["books", "groceries", "movies", "health"], default=None) parser.add_argument("--mode", choices=["mixed"], default=None) parser.add_argument("--debug", action="store_true", default=False) cli_args, _ = parser.parse_known_args() # --------------------------------------------------------------------------- # Config # --------------------------------------------------------------------------- MODE = os.getenv("MODE") or cli_args.mode # "mixed" or None CATEGORY = os.getenv("CATEGORY") or cli_args.category or "groceries" # used only in single-category mode DEBUG_MODE = os.getenv("DEBUG_MODE", "").lower() == "true" or cli_args.debug DATASET_REPO_ID = os.getenv("DATASET_REPO_ID", "your-username/product-study") HF_TOKEN = os.getenv("HF_TOKEN") TINKER_API_KEY = os.getenv("TINKER_API_KEY") MODEL_NAME = "openai/gpt-oss-20b" # --------------------------------------------------------------------------- # Mixed-mode constants # --------------------------------------------------------------------------- # In mixed mode these two categories are always used together MIXED_CATEGORIES = ["movies", "groceries"] # Each category contributes this many items to the shared pool of 100 MIXED_SUBSET_SIZE = 50 # 50 movies + 50 groceries = 100 total SINGLE_SUBSET_SIZE = 100 # legacy single-category mode # --------------------------------------------------------------------------- # Prolific config # --------------------------------------------------------------------------- PROLIFIC_COMPLETION_URL = "https://app.prolific.com/submissions/complete?cc=CYC7ALM1" PROLIFIC_COMPLETION_CODE = "CYC7ALM1" BASE_DIR = os.path.dirname(os.path.abspath(__file__)) DATA_DIR = os.path.join(BASE_DIR, "data") ANNOTATIONS_DIR = os.path.join(BASE_DIR, "annotations") os.makedirs(DATA_DIR, exist_ok=True) os.makedirs(ANNOTATIONS_DIR, exist_ok=True) CATEGORY_TO_HF = { "books": "ehejin/amazon_books", "groceries": "ehejin/amazon_Grocery_and_Gourmet_Food", "movies": "ehejin/amazon_Movies_and_TV", "health": "ehejin/amazon_Health_and_Household", } CATEGORY_DISPLAY = { "books": "Books", "groceries": "Grocery Products", "movies": "Movies & TV", "health": "Health & Household Products", } # Per-product familiarity label (depends on the individual product's category) FAMILIARITY_USED_LABEL = { "books": "Read it before", "movies": "Watched it before", "groceries": "Used it before", "health": "Used it before", } PRODUCTS_PER_USER = 5 MIN_TURNS = 3 MAX_TURNS = 10 # Familiarity values that trigger a product swap SWAP_FAMILIARITY = {"Purchased it before"} DEBUG_DEMOGRAPHICS = { "age": "30", "gender": "Female", "geographic_region": "West", "education_level": "College graduate/some postgrad", "race": "White", "us_citizen": "Yes", "marital_status": "Single", "religion": "Agnostic", "religious_attendance": "Never", "political_affiliation": "Independent", "income": "$50,000-$75,000", "political_views": "Moderate", "household_size": "2", "employment_status": "Full-time employment", } WILLINGNESS_LABELS = { 1: "Definitely would not buy", 2: "Probably would not buy", 3: "Slightly unlikely to buy", 4: "Neutral", 5: "Slightly likely to buy", 6: "Probably would buy", 7: "Definitely would buy", } WILLINGNESS_CHOICES = [f"{v} ({k})" for k, v in WILLINGNESS_LABELS.items()] # --------------------------------------------------------------------------- # Helpers: per-category file paths # --------------------------------------------------------------------------- def _data_path(category: str, suffix: str) -> str: subset = MIXED_SUBSET_SIZE if MODE == "mixed" else SINGLE_SUBSET_SIZE return os.path.join(DATA_DIR, f"{category}_test{subset}_{suffix}") def local_data_path(category: str) -> str: return _data_path(category, "primary.json") def overflow_path(category: str) -> str: return _data_path(category, "overflow.json") def counter_path(category: str) -> str: return _data_path(category, "counter.txt") def counter_lock_path(category: str) -> str: return _data_path(category, "counter.lock") def return_queue_path(category: str) -> str: return _data_path(category, "return_queue.json") # --------------------------------------------------------------------------- # Dataset loading # --------------------------------------------------------------------------- @st.cache_resource def download_and_cache_dataset(category: str, subset_size: int): """Download test split from HuggingFace and cache locally.""" primary_path = local_data_path(category) over_path = overflow_path(category) if os.path.exists(primary_path): print(f"[DATA] Found cached dataset for {category} at {primary_path}") return print(f"[DATA] Downloading {CATEGORY_TO_HF[category]} (test split, first {subset_size}) from HuggingFace...") try: from datasets import load_dataset import huggingface_hub if HF_TOKEN: huggingface_hub.login(token=HF_TOKEN) ds = load_dataset(CATEGORY_TO_HF[category], split="test") def to_list(val): if isinstance(val, list): return val if isinstance(val, str): return [val] if val else [] return [] all_items = [] for row in ds: meta = row.get("metadata", {}) item = { "id": str(uuid.uuid4()), "title": meta.get("title", "") if isinstance(meta, dict) else "", "description": to_list(meta.get("description", []) if isinstance(meta, dict) else []), "features": to_list(meta.get("features", []) if isinstance(meta, dict) else []), "price": meta.get("price", "N/A") if isinstance(meta, dict) else "N/A", "category": category, } all_items.append(item) primary = all_items[:subset_size] overflow = all_items[subset_size:] with open(primary_path, "w") as f: json.dump(primary, f, indent=2) with open(over_path, "w") as f: json.dump(overflow, f, indent=2) print(f"[DATA] {category}: cached {len(primary)} primary + {len(overflow)} overflow items.") except Exception as e: print(f"[DATA] ERROR downloading {category}: {e}") raise @st.cache_resource def load_primary_dataset(category: str): with open(local_data_path(category), "r") as f: return json.load(f) @st.cache_resource def load_overflow_dataset(category: str): path = overflow_path(category) if not os.path.exists(path): return [] with open(path, "r") as f: return json.load(f) def _ensure_datasets(): """Download/cache all needed category datasets.""" if MODE == "mixed": for cat in MIXED_CATEGORIES: download_and_cache_dataset(cat, MIXED_SUBSET_SIZE) else: download_and_cache_dataset(CATEGORY, SINGLE_SUBSET_SIZE) # --------------------------------------------------------------------------- # Per-category counter helpers # --------------------------------------------------------------------------- def _read_counter(category: str) -> int: path = counter_path(category) if not os.path.exists(path): return 0 with open(path, "r") as f: return int(f.read().strip() or "0") def _write_counter(category: str, value: int): with open(counter_path(category), "w") as f: f.write(str(value)) def _read_return_queue(category: str) -> list: path = return_queue_path(category) if not os.path.exists(path): return [] with open(path, "r") as f: try: return json.load(f) except Exception: return [] def _write_return_queue(category: str, queue: list): with open(return_queue_path(category), "w") as f: json.dump(queue, f) # --------------------------------------------------------------------------- # Product assignment # --------------------------------------------------------------------------- def _assign_from_category(category: str, n: int) -> list: """ Atomically assign n products from a single category pool. - Drains the return queue first. - Pulls sequentially from the primary pool. - Wraps around (modulo pool size) when exhausted so user 21+ still get valid items. """ items = load_primary_dataset(category) total = len(items) lock = FileLock(counter_lock_path(category)) with lock: return_queue = _read_return_queue(category) counter = _read_counter(category) assigned = [] for _ in range(n): if return_queue: assigned.append(return_queue.pop(0)) else: # Wrap-around: counter mod total so we cycle through items assigned.append(items[counter % total]) counter += 1 _write_return_queue(category, return_queue) _write_counter(category, counter) return assigned def assign_mixed_products(n: int = PRODUCTS_PER_USER) -> list: """ Assign n products split across movies and groceries. Alternates the majority category each call so coverage stays balanced. User 1: 3 movies + 2 groceries User 2: 2 movies + 3 groceries User 3: 3 movies + 2 groceries ... etc. The split is decided by reading the movies counter parity (even → movies gets 3). """ movies_counter = _read_counter("movies") # Even call-count → movies gets the larger share if (movies_counter // 1) % 2 == 0: n_movies, n_groceries = 3, 2 else: n_movies, n_groceries = 2, 3 # Clamp in case n != 5 if n_movies + n_groceries != n: n_movies = n // 2 n_groceries = n - n_movies movie_items = _assign_from_category("movies", n_movies) grocery_items = _assign_from_category("groceries", n_groceries) combined = movie_items + grocery_items random.shuffle(combined) # mix so user doesn't see all movies then all groceries return combined def assign_products(n: int = PRODUCTS_PER_USER) -> list: """Dispatcher: mixed mode or single-category mode.""" if MODE == "mixed": return assign_mixed_products(n) # Single-category (legacy behaviour) return _assign_from_category(CATEGORY, n) def return_product_to_queue(product: dict): """Put a rejected/swapped product back so it gets reassigned.""" cat = product.get("category", CATEGORY) lock = FileLock(counter_lock_path(cat)) with lock: queue = _read_return_queue(cat) if not any(p["id"] == product["id"] for p in queue): queue.append(product) _write_return_queue(cat, queue) def get_swap_product(exclude_ids: set, category: str) -> dict | None: """ Get a replacement product for the given category. 1. Next unassigned primary product (advances counter). 2. Wrap-around: any primary product not held by this user. 3. Overflow pool. """ items = load_primary_dataset(category) overflow = load_overflow_dataset(category) total = len(items) lock = FileLock(counter_lock_path(category)) with lock: counter = _read_counter(category) # 1. Unassigned (with wrap-around awareness) attempts = 0 while attempts < total: candidate = items[counter % total] counter += 1 attempts += 1 if candidate["id"] not in exclude_ids: _write_counter(category, counter) return candidate # 2. Any primary product not held by this user for p in items: if p["id"] not in exclude_ids: return p # 3. Overflow for p in overflow: if p["id"] not in exclude_ids: return p return None # --------------------------------------------------------------------------- # AI client (Tinker) # --------------------------------------------------------------------------- @st.cache_resource def get_tinker_clients(): """Initialise and cache Tinker sampling client, renderer, and tokenizer.""" import tinker from tinker import types as tinker_types from tinker_cookbook import renderers from tinker_cookbook.tokenizer_utils import get_tokenizer from tinker_cookbook.model_info import get_recommended_renderer_name service_client = tinker.ServiceClient() sampling_client = service_client.create_sampling_client(base_model=MODEL_NAME) tokenizer = get_tokenizer(MODEL_NAME) renderer_name = get_recommended_renderer_name(MODEL_NAME) renderer = renderers.get_renderer(renderer_name, tokenizer) return sampling_client, renderer, tinker_types def call_model(messages: list) -> str: try: from tinker_cookbook import renderers as tinker_renderers sampling_client, renderer, tinker_types = get_tinker_clients() prompt = renderer.build_generation_prompt(messages) params = tinker_types.SamplingParams( max_tokens=1000, temperature=0.7, stop=renderer.get_stop_sequences(), ) result = sampling_client.sample( prompt=prompt, sampling_params=params, num_samples=1, ).result() parsed_message, _ = renderer.parse_response(result.sequences[0].tokens) content = tinker_renderers.format_content_as_string(parsed_message["content"]) content = re.sub(r".*?", "", content, flags=re.DOTALL).strip() return content except Exception as e: print(f"[MODEL] Tinker error: {e}") return f"[Model error: {e}]" # --------------------------------------------------------------------------- # HuggingFace upload # --------------------------------------------------------------------------- @st.cache_resource def get_hf_api(): api = HfApi(token=HF_TOKEN) if HF_TOKEN else HfApi() if HF_TOKEN: try: api.repo_info(repo_id=DATASET_REPO_ID, repo_type="dataset") print(f"[HF] Repo {DATASET_REPO_ID} exists.") except Exception as e: if "404" in str(e) or "not found" in str(e).lower(): api.create_repo(repo_id=DATASET_REPO_ID, repo_type="dataset", private=True) print(f"[HF] Created repo {DATASET_REPO_ID}.") else: print(f"[HF] WARNING: {e}") return api def save_and_upload(state: dict): hf_api = get_hf_api() worker_id = state.get("prolific_pid") or state.get("user_id", "anonymous") submission_id = state.get("submission_id", str(uuid.uuid4())) safe_worker = "".join(c if c.isalnum() else "_" for c in str(worker_id)) mode_tag = state.get("mode", "single") filename = f"{submission_id}_{mode_tag}.json" folder = os.path.join(ANNOTATIONS_DIR, safe_worker) os.makedirs(folder, exist_ok=True) file_path = os.path.join(folder, filename) with open(file_path, "w") as f: json.dump(state, f, indent=2) print(f"[SAVE] Wrote {file_path}") if HF_TOKEN: try: hf_api.upload_file( path_or_fileobj=file_path, path_in_repo=f"{safe_worker}/{filename}", repo_id=DATASET_REPO_ID, repo_type="dataset", ) print("[HF] Uploaded JSON.") except Exception as e: print(f"[HF] JSON upload error: {e}") upload_csv_rows(state, hf_api, safe_worker, submission_id) def upload_csv_rows(state: dict, hf_api, safe_worker: str, submission_id: str): demographics = state.get("demographics", {}) products = state.get("products", []) header = [ "submission_id", "prolific_pid", "study_id", "session_id", "submission_time", "duration_seconds", "mode", "category", "age", "gender", "geographic_region", "education_level", "race", "us_citizen", "marital_status", "religion", "religious_attendance", "political_affiliation", "income", "political_views", "household_size", "employment_status", "product_index", "product_id", "title", "price", "familiarity", "pre_willingness", "pre_willingness_label", "post_willingness", "post_willingness_label", "willingness_delta", "num_turns", "conversation_json", "standout_moment", "thinking_change", "was_swapped", ] rows = [] for i, prod in enumerate(products): conv = prod.get("conversation", {}) refl = prod.get("reflection", {}) pre = prod.get("pre_willingness", "") post = prod.get("post_willingness", "") delta = (post - pre) if isinstance(pre, int) and isinstance(post, int) else "" row = [ submission_id, state.get("prolific_pid", ""), state.get("study_id", ""), state.get("session_id", ""), state.get("meta", {}).get("submission_time", ""), state.get("meta", {}).get("duration_seconds", ""), state.get("mode", "single"), prod.get("category", ""), # per-product category demographics.get("age", ""), demographics.get("gender", ""), demographics.get("geographic_region", ""), demographics.get("education_level", ""), demographics.get("race", ""), demographics.get("us_citizen", ""), demographics.get("marital_status", ""), demographics.get("religion", ""), demographics.get("religious_attendance", ""), demographics.get("political_affiliation", ""), demographics.get("income", ""), demographics.get("political_views", ""), demographics.get("household_size", ""), demographics.get("employment_status", ""), i + 1, prod.get("id", ""), prod.get("title", ""), prod.get("price", ""), prod.get("familiarity", ""), pre, WILLINGNESS_LABELS.get(pre, "") if isinstance(pre, int) else "", post, WILLINGNESS_LABELS.get(post, "") if isinstance(post, int) else "", delta, conv.get("num_turns", 0), json.dumps(conv.get("turns", [])), refl.get("standout_moment", ""), refl.get("thinking_change", ""), prod.get("was_swapped", False), ] rows.append(row) timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S") unique_id = uuid.uuid4().hex[:8] csv_filename = f"csv_submissions/{timestamp_str}_{safe_worker}_{unique_id}.csv" with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False, newline="", encoding="utf-8") as tmp: tmp_path = tmp.name writer = csv.writer(tmp) writer.writerow(header) writer.writerows(rows) if HF_TOKEN: try: hf_api.upload_file( path_or_fileobj=tmp_path, path_in_repo=csv_filename, repo_id=DATASET_REPO_ID, repo_type="dataset", ) print("[HF] Uploaded CSV rows.") except Exception as e: print(f"[HF] CSV upload error: {e}") os.unlink(tmp_path) # --------------------------------------------------------------------------- # Prompt builders # --------------------------------------------------------------------------- def build_sales_system_prompt(product: dict) -> str: title = product.get("title", "this product") price = product.get("price", "N/A") description = product.get("description", []) features = product.get("features", []) desc_text = " ".join(description) if description else "No description available." feat_text = " ".join(features) if features else "" price_str = f"${price}" if price and price != "N/A" and not str(price).startswith("$") else price return f"""You are a warm, knowledgeable sales agent for an Amazon product. Your goal is to convince the user to buy this product. Product name: {title} Product description and features: {desc_text} {feat_text} Price: {price_str} You need to convince the user to buy it. First message rules: - In ONE paragraph: briefly highlight the product's best quality, explain why it's worth buying, and hit them with the strongest benefit - End with an engaging question that draws out their interest or hesitation Follow-up message rules: - In ONE paragraph: acknowledge what they said, address any concerns directly with a concrete benefit or reassurance, end with a question - Use their words against hesitation: if they say it's expensive, talk value; if they doubt quality, cite a feature - Vary your tactics: sometimes appeal to emotion (convenience, joy), sometimes to reason (value, quality) - Use "imagine if..." scenarios to make benefits concrete General style: - Be warm, confident, and conversational — like a helpful friend who knows the product well, not a pushy salesperson - End your messages with an engaging question - Never fabricate statistics, details, or reviews you don't have - Never make up a price different from the one given """ def build_opening_user_message(product: dict) -> str: return f'Tell me about this product and why I should buy it: "{product.get("title", "this product")}"' def parse_willingness(choice_str: str) -> int: try: return int(choice_str.split("(")[1].rstrip(")")) except Exception: return 4 def get_familiarity_choices(category: str) -> list: """Return familiarity options with the correct 'used' label for this product's category.""" used_label = FAMILIARITY_USED_LABEL.get(category, "Used it before") return [ "Never heard of it", "Heard of it, but not used/purchased", used_label, "Purchased it before", ] def needs_swap(familiarity_val: str, pre_will_val: str) -> bool: if familiarity_val in SWAP_FAMILIARITY: return True if pre_will_val == WILLINGNESS_CHOICES[-1]: # "Definitely would buy (7)" return True return False # --------------------------------------------------------------------------- # Welcome screen helpers # --------------------------------------------------------------------------- def study_display_name() -> str: """Human-readable name for what the user will evaluate.""" if MODE == "mixed": return "Movies & TV and Grocery Products" return CATEGORY_DISPLAY.get(CATEGORY, CATEGORY) def study_category_breakdown() -> str: """Extra sentence shown on welcome screen describing the mix.""" if MODE == "mixed": return ( "You will evaluate a mix of **Movies & TV** and **Grocery Products** " "(roughly 2–3 of each)." ) return "" # --------------------------------------------------------------------------- # State initialisation # --------------------------------------------------------------------------- def make_product_slot(p: dict, was_swapped: bool = False) -> dict: return { "id": p.get("id", str(uuid.uuid4())), "title": p.get("title", ""), "description": p.get("description", []), "features": p.get("features", []), "price": p.get("price", "N/A"), "category": p.get("category", CATEGORY), # ← per-product category "familiarity": None, "pre_willingness": None, "post_willingness": None, "willingness_delta": None, "was_swapped": was_swapped, "conversation": { "system_prompt": "", "opening_user_message": "", "turns": [], "num_turns": 0, }, "reflection": {}, } def init_state(): _ensure_datasets() assigned = assign_products(PRODUCTS_PER_USER) try: params = st.query_params except Exception: params = {} return { "submission_id": str(uuid.uuid4()), "user_id": str(uuid.uuid4()), "prolific_pid": params.get("PROLIFIC_PID", ""), "study_id": params.get("STUDY_ID", ""), "session_id": params.get("SESSION_ID", ""), "start_time": time.time(), "mode": MODE or "single", "category": CATEGORY if MODE != "mixed" else "mixed", "demographics": {}, "products": [make_product_slot(p) for p in assigned], "current_product_index": 0, "screen": "welcome", "meta": {}, } # --------------------------------------------------------------------------- # CSS # --------------------------------------------------------------------------- def inject_css(): st.markdown(""" """, unsafe_allow_html=True) # --------------------------------------------------------------------------- # UI helpers # --------------------------------------------------------------------------- def render_product_card_html(product: dict, compact: bool = False) -> str: title = product.get("title", "Unknown Product") price = product.get("price", "N/A") description = product.get("description", []) features = product.get("features", []) category = product.get("category", "") price_str = f"${price}" if price and price != "N/A" and not str(price).startswith("$") else price # Category badge — only shown in mixed mode badge_html = "" if MODE == "mixed" and category: badge_label = CATEGORY_DISPLAY.get(category, category) badge_html = f'
📂 {badge_label}
' desc_html = "" if description: desc_text = " ".join(d for d in description if d) desc_html = f'
📋 Description
{desc_text}
' feat_html = "" if features: items_html = "".join(f"
  • {feat}
  • " for feat in features if feat) feat_html = f'
    ✨ Features
    ' max_h = "max-height:240px;overflow-y:auto;" if compact else "" return f"""
    {badge_html}
    {title}
    {price_str}
    {desc_html} {feat_html}
    """ def render_progress(current: int, total: int = PRODUCTS_PER_USER): pct = int((current / total) * 100) st.markdown(f"""
    Product {current} of {total}
    """, unsafe_allow_html=True) def render_chat_history(turns: list): html = '
    ' for turn in turns: role = turn.get("role", "") content = turn.get("content", "") if role == "assistant": html += f'
    🤖 AI Sales Agent
    {content}
    ' elif role == "user": html += f'
    You
    {content}
    ' html += "
    " st.markdown(html, unsafe_allow_html=True) # --------------------------------------------------------------------------- # Screen renderers # --------------------------------------------------------------------------- def screen_welcome(s): st.markdown("# 🛒 Product Evaluation Study") breakdown = study_category_breakdown() st.markdown( f"Welcome! In this study you will evaluate **{PRODUCTS_PER_USER} {study_display_name()}** products.\n\n" + (f"{breakdown}\n\n" if breakdown else "") + "For each product you will:\n" "1. Rate how familiar you are with the product\n" "2. Rate how willing you are to buy it\n" "3. Chat with an AI about the product (**at least 3 exchanges**)\n" "4. Rate your willingness to buy it again\n" "5. Answer two brief reflection questions\n\n" "After all 5 products, you're done! The study takes about **20–30 minutes**. " "Thank you for participating!" ) if st.button("Begin →", type="primary", use_container_width=True): if DEBUG_MODE: s["demographics"] = DEBUG_DEMOGRAPHICS.copy() s["screen"] = "product_intro" else: s["screen"] = "demographics" st.rerun() def screen_demographics(s): st.markdown("## Demographics — About You") st.markdown("All fields are required before you can proceed.") age = st.text_input("Age (years)", placeholder="e.g. 34") gender = st.selectbox("Gender", ["", "Female", "Male"]) geographic_region = st.selectbox("Geographic region", ["", "West", "South", "Midwest", "Northeast", "Pacific"]) education_level = st.selectbox("Highest education level", [ "", "Less than high school", "High school graduate", "Some college, no degree", "Associate's degree", "College graduate/some postgrad", "Postgraduate", ]) race = st.selectbox("Race / ethnicity", ["", "Asian", "Hispanic", "White", "Black", "Other"]) us_citizen = st.selectbox("Are you a U.S. citizen?", ["", "Yes", "No"]) marital_status = st.selectbox("Marital status", [ "", "Never been married", "Married", "Living with a partner", "Divorced", "Separated", "Widowed", ]) religion = st.selectbox("Religion", [ "", "Protestant", "Roman Catholic", "Mormon", "Orthodox", "Jewish", "Muslim", "Buddhist", "Atheist", "Agnostic", "Nothing in particular", "Other", ]) religious_attendance = st.selectbox("How often do you attend religious services?", [ "", "Never", "Seldom", "A few times a year", "Once or twice a month", "Once a week", "More than once a week", ]) political_affiliation = st.selectbox("Political affiliation", [ "", "Democrat", "Republican", "Independent", "Something else", ]) income = st.selectbox("Household income", [ "", "Less than $30,000", "$30,000-$50,000", "$50,000-$75,000", "$75,000-$100,000", "$100,000 or more", ]) political_views = st.selectbox("Political views", [ "", "Very liberal", "Liberal", "Moderate", "Conservative", "Very conservative", ]) household_size = st.selectbox("Household size", ["", "1", "2", "3", "4", "More than 4"]) employment_status = st.selectbox("Employment status", [ "", "Full-time employment", "Part-time employment", "Self-employed", "Unemployed", "Retired", "Home-maker", "Student", ]) if st.button("Next →", type="primary", use_container_width=True): fields = [age, gender, geographic_region, education_level, race, us_citizen, marital_status, religion, religious_attendance, political_affiliation, income, political_views, household_size, employment_status] if not all([f and (f.strip() if isinstance(f, str) else f) for f in fields]): st.error("⚠️ Please complete all fields.") return if not age.strip().isdigit() or not (1 <= int(age.strip()) <= 120): st.error("⚠️ Please enter a valid age.") return s["demographics"] = { "age": age.strip(), "gender": gender, "geographic_region": geographic_region, "education_level": education_level, "race": race, "us_citizen": us_citizen, "marital_status": marital_status, "religion": religion, "religious_attendance": religious_attendance, "political_affiliation": political_affiliation, "income": income, "political_views": political_views, "household_size": household_size, "employment_status": employment_status, } s["screen"] = "product_intro" st.rerun() def screen_product_intro(s): idx = s["current_product_index"] product = s["products"][idx] product_category = product.get("category", CATEGORY) render_progress(idx + 1) st.markdown("## Product Evaluation") st.markdown("Please read the product information carefully, then answer the two questions below.") st.markdown(render_product_card_html(product), unsafe_allow_html=True) # Use per-product familiarity choices based on the product's own category familiarity_choices = get_familiarity_choices(product_category) familiarity_val = st.radio( "How familiar are you with this product?", familiarity_choices, index=None, key=f"familiarity_{idx}_{product['id']}", ) pre_will_val = st.radio( "How willing would you be to buy this product?", WILLINGNESS_CHOICES, index=None, key=f"pre_will_{idx}_{product['id']}", ) if st.button("Start Chat →", type="primary", use_container_width=True): if not DEBUG_MODE: if not familiarity_val: st.error("⚠️ Please rate your familiarity.") return if not pre_will_val: st.error("⚠️ Please rate your willingness to buy.") return familiarity_val = familiarity_val or familiarity_choices[0] pre_will_val = pre_will_val or WILLINGNESS_CHOICES[3] # Check if we need to swap this product if needs_swap(familiarity_val, pre_will_val) and not DEBUG_MODE: current_ids = {p["id"] for p in s["products"]} replacement = get_swap_product(exclude_ids=current_ids, category=product_category) if replacement: return_product_to_queue(s["products"][idx]) s["products"][idx] = make_product_slot(replacement, was_swapped=True) st.info("We've swapped this product for a better match. Please review the new product below.") st.rerun() return # No replacement found — proceed with this product anyway pre_val = parse_willingness(pre_will_val) s["products"][idx]["familiarity"] = familiarity_val s["products"][idx]["pre_willingness"] = pre_val s["products"][idx]["pre_willingness_label"] = WILLINGNESS_LABELS[pre_val] system_prompt = build_sales_system_prompt(product) opening_user_msg = build_opening_user_message(product) messages = [ {"role": "system", "content": system_prompt}, {"role": "user", "content": opening_user_msg}, ] with st.spinner("Starting conversation…"): ai_reply = call_model(messages) s["products"][idx]["conversation"]["system_prompt"] = system_prompt s["products"][idx]["conversation"]["opening_user_message"] = opening_user_msg s["products"][idx]["conversation"]["turns"] = [ {"turn_index": 0, "role": "assistant", "content": ai_reply, "timestamp": time.time(), "model": MODEL_NAME} ] s["products"][idx]["conversation"]["num_turns"] = 0 s["screen"] = "chat" st.rerun() def screen_chat(s): idx = s["current_product_index"] product = s["products"][idx] conv = s["products"][idx]["conversation"] render_progress(idx + 1) st.markdown("## Chat with the AI") title = product.get("title", "Product") price = product.get("price", "N/A") price_str = f"${price}" if price and price != "N/A" and not str(price).startswith("$") else price with st.expander(f"📦 {title} — {price_str} (click to expand product details)"): st.markdown(render_product_card_html(product, compact=True), unsafe_allow_html=True) num_turns = conv["num_turns"] st.markdown( f"Chat with the AI about whether you'd like to purchase the product. " f"Ask questions, push back, or explore your interest. " f"You need at least **{MIN_TURNS} exchanges** before you can move on." ) display_turns = [t for t in conv["turns"] if t["role"] in ("user", "assistant")] render_chat_history(display_turns) if num_turns >= MAX_TURNS: st.info(f"Maximum turns ({MAX_TURNS}) reached. Please proceed.") else: st.caption(f"Turns: {num_turns} / minimum {MIN_TURNS}") st.caption("💡 If you don't see the latest messages, scroll down while hovering over the conversation.") if num_turns < MAX_TURNS: user_msg = st.text_area( "Your response:", placeholder="Type your response here…", height=100, key=f"chat_input_{idx}_{num_turns}", ) col1, col2 = st.columns([3, 1]) with col2: send_clicked = st.button("Send", type="primary", use_container_width=True) if send_clicked: if not user_msg or not user_msg.strip(): st.error("⚠️ Please type a message.") return if len(user_msg.strip().split()) < 5 and not DEBUG_MODE: st.error(f"⚠️ Please write at least 5 words ({len(user_msg.strip().split())} so far).") return user_msg = user_msg.strip() messages = [ {"role": "system", "content": conv["system_prompt"]}, {"role": "user", "content": conv["opening_user_message"]}, ] for turn in conv["turns"]: messages.append({"role": turn["role"], "content": turn["content"]}) messages.append({"role": "user", "content": user_msg}) with st.spinner("AI is responding…"): ai_reply = call_model(messages) conv["turns"].append({"turn_index": len(conv["turns"]), "role": "user", "content": user_msg, "timestamp": time.time()}) conv["turns"].append({"turn_index": len(conv["turns"]), "role": "assistant", "content": ai_reply, "timestamp": time.time(), "model": MODEL_NAME}) conv["num_turns"] = num_turns + 1 s["products"][idx]["conversation"] = conv st.rerun() can_finish = num_turns >= MIN_TURNS or num_turns >= MAX_TURNS or DEBUG_MODE if can_finish: if st.button("I'm done chatting →", use_container_width=True): s["screen"] = "post_will" st.rerun() else: st.button("I'm done chatting →", disabled=True, use_container_width=True, help=f"Complete at least {MIN_TURNS} exchanges first.") def screen_post_willingness(s): idx = s["current_product_index"] product = s["products"][idx] render_progress(idx + 1) st.markdown("## Your View Now") st.markdown("Now that you've chatted with the AI, rate your willingness to buy again.") st.markdown(render_product_card_html(product), unsafe_allow_html=True) post_will_val = st.radio( "How willing would you be to buy this product now?", WILLINGNESS_CHOICES, index=None, key=f"post_will_{idx}_{product['id']}", ) if st.button("Next →", type="primary", use_container_width=True): if not post_will_val and not DEBUG_MODE: st.error("⚠️ Please rate your willingness to buy.") return post_will_val = post_will_val or WILLINGNESS_CHOICES[3] post_val = parse_willingness(post_will_val) pre_val = s["products"][idx].get("pre_willingness", 4) delta = post_val - pre_val s["products"][idx]["post_willingness"] = post_val s["products"][idx]["post_willingness_label"] = WILLINGNESS_LABELS[post_val] s["products"][idx]["willingness_delta"] = delta s["screen"] = "reflection" st.rerun() def screen_reflection(s): idx = s["current_product_index"] render_progress(idx + 1) st.markdown("## Reflection") standout = st.text_area( "What did the AI say that stood out to you most?", placeholder="Describe a specific argument, question, or moment from the conversation…", height=120, key=f"standout_{idx}", ) thinking_change = st.text_area( "How did your thinking about this product change (or not change) during the chat? Why?", placeholder="Be as specific as you can…", height=120, key=f"thinking_{idx}", ) next_label = "Next Product →" if idx + 1 < PRODUCTS_PER_USER else "Submit Study →" if st.button(next_label, type="primary", use_container_width=True): if not DEBUG_MODE: if not standout or not standout.strip(): st.error("⚠️ Please answer the first reflection question.") return if len(standout.strip().split()) < 10: st.error(f"⚠️ Please write at least 10 words for the first question ({len(standout.strip().split())} so far).") return if not thinking_change or not thinking_change.strip(): st.error("⚠️ Please answer the second reflection question.") return if len(thinking_change.strip().split()) < 10: st.error(f"⚠️ Please write at least 10 words for the second question ({len(thinking_change.strip().split())} so far).") return standout = (standout or "").strip() or "[debug placeholder]" thinking_change = (thinking_change or "").strip() or "[debug placeholder]" s["products"][idx]["reflection"] = { "standout_moment": standout, "thinking_change": thinking_change, } next_idx = idx + 1 s["current_product_index"] = next_idx if next_idx >= PRODUCTS_PER_USER: end_time = time.time() s["meta"] = { "submission_time": end_time, "duration_seconds": round(end_time - s.get("start_time", end_time), 1), "model": MODEL_NAME, "mode": MODE or "single", "category": CATEGORY if MODE != "mixed" else "mixed", } with st.spinner("Saving your responses…"): save_and_upload(s) s["screen"] = "done" else: s["screen"] = "product_intro" st.rerun() def screen_done(s): st.markdown("## ✅ Study Complete!") st.markdown("**Thank you for completing the study!**") st.markdown(f"Here's a summary of how your willingness changed across the {PRODUCTS_PER_USER} products:") rows = [] for i, p in enumerate(s["products"]): pre = p.get("pre_willingness", "?") post = p.get("post_willingness", "?") delta = p.get("willingness_delta", 0) arrow = "➡️" if delta == 0 else ("⬆️" if delta > 0 else "⬇️") cat_label = CATEGORY_DISPLAY.get(p.get("category", ""), "") if MODE == "mixed" else "" rows.append({ "#": i + 1, **({"Category": cat_label} if MODE == "mixed" else {}), "Product": p.get("title", "")[:55] + ("…" if len(p.get("title", "")) > 55 else ""), "Before": WILLINGNESS_LABELS.get(pre, str(pre)), "After": WILLINGNESS_LABELS.get(post, str(post)), "Change": f"{arrow} {delta:+d}" if isinstance(delta, int) else "–", }) import pandas as pd st.dataframe(pd.DataFrame(rows), use_container_width=True, hide_index=True) st.markdown("---") st.success( f"**Your completion code:** `{PROLIFIC_COMPLETION_CODE}`\n\n" "Please copy this code and paste it on the Prolific website to complete your submission." ) # --------------------------------------------------------------------------- # Main # --------------------------------------------------------------------------- def main(): st.set_page_config(page_title="Product Study", page_icon="🛒", layout="centered") inject_css() if "study_state" not in st.session_state: st.session_state.study_state = init_state() s = st.session_state.study_state screen = s.get("screen", "welcome") if screen == "welcome": screen_welcome(s) elif screen == "demographics": screen_demographics(s) elif screen == "product_intro": screen_product_intro(s) elif screen == "chat": screen_chat(s) elif screen == "post_will": screen_post_willingness(s) elif screen == "reflection": screen_reflection(s) elif screen == "done": screen_done(s) if __name__ == "__main__": main()