ehejin commited on
Commit
9455a27
·
1 Parent(s): 839eaee

changed to tinker api and merged movies and groceries

Browse files
Files changed (2) hide show
  1. Dockerfile +1 -1
  2. src/streamlit_app.py +335 -193
Dockerfile CHANGED
@@ -17,4 +17,4 @@ EXPOSE 8501
17
 
18
  HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
19
 
20
- ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
 
17
 
18
  HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
19
 
20
+ ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0", "--", "--mode", "mixed"]
src/streamlit_app.py CHANGED
@@ -1,21 +1,23 @@
1
  """
2
  Streamlit App: AI Product Willingness User Study
3
  =================================================
4
- Run locally:
5
  streamlit run src/streamlit_app.py -- --category groceries
6
  streamlit run src/streamlit_app.py -- --category groceries --debug
7
 
 
 
 
 
8
  On HuggingFace Spaces, set these environment variables in Space Settings → Variables:
9
  HF_TOKEN - HuggingFace token
10
  TINKER_API_KEY - Tinker AI API key
11
- TINKER_MODEL_PATH - Tinker sampler checkpoint path
12
  DATASET_REPO_ID - HuggingFace dataset repo to upload results
13
- CATEGORY - groceries | books | movies | health (default: groceries)
 
14
  DEBUG_MODE - "true" to skip validation (optional)
15
  """
16
 
17
- import asyncio
18
- import concurrent.futures
19
  import csv
20
  import json
21
  import os
@@ -32,33 +34,45 @@ import streamlit as st
32
  from dotenv import load_dotenv
33
  from filelock import FileLock
34
  from huggingface_hub import HfApi
35
- from openai import AsyncOpenAI
36
 
37
  load_dotenv()
38
 
39
  # ---------------------------------------------------------------------------
40
- # CLI args (supported locally; ignored on HF Spaces — use env vars instead)
41
  # ---------------------------------------------------------------------------
42
  import argparse
43
  parser = argparse.ArgumentParser(add_help=False)
44
  parser.add_argument("--category", choices=["books", "groceries", "movies", "health"], default=None)
 
45
  parser.add_argument("--debug", action="store_true", default=False)
46
  cli_args, _ = parser.parse_known_args()
47
 
48
  # ---------------------------------------------------------------------------
49
- # Config (env vars take precedence, then CLI args, then defaults)
50
  # ---------------------------------------------------------------------------
51
- CATEGORY = os.getenv("CATEGORY") or cli_args.category or "groceries"
 
52
  DEBUG_MODE = os.getenv("DEBUG_MODE", "").lower() == "true" or cli_args.debug
53
  DATASET_REPO_ID = os.getenv("DATASET_REPO_ID", "your-username/product-study")
54
  HF_TOKEN = os.getenv("HF_TOKEN")
55
 
56
- # TINKER_API_KEY = os.getenv("TINKER_API_KEY")
57
- # TINKER_BASE_URL = "https://tinker.thinkingmachines.dev/services/tinker-prod/oai/api/v1"
58
- # MODEL_NAME = os.getenv("TINKER_MODEL_PATH", "tinker://YOUR_RUN_ID:train:0/sampler_weights/000080")
59
- TOGETHER_API_KEY = os.getenv("TOGETHER_API_KEY")
60
- TOGETHER_BASE_URL = "https://api.together.xyz/v1"
61
- MODEL_NAME = "openai/gpt-oss-20b" # or whichever model you want
 
 
 
 
 
 
 
 
 
 
 
62
 
63
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
64
  DATA_DIR = os.path.join(BASE_DIR, "data")
@@ -67,28 +81,28 @@ os.makedirs(DATA_DIR, exist_ok=True)
67
  os.makedirs(ANNOTATIONS_DIR, exist_ok=True)
68
 
69
  CATEGORY_TO_HF = {
70
- "books": "ehejin/amazon_books",
71
  "groceries": "ehejin/amazon_Grocery_and_Gourmet_Food",
72
- "movies": "ehejin/amazon_Movies_and_TV",
73
- "health": "ehejin/amazon_Health_and_Household",
74
  }
75
  CATEGORY_DISPLAY = {
76
- "books": "Books",
77
  "groceries": "Grocery Products",
78
- "movies": "Movies & TV",
79
- "health": "Health & Household Products",
80
  }
 
81
  FAMILIARITY_USED_LABEL = {
82
- "books": "Read it before",
83
- "movies": "Watched it before",
84
  "groceries": "Used it before",
85
- "health": "Used it before",
86
  }
87
 
88
  PRODUCTS_PER_USER = 5
89
  MIN_TURNS = 3
90
  MAX_TURNS = 10
91
- TEST_SUBSET_SIZE = 100 # only use first 100 items from test split
92
 
93
  # Familiarity values that trigger a product swap
94
  SWAP_FAMILIARITY = {"Purchased it before"}
@@ -114,32 +128,50 @@ WILLINGNESS_LABELS = {
114
  }
115
  WILLINGNESS_CHOICES = [f"{v} ({k})" for k, v in WILLINGNESS_LABELS.items()]
116
 
 
117
  # ---------------------------------------------------------------------------
118
- # Dataset loading test split, first 100 items
119
  # ---------------------------------------------------------------------------
120
- LOCAL_DATA_PATH = os.path.join(DATA_DIR, f"{CATEGORY}_test100.json")
121
- # Counter tracks which of the 100 products have been assigned globally
122
- COUNTER_PATH = os.path.join(DATA_DIR, f"{CATEGORY}_counter.txt")
123
- COUNTER_LOCK_PATH = os.path.join(DATA_DIR, f"{CATEGORY}_counter.lock")
124
- RETURN_QUEUE_PATH = os.path.join(DATA_DIR, f"{CATEGORY}_return_queue.json")
125
- # Overflow pool for swap replacements (products beyond the 100, or re-used ones)
126
- OVERFLOW_PATH = os.path.join(DATA_DIR, f"{CATEGORY}_overflow.json")
127
 
 
 
128
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  @st.cache_resource
130
- def download_and_cache_dataset():
131
- """Download test split (first 100 items) from HuggingFace and cache locally."""
132
- if os.path.exists(LOCAL_DATA_PATH):
133
- print(f"[DATA] Found cached dataset at {LOCAL_DATA_PATH}")
 
 
134
  return
135
- print(f"[DATA] Downloading {CATEGORY_TO_HF[CATEGORY]} (test split) from HuggingFace...")
136
  try:
