Spaces:
Running
Running
changed to tinker api and merged movies and groceries
Browse files- Dockerfile +1 -1
- src/streamlit_app.py +335 -193
Dockerfile
CHANGED
|
@@ -17,4 +17,4 @@ EXPOSE 8501
|
|
| 17 |
|
| 18 |
HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
|
| 19 |
|
| 20 |
-
ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
|
|
|
|
| 17 |
|
| 18 |
HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
|
| 19 |
|
| 20 |
+
ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0", "--", "--mode", "mixed"]
|
src/streamlit_app.py
CHANGED
|
@@ -1,21 +1,23 @@
|
|
| 1 |
"""
|
| 2 |
Streamlit App: AI Product Willingness User Study
|
| 3 |
=================================================
|
| 4 |
-
Run locally:
|
| 5 |
streamlit run src/streamlit_app.py -- --category groceries
|
| 6 |
streamlit run src/streamlit_app.py -- --category groceries --debug
|
| 7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
On HuggingFace Spaces, set these environment variables in Space Settings → Variables:
|
| 9 |
HF_TOKEN - HuggingFace token
|
| 10 |
TINKER_API_KEY - Tinker AI API key
|
| 11 |
-
TINKER_MODEL_PATH - Tinker sampler checkpoint path
|
| 12 |
DATASET_REPO_ID - HuggingFace dataset repo to upload results
|
| 13 |
-
CATEGORY - groceries | books | movies | health (
|
|
|
|
| 14 |
DEBUG_MODE - "true" to skip validation (optional)
|
| 15 |
"""
|
| 16 |
|
| 17 |
-
import asyncio
|
| 18 |
-
import concurrent.futures
|
| 19 |
import csv
|
| 20 |
import json
|
| 21 |
import os
|
|
@@ -32,33 +34,45 @@ import streamlit as st
|
|
| 32 |
from dotenv import load_dotenv
|
| 33 |
from filelock import FileLock
|
| 34 |
from huggingface_hub import HfApi
|
| 35 |
-
from openai import AsyncOpenAI
|
| 36 |
|
| 37 |
load_dotenv()
|
| 38 |
|
| 39 |
# ---------------------------------------------------------------------------
|
| 40 |
-
# CLI args
|
| 41 |
# ---------------------------------------------------------------------------
|
| 42 |
import argparse
|
| 43 |
parser = argparse.ArgumentParser(add_help=False)
|
| 44 |
parser.add_argument("--category", choices=["books", "groceries", "movies", "health"], default=None)
|
|
|
|
| 45 |
parser.add_argument("--debug", action="store_true", default=False)
|
| 46 |
cli_args, _ = parser.parse_known_args()
|
| 47 |
|
| 48 |
# ---------------------------------------------------------------------------
|
| 49 |
-
# Config
|
| 50 |
# ---------------------------------------------------------------------------
|
| 51 |
-
|
|
|
|
| 52 |
DEBUG_MODE = os.getenv("DEBUG_MODE", "").lower() == "true" or cli_args.debug
|
| 53 |
DATASET_REPO_ID = os.getenv("DATASET_REPO_ID", "your-username/product-study")
|
| 54 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 55 |
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 64 |
DATA_DIR = os.path.join(BASE_DIR, "data")
|
|
@@ -67,28 +81,28 @@ os.makedirs(DATA_DIR, exist_ok=True)
|
|
| 67 |
os.makedirs(ANNOTATIONS_DIR, exist_ok=True)
|
| 68 |
|
| 69 |
CATEGORY_TO_HF = {
|
| 70 |
-
"books":
|
| 71 |
"groceries": "ehejin/amazon_Grocery_and_Gourmet_Food",
|
| 72 |
-
"movies":
|
| 73 |
-
"health":
|
| 74 |
}
|
| 75 |
CATEGORY_DISPLAY = {
|
| 76 |
-
"books":
|
| 77 |
"groceries": "Grocery Products",
|
| 78 |
-
"movies":
|
| 79 |
-
"health":
|
| 80 |
}
|
|
|
|
| 81 |
FAMILIARITY_USED_LABEL = {
|
| 82 |
-
"books":
|
| 83 |
-
"movies":
|
| 84 |
"groceries": "Used it before",
|
| 85 |
-
"health":
|
| 86 |
}
|
| 87 |
|
| 88 |
PRODUCTS_PER_USER = 5
|
| 89 |
MIN_TURNS = 3
|
| 90 |
MAX_TURNS = 10
|
| 91 |
-
TEST_SUBSET_SIZE = 100 # only use first 100 items from test split
|
| 92 |
|
| 93 |
# Familiarity values that trigger a product swap
|
| 94 |
SWAP_FAMILIARITY = {"Purchased it before"}
|
|
@@ -114,32 +128,50 @@ WILLINGNESS_LABELS = {
|
|
| 114 |
}
|
| 115 |
WILLINGNESS_CHOICES = [f"{v} ({k})" for k, v in WILLINGNESS_LABELS.items()]
|
| 116 |
|
|
|
|
| 117 |
# ---------------------------------------------------------------------------
|
| 118 |
-
#
|
| 119 |
# ---------------------------------------------------------------------------
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
|
|
|
|
|
|
|
| 128 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
@st.cache_resource
|
| 130 |
-
def download_and_cache_dataset():
|
| 131 |
-
"""Download test split
|
| 132 |
-
|
| 133 |
-
|
|
|
|
|
|
|
| 134 |
return
|
| 135 |
-
print(f"[DATA] Downloading {CATEGORY_TO_HF[
|
| 136 |
try:
|
| 137 |
from datasets import load_dataset
|
| 138 |
import huggingface_hub
|
| 139 |
if HF_TOKEN:
|
| 140 |
huggingface_hub.login(token=HF_TOKEN)
|
| 141 |
|
| 142 |
-
ds = load_dataset(CATEGORY_TO_HF[
|
| 143 |
|
| 144 |
def to_list(val):
|
| 145 |
if isinstance(val, list): return val
|
|
@@ -155,183 +187,244 @@ def download_and_cache_dataset():
|
|
| 155 |
"description": to_list(meta.get("description", []) if isinstance(meta, dict) else []),
|
| 156 |
"features": to_list(meta.get("features", []) if isinstance(meta, dict) else []),
|
| 157 |
"price": meta.get("price", "N/A") if isinstance(meta, dict) else "N/A",
|
| 158 |
-
"category":
|
| 159 |
}
|
| 160 |
all_items.append(item)
|
| 161 |
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
overflow = all_items[TEST_SUBSET_SIZE:]
|
| 165 |
|
| 166 |
-
with open(
|
| 167 |
json.dump(primary, f, indent=2)
|
| 168 |
-
with open(
|
| 169 |
json.dump(overflow, f, indent=2)
|
| 170 |
|
| 171 |
-
print(f"[DATA]
|
| 172 |
except Exception as e:
|
| 173 |
-
print(f"[DATA] ERROR downloading
|
| 174 |
raise
|
| 175 |
|
| 176 |
|
| 177 |
@st.cache_resource
|
| 178 |
-
def load_primary_dataset():
|
| 179 |
-
with open(
|
| 180 |
return json.load(f)
|
| 181 |
|
| 182 |
|
| 183 |
@st.cache_resource
|
| 184 |
-
def load_overflow_dataset():
|
| 185 |
-
|
|
|
|
| 186 |
return []
|
| 187 |
-
with open(
|
| 188 |
return json.load(f)
|
| 189 |
|
| 190 |
|
| 191 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
"""
|
| 193 |
-
Atomically assign
|
| 194 |
-
Drains the return queue first
|
| 195 |
-
|
| 196 |
-
|
| 197 |
"""
|
| 198 |
-
items = load_primary_dataset()
|
| 199 |
total = len(items)
|
| 200 |
-
lock = FileLock(
|
| 201 |
-
with lock:
|
| 202 |
-
# Load return queue
|
| 203 |
-
return_queue = []
|
| 204 |
-
if os.path.exists(RETURN_QUEUE_PATH):
|
| 205 |
-
with open(RETURN_QUEUE_PATH, "r") as f:
|
| 206 |
-
try:
|
| 207 |
-
return_queue = json.load(f)
|
| 208 |
-
except Exception:
|
| 209 |
-
return_queue = []
|
| 210 |
-
|
| 211 |
-
# Load counter
|
| 212 |
-
counter = 0
|
| 213 |
-
if os.path.exists(COUNTER_PATH):
|
| 214 |
-
with open(COUNTER_PATH, "r") as f:
|
| 215 |
-
counter = int(f.read().strip() or "0")
|
| 216 |
|
|
|
|
|
|
|
|
|
|
| 217 |
assigned = []
|
|
|
|
| 218 |
for _ in range(n):
|
| 219 |
if return_queue:
|
| 220 |
-
# Prioritise returned products so they still get reviewed
|
| 221 |
assigned.append(return_queue.pop(0))
|
| 222 |
-
elif counter < total:
|
| 223 |
-
assigned.append(items[counter])
|
| 224 |
-
counter += 1
|
| 225 |
else:
|
| 226 |
-
#
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
# Persist state
|
| 233 |
-
with open(RETURN_QUEUE_PATH, "w") as f:
|
| 234 |
-
json.dump(return_queue, f)
|
| 235 |
-
with open(COUNTER_PATH, "w") as f:
|
| 236 |
-
f.write(str(counter))
|
| 237 |
|
| 238 |
return assigned
|
| 239 |
|
| 240 |
|
| 241 |
-
def
|
| 242 |
"""
|
| 243 |
-
|
| 244 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 245 |
"""
|
| 246 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 247 |
with lock:
|
| 248 |
-
queue =
|
| 249 |
-
if os.path.exists(RETURN_QUEUE_PATH):
|
| 250 |
-
with open(RETURN_QUEUE_PATH, "r") as f:
|
| 251 |
-
try:
|
| 252 |
-
queue = json.load(f)
|
| 253 |
-
except Exception:
|
| 254 |
-
queue = []
|
| 255 |
-
# Avoid duplicates
|
| 256 |
if not any(p["id"] == product["id"] for p in queue):
|
| 257 |
queue.append(product)
|
| 258 |
-
|
| 259 |
-
json.dump(queue, f)
|
| 260 |
|
| 261 |
|
| 262 |
-
def get_swap_product(exclude_ids: set) -> dict | None:
|
| 263 |
"""
|
| 264 |
-
Get
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
"""
|
| 269 |
-
items
|
| 270 |
-
overflow = load_overflow_dataset()
|
| 271 |
-
total
|
| 272 |
|
| 273 |
-
lock = FileLock(
|
| 274 |
with lock:
|
| 275 |
-
counter =
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
while counter < total:
|
| 282 |
-
candidate = items[counter]
|
| 283 |
counter += 1
|
|
|
|
| 284 |
if candidate["id"] not in exclude_ids:
|
| 285 |
-
|
| 286 |
-
with open(COUNTER_PATH, "w") as f:
|
| 287 |
-
f.write(str(counter))
|
| 288 |
return candidate
|
| 289 |
|
| 290 |
-
# 2.
|
| 291 |
for p in items:
|
| 292 |
if p["id"] not in exclude_ids:
|
| 293 |
return p
|
| 294 |
|
| 295 |
-
# 3.
|
| 296 |
for p in overflow:
|
| 297 |
if p["id"] not in exclude_ids:
|
| 298 |
return p
|
| 299 |
|
| 300 |
-
return None
|
| 301 |
|
| 302 |
|
| 303 |
# ---------------------------------------------------------------------------
|
| 304 |
-
# AI client
|
| 305 |
# ---------------------------------------------------------------------------
|
| 306 |
@st.cache_resource
|
| 307 |
-
def
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 313 |
|
| 314 |
-
def call_model(messages: list) -> str:
|
| 315 |
-
async def _call():
|
| 316 |
-
try:
|
| 317 |
-
client = get_model_client()
|
| 318 |
-
response = await client.chat.completions.create(
|
| 319 |
-
model=MODEL_NAME,
|
| 320 |
-
messages=messages,
|
| 321 |
-
max_tokens=1000,
|
| 322 |
-
temperature=0.7,
|
| 323 |
-
top_p=0.9,
|
| 324 |
-
)
|
| 325 |
-
content = response.choices[0].message.content.strip()
|
| 326 |
-
content = re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL).strip()
|
| 327 |
-
return content
|
| 328 |
-
except Exception as e:
|
| 329 |
-
print(f"[MODEL] Error: {e}")
|
| 330 |
-
return f"[Model error: {e}]"
|
| 331 |
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 335 |
|
| 336 |
|
| 337 |
# ---------------------------------------------------------------------------
|
|
@@ -355,10 +448,11 @@ def get_hf_api():
|
|
| 355 |
|
| 356 |
def save_and_upload(state: dict):
|
| 357 |
hf_api = get_hf_api()
|
| 358 |
-
worker_id = state.get("
|
| 359 |
submission_id = state.get("submission_id", str(uuid.uuid4()))
|
| 360 |
safe_worker = "".join(c if c.isalnum() else "_" for c in str(worker_id))
|
| 361 |
-
|
|
|
|
| 362 |
folder = os.path.join(ANNOTATIONS_DIR, safe_worker)
|
| 363 |
os.makedirs(folder, exist_ok=True)
|
| 364 |
file_path = os.path.join(folder, filename)
|
|
@@ -383,7 +477,8 @@ def upload_csv_rows(state: dict, hf_api, safe_worker: str, submission_id: str):
|
|
| 383 |
demographics = state.get("demographics", {})
|
| 384 |
products = state.get("products", [])
|
| 385 |
header = [
|
| 386 |
-
"submission_id", "
|
|
|
|
| 387 |
"age", "gender", "geographic_region", "education_level", "race",
|
| 388 |
"us_citizen", "marital_status", "religion", "religious_attendance",
|
| 389 |
"political_affiliation", "income", "political_views", "household_size", "employment_status",
|
|
@@ -400,10 +495,14 @@ def upload_csv_rows(state: dict, hf_api, safe_worker: str, submission_id: str):
|
|
| 400 |
post = prod.get("post_willingness", "")
|
| 401 |
delta = (post - pre) if isinstance(pre, int) and isinstance(post, int) else ""
|
| 402 |
row = [
|
| 403 |
-
submission_id,
|
|
|
|
|
|
|
|
|
|
| 404 |
state.get("meta", {}).get("submission_time", ""),
|
| 405 |
state.get("meta", {}).get("duration_seconds", ""),
|
| 406 |
-
|
|
|
|
| 407 |
demographics.get("age", ""), demographics.get("gender", ""),
|
| 408 |
demographics.get("geographic_region", ""), demographics.get("education_level", ""),
|
| 409 |
demographics.get("race", ""), demographics.get("us_citizen", ""),
|
|
@@ -491,8 +590,9 @@ def parse_willingness(choice_str: str) -> int:
|
|
| 491 |
return 4
|
| 492 |
|
| 493 |
|
| 494 |
-
def get_familiarity_choices():
|
| 495 |
-
|
|
|
|
| 496 |
return [
|
| 497 |
"Never heard of it",
|
| 498 |
"Heard of it, but not used/purchased",
|
|
@@ -502,7 +602,6 @@ def get_familiarity_choices():
|
|
| 502 |
|
| 503 |
|
| 504 |
def needs_swap(familiarity_val: str, pre_will_val: str) -> bool:
|
| 505 |
-
"""Return True if this product should be swapped out."""
|
| 506 |
if familiarity_val in SWAP_FAMILIARITY:
|
| 507 |
return True
|
| 508 |
if pre_will_val == WILLINGNESS_CHOICES[-1]: # "Definitely would buy (7)"
|
|
@@ -510,6 +609,26 @@ def needs_swap(familiarity_val: str, pre_will_val: str) -> bool:
|
|
| 510 |
return False
|
| 511 |
|
| 512 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 513 |
# ---------------------------------------------------------------------------
|
| 514 |
# State initialisation
|
| 515 |
# ---------------------------------------------------------------------------
|
|
@@ -520,6 +639,7 @@ def make_product_slot(p: dict, was_swapped: bool = False) -> dict:
|
|
| 520 |
"description": p.get("description", []),
|
| 521 |
"features": p.get("features", []),
|
| 522 |
"price": p.get("price", "N/A"),
|
|
|
|
| 523 |
"familiarity": None,
|
| 524 |
"pre_willingness": None,
|
| 525 |
"post_willingness": None,
|
|
@@ -536,7 +656,7 @@ def make_product_slot(p: dict, was_swapped: bool = False) -> dict:
|
|
| 536 |
|
| 537 |
|
| 538 |
def init_state():
|
| 539 |
-
|
| 540 |
assigned = assign_products(PRODUCTS_PER_USER)
|
| 541 |
|
| 542 |
try:
|
|
@@ -547,12 +667,12 @@ def init_state():
|
|
| 547 |
return {
|
| 548 |
"submission_id": str(uuid.uuid4()),
|
| 549 |
"user_id": str(uuid.uuid4()),
|
| 550 |
-
"
|
| 551 |
-
"
|
| 552 |
-
"
|
| 553 |
-
"turk_submit_to": params.get("turkSubmitTo", ""),
|
| 554 |
"start_time": time.time(),
|
| 555 |
-
"
|
|
|
|
| 556 |
"demographics": {},
|
| 557 |
"products": [make_product_slot(p) for p in assigned],
|
| 558 |
"current_product_index": 0,
|
|
@@ -586,6 +706,14 @@ def inject_css():
|
|
| 586 |
}
|
| 587 |
.pc-title { font-size: 1.05rem; font-weight: 700; color: #1a1a2e; line-height: 1.35; flex: 1; }
|
| 588 |
.pc-price { font-size: 1.2rem; font-weight: 800; color: #16a34a; white-space: nowrap; }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 589 |
.pc-section { margin-top: 0.5rem; }
|
| 590 |
.pc-section-title {
|
| 591 |
font-weight: 600; font-size: 0.85rem; color: #475569;
|
|
@@ -616,15 +744,20 @@ def render_product_card_html(product: dict, compact: bool = False) -> str:
|
|
| 616 |
price = product.get("price", "N/A")
|
| 617 |
description = product.get("description", [])
|
| 618 |
features = product.get("features", [])
|
|
|
|
| 619 |
price_str = f"${price}" if price and price != "N/A" and not str(price).startswith("$") else price
|
| 620 |
|
| 621 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 622 |
desc_html = ""
|
| 623 |
if description:
|
| 624 |
desc_text = " ".join(d for d in description if d)
|
| 625 |
desc_html = f'<div class="pc-section"><div class="pc-section-title">📋 Description</div><div class="pc-desc">{desc_text}</div></div>'
|
| 626 |
|
| 627 |
-
# Features: bullet points
|
| 628 |
feat_html = ""
|
| 629 |
if features:
|
| 630 |
items_html = "".join(f"<li>{feat}</li>" for feat in features if feat)
|
|
@@ -633,6 +766,7 @@ def render_product_card_html(product: dict, compact: bool = False) -> str:
|
|
| 633 |
max_h = "max-height:240px;overflow-y:auto;" if compact else ""
|
| 634 |
return f"""
|
| 635 |
<div class="product-card" style="{max_h}">
|
|
|
|
| 636 |
<div class="pc-header">
|
| 637 |
<div class="pc-title">{title}</div>
|
| 638 |
<div class="pc-price">{price_str}</div>
|
|
@@ -668,8 +802,11 @@ def render_chat_history(turns: list):
|
|
| 668 |
# ---------------------------------------------------------------------------
|
| 669 |
def screen_welcome(s):
|
| 670 |
st.markdown("# 🛒 Product Evaluation Study")
|
|
|
|
| 671 |
st.markdown(
|
| 672 |
-
f"Welcome! In this study you will evaluate **{PRODUCTS_PER_USER} {
|
|
|
|
|
|
|
| 673 |
"For each product you will:\n"
|
| 674 |
"1. Rate how familiar you are with the product\n"
|
| 675 |
"2. Rate how willing you are to buy it\n"
|
|
@@ -755,14 +892,19 @@ def screen_demographics(s):
|
|
| 755 |
def screen_product_intro(s):
|
| 756 |
idx = s["current_product_index"]
|
| 757 |
product = s["products"][idx]
|
|
|
|
|
|
|
| 758 |
render_progress(idx + 1)
|
| 759 |
st.markdown("## Product Evaluation")
|
| 760 |
st.markdown("Please read the product information carefully, then answer the two questions below.")
|
| 761 |
st.markdown(render_product_card_html(product), unsafe_allow_html=True)
|
| 762 |
|
|
|
|
|
|
|
|
|
|
| 763 |
familiarity_val = st.radio(
|
| 764 |
"How familiar are you with this product?",
|
| 765 |
-
|
| 766 |
index=None,
|
| 767 |
key=f"familiarity_{idx}_{product['id']}",
|
| 768 |
)
|
|
@@ -782,21 +924,20 @@ def screen_product_intro(s):
|
|
| 782 |
st.error("⚠️ Please rate your willingness to buy.")
|
| 783 |
return
|
| 784 |
|
| 785 |
-
familiarity_val = familiarity_val or
|
| 786 |
pre_will_val = pre_will_val or WILLINGNESS_CHOICES[3]
|
| 787 |
|
| 788 |
# Check if we need to swap this product
|
| 789 |
if needs_swap(familiarity_val, pre_will_val) and not DEBUG_MODE:
|
| 790 |
current_ids = {p["id"] for p in s["products"]}
|
| 791 |
-
replacement = get_swap_product(exclude_ids=current_ids)
|
| 792 |
if replacement:
|
| 793 |
-
# Return the rejected product to the queue so it gets reviewed by someone else
|
| 794 |
return_product_to_queue(s["products"][idx])
|
| 795 |
s["products"][idx] = make_product_slot(replacement, was_swapped=True)
|
| 796 |
st.info("We've swapped this product for a better match. Please review the new product below.")
|
| 797 |
st.rerun()
|
| 798 |
return
|
| 799 |
-
#
|
| 800 |
|
| 801 |
pre_val = parse_willingness(pre_will_val)
|
| 802 |
s["products"][idx]["familiarity"] = familiarity_val
|
|
@@ -977,7 +1118,8 @@ def screen_reflection(s):
|
|
| 977 |
"submission_time": end_time,
|
| 978 |
"duration_seconds": round(end_time - s.get("start_time", end_time), 1),
|
| 979 |
"model": MODEL_NAME,
|
| 980 |
-
"
|
|
|
|
| 981 |
}
|
| 982 |
with st.spinner("Saving your responses…"):
|
| 983 |
save_and_upload(s)
|
|
@@ -998,9 +1140,11 @@ def screen_done(s):
|
|
| 998 |
post = p.get("post_willingness", "?")
|
| 999 |
delta = p.get("willingness_delta", 0)
|
| 1000 |
arrow = "➡️" if delta == 0 else ("⬆️" if delta > 0 else "⬇️")
|
|
|
|
| 1001 |
rows.append({
|
| 1002 |
"#": i + 1,
|
| 1003 |
-
|
|
|
|
| 1004 |
"Before": WILLINGNESS_LABELS.get(pre, str(pre)),
|
| 1005 |
"After": WILLINGNESS_LABELS.get(post, str(post)),
|
| 1006 |
"Change": f"{arrow} {delta:+d}" if isinstance(delta, int) else "–",
|
|
@@ -1008,22 +1152,20 @@ def screen_done(s):
|
|
| 1008 |
import pandas as pd
|
| 1009 |
st.dataframe(pd.DataFrame(rows), use_container_width=True, hide_index=True)
|
| 1010 |
|
| 1011 |
-
|
| 1012 |
-
|
| 1013 |
-
|
| 1014 |
-
|
| 1015 |
-
|
| 1016 |
-
|
| 1017 |
-
|
| 1018 |
-
|
| 1019 |
-
|
| 1020 |
-
|
| 1021 |
-
|
| 1022 |
-
|
| 1023 |
-
|
| 1024 |
-
|
| 1025 |
-
</form>
|
| 1026 |
-
""", unsafe_allow_html=True)
|
| 1027 |
|
| 1028 |
|
| 1029 |
# ---------------------------------------------------------------------------
|
|
|
|
| 1 |
"""
|
| 2 |
Streamlit App: AI Product Willingness User Study
|
| 3 |
=================================================
|
| 4 |
+
Run locally (single category):
|
| 5 |
streamlit run src/streamlit_app.py -- --category groceries
|
| 6 |
streamlit run src/streamlit_app.py -- --category groceries --debug
|
| 7 |
|
| 8 |
+
Run locally (mixed mode — movies + groceries):
|
| 9 |
+
streamlit run src/streamlit_app.py -- --mode mixed
|
| 10 |
+
streamlit run src/streamlit_app.py -- --mode mixed --debug
|
| 11 |
+
|
| 12 |
On HuggingFace Spaces, set these environment variables in Space Settings → Variables:
|
| 13 |
HF_TOKEN - HuggingFace token
|
| 14 |
TINKER_API_KEY - Tinker AI API key
|
|
|
|
| 15 |
DATASET_REPO_ID - HuggingFace dataset repo to upload results
|
| 16 |
+
CATEGORY - groceries | books | movies | health (single-category mode)
|
| 17 |
+
MODE - mixed (overrides CATEGORY; runs movies + groceries together)
|
| 18 |
DEBUG_MODE - "true" to skip validation (optional)
|
| 19 |
"""
|
| 20 |
|
|
|
|
|
|
|
| 21 |
import csv
|
| 22 |
import json
|
| 23 |
import os
|
|
|
|
| 34 |
from dotenv import load_dotenv
|
| 35 |
from filelock import FileLock
|
| 36 |
from huggingface_hub import HfApi
|
|
|
|
| 37 |
|
| 38 |
load_dotenv()
|
| 39 |
|
| 40 |
# ---------------------------------------------------------------------------
|
| 41 |
+
# CLI args
|
| 42 |
# ---------------------------------------------------------------------------
|
| 43 |
import argparse
|
| 44 |
parser = argparse.ArgumentParser(add_help=False)
|
| 45 |
parser.add_argument("--category", choices=["books", "groceries", "movies", "health"], default=None)
|
| 46 |
+
parser.add_argument("--mode", choices=["mixed"], default=None)
|
| 47 |
parser.add_argument("--debug", action="store_true", default=False)
|
| 48 |
cli_args, _ = parser.parse_known_args()
|
| 49 |
|
| 50 |
# ---------------------------------------------------------------------------
|
| 51 |
+
# Config
|
| 52 |
# ---------------------------------------------------------------------------
|
| 53 |
+
MODE = os.getenv("MODE") or cli_args.mode # "mixed" or None
|
| 54 |
+
CATEGORY = os.getenv("CATEGORY") or cli_args.category or "groceries" # used only in single-category mode
|
| 55 |
DEBUG_MODE = os.getenv("DEBUG_MODE", "").lower() == "true" or cli_args.debug
|
| 56 |
DATASET_REPO_ID = os.getenv("DATASET_REPO_ID", "your-username/product-study")
|
| 57 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 58 |
|
| 59 |
+
TINKER_API_KEY = os.getenv("TINKER_API_KEY")
|
| 60 |
+
MODEL_NAME = "openai/gpt-oss-20b"
|
| 61 |
+
|
| 62 |
+
# ---------------------------------------------------------------------------
|
| 63 |
+
# Mixed-mode constants
|
| 64 |
+
# ---------------------------------------------------------------------------
|
| 65 |
+
# In mixed mode these two categories are always used together
|
| 66 |
+
MIXED_CATEGORIES = ["movies", "groceries"]
|
| 67 |
+
# Each category contributes this many items to the shared pool of 100
|
| 68 |
+
MIXED_SUBSET_SIZE = 50 # 50 movies + 50 groceries = 100 total
|
| 69 |
+
SINGLE_SUBSET_SIZE = 100 # legacy single-category mode
|
| 70 |
+
|
| 71 |
+
# ---------------------------------------------------------------------------
|
| 72 |
+
# Prolific config
|
| 73 |
+
# ---------------------------------------------------------------------------
|
| 74 |
+
PROLIFIC_COMPLETION_URL = "https://app.prolific.com/submissions/complete?cc=CYC7ALM1"
|
| 75 |
+
PROLIFIC_COMPLETION_CODE = "CYC7ALM1"
|
| 76 |
|
| 77 |
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 78 |
DATA_DIR = os.path.join(BASE_DIR, "data")
|
|
|
|
| 81 |
os.makedirs(ANNOTATIONS_DIR, exist_ok=True)
|
| 82 |
|
| 83 |
CATEGORY_TO_HF = {
|
| 84 |
+
"books": "ehejin/amazon_books",
|
| 85 |
"groceries": "ehejin/amazon_Grocery_and_Gourmet_Food",
|
| 86 |
+
"movies": "ehejin/amazon_Movies_and_TV",
|
| 87 |
+
"health": "ehejin/amazon_Health_and_Household",
|
| 88 |
}
|
| 89 |
CATEGORY_DISPLAY = {
|
| 90 |
+
"books": "Books",
|
| 91 |
"groceries": "Grocery Products",
|
| 92 |
+
"movies": "Movies & TV",
|
| 93 |
+
"health": "Health & Household Products",
|
| 94 |
}
|
| 95 |
+
# Per-product familiarity label (depends on the individual product's category)
|
| 96 |
FAMILIARITY_USED_LABEL = {
|
| 97 |
+
"books": "Read it before",
|
| 98 |
+
"movies": "Watched it before",
|
| 99 |
"groceries": "Used it before",
|
| 100 |
+
"health": "Used it before",
|
| 101 |
}
|
| 102 |
|
| 103 |
PRODUCTS_PER_USER = 5
|
| 104 |
MIN_TURNS = 3
|
| 105 |
MAX_TURNS = 10
|
|
|
|
| 106 |
|
| 107 |
# Familiarity values that trigger a product swap
|
| 108 |
SWAP_FAMILIARITY = {"Purchased it before"}
|
|
|
|
| 128 |
}
|
| 129 |
WILLINGNESS_CHOICES = [f"{v} ({k})" for k, v in WILLINGNESS_LABELS.items()]
|
| 130 |
|
| 131 |
+
|
| 132 |
# ---------------------------------------------------------------------------
|
| 133 |
+
# Helpers: per-category file paths
|
| 134 |
# ---------------------------------------------------------------------------
|
| 135 |
+
def _data_path(category: str, suffix: str) -> str:
    """Build the on-disk cache path for *category* with the given *suffix*.

    The pool size is baked into the filename so mixed-mode (50 items per
    category) and single-category (100 items) caches never collide.
    """
    pool_size = MIXED_SUBSET_SIZE if MODE == "mixed" else SINGLE_SUBSET_SIZE
    return os.path.join(DATA_DIR, f"{category}_test{pool_size}_{suffix}")


def local_data_path(category: str) -> str:
    """Path of the cached primary (first-N) item list for *category*."""
    return _data_path(category, "primary.json")


def overflow_path(category: str) -> str:
    """Path of the cached overflow (beyond-N) item list for *category*."""
    return _data_path(category, "overflow.json")


def counter_path(category: str) -> str:
    """Path of the sequential-assignment counter file for *category*."""
    return _data_path(category, "counter.txt")


def counter_lock_path(category: str) -> str:
    """Path of the file lock guarding this category's counter and queue."""
    return _data_path(category, "counter.lock")


def return_queue_path(category: str) -> str:
    """Path of the queue of products returned for reassignment."""
    return _data_path(category, "return_queue.json")
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
# ---------------------------------------------------------------------------
|
| 157 |
+
# Dataset loading
|
| 158 |
+
# ---------------------------------------------------------------------------
|
| 159 |
@st.cache_resource
|
| 160 |
+
def download_and_cache_dataset(category: str, subset_size: int):
|
| 161 |
+
"""Download test split from HuggingFace and cache locally."""
|
| 162 |
+
primary_path = local_data_path(category)
|
| 163 |
+
over_path = overflow_path(category)
|
| 164 |
+
if os.path.exists(primary_path):
|
| 165 |
+
print(f"[DATA] Found cached dataset for {category} at {primary_path}")
|
| 166 |
return
|
| 167 |
+
print(f"[DATA] Downloading {CATEGORY_TO_HF[category]} (test split, first {subset_size}) from HuggingFace...")
|
| 168 |
try:
|
| 169 |
from datasets import load_dataset
|
| 170 |
import huggingface_hub
|
| 171 |
if HF_TOKEN:
|
| 172 |
huggingface_hub.login(token=HF_TOKEN)
|
| 173 |
|
| 174 |
+
ds = load_dataset(CATEGORY_TO_HF[category], split="test")
|
| 175 |
|
| 176 |
def to_list(val):
|
| 177 |
if isinstance(val, list): return val
|
|
|
|
| 187 |
"description": to_list(meta.get("description", []) if isinstance(meta, dict) else []),
|
| 188 |
"features": to_list(meta.get("features", []) if isinstance(meta, dict) else []),
|
| 189 |
"price": meta.get("price", "N/A") if isinstance(meta, dict) else "N/A",
|
| 190 |
+
"category": category,
|
| 191 |
}
|
| 192 |
all_items.append(item)
|
| 193 |
|
| 194 |
+
primary = all_items[:subset_size]
|
| 195 |
+
overflow = all_items[subset_size:]
|
|
|
|
| 196 |
|
| 197 |
+
with open(primary_path, "w") as f:
|
| 198 |
json.dump(primary, f, indent=2)
|
| 199 |
+
with open(over_path, "w") as f:
|
| 200 |
json.dump(overflow, f, indent=2)
|
| 201 |
|
| 202 |
+
print(f"[DATA] {category}: cached {len(primary)} primary + {len(overflow)} overflow items.")
|
| 203 |
except Exception as e:
|
| 204 |
+
print(f"[DATA] ERROR downloading {category}: {e}")
|
| 205 |
raise
|
| 206 |
|
| 207 |
|
| 208 |
@st.cache_resource
def load_primary_dataset(category: str):
    """Load (and cache via Streamlit) the primary item list for *category*."""
    path = local_data_path(category)
    with open(path, "r") as f:
        return json.load(f)
|
| 212 |
|
| 213 |
|
| 214 |
@st.cache_resource
def load_overflow_dataset(category: str):
    """Load (and cache via Streamlit) the overflow item list for *category*.

    Returns an empty list when no overflow file has been cached yet.
    """
    try:
        with open(overflow_path(category), "r") as f:
            return json.load(f)
    except FileNotFoundError:
        return []
|
| 221 |
|
| 222 |
|
| 223 |
+
def _ensure_datasets():
    """Download/cache every category dataset the current mode needs."""
    if MODE == "mixed":
        needed = [(cat, MIXED_SUBSET_SIZE) for cat in MIXED_CATEGORIES]
    else:
        needed = [(CATEGORY, SINGLE_SUBSET_SIZE)]
    for cat, subset in needed:
        download_and_cache_dataset(cat, subset)
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
# ---------------------------------------------------------------------------
|
| 233 |
+
# Per-category counter helpers
|
| 234 |
+
# ---------------------------------------------------------------------------
|
| 235 |
+
def _read_counter(category: str) -> int:
    """Return the persisted assignment counter for *category* (0 if absent)."""
    try:
        with open(counter_path(category), "r") as f:
            raw = f.read().strip()
    except FileNotFoundError:
        return 0
    return int(raw or "0")


def _write_counter(category: str, value: int):
    """Persist the assignment counter for *category*."""
    with open(counter_path(category), "w") as f:
        f.write(str(value))


def _read_return_queue(category: str) -> list:
    """Return the persisted return queue for *category*.

    Yields [] when the file is missing; a file that exists but does not
    parse as JSON is also treated as an empty queue (best-effort recovery).
    """
    path = return_queue_path(category)
    if not os.path.exists(path):
        return []
    with open(path, "r") as f:
        try:
            return json.load(f)
        except Exception:
            return []


def _write_return_queue(category: str, queue: list):
    """Persist the return queue for *category*."""
    with open(return_queue_path(category), "w") as f:
        json.dump(queue, f)
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
# ---------------------------------------------------------------------------
|
| 265 |
+
# Product assignment
|
| 266 |
+
# ---------------------------------------------------------------------------
|
| 267 |
+
def _assign_from_category(category: str, n: int) -> list:
    """Atomically assign n products from a single category pool.

    Order of precedence:
      1. Drain the return queue first (products handed back by swaps),
         so rejected products still get reviewed by someone.
      2. Pull sequentially from the primary pool, wrapping around
         (modulo pool size) once exhausted so late users still get items.

    All state reads/writes happen under the category's file lock so
    concurrent sessions never hand out inconsistent counters/queues.

    Raises:
        ValueError: if the primary pool for *category* is empty — previously
            this surfaced as a cryptic ZeroDivisionError from ``counter % total``.
    """
    items = load_primary_dataset(category)
    total = len(items)
    # Bug fix: guard the empty pool before taking ``counter % total``.
    if total == 0:
        raise ValueError(f"Primary dataset for category '{category}' is empty")

    lock = FileLock(counter_lock_path(category))
    with lock:
        return_queue = _read_return_queue(category)
        counter = _read_counter(category)
        assigned = []

        for _ in range(n):
            if return_queue:
                # Returned products take priority so they still get reviewed.
                assigned.append(return_queue.pop(0))
            else:
                # Wrap-around: counter mod total cycles through the pool.
                assigned.append(items[counter % total])
                counter += 1

        # Persist state while still holding the lock.
        _write_return_queue(category, return_queue)
        _write_counter(category, counter)

    return assigned
|
| 295 |
|
| 296 |
|
| 297 |
+
def assign_mixed_products(n: int = PRODUCTS_PER_USER) -> list:
    """Assign n products split across movies and groceries.

    Alternates the majority category per batch so coverage stays balanced:
        batch 1: 3 movies + 2 groceries
        batch 2: 2 movies + 3 groceries
        batch 3: 3 movies + 2 groceries ... etc.

    Bug fix: the previous check used the movies counter alone,
    ``(movies_counter // 1) % 2`` — the ``// 1`` was a no-op, and the movies
    counter advances 3, 5, 7, ... so its parity stays odd after the first
    batch and the split never alternated again.  The batch index is now
    derived from the COMBINED counters, which advance by exactly n per call.

    NOTE(review): items served from the return queue do not advance the
    counters, so under swaps the alternation is approximate — confirm this
    is acceptable for the study's balance requirements.
    """
    if n <= 0:
        return []

    # Combined counters advance by n per batch (queue draining aside),
    # so integer-dividing by n recovers the batch index; its parity
    # decides which category gets the larger share.
    batch_index = (_read_counter("movies") + _read_counter("groceries")) // n
    if batch_index % 2 == 0:
        n_movies, n_groceries = 3, 2
    else:
        n_movies, n_groceries = 2, 3

    # Clamp in case n != 5
    if n_movies + n_groceries != n:
        n_movies = n // 2
        n_groceries = n - n_movies

    movie_items = _assign_from_category("movies", n_movies)
    grocery_items = _assign_from_category("groceries", n_groceries)

    combined = movie_items + grocery_items
    random.shuffle(combined)  # mix so user doesn't see all movies then all groceries
    return combined
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
def assign_products(n: int = PRODUCTS_PER_USER) -> list:
    """Entry point for product assignment.

    Dispatches to the mixed-mode assigner when MODE == "mixed"; otherwise
    falls back to the legacy single-category behaviour for CATEGORY.
    """
    if MODE != "mixed":
        # Single-category (legacy behaviour)
        return _assign_from_category(CATEGORY, n)
    return assign_mixed_products(n)
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
def return_product_to_queue(product: dict):
    """Put a rejected/swapped product back so it gets reassigned.

    The product's own category decides which per-category return queue it
    joins; a product with an ``id`` already in the queue is never enqueued
    twice, and the queue file is rewritten only when something was added.
    """
    cat = product.get("category", CATEGORY)
    with FileLock(counter_lock_path(cat)):
        queue = _read_return_queue(cat)
        already_queued = any(entry["id"] == product["id"] for entry in queue)
        if not already_queued:
            queue.append(product)
            _write_return_queue(cat, queue)
|
|
|
|
| 345 |
|
| 346 |
|
| 347 |
+
def get_swap_product(exclude_ids: set, category: str) -> dict | None:
    """
    Get a replacement product for the given category.

    Search order:
      1. Next unassigned primary product (advances the persistent counter,
         wrapping around the primary list at most once).
      2. Any primary product not already held by this user.
      3. Overflow pool.

    Returns:
        A product dict, or None when every known product is excluded.
    """
    items = load_primary_dataset(category)
    overflow = load_overflow_dataset(category)
    total = len(items)

    with FileLock(counter_lock_path(category)):
        counter = _read_counter(category)

        # 1. Walk at most one full cycle of the primary list; persist the
        #    counter only when a usable candidate is found.
        for _ in range(total):
            candidate = items[counter % total]
            counter += 1
            if candidate["id"] not in exclude_ids:
                _write_counter(category, counter)
                return candidate

    # 2./3. Fall back to any primary item, then to the overflow pool.
    for pool in (items, overflow):
        for entry in pool:
            if entry["id"] not in exclude_ids:
                return entry

    return None
|
| 383 |
|
| 384 |
|
| 385 |
# ---------------------------------------------------------------------------
|
| 386 |
+
# AI client (Tinker)
|
| 387 |
# ---------------------------------------------------------------------------
|
| 388 |
@st.cache_resource
def get_tinker_clients():
    """Initialise and cache the Tinker sampling client, renderer, and tokenizer.

    Cached via st.cache_resource so the service connection is created only
    once per Streamlit process.

    Returns:
        Tuple of (sampling_client, renderer, tinker_types module).
    """
    import tinker
    from tinker import types as tinker_types
    from tinker_cookbook import renderers
    from tinker_cookbook.model_info import get_recommended_renderer_name
    from tinker_cookbook.tokenizer_utils import get_tokenizer

    service = tinker.ServiceClient()
    client = service.create_sampling_client(base_model=MODEL_NAME)

    tok = get_tokenizer(MODEL_NAME)
    renderer = renderers.get_renderer(get_recommended_renderer_name(MODEL_NAME), tok)
    return client, renderer, tinker_types
|
| 403 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 404 |
|
| 405 |
+
def call_model(messages: list) -> str:
    """Sample one assistant reply from the Tinker model for the given chat messages.

    Any <think>...</think> reasoning spans are stripped from the output.
    On any failure the error is printed and a "[Model error: ...]" placeholder
    string is returned instead of raising.
    """
    try:
        from tinker_cookbook import renderers as tinker_renderers

        sampling_client, renderer, tinker_types = get_tinker_clients()

        sampling_params = tinker_types.SamplingParams(
            max_tokens=1000,
            temperature=0.7,
            stop=renderer.get_stop_sequences(),
        )
        future = sampling_client.sample(
            prompt=renderer.build_generation_prompt(messages),
            sampling_params=sampling_params,
            num_samples=1,
        )
        response = future.result()

        message, _ = renderer.parse_response(response.sequences[0].tokens)
        text = tinker_renderers.format_content_as_string(message["content"])
        # Drop hidden chain-of-thought blocks before showing text to the user.
        return re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL).strip()
    except Exception as e:
        print(f"[MODEL] Tinker error: {e}")
        return f"[Model error: {e}]"