137
  from datasets import load_dataset
138
  import huggingface_hub
139
  if HF_TOKEN:
140
  huggingface_hub.login(token=HF_TOKEN)
141
 
142
- ds = load_dataset(CATEGORY_TO_HF[CATEGORY], split="test")
143
 
144
  def to_list(val):
145
  if isinstance(val, list): return val
@@ -155,183 +187,244 @@ def download_and_cache_dataset():
155
  "description": to_list(meta.get("description", []) if isinstance(meta, dict) else []),
156
  "features": to_list(meta.get("features", []) if isinstance(meta, dict) else []),
157
  "price": meta.get("price", "N/A") if isinstance(meta, dict) else "N/A",
158
- "category": CATEGORY,
159
  }
160
  all_items.append(item)
161
 
162
- # First 100 are the primary pool; the rest are the overflow/swap pool
163
- primary = all_items[:TEST_SUBSET_SIZE]
164
- overflow = all_items[TEST_SUBSET_SIZE:]
165
 
166
- with open(LOCAL_DATA_PATH, "w") as f:
167
  json.dump(primary, f, indent=2)
168
- with open(OVERFLOW_PATH, "w") as f:
169
  json.dump(overflow, f, indent=2)
170
 
171
- print(f"[DATA] Cached {len(primary)} primary + {len(overflow)} overflow items.")
172
  except Exception as e:
173
- print(f"[DATA] ERROR downloading dataset: {e}")
174
  raise
175
 
176
 
177
  @st.cache_resource
178
- def load_primary_dataset():
179
- with open(LOCAL_DATA_PATH, "r") as f:
180
  return json.load(f)
181
 
182
 
183
  @st.cache_resource
184
- def load_overflow_dataset():
185
- if not os.path.exists(OVERFLOW_PATH):
 
186
  return []
187
- with open(OVERFLOW_PATH, "r") as f:
188
  return json.load(f)
189
 
190
 
191
- def assign_products(n=PRODUCTS_PER_USER):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
  """
193
- Atomically assign the next n products.
194
- Drains the return queue first (rejected products waiting for reassignment),
195
- then pulls from the primary pool sequentially.
196
- Falls back to overflow only if primary 100 is fully exhausted.
197
  """
198
- items = load_primary_dataset()
199
  total = len(items)
200
- lock = FileLock(COUNTER_LOCK_PATH)
201
- with lock:
202
- # Load return queue
203
- return_queue = []
204
- if os.path.exists(RETURN_QUEUE_PATH):
205
- with open(RETURN_QUEUE_PATH, "r") as f:
206
- try:
207
- return_queue = json.load(f)
208
- except Exception:
209
- return_queue = []
210
-
211
- # Load counter
212
- counter = 0
213
- if os.path.exists(COUNTER_PATH):
214
- with open(COUNTER_PATH, "r") as f:
215
- counter = int(f.read().strip() or "0")
216
 
 
 
 
217
  assigned = []
 
218
  for _ in range(n):
219
  if return_queue:
220
- # Prioritise returned products so they still get reviewed
221
  assigned.append(return_queue.pop(0))
222
- elif counter < total:
223
- assigned.append(items[counter])
224
- counter += 1
225
  else:
226
- # Primary pool exhausted fall back to overflow
227
- overflow = load_overflow_dataset()
228
- if overflow:
229
- assigned.append(overflow[0])
230
- # If truly nothing left, skip (shouldn't happen with 20 users / 100 products)
231
-
232
- # Persist state
233
- with open(RETURN_QUEUE_PATH, "w") as f:
234
- json.dump(return_queue, f)
235
- with open(COUNTER_PATH, "w") as f:
236
- f.write(str(counter))
237
 
238
  return assigned
239
 
240
 
241
- def return_product_to_queue(product: dict):
242
  """
243
- Put a rejected/swapped product back into the queue so it gets
244
- reassigned to the next available user slot.
 
 
 
 
 
 
245
  """
246
- lock = FileLock(COUNTER_LOCK_PATH)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
  with lock:
248
- queue = []
249
- if os.path.exists(RETURN_QUEUE_PATH):
250
- with open(RETURN_QUEUE_PATH, "r") as f:
251
- try:
252
- queue = json.load(f)
253
- except Exception:
254
- queue = []
255
- # Avoid duplicates
256
  if not any(p["id"] == product["id"] for p in queue):
257
  queue.append(product)
258
- with open(RETURN_QUEUE_PATH, "w") as f:
259
- json.dump(queue, f)
260
 
261
 
262
- def get_swap_product(exclude_ids: set) -> dict | None:
263
  """
264
- Get the next unassigned product from the primary pool (advances the counter
265
- so the picked product is consumed and won't be assigned again),
266
- fall back to any primary product not held by this user (overlap ok),
267
- then overflow.
268
  """
269
- items = load_primary_dataset()
270
- overflow = load_overflow_dataset()
271
- total = len(items)
272
 
273
- lock = FileLock(COUNTER_LOCK_PATH)
274
  with lock:
275
- counter = 0
276
- if os.path.exists(COUNTER_PATH):
277
- with open(COUNTER_PATH, "r") as f:
278
- counter = int(f.read().strip() or "0")
279
-
280
- # 1. Try unassigned primary products — advance counter when we pick one
281
- while counter < total:
282
- candidate = items[counter]
283
  counter += 1
 
284
  if candidate["id"] not in exclude_ids:
285
- # Persist advanced counter so this product isn't assigned again
286
- with open(COUNTER_PATH, "w") as f:
287
- f.write(str(counter))
288
  return candidate
289
 
290
- # 2. Primary pool exhausted — any primary product not held by this user (overlap ok)
291
  for p in items:
292
  if p["id"] not in exclude_ids:
293
  return p
294
 
295
- # 3. Last resort: overflow
296
  for p in overflow:
297
  if p["id"] not in exclude_ids:
298
  return p
299
 
300
- return None # extremely unlikely
301
 
302
 
303
  # ---------------------------------------------------------------------------
304
- # AI client
305
  # ---------------------------------------------------------------------------
306
  @st.cache_resource
307
- def get_model_client():
308
- return AsyncOpenAI(
309
- base_url=TOGETHER_BASE_URL,
310
- api_key=TOGETHER_API_KEY,
311
- timeout=60.0,
312
- )
 
 
 
 
 
 
 
 
313
 
314
- def call_model(messages: list) -> str:
315
- async def _call():
316
- try:
317
- client = get_model_client()
318
- response = await client.chat.completions.create(
319
- model=MODEL_NAME,
320
- messages=messages,
321
- max_tokens=1000,
322
- temperature=0.7,
323
- top_p=0.9,
324
- )
325
- content = response.choices[0].message.content.strip()
326
- content = re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL).strip()
327
- return content
328
- except Exception as e:
329
- print(f"[MODEL] Error: {e}")
330
- return f"[Model error: {e}]"
331
 
332
- with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
333
- future = pool.submit(asyncio.run, _call())
334
- return future.result()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
335
 
336
 
337
  # ---------------------------------------------------------------------------
@@ -355,10 +448,11 @@ def get_hf_api():
355
 
356
  def save_and_upload(state: dict):
357
  hf_api = get_hf_api()
358
- worker_id = state.get("worker_id") or state.get("user_id", "anonymous")
359
  submission_id = state.get("submission_id", str(uuid.uuid4()))
360
  safe_worker = "".join(c if c.isalnum() else "_" for c in str(worker_id))
361
- filename = f"{submission_id}_{CATEGORY}.json"
 
362
  folder = os.path.join(ANNOTATIONS_DIR, safe_worker)
363
  os.makedirs(folder, exist_ok=True)
364
  file_path = os.path.join(folder, filename)
@@ -383,7 +477,8 @@ def upload_csv_rows(state: dict, hf_api, safe_worker: str, submission_id: str):
383
  demographics = state.get("demographics", {})
384
  products = state.get("products", [])