|
| 428 |
|
| 429 |
|
| 430 |
# ---------------------------------------------------------------------------
|
|
|
|
| 448 |
|
| 449 |
def save_and_upload(state: dict):
|
| 450 |
hf_api = get_hf_api()
|
| 451 |
+
worker_id = state.get("prolific_pid") or state.get("user_id", "anonymous")
|
| 452 |
submission_id = state.get("submission_id", str(uuid.uuid4()))
|
| 453 |
safe_worker = "".join(c if c.isalnum() else "_" for c in str(worker_id))
|
| 454 |
+
mode_tag = state.get("mode", "single")
|
| 455 |
+
filename = f"{submission_id}_{mode_tag}.json"
|
| 456 |
folder = os.path.join(ANNOTATIONS_DIR, safe_worker)
|
| 457 |
os.makedirs(folder, exist_ok=True)
|
| 458 |
file_path = os.path.join(folder, filename)
|
|
|
|
| 477 |
demographics = state.get("demographics", {})
|
| 478 |
products = state.get("products", [])
|
| 479 |
header = [
|
| 480 |
+
"submission_id", "prolific_pid", "study_id", "session_id",
|
| 481 |
+
"submission_time", "duration_seconds", "mode", "category",
|
| 482 |
"age", "gender", "geographic_region", "education_level", "race",
|
| 483 |
"us_citizen", "marital_status", "religion", "religious_attendance",
|
| 484 |
"political_affiliation", "income", "political_views", "household_size", "employment_status",
|
|
|
|
| 495 |
post = prod.get("post_willingness", "")
|
| 496 |
delta = (post - pre) if isinstance(pre, int) and isinstance(post, int) else ""
|
| 497 |
row = [
|
| 498 |
+
submission_id,
|
| 499 |
+
state.get("prolific_pid", ""),
|
| 500 |
+
state.get("study_id", ""),
|
| 501 |
+
state.get("session_id", ""),
|
| 502 |
state.get("meta", {}).get("submission_time", ""),
|
| 503 |
state.get("meta", {}).get("duration_seconds", ""),
|
| 504 |
+
state.get("mode", "single"),
|
| 505 |
+
prod.get("category", ""), # per-product category
|
| 506 |
demographics.get("age", ""), demographics.get("gender", ""),
|
| 507 |
demographics.get("geographic_region", ""), demographics.get("education_level", ""),
|
| 508 |
demographics.get("race", ""), demographics.get("us_citizen", ""),
|
|
|
|
| 590 |
return 4
|
| 591 |
|
| 592 |
|
| 593 |
+
def get_familiarity_choices(category: str) -> list:
|
| 594 |
+
"""Return familiarity options with the correct 'used' label for this product's category."""
|
| 595 |
+
used_label = FAMILIARITY_USED_LABEL.get(category, "Used it before")
|
| 596 |
return [
|
| 597 |
"Never heard of it",
|
| 598 |
"Heard of it, but not used/purchased",
|
|
|
|
| 602 |
|
| 603 |
|
| 604 |
def needs_swap(familiarity_val: str, pre_will_val: str) -> bool:
|
|
|
|
| 605 |
if familiarity_val in SWAP_FAMILIARITY:
|
| 606 |
return True
|
| 607 |
if pre_will_val == WILLINGNESS_CHOICES[-1]: # "Definitely would buy (7)"
|
|
|
|
| 609 |
return False
|
| 610 |
|
| 611 |
|
| 612 |
+
# ---------------------------------------------------------------------------
|
| 613 |
+
# Welcome screen helpers
|
| 614 |
+
# ---------------------------------------------------------------------------
|
| 615 |
+
def study_display_name() -> str:
    """Human-readable name for what the user will evaluate."""
    return (
        "Movies & TV and Grocery Products"
        if MODE == "mixed"
        else CATEGORY_DISPLAY.get(CATEGORY, CATEGORY)
    )
|
| 620 |
+
|
| 621 |
+
|
| 622 |
+
def study_category_breakdown() -> str:
    """Extra sentence shown on the welcome screen describing the mix.

    Returns an empty string in single-category mode (nothing extra to explain).
    """
    if MODE != "mixed":
        return ""
    return (
        "You will evaluate a mix of **Movies & TV** and **Grocery Products** "
        "(roughly 2–3 of each)."
    )
|
| 630 |
+
|
| 631 |
+
|
| 632 |
# ---------------------------------------------------------------------------
|
| 633 |
# State initialisation
|
| 634 |
# ---------------------------------------------------------------------------
|
|
|
|
| 639 |
"description": p.get("description", []),
|
| 640 |
"features": p.get("features", []),
|
| 641 |
"price": p.get("price", "N/A"),
|
| 642 |
+
"category": p.get("category", CATEGORY), # ← per-product category
|
| 643 |
"familiarity": None,
|
| 644 |
"pre_willingness": None,
|
| 645 |
"post_willingness": None,
|
|
|
|
| 656 |
|
| 657 |
|
| 658 |
def init_state():
|
| 659 |
+
_ensure_datasets()
|
| 660 |
assigned = assign_products(PRODUCTS_PER_USER)
|
| 661 |
|
| 662 |
try:
|
|
|
|
| 667 |
return {
|
| 668 |
"submission_id": str(uuid.uuid4()),
|
| 669 |
"user_id": str(uuid.uuid4()),
|
| 670 |
+
"prolific_pid": params.get("PROLIFIC_PID", ""),
|
| 671 |
+
"study_id": params.get("STUDY_ID", ""),
|
| 672 |
+
"session_id": params.get("SESSION_ID", ""),
|
|
|
|
| 673 |
"start_time": time.time(),
|
| 674 |
+
"mode": MODE or "single",
|
| 675 |
+
"category": CATEGORY if MODE != "mixed" else "mixed",
|
| 676 |
"demographics": {},
|
| 677 |
"products": [make_product_slot(p) for p in assigned],
|
| 678 |
"current_product_index": 0,
|
|
|
|
| 706 |
}
|
| 707 |
.pc-title { font-size: 1.05rem; font-weight: 700; color: #1a1a2e; line-height: 1.35; flex: 1; }
|
| 708 |
.pc-price { font-size: 1.2rem; font-weight: 800; color: #16a34a; white-space: nowrap; }
|
| 709 |
+
.pc-category-badge {
|
| 710 |
+
display: inline-block;
|
| 711 |
+
font-size: 0.75rem; font-weight: 600;
|
| 712 |
+
padding: 0.15rem 0.55rem;
|
| 713 |
+
border-radius: 99px;
|
| 714 |
+
margin-bottom: 0.4rem;
|
| 715 |
+
background: #dbeafe; color: #1e40af;
|
| 716 |
+
}
|
| 717 |
.pc-section { margin-top: 0.5rem; }
|
| 718 |
.pc-section-title {
|
| 719 |
font-weight: 600; font-size: 0.85rem; color: #475569;
|
|
|
|
| 744 |
price = product.get("price", "N/A")
|
| 745 |
description = product.get("description", [])
|
| 746 |
features = product.get("features", [])
|
| 747 |
+
category = product.get("category", "")
|
| 748 |
price_str = f"${price}" if price and price != "N/A" and not str(price).startswith("$") else price
|
| 749 |
|
| 750 |
+
# Category badge — only shown in mixed mode
|
| 751 |
+
badge_html = ""
|
| 752 |
+
if MODE == "mixed" and category:
|
| 753 |
+
badge_label = CATEGORY_DISPLAY.get(category, category)
|
| 754 |
+
badge_html = f'<div class="pc-category-badge">📂 {badge_label}</div>'
|
| 755 |
+
|
| 756 |
desc_html = ""
|
| 757 |
if description:
|
| 758 |
desc_text = " ".join(d for d in description if d)
|
| 759 |
desc_html = f'<div class="pc-section"><div class="pc-section-title">📋 Description</div><div class="pc-desc">{desc_text}</div></div>'
|
| 760 |
|
|
|
|
| 761 |
feat_html = ""
|
| 762 |
if features:
|
| 763 |
items_html = "".join(f"<li>{feat}</li>" for feat in features if feat)
|
|
|
|
| 766 |
max_h = "max-height:240px;overflow-y:auto;" if compact else ""
|
| 767 |
return f"""
|
| 768 |
<div class="product-card" style="{max_h}">
|
| 769 |
+
{badge_html}
|
| 770 |
<div class="pc-header">
|
| 771 |
<div class="pc-title">{title}</div>
|
| 772 |
<div class="pc-price">{price_str}</div>
|
|
|
|
| 802 |
# ---------------------------------------------------------------------------
|
| 803 |
def screen_welcome(s):
|
| 804 |
st.markdown("# 🛒 Product Evaluation Study")
|
| 805 |
+
breakdown = study_category_breakdown()
|
| 806 |
st.markdown(
|
| 807 |
+
f"Welcome! In this study you will evaluate **{PRODUCTS_PER_USER} {study_display_name()}** products.\n\n"
|
| 808 |
+
+ (f"{breakdown}\n\n" if breakdown else "")
|
| 809 |
+
+
|
| 810 |
"For each product you will:\n"
|
| 811 |
"1. Rate how familiar you are with the product\n"
|
| 812 |
"2. Rate how willing you are to buy it\n"
|
|
|
|
| 892 |
def screen_product_intro(s):
|
| 893 |
idx = s["current_product_index"]
|
| 894 |
product = s["products"][idx]
|
| 895 |
+
product_category = product.get("category", CATEGORY)
|
| 896 |
+
|
| 897 |
render_progress(idx + 1)
|
| 898 |
st.markdown("## Product Evaluation")
|
| 899 |
st.markdown("Please read the product information carefully, then answer the two questions below.")
|
| 900 |
st.markdown(render_product_card_html(product), unsafe_allow_html=True)
|
| 901 |
|
| 902 |
+
# Use per-product familiarity choices based on the product's own category
|
| 903 |
+
familiarity_choices = get_familiarity_choices(product_category)
|
| 904 |
+
|
| 905 |
familiarity_val = st.radio(
|
| 906 |
"How familiar are you with this product?",
|
| 907 |
+
familiarity_choices,
|
| 908 |
index=None,
|
| 909 |
key=f"familiarity_{idx}_{product['id']}",
|
| 910 |
)
|
|
|
|
| 924 |
st.error("⚠️ Please rate your willingness to buy.")
|
| 925 |
return
|
| 926 |
|
| 927 |
+
familiarity_val = familiarity_val or familiarity_choices[0]
|
| 928 |
pre_will_val = pre_will_val or WILLINGNESS_CHOICES[3]
|
| 929 |
|
| 930 |
# Check if we need to swap this product
|
| 931 |
if needs_swap(familiarity_val, pre_will_val) and not DEBUG_MODE:
|
| 932 |
current_ids = {p["id"] for p in s["products"]}
|
| 933 |
+
replacement = get_swap_product(exclude_ids=current_ids, category=product_category)
|
| 934 |
if replacement:
|
|
|
|
| 935 |
return_product_to_queue(s["products"][idx])
|
| 936 |
s["products"][idx] = make_product_slot(replacement, was_swapped=True)
|
| 937 |
st.info("We've swapped this product for a better match. Please review the new product below.")
|
| 938 |
st.rerun()
|
| 939 |
return
|
| 940 |
+
# No replacement found — proceed with this product anyway
|
| 941 |
|
| 942 |
pre_val = parse_willingness(pre_will_val)
|
| 943 |
s["products"][idx]["familiarity"] = familiarity_val
|
|
|
|
| 1118 |
"submission_time": end_time,
|
| 1119 |
"duration_seconds": round(end_time - s.get("start_time", end_time), 1),
|
| 1120 |
"model": MODEL_NAME,
|
| 1121 |
+
"mode": MODE or "single",
|
| 1122 |
+
"category": CATEGORY if MODE != "mixed" else "mixed",
|
| 1123 |
}
|
| 1124 |
with st.spinner("Saving your responses…"):
|
| 1125 |
save_and_upload(s)
|
|
|
|
| 1140 |
post = p.get("post_willingness", "?")
|
| 1141 |
delta = p.get("willingness_delta", 0)
|
| 1142 |
arrow = "➡️" if delta == 0 else ("⬆️" if delta > 0 else "⬇️")
|
| 1143 |
+
cat_label = CATEGORY_DISPLAY.get(p.get("category", ""), "") if MODE == "mixed" else ""
|
| 1144 |
rows.append({
|
| 1145 |
"#": i + 1,
|
| 1146 |
+
**({"Category": cat_label} if MODE == "mixed" else {}),
|
| 1147 |
+
"Product": p.get("title", "")[:55] + ("…" if len(p.get("title", "")) > 55 else ""),
|
| 1148 |
"Before": WILLINGNESS_LABELS.get(pre, str(pre)),
|
| 1149 |
"After": WILLINGNESS_LABELS.get(post, str(post)),
|
| 1150 |
"Change": f"{arrow} {delta:+d}" if isinstance(delta, int) else "–",
|
|
|
|
| 1152 |
import pandas as pd
|
| 1153 |
st.dataframe(pd.DataFrame(rows), use_container_width=True, hide_index=True)
|
| 1154 |
|
| 1155 |
+
st.markdown("---")
|
| 1156 |
+
st.success(
|
| 1157 |
+
f"**Your completion code:** `{PROLIFIC_COMPLETION_CODE}`\n\n"
|
| 1158 |
+
"You can either click the button below to return to Prolific automatically, "
|
| 1159 |
+
"or copy the code above and paste it on the Prolific website."
|
| 1160 |
+
)
|
| 1161 |
+
st.markdown(
|
| 1162 |
+
f"""<a href="{PROLIFIC_COMPLETION_URL}" target="_self">
|
| 1163 |
+
<button style="background:#2563eb;color:white;border:none;padding:12px 28px;
|
| 1164 |
+
font-size:1rem;border-radius:6px;cursor:pointer;margin-top:8px;">
|
| 1165 |
+
✅ Return to Prolific
|
| 1166 |
+
</button></a>""",
|
| 1167 |
+
unsafe_allow_html=True,
|
| 1168 |
+
)
|
|
|
|
|
|
|
| 1169 |
|
| 1170 |
|
| 1171 |
# ---------------------------------------------------------------------------
|