385
  header = [
386
- "submission_id", "worker_id", "submission_time", "duration_seconds", "category",
 
387
  "age", "gender", "geographic_region", "education_level", "race",
388
  "us_citizen", "marital_status", "religion", "religious_attendance",
389
  "political_affiliation", "income", "political_views", "household_size", "employment_status",
@@ -400,10 +495,14 @@ def upload_csv_rows(state: dict, hf_api, safe_worker: str, submission_id: str):
400
  post = prod.get("post_willingness", "")
401
  delta = (post - pre) if isinstance(pre, int) and isinstance(post, int) else ""
402
  row = [
403
- submission_id, state.get("worker_id", ""),
 
 
 
404
  state.get("meta", {}).get("submission_time", ""),
405
  state.get("meta", {}).get("duration_seconds", ""),
406
- CATEGORY,
 
407
  demographics.get("age", ""), demographics.get("gender", ""),
408
  demographics.get("geographic_region", ""), demographics.get("education_level", ""),
409
  demographics.get("race", ""), demographics.get("us_citizen", ""),
@@ -491,8 +590,9 @@ def parse_willingness(choice_str: str) -> int:
491
  return 4
492
 
493
 
494
- def get_familiarity_choices():
495
- used_label = FAMILIARITY_USED_LABEL.get(CATEGORY, "Used it before")
 
496
  return [
497
  "Never heard of it",
498
  "Heard of it, but not used/purchased",
@@ -502,7 +602,6 @@ def get_familiarity_choices():
502
 
503
 
504
  def needs_swap(familiarity_val: str, pre_will_val: str) -> bool:
505
- """Return True if this product should be swapped out."""
506
  if familiarity_val in SWAP_FAMILIARITY:
507
  return True
508
  if pre_will_val == WILLINGNESS_CHOICES[-1]: # "Definitely would buy (7)"
@@ -510,6 +609,26 @@ def needs_swap(familiarity_val: str, pre_will_val: str) -> bool:
510
  return False
511
 
512
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
513
  # ---------------------------------------------------------------------------
514
  # State initialisation
515
  # ---------------------------------------------------------------------------
@@ -520,6 +639,7 @@ def make_product_slot(p: dict, was_swapped: bool = False) -> dict:
520
  "description": p.get("description", []),
521
  "features": p.get("features", []),
522
  "price": p.get("price", "N/A"),
 
523
  "familiarity": None,
524
  "pre_willingness": None,
525
  "post_willingness": None,
@@ -536,7 +656,7 @@ def make_product_slot(p: dict, was_swapped: bool = False) -> dict:
536
 
537
 
538
  def init_state():
539
- download_and_cache_dataset()
540
  assigned = assign_products(PRODUCTS_PER_USER)
541
 
542
  try:
@@ -547,12 +667,12 @@ def init_state():
547
  return {
548
  "submission_id": str(uuid.uuid4()),
549
  "user_id": str(uuid.uuid4()),
550
- "worker_id": params.get("workerId", ""),
551
- "assignment_id": params.get("assignmentId", ""),
552
- "hit_id": params.get("hitId", ""),
553
- "turk_submit_to": params.get("turkSubmitTo", ""),
554
  "start_time": time.time(),
555
- "category": CATEGORY,
 
556
  "demographics": {},
557
  "products": [make_product_slot(p) for p in assigned],
558
  "current_product_index": 0,
@@ -586,6 +706,14 @@ def inject_css():
586
  }
587
  .pc-title { font-size: 1.05rem; font-weight: 700; color: #1a1a2e; line-height: 1.35; flex: 1; }
588
  .pc-price { font-size: 1.2rem; font-weight: 800; color: #16a34a; white-space: nowrap; }
 
 
 
 
 
 
 
 
589
  .pc-section { margin-top: 0.5rem; }
590
  .pc-section-title {
591
  font-weight: 600; font-size: 0.85rem; color: #475569;
@@ -616,15 +744,20 @@ def render_product_card_html(product: dict, compact: bool = False) -> str:
616
  price = product.get("price", "N/A")
617
  description = product.get("description", [])
618
  features = product.get("features", [])
 
619
  price_str = f"${price}" if price and price != "N/A" and not str(price).startswith("$") else price
620
 
621
- # Description: joined with spaces as prose
 
 
 
 
 
622
  desc_html = ""
623
  if description:
624
  desc_text = " ".join(d for d in description if d)
625
  desc_html = f'<div class="pc-section"><div class="pc-section-title">📋 Description</div><div class="pc-desc">{desc_text}</div></div>'
626
 
627
- # Features: bullet points
628
  feat_html = ""
629
  if features:
630
  items_html = "".join(f"<li>{feat}</li>" for feat in features if feat)
@@ -633,6 +766,7 @@ def render_product_card_html(product: dict, compact: bool = False) -> str:
633
  max_h = "max-height:240px;overflow-y:auto;" if compact else ""
634
  return f"""
635
  <div class="product-card" style="{max_h}">
 
636
  <div class="pc-header">
637
  <div class="pc-title">{title}</div>
638
  <div class="pc-price">{price_str}</div>
@@ -668,8 +802,11 @@ def render_chat_history(turns: list):
668
  # ---------------------------------------------------------------------------
669
  def screen_welcome(s):
670
  st.markdown("# 🛒 Product Evaluation Study")
 
671
  st.markdown(
672
- f"Welcome! In this study you will evaluate **{PRODUCTS_PER_USER} {CATEGORY_DISPLAY[CATEGORY]}** products.\n\n"
 
 
673
  "For each product you will:\n"
674
  "1. Rate how familiar you are with the product\n"
675
  "2. Rate how willing you are to buy it\n"
@@ -755,14 +892,19 @@ def screen_demographics(s):
755
  def screen_product_intro(s):
756
  idx = s["current_product_index"]
757
  product = s["products"][idx]
 
 
758
  render_progress(idx + 1)
759
  st.markdown("## Product Evaluation")
760
  st.markdown("Please read the product information carefully, then answer the two questions below.")
761
  st.markdown(render_product_card_html(product), unsafe_allow_html=True)
762
 
 
 
 
763
  familiarity_val = st.radio(
764
  "How familiar are you with this product?",
765
- get_familiarity_choices(),
766
  index=None,
767
  key=f"familiarity_{idx}_{product['id']}",
768
  )
@@ -782,21 +924,20 @@ def screen_product_intro(s):
782
  st.error("⚠️ Please rate your willingness to buy.")
783
  return
784
 
785
- familiarity_val = familiarity_val or get_familiarity_choices()[0]
786
  pre_will_val = pre_will_val or WILLINGNESS_CHOICES[3]
787
 
788
  # Check if we need to swap this product
789
  if needs_swap(familiarity_val, pre_will_val) and not DEBUG_MODE:
790
  current_ids = {p["id"] for p in s["products"]}
791
- replacement = get_swap_product(exclude_ids=current_ids)
792
  if replacement:
793
- # Return the rejected product to the queue so it gets reviewed by someone else
794
  return_product_to_queue(s["products"][idx])
795
  s["products"][idx] = make_product_slot(replacement, was_swapped=True)
796
  st.info("We've swapped this product for a better match. Please review the new product below.")
797
  st.rerun()
798
  return
799
- # If no replacement found, proceed anyway
800
 
801
  pre_val = parse_willingness(pre_will_val)
802
  s["products"][idx]["familiarity"] = familiarity_val
@@ -977,7 +1118,8 @@ def screen_reflection(s):
977
  "submission_time": end_time,
978
  "duration_seconds": round(end_time - s.get("start_time", end_time), 1),
979
  "model": MODEL_NAME,
980
- "category": CATEGORY,
 
981
  }
982
  with st.spinner("Saving your responses…"):
983
  save_and_upload(s)
@@ -998,9 +1140,11 @@ def screen_done(s):
998
  post = p.get("post_willingness", "?")
999
  delta = p.get("willingness_delta", 0)
1000
  arrow = "➡️" if delta == 0 else ("⬆️" if delta > 0 else "⬇️")
 
1001
  rows.append({
1002
  "#": i + 1,
1003
- "Product": p.get("title", "")[:60] + ("…" if len(p.get("title", "")) > 60 else ""),
 
1004
  "Before": WILLINGNESS_LABELS.get(pre, str(pre)),
1005
  "After": WILLINGNESS_LABELS.get(post, str(post)),
1006
  "Change": f"{arrow} {delta:+d}" if isinstance(delta, int) else "–",
@@ -1008,22 +1152,20 @@ def screen_done(s):
1008
  import pandas as pd
1009
  st.dataframe(pd.DataFrame(rows), use_container_width=True, hide_index=True)
1010
 
1011
- assignment_id = s.get("assignment_id", "")
1012
- turk_submit_to = s.get("turk_submit_to", "")
1013
- if assignment_id and turk_submit_to:
1014
- submit_url = f"{turk_submit_to}/mturk/externalSubmit"
1015
- submission_id = s.get("submission_id", "")
1016
- st.markdown(f"""
1017
- <form id="mturk-submit-form" method="POST" action="{submit_url}">
1018
- <input type="hidden" name="assignmentId" value="{assignment_id}" />
1019
- <input type="hidden" name="submission_id" value="{submission_id}" />
1020
- <button type="submit" style="
1021
- background:#2563eb; color:white; border:none; padding:12px 28px;
1022
- font-size:1rem; border-radius:6px; cursor:pointer; margin-top:12px;">
1023
- ✅ Submit to MTurk
1024
- </button>
1025
- </form>
1026
- """, unsafe_allow_html=True)
1027
 
1028
 
1029
  # ---------------------------------------------------------------------------
 
1
  """
2
  Streamlit App: AI Product Willingness User Study
3
  =================================================
4
+ Run locally (single category):
5
  streamlit run src/streamlit_app.py -- --category groceries
6
  streamlit run src/streamlit_app.py -- --category groceries --debug
7
 
8
+ Run locally (mixed mode — movies + groceries):
9
+ streamlit run src/streamlit_app.py -- --mode mixed
10
+ streamlit run src/streamlit_app.py -- --mode mixed --debug
11
+
12
  On HuggingFace Spaces, set these environment variables in Space Settings → Variables:
13
  HF_TOKEN - HuggingFace token
14
  TINKER_API_KEY - Tinker AI API key
 
15
  DATASET_REPO_ID - HuggingFace dataset repo to upload results
16
+ CATEGORY - groceries | books | movies | health (single-category mode)
17
+ MODE - mixed (overrides CATEGORY; runs movies + groceries together)
18
  DEBUG_MODE - "true" to skip validation (optional)
19
  """
20
 
 
 
21
  import csv
22
  import json
23
  import os
 
34
  from dotenv import load_dotenv
35
  from filelock import FileLock
36
  from huggingface_hub import HfApi
 
37
 
38
  load_dotenv()
39
 
40
  # ---------------------------------------------------------------------------
41
+ # CLI args
42
  # ---------------------------------------------------------------------------
43
  import argparse
44
  parser = argparse.ArgumentParser(add_help=False)
45
  parser.add_argument("--category", choices=["books", "groceries", "movies", "health"], default=None)
46
+ parser.add_argument("--mode", choices=["mixed"], default=None)
47
  parser.add_argument("--debug", action="store_true", default=False)
48
  cli_args, _ = parser.parse_known_args()
49
 
50
  # ---------------------------------------------------------------------------
51
+ # Config
52
  # ---------------------------------------------------------------------------
53
+ MODE = os.getenv("MODE") or cli_args.mode # "mixed" or None
54
+ CATEGORY = os.getenv("CATEGORY") or cli_args.category or "groceries" # used only in single-category mode
55
  DEBUG_MODE = os.getenv("DEBUG_MODE", "").lower() == "true" or cli_args.debug
56
  DATASET_REPO_ID = os.getenv("DATASET_REPO_ID", "your-username/product-study")
57
  HF_TOKEN = os.getenv("HF_TOKEN")
58
 
59
+ TINKER_API_KEY = os.getenv("TINKER_API_KEY")
60
+ MODEL_NAME = "openai/gpt-oss-20b"
61
+
62
+ # ---------------------------------------------------------------------------
63
+ # Mixed-mode constants
64
+ # ---------------------------------------------------------------------------
65
+ # In mixed mode these two categories are always used together
66
+ MIXED_CATEGORIES = ["movies", "groceries"]
67
+ # Each category contributes this many items to the shared pool of 100
68
+ MIXED_SUBSET_SIZE = 50 # 50 movies + 50 groceries = 100 total
69
+ SINGLE_SUBSET_SIZE = 100 # legacy single-category mode
70
+
71
+ # ---------------------------------------------------------------------------
72
+ # Prolific config
73
+ # ---------------------------------------------------------------------------
74
+ PROLIFIC_COMPLETION_URL = "https://app.prolific.com/submissions/complete?cc=CYC7ALM1"
75
+ PROLIFIC_COMPLETION_CODE = "CYC7ALM1"
76
 
77
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
78
  DATA_DIR = os.path.join(BASE_DIR, "data")
 
81
  os.makedirs(ANNOTATIONS_DIR, exist_ok=True)
82
 
83
  CATEGORY_TO_HF = {
84
+ "books": "ehejin/amazon_books",
85
  "groceries": "ehejin/amazon_Grocery_and_Gourmet_Food",
86
+ "movies": "ehejin/amazon_Movies_and_TV",
87
+ "health": "ehejin/amazon_Health_and_Household",
88
  }
89
  CATEGORY_DISPLAY = {
90
+ "books": "Books",
91
  "groceries": "Grocery Products",
92
+ "movies": "Movies & TV",
93
+ "health": "Health & Household Products",
94
  }
95
+ # Per-product familiarity label (depends on the individual product's category)
96
  FAMILIARITY_USED_LABEL = {
97
+ "books": "Read it before",
98
+ "movies": "Watched it before",
99
  "groceries": "Used it before",
100
+ "health": "Used it before",
101
  }
102
 
103
  PRODUCTS_PER_USER = 5
104
  MIN_TURNS = 3
105
  MAX_TURNS = 10
 
106
 
107
  # Familiarity values that trigger a product swap
108
  SWAP_FAMILIARITY = {"Purchased it before"}
 
128
  }
129
  WILLINGNESS_CHOICES = [f"{v} ({k})" for k, v in WILLINGNESS_LABELS.items()]
130
 
131
+
132
  # ---------------------------------------------------------------------------
133
+ # Helpers: per-category file paths
134
  # ---------------------------------------------------------------------------
135
+ def _data_path(category: str, suffix: str) -> str:
136
+ subset = MIXED_SUBSET_SIZE if MODE == "mixed" else SINGLE_SUBSET_SIZE
137
+ return os.path.join(DATA_DIR, f"{category}_test{subset}_{suffix}")
138
+
139
+
140
+ def local_data_path(category: str) -> str:
141
+ return _data_path(category, "primary.json")
142
 
143
+ def overflow_path(category: str) -> str:
144
+ return _data_path(category, "overflow.json")
145
 
146
+ def counter_path(category: str) -> str:
147
+ return _data_path(category, "counter.txt")
148
+
149
+ def counter_lock_path(category: str) -> str:
150
+ return _data_path(category, "counter.lock")
151
+
152
+ def return_queue_path(category: str) -> str:
153
+ return _data_path(category, "return_queue.json")
154
+
155
+
156
+ # ---------------------------------------------------------------------------
157
+ # Dataset loading
158
+ # ---------------------------------------------------------------------------
159
  @st.cache_resource
160
+ def download_and_cache_dataset(category: str, subset_size: int):
161
+ """Download test split from HuggingFace and cache locally."""
162
+ primary_path = local_data_path(category)
163
+ over_path = overflow_path(category)
164
+ if os.path.exists(primary_path):
165
+ print(f"[DATA] Found cached dataset for {category} at {primary_path}")
166
  return
167
+ print(f"[DATA] Downloading {CATEGORY_TO_HF[category]} (test split, first {subset_size}) from HuggingFace...")
168
  try:
169
  from datasets import load_dataset
170
  import huggingface_hub
171
  if HF_TOKEN:
172
  huggingface_hub.login(token=HF_TOKEN)
173
 
174
+ ds = load_dataset(CATEGORY_TO_HF[category], split="test")
175
 
176
  def to_list(val):
177
  if isinstance(val, list): return val
 
187
  "description": to_list(meta.get("description", []) if isinstance(meta, dict) else []),
188
  "features": to_list(meta.get("features", []) if isinstance(meta, dict) else []),
189
  "price": meta.get("price", "N/A") if isinstance(meta, dict) else "N/A",
190
+ "category": category,
191
  }
192
  all_items.append(item)
193
 
194
+ primary = all_items[:subset_size]
195
+ overflow = all_items[subset_size:]
 
196
 
197
+ with open(primary_path, "w") as f:
198
  json.dump(primary, f, indent=2)
199
+ with open(over_path, "w") as f:
200
  json.dump(overflow, f, indent=2)
201
 
202
+ print(f"[DATA] {category}: cached {len(primary)} primary + {len(overflow)} overflow items.")
203
  except Exception as e:
204
+ print(f"[DATA] ERROR downloading {category}: {e}")
205
  raise
206
 
207
 
208
  @st.cache_resource
209
+ def load_primary_dataset(category: str):
210
+ with open(local_data_path(category), "r") as f:
211
  return json.load(f)
212
 
213
 
214
  @st.cache_resource
215
+ def load_overflow_dataset(category: str):
216
+ path = overflow_path(category)
217
+ if not os.path.exists(path):
218
  return []
219
+ with open(path, "r") as f:
220
  return json.load(f)
221
 
222
 
223
+ def _ensure_datasets():
224
+ """Download/cache all needed category datasets."""
225
+ if MODE == "mixed":
226
+ for cat in MIXED_CATEGORIES:
227
+ download_and_cache_dataset(cat, MIXED_SUBSET_SIZE)
228
+ else:
229
+ download_and_cache_dataset(CATEGORY, SINGLE_SUBSET_SIZE)
230
+
231
+
232
+ # ---------------------------------------------------------------------------
233
+ # Per-category counter helpers
234
+ # ---------------------------------------------------------------------------
235
+ def _read_counter(category: str) -> int:
236
+ path = counter_path(category)
237
+ if not os.path.exists(path):
238
+ return 0
239
+ with open(path, "r") as f:
240
+ return int(f.read().strip() or "0")
241
+
242
+
243
+ def _write_counter(category: str, value: int):
244
+ with open(counter_path(category), "w") as f:
245
+ f.write(str(value))
246
+
247
+
248
+ def _read_return_queue(category: str) -> list:
249
+ path = return_queue_path(category)
250
+ if not os.path.exists(path):
251
+ return []
252
+ with open(path, "r") as f:
253
+ try:
254
+ return json.load(f)
255
+ except Exception:
256
+ return []
257
+
258
+
259
+ def _write_return_queue(category: str, queue: list):
260
+ with open(return_queue_path(category), "w") as f:
261
+ json.dump(queue, f)
262
+
263
+
264
+ # ---------------------------------------------------------------------------
265
+ # Product assignment
266
+ # ---------------------------------------------------------------------------
267
+ def _assign_from_category(category: str, n: int) -> list:
268
  """
269
+ Atomically assign n products from a single category pool.
270
+ - Drains the return queue first.
271
+ - Pulls sequentially from the primary pool.
272
+ - Wraps around (modulo pool size) when exhausted so user 21+ still get valid items.
273
  """
274
+ items = load_primary_dataset(category)
275
  total = len(items)
276
+ lock = FileLock(counter_lock_path(category))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
277
 
278
+ with lock:
279
+ return_queue = _read_return_queue(category)
280
+ counter = _read_counter(category)
281
  assigned = []
282
+
283
  for _ in range(n):
284
  if return_queue:
 
285
  assigned.append(return_queue.pop(0))
 
 
 
286
  else:
287
+ # Wrap-around: counter mod total so we cycle through items
288
+ assigned.append(items[counter % total])
289
+ counter += 1
290
+
291
+ _write_return_queue(category, return_queue)
292
+ _write_counter(category, counter)
 
 
 
 
 
293
 
294
  return assigned
295
 
296
 
297
def assign_mixed_products(n: int = PRODUCTS_PER_USER) -> list:
    """
    Assign n products split across movies and groceries, alternating which
    category gets the larger share so coverage stays balanced.

    User 1: 3 movies + 2 groceries
    User 2: 2 movies + 3 groceries
    User 3: 3 movies + 2 groceries ... etc.

    The split is derived from the movies counter.  In the steady state the
    counter advances by 3, then 2, per call (0, 3, 5, 8, 10, ...), so
    ``counter % 5 < 3`` is True exactly on the calls where movies is due
    the larger share.

    BUGFIX: the previous test ``(movies_counter // 1) % 2 == 0`` was a
    no-op parity check — after the first call the counter goes 3, 5, 7, ...
    (always odd), so every user after the first received only 2 movies.

    NOTE(review): swaps via get_swap_product also advance the counter, so
    the 3/2 alternation can drift after a swap — acceptable imprecision,
    same as before.
    """
    movies_counter = _read_counter("movies")
    if movies_counter % 5 < 3:
        n_movies, n_groceries = 3, 2
    else:
        n_movies, n_groceries = 2, 3

    # Clamp in case n != 5
    if n_movies + n_groceries != n:
        n_movies = n // 2
        n_groceries = n - n_movies

    movie_items = _assign_from_category("movies", n_movies)
    grocery_items = _assign_from_category("groceries", n_groceries)

    combined = movie_items + grocery_items
    random.shuffle(combined)  # mix so user doesn't see all movies then all groceries
    return combined
326
+
327
+
328
def assign_products(n: int = PRODUCTS_PER_USER) -> list:
    """Dispatcher: mixed mode draws from both pools, otherwise use CATEGORY."""
    if MODE != "mixed":
        # Single-category (legacy behaviour)
        return _assign_from_category(CATEGORY, n)
    return assign_mixed_products(n)
334
+
335
+
336
def return_product_to_queue(product: dict):
    """Put a rejected/swapped product back into its category's return queue."""
    cat = product.get("category", CATEGORY)
    with FileLock(counter_lock_path(cat)):
        queue = _read_return_queue(cat)
        # Avoid enqueueing the same product twice.
        if all(entry["id"] != product["id"] for entry in queue):
            queue.append(product)
        _write_return_queue(cat, queue)
 
345
 
346
 
347
def get_swap_product(exclude_ids: set, category: str) -> dict | None:
    """
    Pick a replacement product for the given category.

    Search order:
      1. Next unassigned primary product, starting at the shared counter
         (advances and persists the counter on success).
      2. Wrap-around: any primary product this user does not already hold.
      3. Overflow pool.

    Returns None when every candidate is excluded.
    """
    primary = load_primary_dataset(category)
    overflow = load_overflow_dataset(category)
    pool_size = len(primary)

    with FileLock(counter_lock_path(category)):
        cursor = _read_counter(category)

        # 1. Walk the primary pool at most once, starting at the counter.
        for step in range(pool_size):
            candidate = primary[(cursor + step) % pool_size]
            if candidate["id"] not in exclude_ids:
                # Persist the counter just past the chosen item.
                _write_counter(category, cursor + step + 1)
                return candidate

        # 2. Any primary product not held by this user.
        for item in primary:
            if item["id"] not in exclude_ids:
                return item

        # 3. Overflow pool.
        for item in overflow:
            if item["id"] not in exclude_ids:
                return item

    return None
383
 
384
 
385
  # ---------------------------------------------------------------------------
386
+ # AI client (Tinker)
387
  # ---------------------------------------------------------------------------
388
@st.cache_resource
def get_tinker_clients():
    """Initialise and cache the Tinker sampling client, renderer, and tokenizer.

    Wrapped in st.cache_resource so the service connection is created once
    per process rather than on every Streamlit rerun.
    """
    import tinker
    from tinker import types as tinker_types
    from tinker_cookbook import renderers
    from tinker_cookbook.model_info import get_recommended_renderer_name
    from tinker_cookbook.tokenizer_utils import get_tokenizer

    service = tinker.ServiceClient()
    sampler = service.create_sampling_client(base_model=MODEL_NAME)
    tok = get_tokenizer(MODEL_NAME)
    renderer = renderers.get_renderer(get_recommended_renderer_name(MODEL_NAME), tok)
    return sampler, renderer, tinker_types
403
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
404
 
405
def call_model(messages: list) -> str:
    """Send a chat message list to the Tinker sampler and return the reply text.

    Any failure is caught and surfaced inline as "[Model error: ...]" so the
    Streamlit UI keeps working even when the model backend is unavailable.
    """
    try:
        from tinker_cookbook import renderers as tinker_renderers

        sampling_client, renderer, tinker_types = get_tinker_clients()

        prompt = renderer.build_generation_prompt(messages)
        sampling_params = tinker_types.SamplingParams(
            max_tokens=1000,
            temperature=0.7,
            stop=renderer.get_stop_sequences(),
        )
        result = sampling_client.sample(
            prompt=prompt,
            sampling_params=sampling_params,
            num_samples=1,
        ).result()

        parsed, _ = renderer.parse_response(result.sequences[0].tokens)
        text = tinker_renderers.format_content_as_string(parsed["content"])
        # Strip any chain-of-thought block the model may emit.
        return re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL).strip()
    except Exception as e:
        print(f"[MODEL] Tinker error: {e}")
        return f"[Model error: {e}]"
428
 
429
 
430
  # ---------------------------------------------------------------------------
 
448
 
449
  def save_and_upload(state: dict):
450
  hf_api = get_hf_api()
451
+ worker_id = state.get("prolific_pid") or state.get("user_id", "anonymous")
452
  submission_id = state.get("submission_id", str(uuid.uuid4()))
453
  safe_worker = "".join(c if c.isalnum() else "_" for c in str(worker_id))
454
+ mode_tag = state.get("mode", "single")
455
+ filename = f"{submission_id}_{mode_tag}.json"
456
  folder = os.path.join(ANNOTATIONS_DIR, safe_worker)
457
  os.makedirs(folder, exist_ok=True)
458
  file_path = os.path.join(folder, filename)
 
477
  demographics = state.get("demographics", {})
478
  products = state.get("products", [])
479
  header = [
480
+ "submission_id", "prolific_pid", "study_id", "session_id",
481
+ "submission_time", "duration_seconds", "mode", "category",
482
  "age", "gender", "geographic_region", "education_level", "race",
483
  "us_citizen", "marital_status", "religion", "religious_attendance",
484
  "political_affiliation", "income", "political_views", "household_size", "employment_status",
 
495
  post = prod.get("post_willingness", "")
496
  delta = (post - pre) if isinstance(pre, int) and isinstance(post, int) else ""
497
  row = [
498
+ submission_id,
499
+ state.get("prolific_pid", ""),
500
+ state.get("study_id", ""),
501
+ state.get("session_id", ""),
502
  state.get("meta", {}).get("submission_time", ""),
503
  state.get("meta", {}).get("duration_seconds", ""),
504
+ state.get("mode", "single"),
505
+ prod.get("category", ""), # per-product category
506
  demographics.get("age", ""), demographics.get("gender", ""),
507
  demographics.get("geographic_region", ""), demographics.get("education_level", ""),
508
  demographics.get("race", ""), demographics.get("us_citizen", ""),
 
590
  return 4
591
 
592
 
593
+ def get_familiarity_choices(category: str) -> list:
594
+ """Return familiarity options with the correct 'used' label for this product's category."""
595
+ used_label = FAMILIARITY_USED_LABEL.get(category, "Used it before")
596
  return [
597
  "Never heard of it",
598
  "Heard of it, but not used/purchased",
 
602
 
603
 
604
  def needs_swap(familiarity_val: str, pre_will_val: str) -> bool:
 
605
  if familiarity_val in SWAP_FAMILIARITY:
606
  return True
607
  if pre_will_val == WILLINGNESS_CHOICES[-1]: # "Definitely would buy (7)"
 
609
  return False
610
 
611
 
612
+ # ---------------------------------------------------------------------------
613
+ # Welcome screen helpers
614
+ # ---------------------------------------------------------------------------
615
def study_display_name() -> str:
    """Human-readable name for what the user will evaluate."""
    single = CATEGORY_DISPLAY.get(CATEGORY, CATEGORY)
    return "Movies & TV and Grocery Products" if MODE == "mixed" else single
620
+
621
+
622
def study_category_breakdown() -> str:
    """Extra welcome-screen sentence describing the mix; empty in single mode."""
    if MODE != "mixed":
        return ""
    return (
        "You will evaluate a mix of **Movies & TV** and **Grocery Products** "
        "(roughly 2–3 of each)."
    )
630
+
631
+
632
  # ---------------------------------------------------------------------------
633
  # State initialisation
634
  # ---------------------------------------------------------------------------
 
639
  "description": p.get("description", []),
640
  "features": p.get("features", []),
641
  "price": p.get("price", "N/A"),
642
+ "category": p.get("category", CATEGORY), # ← per-product category
643
  "familiarity": None,
644
  "pre_willingness": None,
645
  "post_willingness": None,
 
656
 
657
 
658
  def init_state():
659
+ _ensure_datasets()
660
  assigned = assign_products(PRODUCTS_PER_USER)
661
 
662
  try:
 
667
  return {
668
  "submission_id": str(uuid.uuid4()),
669
  "user_id": str(uuid.uuid4()),
670
+ "prolific_pid": params.get("PROLIFIC_PID", ""),
671
+ "study_id": params.get("STUDY_ID", ""),
672
+ "session_id": params.get("SESSION_ID", ""),
 
673
  "start_time": time.time(),
674
+ "mode": MODE or "single",
675
+ "category": CATEGORY if MODE != "mixed" else "mixed",
676
  "demographics": {},
677
  "products": [make_product_slot(p) for p in assigned],
678
  "current_product_index": 0,
 
706
  }
707
  .pc-title { font-size: 1.05rem; font-weight: 700; color: #1a1a2e; line-height: 1.35; flex: 1; }
708
  .pc-price { font-size: 1.2rem; font-weight: 800; color: #16a34a; white-space: nowrap; }
709
+ .pc-category-badge {
710
+ display: inline-block;
711
+ font-size: 0.75rem; font-weight: 600;
712
+ padding: 0.15rem 0.55rem;
713
+ border-radius: 99px;
714
+ margin-bottom: 0.4rem;
715
+ background: #dbeafe; color: #1e40af;
716
+ }
717
  .pc-section { margin-top: 0.5rem; }
718
  .pc-section-title {
719
  font-weight: 600; font-size: 0.85rem; color: #475569;
 
744
  price = product.get("price", "N/A")
745
  description = product.get("description", [])
746
  features = product.get("features", [])
747
+ category = product.get("category", "")
748
  price_str = f"${price}" if price and price != "N/A" and not str(price).startswith("$") else price
749
 
750
+ # Category badge only shown in mixed mode
751
+ badge_html = ""
752
+ if MODE == "mixed" and category:
753
+ badge_label = CATEGORY_DISPLAY.get(category, category)
754
+ badge_html = f'<div class="pc-category-badge">📂 {badge_label}</div>'
755
+
756
  desc_html = ""
757
  if description:
758
  desc_text = " ".join(d for d in description if d)
759
  desc_html = f'<div class="pc-section"><div class="pc-section-title">📋 Description</div><div class="pc-desc">{desc_text}</div></div>'
760
 
 
761
  feat_html = ""
762
  if features:
763
  items_html = "".join(f"<li>{feat}</li>" for feat in features if feat)
 
766
  max_h = "max-height:240px;overflow-y:auto;" if compact else ""
767
  return f"""
768
  <div class="product-card" style="{max_h}">
769
+ {badge_html}
770
  <div class="pc-header">
771
  <div class="pc-title">{title}</div>
772
  <div class="pc-price">{price_str}</div>
 
802
  # ---------------------------------------------------------------------------
803
  def screen_welcome(s):
804
  st.markdown("# 🛒 Product Evaluation Study")
805
+ breakdown = study_category_breakdown()
806
  st.markdown(
807
+ f"Welcome! In this study you will evaluate **{PRODUCTS_PER_USER} {study_display_name()}** products.\n\n"
808
+ + (f"{breakdown}\n\n" if breakdown else "")
809
+ +
810
  "For each product you will:\n"
811
  "1. Rate how familiar you are with the product\n"
812
  "2. Rate how willing you are to buy it\n"
 
892
  def screen_product_intro(s):
893
  idx = s["current_product_index"]
894
  product = s["products"][idx]
895
+ product_category = product.get("category", CATEGORY)
896
+
897
  render_progress(idx + 1)
898
  st.markdown("## Product Evaluation")
899
  st.markdown("Please read the product information carefully, then answer the two questions below.")
900
  st.markdown(render_product_card_html(product), unsafe_allow_html=True)
901
 
902
+ # Use per-product familiarity choices based on the product's own category
903
+ familiarity_choices = get_familiarity_choices(product_category)
904
+
905
  familiarity_val = st.radio(
906
  "How familiar are you with this product?",
907
+ familiarity_choices,
908
  index=None,
909
  key=f"familiarity_{idx}_{product['id']}",
910
  )
 
924
  st.error("⚠️ Please rate your willingness to buy.")
925
  return
926
 
927
+ familiarity_val = familiarity_val or familiarity_choices[0]
928
  pre_will_val = pre_will_val or WILLINGNESS_CHOICES[3]
929
 
930
  # Check if we need to swap this product
931
  if needs_swap(familiarity_val, pre_will_val) and not DEBUG_MODE:
932
  current_ids = {p["id"] for p in s["products"]}
933
+ replacement = get_swap_product(exclude_ids=current_ids, category=product_category)
934
  if replacement:
 
935
  return_product_to_queue(s["products"][idx])
936
  s["products"][idx] = make_product_slot(replacement, was_swapped=True)
937
  st.info("We've swapped this product for a better match. Please review the new product below.")
938
  st.rerun()
939
  return
940
+ # No replacement found proceed with this product anyway
941
 
942
  pre_val = parse_willingness(pre_will_val)
943
  s["products"][idx]["familiarity"] = familiarity_val
 
1118
  "submission_time": end_time,
1119
  "duration_seconds": round(end_time - s.get("start_time", end_time), 1),
1120
  "model": MODEL_NAME,
1121
+ "mode": MODE or "single",
1122
+ "category": CATEGORY if MODE != "mixed" else "mixed",
1123
  }
1124
  with st.spinner("Saving your responses…"):
1125
  save_and_upload(s)
 
1140
  post = p.get("post_willingness", "?")
1141
  delta = p.get("willingness_delta", 0)
1142
  arrow = "➡️" if delta == 0 else ("⬆️" if delta > 0 else "⬇️")
1143
+ cat_label = CATEGORY_DISPLAY.get(p.get("category", ""), "") if MODE == "mixed" else ""
1144
  rows.append({
1145
  "#": i + 1,
1146
+ **({"Category": cat_label} if MODE == "mixed" else {}),
1147
+ "Product": p.get("title", "")[:55] + ("…" if len(p.get("title", "")) > 55 else ""),
1148
  "Before": WILLINGNESS_LABELS.get(pre, str(pre)),
1149
  "After": WILLINGNESS_LABELS.get(post, str(post)),
1150
  "Change": f"{arrow} {delta:+d}" if isinstance(delta, int) else "–",
 
1152
  import pandas as pd
1153
  st.dataframe(pd.DataFrame(rows), use_container_width=True, hide_index=True)
1154
 
1155
+ st.markdown("---")
1156
+ st.success(
1157
+ f"**Your completion code:** `{PROLIFIC_COMPLETION_CODE}`\n\n"
1158
+ "You can either click the button below to return to Prolific automatically, "
1159
+ "or copy the code above and paste it on the Prolific website."
1160
+ )
1161
+ st.markdown(
1162
+ f"""<a href="{PROLIFIC_COMPLETION_URL}" target="_self">
1163
+ <button style="background:#2563eb;color:white;border:none;padding:12px 28px;
1164
+ font-size:1rem;border-radius:6px;cursor:pointer;margin-top:8px;">
1165
+ Return to Prolific
1166
+ </button></a>""",
1167
+ unsafe_allow_html=True,
1168
+ )
 
 
1169
 
1170
 
1171
  # ---------------------------------------------------------------------------