j-js committed on
Commit
fa70564
·
verified ·
1 Parent(s): 0e6c031

Upload 11 files

Browse files
Files changed (11) hide show
  1. app.py +53 -84
  2. context_parser.py +52 -23
  3. conversation_logic.py +100 -74
  4. formatting.py +18 -34
  5. generator_engine.py +20 -168
  6. logging_store.py +60 -126
  7. models.py +39 -83
  8. quant_solver.py +116 -171
  9. retrieval_engine.py +81 -116
  10. ui_html.py +25 -56
  11. utils.py +81 -83
app.py CHANGED
@@ -1,27 +1,27 @@
1
  from __future__ import annotations
2
 
3
- import secrets
4
  from typing import Any, Dict
5
 
6
- from fastapi import FastAPI, Header, HTTPException, Request
7
  from fastapi.middleware.cors import CORSMiddleware
8
  from fastapi.responses import HTMLResponse, JSONResponse
9
 
10
- from config import settings
11
- from context_parser import detect_help_mode
12
- from conversation_logic import generate_response
13
- from generator_engine import GenerativeEngine
14
  from logging_store import LoggingStore
15
  from models import ChatRequest, EventLogRequest, SessionFinalizeRequest, SessionStartRequest
16
  from retrieval_engine import RetrievalEngine
17
  from ui_html import HOME_HTML
18
- from utils import clamp01, flatten_history, get_user_text, parse_hidden_context, split_unity_message
19
 
20
- app = FastAPI(title=settings.app_name, version=settings.app_version)
21
  retriever = RetrievalEngine()
22
- generator = GenerativeEngine()
23
- logging_store = LoggingStore()
 
24
 
 
25
  app.add_middleware(
26
  CORSMiddleware,
27
  allow_origins=["*"],
@@ -31,14 +31,9 @@ app.add_middleware(
31
  )
32
 
33
 
34
- def _require_ingest_key(x_ingest_key: str | None) -> None:
35
- if settings.ingest_api_key and not secrets.compare_digest(x_ingest_key or "", settings.ingest_api_key):
36
- raise HTTPException(status_code=401, detail="Invalid ingest key")
37
-
38
-
39
- def _require_research_key(x_research_key: str | None) -> None:
40
- if settings.research_api_key and not secrets.compare_digest(x_research_key or "", settings.research_api_key):
41
- raise HTTPException(status_code=401, detail="Invalid research key")
42
 
43
 
44
  @app.get("/", response_class=HTMLResponse)
@@ -46,20 +41,6 @@ def home() -> str:
46
  return HOME_HTML
47
 
48
 
49
- @app.get("/health")
50
- def health() -> Dict[str, Any]:
51
- return {
52
- "ok": True,
53
- "app": settings.app_name,
54
- "version": settings.app_version,
55
- "retrieval_rows": len(retriever.rows),
56
- "generator_model": settings.generator_model,
57
- "generator_task": settings.generator_task,
58
- "hub_log_push_enabled": settings.push_logs_to_hub,
59
- "hub_log_repo": settings.log_dataset_repo_id,
60
- }
61
-
62
-
63
  @app.post("/chat")
64
  async def chat(request: Request) -> JSONResponse:
65
  raw_body: Any = None
@@ -76,86 +57,74 @@ async def chat(request: Request) -> JSONResponse:
76
 
77
  full_text = get_user_text(req, raw_body)
78
  hidden_context, actual_user_message = split_unity_message(full_text)
79
- game_context = parse_hidden_context(hidden_context)
80
- retrieval_query = actual_user_message or game_context.question_text or full_text
81
- retrieval_chunks = retriever.search(retrieval_query)
 
82
 
83
  tone = clamp01(req_data.get("tone", req.tone), 0.5)
84
  verbosity = clamp01(req_data.get("verbosity", req.verbosity), 0.5)
85
  transparency = clamp01(req_data.get("transparency", req.transparency), 0.5)
86
- help_mode = detect_help_mode(actual_user_message or full_text, req_data.get("help_mode", req.help_mode))
87
- chat_history = flatten_history(req_data.get("chat_history") or req_data.get("history") or req.chat_history or req.history)
88
 
89
- result = generate_response(
90
- user_text=actual_user_message or full_text,
 
 
 
91
  tone=tone,
92
  verbosity=verbosity,
93
  transparency=transparency,
 
94
  help_mode=help_mode,
95
- game_context=game_context,
96
- chat_history=chat_history,
97
- retrieval_chunks=retrieval_chunks,
98
- generator=generator,
99
  )
100
 
101
- return JSONResponse({"reply": result.reply, "meta": result.meta})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
 
104
  @app.post("/log/session/start")
105
- def start_session(
106
- req: SessionStartRequest,
107
- x_ingest_key: str | None = Header(default=None),
108
- ) -> Dict[str, Any]:
109
- _require_ingest_key(x_ingest_key)
110
- record = logging_store.start_session(req)
111
- return {"ok": True, "session_id": record.session_id, "record": record.model_dump()}
112
 
113
 
114
  @app.post("/log/event")
115
- def log_event(
116
- req: EventLogRequest,
117
- x_ingest_key: str | None = Header(default=None),
118
- ) -> Dict[str, Any]:
119
- _require_ingest_key(x_ingest_key)
120
- try:
121
- return logging_store.append_event(req)
122
- except FileNotFoundError as e:
123
- raise HTTPException(status_code=404, detail=str(e)) from e
124
 
125
 
126
  @app.post("/log/session/finalize")
127
- def finalize_session(
128
- req: SessionFinalizeRequest,
129
- x_ingest_key: str | None = Header(default=None),
130
- ) -> Dict[str, Any]:
131
- _require_ingest_key(x_ingest_key)
132
- try:
133
- return logging_store.finalize_session(req)
134
- except FileNotFoundError as e:
135
- raise HTTPException(status_code=404, detail=str(e)) from e
136
 
137
 
138
  @app.get("/research/sessions")
139
- def research_sessions(
140
- x_research_key: str | None = Header(default=None),
141
- ) -> Dict[str, Any]:
142
- _require_research_key(x_research_key)
143
- return {"ok": True, "sessions": logging_store.list_sessions()}
144
 
145
 
146
  @app.get("/research/session/{session_id}")
147
- def research_session(
148
- session_id: str,
149
- x_research_key: str | None = Header(default=None),
150
- ) -> Dict[str, Any]:
151
- _require_research_key(x_research_key)
152
- try:
153
- return {"ok": True, **logging_store.read_session_bundle(session_id)}
154
- except FileNotFoundError as e:
155
- raise HTTPException(status_code=404, detail=str(e)) from e
156
 
157
 
158
  if __name__ == "__main__":
159
  import uvicorn
160
 
161
- uvicorn.run(app, host="0.0.0.0", port=settings.port)
 
 
1
  from __future__ import annotations
2
 
3
+ import os
4
  from typing import Any, Dict
5
 
6
+ from fastapi import FastAPI, Request
7
  from fastapi.middleware.cors import CORSMiddleware
8
  from fastapi.responses import HTMLResponse, JSONResponse
9
 
10
+ from context_parser import detect_intent, extract_game_context_fields, intent_to_help_mode, split_unity_message
11
+ from conversation_logic import ConversationEngine
12
+ from generator_engine import GeneratorEngine
 
13
  from logging_store import LoggingStore
14
  from models import ChatRequest, EventLogRequest, SessionFinalizeRequest, SessionStartRequest
15
  from retrieval_engine import RetrievalEngine
16
  from ui_html import HOME_HTML
17
+ from utils import clamp01, get_user_text
18
 
 
19
  retriever = RetrievalEngine()
20
+ generator = GeneratorEngine()
21
+ engine = ConversationEngine(retriever=retriever, generator=generator)
22
+ store = LoggingStore()
23
 
24
+ app = FastAPI(title="Trading Game AI V2", version="2.2.0")
25
  app.add_middleware(
26
  CORSMiddleware,
27
  allow_origins=["*"],
 
31
  )
32
 
33
 
34
@app.get("/health")
def health() -> Dict[str, Any]:
    """Liveness/readiness probe.

    Reports a fixed app identity string and whether the text-generation
    pipeline loaded successfully (``generator.available()``), so deploys can
    distinguish "up" from "up but running in fallback mode".
    """
    return {"ok": True, "app": "Trading Game AI V2", "generator_available": generator.available()}
 
 
 
 
 
37
 
38
 
39
  @app.get("/", response_class=HTMLResponse)
 
41
  return HOME_HTML
42
 
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  @app.post("/chat")
45
  async def chat(request: Request) -> JSONResponse:
46
  raw_body: Any = None
 
57
 
58
  full_text = get_user_text(req, raw_body)
59
  hidden_context, actual_user_message = split_unity_message(full_text)
60
+ game_fields = extract_game_context_fields(hidden_context)
61
+
62
+ question_text = (req.question_text or "").strip() or game_fields["question"]
63
+ options_text = game_fields["options"]
64
 
65
  tone = clamp01(req_data.get("tone", req.tone), 0.5)
66
  verbosity = clamp01(req_data.get("verbosity", req.verbosity), 0.5)
67
  transparency = clamp01(req_data.get("transparency", req.transparency), 0.5)
 
 
68
 
69
+ intent = detect_intent(actual_user_message or full_text, req_data.get("help_mode", req.help_mode))
70
+ help_mode = intent_to_help_mode(intent)
71
+
72
+ result = engine.generate_response(
73
+ raw_user_text=actual_user_message or full_text,
74
  tone=tone,
75
  verbosity=verbosity,
76
  transparency=transparency,
77
+ intent=intent,
78
  help_mode=help_mode,
79
+ chat_history=req.chat_history or req.history or [],
80
+ question_text=question_text,
81
+ options_text=options_text,
 
82
  )
83
 
84
+ return JSONResponse(
85
+ {
86
+ "reply": result.reply,
87
+ "meta": {
88
+ "domain": result.domain,
89
+ "solved": result.solved,
90
+ "help_mode": result.help_mode,
91
+ "answer_letter": result.answer_letter,
92
+ "answer_value": result.answer_value,
93
+ "topic": result.topic,
94
+ "used_retrieval": result.used_retrieval,
95
+ "used_generator": result.used_generator,
96
+ },
97
+ }
98
+ )
99
 
100
 
101
@app.post("/log/session/start")
def log_session_start(payload: SessionStartRequest) -> Dict[str, Any]:
    """Create a new logging-session record in the store.

    NOTE(review): the previous revision of this endpoint required an
    ``X-Ingest-Key`` header (checked with ``secrets.compare_digest``); this
    version accepts unauthenticated writes — confirm removing ingest auth
    was intentional.
    """
    return store.start_session(payload.session_id, payload.user_id, payload.condition, payload.metadata)
 
 
 
 
 
104
 
105
 
106
@app.post("/log/event")
def log_event(payload: EventLogRequest) -> Dict[str, Any]:
    """Append one event to an existing logging session.

    NOTE(review): the previous revision required an ``X-Ingest-Key`` header
    and translated ``FileNotFoundError`` (unknown session) into an HTTP 404;
    here a missing session would surface as an unhandled 500 — confirm both
    changes were intentional.
    """
    return store.log_event(payload.session_id, payload.event_type, payload.payload, payload.timestamp)
 
 
 
 
 
 
 
109
 
110
 
111
@app.post("/log/session/finalize")
def log_session_finalize(payload: SessionFinalizeRequest) -> Dict[str, Any]:
    """Mark a logging session as complete, attaching its summary.

    NOTE(review): the previous revision required an ``X-Ingest-Key`` header
    and mapped ``FileNotFoundError`` to HTTP 404; both safeguards are gone
    in this version — confirm intentional.
    """
    return store.finalize_session(payload.session_id, payload.summary)
 
 
 
 
 
 
 
114
 
115
 
116
@app.get("/research/sessions")
def research_sessions() -> Dict[str, Any]:
    """List all logged sessions for researchers.

    NOTE(review): the previous revision gated this behind an
    ``X-Research-Key`` header; participant data is now readable without any
    credential — confirm this is acceptable for the study.
    """
    return {"sessions": store.list_sessions()}
 
 
 
119
 
120
 
121
@app.get("/research/session/{session_id}")
def research_session(session_id: str) -> Dict[str, Any]:
    """Return the full record bundle for one session.

    NOTE(review): the previous revision required an ``X-Research-Key``
    header and returned HTTP 404 on ``FileNotFoundError``; this version is
    unauthenticated and an unknown ``session_id`` will surface however
    ``store.get_session`` fails (possibly a 500) — confirm intentional.
    """
    return store.get_session(session_id)
 
 
 
 
 
 
 
124
 
125
 
126
  if __name__ == "__main__":
127
  import uvicorn
128
 
129
+ port = int(os.getenv("PORT", "7860"))
130
+ uvicorn.run(app, host="0.0.0.0", port=port)
context_parser.py CHANGED
@@ -1,36 +1,65 @@
1
  from __future__ import annotations
2
 
3
  import re
 
4
 
5
 
6
- HELP_MODES = {"hint", "guide", "walkthrough", "answer", "explain"}
 
 
 
 
 
 
 
 
 
7
 
8
 
 
 
 
 
9
 
10
- def detect_help_mode(user_text: str, explicit_help_mode: str | None = None) -> str:
11
- if explicit_help_mode and explicit_help_mode.lower() in HELP_MODES:
12
- mode = explicit_help_mode.lower()
13
- return "walkthrough" if mode == "guide" else mode
 
 
 
 
 
 
 
14
 
15
- text = (user_text or "").lower()
16
- if any(phrase in text for phrase in ["just a hint", "small hint", "nudge", "dont tell me"]):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  return "hint"
18
- if any(phrase in text for phrase in ["walk me through", "step by step", "show the steps", "explain how"]):
19
- return "walkthrough"
20
- if any(phrase in text for phrase in ["what is the answer", "which option", "is it a or b", "final answer"]):
21
  return "answer"
22
- return "explain"
23
 
24
 
25
-
26
- def is_social_or_meta_message(text: str) -> bool:
27
- lower = (text or "").strip().lower()
28
- if not lower:
29
- return True
30
- social_patterns = [
31
- r"^(hi|hello|hey|thanks|thank you|ok|okay|cool|great|nice)\b",
32
- r"how are you",
33
- r"what can you do",
34
- r"who are you",
35
- ]
36
- return any(re.search(p, lower) for p in social_patterns)
 
1
  from __future__ import annotations
2
 
3
  import re
4
+ from typing import Dict, Optional
5
 
6
 
7
def split_unity_message(full_text: str) -> tuple[str, str]:
    """Split a raw Unity payload into ``(hidden_context, user_message)``.

    Everything before the first ``USER_MESSAGE:`` marker is hidden game
    context; everything after it is the player's actual message. When the
    marker is absent, the whole text is the user message and the hidden
    context is empty. Both halves are whitespace-stripped.
    """
    if not full_text:
        return "", ""
    hidden_part, marker, user_part = full_text.partition("USER_MESSAGE:")
    if not marker:
        # No marker anywhere: nothing is hidden context.
        return "", full_text.strip()
    return hidden_part.strip(), user_part.strip()
17
 
18
 
19
def extract_game_context_fields(hidden_context: str) -> Dict[str, str]:
    """Pull structured fields out of the hidden Unity context block.

    Parameters
    ----------
    hidden_context:
        Raw text preceding the ``USER_MESSAGE:`` marker, expected to contain
        ``Category:`` / ``Difficulty:`` / ``Question:`` / ``Options:`` lines.

    Returns
    -------
    Dict[str, str]
        Always contains the keys ``category``, ``difficulty``, ``question``
        and ``options``; fields not found remain empty strings.
    """
    fields = {"category": "", "difficulty": "", "question": "", "options": ""}
    if not hidden_context:
        return fields

    # BUG FIX: single-line fields previously used a greedy (.+) combined with
    # re.DOTALL, which made "Category:"/"Difficulty:" swallow every following
    # line of the context. Restrict them to one line.
    single_line = {
        "category": r"Category:\s*([^\n]+)",
        "difficulty": r"Difficulty:\s*([^\n]+)",
    }
    # Question/Options may legitimately span lines, so keep DOTALL and stop
    # lazily at the next known section header (or end of text).
    multi_line = {
        "question": r"Question:\s*(.+?)(?:\nOptions:|\nPlayer balance:|\nLast outcome:|$)",
        "options": r"Options:\s*(.+?)(?:\nPlayer balance:|\nLast outcome:|$)",
    }
    for key, pattern in single_line.items():
        m = re.search(pattern, hidden_context)
        if m:
            fields[key] = m.group(1).strip()
    for key, pattern in multi_line.items():
        m = re.search(pattern, hidden_context, re.DOTALL)
        if m:
            fields[key] = m.group(1).strip()
    return fields
35
 
36
+
37
def detect_intent(user_text: str, supplied: Optional[str] = None) -> str:
    """Classify what kind of help the user is asking for.

    An explicitly ``supplied`` intent wins when it is one of the recognised
    values; otherwise keyword heuristics over the lowercased message decide.
    Falls back to ``"answer"`` when nothing matches.
    """
    lower = (user_text or "").strip().lower()

    if supplied:
        supplied = supplied.strip().lower()
        if supplied in {"hint", "method", "walkthrough", "step_by_step", "full_working", "answer"}:
            return supplied

    # Ordered from most to least specific; the first matching bucket wins,
    # mirroring the original if/elif cascade exactly.
    buckets = (
        ("full_working", ("full working", "full working out", "show all the working", "complete working")),
        ("step_by_step", ("step by step", "walkthrough", "work through", "explain step by step")),
        ("method", ("how do i solve", "how to solve", "method", "what method", "approach")),
        ("hint", ("hint", "nudge", "first step", "how do i start", "what do i do first")),
        ("step_by_step", ("why", "explain", "breakdown")),
        ("answer", ("solve", "what is", "answer", "give me the answer")),
    )
    for intent, phrases in buckets:
        if any(phrase in lower for phrase in phrases):
            return intent
    return "answer"
58
 
59
 
60
def intent_to_help_mode(intent: str) -> str:
    """Collapse a fine-grained intent into one of the three help modes.

    Parameters
    ----------
    intent:
        Any value ``detect_intent`` can return.

    Returns
    -------
    str
        ``"hint"``, ``"walkthrough"`` or ``"answer"``.
    """
    if intent == "hint":
        return "hint"
    # BUG FIX: "walkthrough" is an accepted explicit intent in detect_intent
    # but previously fell through to "answer" here; it must map to the
    # walkthrough mode like the other guided intents.
    if intent in {"method", "step_by_step", "full_working", "walkthrough"}:
        return "walkthrough"
    return "answer"
 
 
 
 
 
 
conversation_logic.py CHANGED
@@ -1,80 +1,106 @@
1
  from __future__ import annotations
2
 
3
- from typing import Any, Dict, List
4
 
5
- from context_parser import is_social_or_meta_message
6
- from formatting import build_guidance_text
7
- from generator_engine import GenerativeEngine
8
- from models import GameContext, ResponsePackage, RetrievalChunk, SolverResult
9
  from quant_solver import is_quant_question, solve_quant
10
- from safety import enforce_study_guardrails
11
- from utils import soft_truncate
12
-
13
-
14
- def _summarize_retrieval(chunks: List[RetrievalChunk]) -> str:
15
- if not chunks:
16
- return ""
17
- lines = ["Relevant study notes:"]
18
- for chunk in chunks[:3]:
19
- snippet = soft_truncate(chunk.text, 180)
20
- label = chunk.topic_guess or chunk.source_name or chunk.chunk_id
21
- lines.append(f"- {label}: {snippet}")
22
- return "\n".join(lines)
23
-
24
-
25
- def generate_response(
26
- *,
27
- user_text: str,
28
- tone: float,
29
- verbosity: float,
30
- transparency: float,
31
- help_mode: str,
32
- game_context: GameContext,
33
- chat_history: List[Dict[str, Any]],
34
- retrieval_chunks: List[RetrievalChunk],
35
- generator: GenerativeEngine,
36
- ) -> ResponsePackage:
37
- combined = "\n".join(part for part in [game_context.combined_question_block, user_text] if part.strip()).strip()
38
- retrieval_text = _summarize_retrieval(retrieval_chunks)
39
-
40
- if is_quant_question(combined):
41
- result: SolverResult = solve_quant(combined)
42
- reply = build_guidance_text(result, verbosity=verbosity, transparency=transparency, tone=tone)
43
- if retrieval_text and verbosity >= 0.5:
44
- reply = reply + "\n\n" + retrieval_text
45
- reply = enforce_study_guardrails(reply, result)
46
- return ResponsePackage(
47
- reply=reply,
48
- meta={
49
- "domain": result.domain,
50
- "solved": result.solved,
51
- "help_mode": help_mode,
52
- "answer_letter": None,
53
- "answer_value": None,
54
- "topic": result.detected_topic,
55
- "used_retrieval": bool(retrieval_chunks),
56
- "used_generator": False,
57
- },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  )
 
 
 
 
 
59
 
60
- reply = generator.generate(
61
- user_text=user_text,
62
- chat_history=chat_history,
63
- game_context=game_context,
64
- retrieval_summary=retrieval_text,
65
- tone=tone,
66
- verbosity=verbosity,
67
- transparency=transparency,
68
- )
69
- reply = enforce_study_guardrails(reply)
70
- return ResponsePackage(
71
- reply=reply,
72
- meta={
73
- "domain": "general",
74
- "solved": False,
75
- "help_mode": help_mode,
76
- "topic": "social" if is_social_or_meta_message(user_text) else "general",
77
- "used_retrieval": bool(retrieval_chunks),
78
- "used_generator": True,
79
- },
80
- )
 
1
  from __future__ import annotations
2
 
3
+ from typing import Any, Dict, List, Optional
4
 
5
+ from context_parser import intent_to_help_mode
6
+ from formatting import format_reply
7
+ from generator_engine import GeneratorEngine
8
+ from models import RetrievedChunk, SolverResult
9
  from quant_solver import is_quant_question, solve_quant
10
+ from retrieval_engine import RetrievalEngine
11
+ from utils import short_lines
12
+
13
+
14
+ def _teaching_lines(chunks: List[RetrievedChunk]) -> List[str]:
15
+ lines = []
16
+ for chunk in chunks:
17
+ text = chunk.text.strip().replace("\n", " ")
18
+ if len(text) > 220:
19
+ text = text[:217].rstrip() + "…"
20
+ lines.append(f"- {chunk.topic}: {text}")
21
+ return lines
22
+
23
+
24
+ def _compose_quant_reply(result: SolverResult, intent: str, reveal_answer: bool, verbosity: float) -> str:
25
+ steps = result.steps or []
26
+ internal = result.internal_answer or result.answer_value or ""
27
+
28
+ if intent == "hint":
29
+ return steps[0] if steps else "Start by translating the wording into an equation."
30
+
31
+ if intent == "method":
32
+ body = "Use this method:\n" + "\n".join(f"- {s}" for s in steps[:3])
33
+ if reveal_answer and internal:
34
+ body += f"\n\nInternal result: {internal}."
35
+ return body
36
+
37
+ if intent in {"step_by_step", "full_working"}:
38
+ body = "\n".join(f"{i+1}. {s}" for i, s in enumerate(steps[:4])) if steps else "Work through the algebra one step at a time."
39
+ if reveal_answer and internal:
40
+ body += f"\n\nSo the result is {internal}."
41
+ return body
42
+
43
+ if reveal_answer and internal:
44
+ return f"The result is {internal}."
45
+
46
+ if steps:
47
+ return "\n".join(f"- {s}" for s in steps[:2])
48
+ return result.reply or "I can help solve this, but I need a little more structure from the question."
49
+
50
+
51
class ConversationEngine:
    """Routes a chat turn to the quant solver or the free-text generator.

    Quant-looking questions go through ``solve_quant`` plus retrieval-backed
    study notes; everything else goes to the (optional) text generator with
    a canned fallback reply when the generator is unavailable.
    """

    def __init__(self, retriever: RetrievalEngine, generator: Optional[GeneratorEngine] = None):
        # A generator instance is always present; a fresh one is built when
        # none is injected (it may still report unavailable at runtime).
        self.retriever = retriever
        self.generator = generator or GeneratorEngine()

    def generate_response(
        self,
        raw_user_text: str,
        tone: float,
        verbosity: float,
        transparency: float,
        intent: str,
        help_mode: str,
        chat_history: Optional[List[Dict[str, Any]]] = None,
        question_text: str = "",
        options_text: str = "",
        retrieval_context: str = "",
    ) -> SolverResult:
        """Produce a SolverResult whose ``reply`` is the formatted answer text.

        NOTE(review): ``chat_history`` and ``retrieval_context`` are accepted
        but never read in this body — confirm whether they were meant to feed
        the generator prompt / retrieval call.
        """
        user_text = (raw_user_text or "").strip()
        question_block = "\n".join([x for x in [question_text.strip(), options_text.strip()] if x]).strip()
        # Prefer the user's own text; fall back to the game-supplied question.
        solver_input = user_text or question_text or question_block

        if is_quant_question(solver_input) or (question_text and is_quant_question(question_text)):
            result = solve_quant(solver_input if is_quant_question(solver_input) else question_text)
            result.help_mode = help_mode
            # The numeric answer is revealed for explicit answer requests or
            # when the transparency slider is near its maximum.
            reveal_answer = intent in {"answer", "full_working"} or transparency >= 0.85

            if result.topic:
                chunks = self.retriever.search(question_text or user_text, topic=result.topic, intent=intent, k=3)
            else:
                chunks = self.retriever.search(question_text or user_text, topic="general", intent=intent, k=3)
            result.teaching_chunks = chunks
            result.used_retrieval = bool(chunks)

            core = _compose_quant_reply(result, intent, reveal_answer=reveal_answer, verbosity=verbosity)
            # Study notes are appended only for teaching-style intents, not
            # for bare answer requests.
            if chunks and intent in {"hint", "method", "step_by_step", "full_working"}:
                core += "\n\nRelevant study notes:\n" + "\n".join(_teaching_lines(chunks))
            result.reply = format_reply(core, tone, verbosity, transparency, help_mode)
            return result

        # Non-quant conversational support
        result = SolverResult(domain="general", solved=False, help_mode=help_mode)
        prompt = (
            "You are a helpful study assistant. Reply naturally and briefly. "
            "Do not invent facts. If the user is asking for emotional support or general help, be supportive and practical.\n\n"
            f"User message: {user_text}"
        )
        generated = self.generator.generate(prompt) if self.generator and self.generator.available() else None
        if generated:
            result.reply = format_reply(generated, tone, verbosity, transparency, help_mode)
            result.used_generator = True
            return result

        # Generator missing or returned nothing: canned capability reply.
        fallback = "I can help with the current question, explain a method, or talk through the next step."
        result.reply = format_reply(fallback, tone, verbosity, transparency, help_mode)
        return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
formatting.py CHANGED
@@ -1,43 +1,27 @@
1
  from __future__ import annotations
2
 
3
- from typing import List
4
 
5
- from models import SolverResult
6
- from utils import soft_truncate
 
 
 
 
7
 
8
 
 
 
 
 
 
9
 
10
- def tone_prefix(tone: float) -> str:
11
- if tone < 0.34:
12
- return ""
13
- if tone < 0.67:
14
- return "You’ve got this. "
15
- return "You’re doing the right thing by breaking it down. "
16
 
 
 
17
 
 
 
18
 
19
- def build_guidance_text(result: SolverResult, verbosity: float, transparency: float, tone: float) -> str:
20
- lines: List[str] = []
21
- prefix = tone_prefix(tone)
22
- if prefix:
23
- lines.append(prefix.strip())
24
-
25
- lines.append(result.explanation)
26
-
27
- if verbosity >= 0.34 and result.steps:
28
- max_steps = 2 if verbosity < 0.67 else 4
29
- chosen_steps = result.steps[:max_steps]
30
- if transparency < 0.34:
31
- lines.append("Focus on the setup rather than the final value:")
32
- elif transparency < 0.67:
33
- lines.append("Use this structure:")
34
- else:
35
- lines.append("Here is the reasoning structure to follow:")
36
- lines.extend(f"- {step}" for step in chosen_steps)
37
-
38
- if transparency >= 0.67 and result.internal_answer_value:
39
- lines.append("When you finish the calculation, compare your value with the available choices rather than jumping straight to a choice now.")
40
- else:
41
- lines.append("Then compare your result against the options yourself.")
42
-
43
- return soft_truncate("\n".join(lines).strip(), 1600)
 
1
  from __future__ import annotations
2
 
 
3
 
4
def style_prefix(tone: float) -> str:
    """Pick an opening phrase whose warmth scales with the tone slider ([0, 1])."""
    if tone < 0.33:
        return "Let’s solve it efficiently."
    if tone < 0.66:
        return "Let’s work through it."
    return "You’ve got this — let’s solve it cleanly."


def format_reply(core: str, tone: float, verbosity: float, transparency: float, help_mode: str) -> str:
    """Wrap the composed reply body with a tone-styled prefix and mode header.

    Parameters
    ----------
    core: the already-composed reply body (may be empty).
    tone, verbosity, transparency: slider values in [0, 1].
    help_mode: "hint", "walkthrough" or "answer".

    Returns the prefix alone when ``core`` is empty; adds a "Hint:" header in
    hint mode and a "Walkthrough:" header in walkthrough mode when verbosity
    is at least 0.4. The previous revision's ``transparency >= 0.75`` answer
    branch returned the exact same string as the fallback, so it has been
    folded away — behavior is unchanged.

    NOTE(review): ``transparency`` now visibly has no effect here — confirm
    whether answer-mode formatting was meant to vary with it.
    """
    prefix = style_prefix(tone)
    core = (core or "").strip()
    if not core:
        return prefix

    if help_mode == "hint":
        return f"{prefix}\n\nHint:\n{core}"

    if help_mode == "walkthrough" and verbosity >= 0.4:
        return f"{prefix}\n\nWalkthrough:\n{core}"

    return f"{prefix}\n\n{core}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
generator_engine.py CHANGED
@@ -1,181 +1,33 @@
1
  from __future__ import annotations
2
 
3
- from typing import Any, Dict, List
4
 
5
  try:
6
  from transformers import pipeline
7
  except Exception:
8
  pipeline = None
9
 
10
- from config import settings
11
- from models import GameContext
12
- from utils import normalize_spaces, soft_truncate
13
 
14
-
15
- class GenerativeEngine:
16
- def __init__(self) -> None:
17
- self.task = settings.generator_task
18
- self.model_name = settings.generator_model
19
  self.pipe = None
20
- if pipeline is None:
21
- return
22
- try:
23
- self.pipe = pipeline(
24
- task=self.task,
25
- model=self.model_name,
26
- tokenizer=self.model_name,
27
- device=-1,
28
- )
29
- except Exception:
30
- self.pipe = None
31
-
32
- def _style_instruction(self, tone: float, verbosity: float, transparency: float) -> str:
33
- if tone < 0.34:
34
- tone_label = "neutral and direct"
35
- elif tone < 0.67:
36
- tone_label = "friendly and professional"
37
- else:
38
- tone_label = "warm, encouraging, and supportive"
39
-
40
- if verbosity < 0.34:
41
- verbosity_label = "brief"
42
- elif verbosity < 0.67:
43
- verbosity_label = "moderate length"
44
- else:
45
- verbosity_label = "detailed"
46
-
47
- if transparency < 0.34:
48
- transparency_label = "do not expose chain-of-thought; keep reasoning concise"
49
- elif transparency < 0.67:
50
- transparency_label = "give a short explanation of reasoning"
51
- else:
52
- transparency_label = "explain the logic clearly but do not reveal hidden chain-of-thought"
53
-
54
- return (
55
- f"Style: {tone_label}. "
56
- f"Length: {verbosity_label}. "
57
- f"Reasoning visibility: {transparency_label}."
58
- )
59
-
60
- def _build_prompt(
61
- self,
62
- *,
63
- user_text: str,
64
- chat_history: List[Dict[str, Any]],
65
- game_context: GameContext,
66
- retrieval_summary: str,
67
- tone: float,
68
- verbosity: float,
69
- transparency: float,
70
- ) -> str:
71
- history_lines: List[str] = []
72
- for item in chat_history[-8:]:
73
- role = str(item.get("role", "user")).strip().lower()
74
- text = normalize_spaces(str(item.get("text", "")))
75
- if text:
76
- history_lines.append(f"{role.title()}: {text}")
77
-
78
- context_lines: List[str] = []
79
- if game_context.question_text:
80
- context_lines.append(f"Current question: {game_context.question_text}")
81
- if game_context.options_text:
82
- context_lines.append(f"Answer options: {game_context.options_text}")
83
- if game_context.question_category:
84
- context_lines.append(f"Question category: {game_context.question_category}")
85
- if game_context.question_difficulty:
86
- context_lines.append(f"Question difficulty: {game_context.question_difficulty}")
87
- if retrieval_summary:
88
- context_lines.append(f"Reference material:\n{retrieval_summary}")
89
 
90
- system = (
91
- "You are an AI assistant inside a Unity trading-game study. "
92
- "For non-quant messages, respond naturally and helpfully. "
93
- "If the message is about gameplay, AI usage, sliders, or study procedure, answer clearly and accurately. "
94
- "Do not invent access to hidden systems or researcher-only information. "
95
- "Never claim certainty about unknown participant details. "
96
- + self._style_instruction(tone, verbosity, transparency)
97
- )
98
 
99
- user_block = f"User: {normalize_spaces(user_text)}"
100
- prompt = "\n\n".join(
101
- part for part in [
102
- system,
103
- "Recent conversation:\n" + "\n".join(history_lines) if history_lines else "",
104
- "Game context:\n" + "\n".join(context_lines) if context_lines else "",
105
- user_block,
106
- "Assistant:"
107
- ] if part
108
- )
109
- return prompt
110
-
111
- def _fallback_reply(self, user_text: str, retrieval_summary: str) -> str:
112
- lower = (user_text or "").lower()
113
- if "what can you do" in lower:
114
- return (
115
- "I can help with GMAT-style quant questions, explain how the sliders affect the AI, "
116
- "talk through strategy in the trading-game study, and answer general non-quant questions."
117
- )
118
- if "token" in lower or "cost" in lower:
119
- return (
120
- "AI use can be treated as a limited resource in the study. "
121
- "A prompt can have a visible token cost before sending, and that cost can be logged for later analysis."
122
- )
123
- if retrieval_summary:
124
- return "I can help with that. Here are the most relevant notes I found:\n\n" + retrieval_summary
125
- return "I can help with game questions, AI-use questions, and general non-quant conversation."
126
-
127
- def generate(
128
- self,
129
- *,
130
- user_text: str,
131
- chat_history: List[Dict[str, Any]],
132
- game_context: GameContext,
133
- retrieval_summary: str,
134
- tone: float,
135
- verbosity: float,
136
- transparency: float,
137
- ) -> str:
138
  if self.pipe is None:
139
- return self._fallback_reply(user_text, retrieval_summary)
140
-
141
- prompt = self._build_prompt(
142
- user_text=user_text,
143
- chat_history=chat_history,
144
- game_context=game_context,
145
- retrieval_summary=retrieval_summary,
146
- tone=tone,
147
- verbosity=verbosity,
148
- transparency=transparency,
149
- )
150
-
151
- kwargs: Dict[str, Any] = {
152
- "max_new_tokens": settings.generator_max_new_tokens,
153
- }
154
- if self.task == "text-generation":
155
- kwargs.update(
156
- {
157
- "do_sample": settings.generator_do_sample,
158
- "temperature": settings.generator_temperature,
159
- "top_p": settings.generator_top_p,
160
- "return_full_text": False,
161
- }
162
- )
163
- else:
164
- kwargs.update(
165
- {
166
- "do_sample": settings.generator_do_sample,
167
- "temperature": settings.generator_temperature,
168
- "top_p": settings.generator_top_p,
169
- }
170
- )
171
-
172
- result = self.pipe(prompt, **kwargs)
173
- if not result:
174
- return self._fallback_reply(user_text, retrieval_summary)
175
-
176
- first = result[0]
177
- text = first.get("generated_text") or first.get("summary_text") or ""
178
- text = normalize_spaces(text)
179
- if not text:
180
- return self._fallback_reply(user_text, retrieval_summary)
181
- return soft_truncate(text, 1200)
 
1
  from __future__ import annotations
2
 
3
+ from typing import Optional
4
 
5
  try:
6
  from transformers import pipeline
7
  except Exception:
8
  pipeline = None
9
 
 
 
 
10
 
11
class GeneratorEngine:
    """Thin wrapper around a HuggingFace text2text-generation pipeline.

    Degrades gracefully: when ``transformers`` is not importable or the model
    fails to load, ``available()`` is False and ``generate()`` returns None.
    """

    def __init__(self, model_name: str = "google/flan-t5-small"):
        self.model_name = model_name
        self.pipe = None
        if pipeline is None:
            # transformers not installed; stay in fallback mode.
            return
        try:
            self.pipe = pipeline("text2text-generation", model=model_name)
        except Exception:
            # Model download/load failed; stay in fallback mode.
            self.pipe = None

    def available(self) -> bool:
        """True when the underlying pipeline loaded successfully."""
        return self.pipe is not None

    def generate(self, prompt: str, max_new_tokens: int = 96) -> Optional[str]:
        """Run the pipeline on ``prompt``; returns None on any failure.

        Decoding is greedy (``do_sample=False``) so replies are deterministic.
        """
        if self.pipe is None:
            return None
        try:
            raw = self.pipe(prompt, max_new_tokens=max_new_tokens, do_sample=False)
            if raw and isinstance(raw, list):
                return str(raw[0].get("generated_text", "")).strip()
        except Exception:
            return None
        return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
logging_store.py CHANGED
@@ -1,142 +1,76 @@
1
  from __future__ import annotations
2
 
3
  import json
4
- import uuid
5
  from datetime import datetime, timezone
6
- from pathlib import Path
7
  from typing import Any, Dict, List, Optional
8
 
9
- from huggingface_hub import CommitOperationAdd, HfApi
10
-
11
- from config import settings
12
- from models import EventLogRequest, SessionFinalizeRequest, SessionRecord, SessionStartRequest
13
-
14
-
15
- def _utc_stamp() -> str:
16
- return datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%S%fZ")
17
-
18
 
19
  class LoggingStore:
20
- def __init__(self) -> None:
21
- self.root = Path(settings.local_log_dir)
22
- self.root.mkdir(parents=True, exist_ok=True)
23
- self.sessions_dir = self.root / "sessions"
24
- self.events_dir = self.root / "events"
25
- self.sessions_dir.mkdir(parents=True, exist_ok=True)
26
- self.events_dir.mkdir(parents=True, exist_ok=True)
27
- self.hf_api: Optional[HfApi] = HfApi(token=settings.hf_token) if settings.hf_token else None
28
-
29
- def _session_path(self, session_id: str) -> Path:
30
- return self.sessions_dir / f"{session_id}.json"
31
-
32
- def _event_path(self, session_id: str, event_type: str) -> Path:
33
- safe_event = "".join(c if c.isalnum() or c in "-._" else "_" for c in event_type)[:80]
34
- session_dir = self.events_dir / session_id
35
- session_dir.mkdir(parents=True, exist_ok=True)
36
- return session_dir / f"{_utc_stamp()}__{safe_event}.json"
37
-
38
- def _load_session(self, session_id: str) -> Optional[Dict[str, Any]]:
39
- path = self._session_path(session_id)
40
- if not path.exists():
41
- return None
42
- return json.loads(path.read_text(encoding="utf-8"))
43
-
44
- def _write_json(self, path: Path, payload: Dict[str, Any]) -> None:
45
- path.write_text(json.dumps(payload, indent=2, ensure_ascii=False), encoding="utf-8")
46
-
47
- def _push_file_to_hub(self, local_path: Path, repo_path: str) -> None:
48
- if not (settings.push_logs_to_hub and settings.log_dataset_repo_id and self.hf_api):
49
- return
50
- self.hf_api.create_repo(
51
- repo_id=settings.log_dataset_repo_id,
52
- repo_type="dataset",
53
- private=settings.log_dataset_private,
54
- exist_ok=True,
55
- )
56
- self.hf_api.create_commit(
57
- repo_id=settings.log_dataset_repo_id,
58
- repo_type="dataset",
59
- commit_message=f"Add log file {repo_path}",
60
- operations=[
61
- CommitOperationAdd(
62
- path_in_repo=repo_path,
63
- path_or_fileobj=str(local_path),
64
- )
65
- ],
66
- )
67
-
68
- def start_session(self, req: SessionStartRequest) -> SessionRecord:
69
- session_id = req.session_id or str(uuid.uuid4())
70
- record = SessionRecord(
71
- participant_id=req.participant_id,
72
- session_id=session_id,
73
- started_at=req.started_at,
74
- condition=req.condition,
75
- study_id=req.study_id,
76
- game_version=req.game_version,
77
- metadata=req.metadata,
78
- summary={},
79
- event_count=0,
80
- )
81
- path = self._session_path(session_id)
82
- payload = record.model_dump()
83
- self._write_json(path, payload)
84
- self._push_file_to_hub(path, f"sessions/{session_id}.json")
85
  return record
86
 
87
- def append_event(self, req: EventLogRequest) -> Dict[str, Any]:
88
- session = self._load_session(req.session_id)
89
- if session is None:
90
- raise FileNotFoundError(f"Unknown session_id: {req.session_id}")
91
-
92
- payload = req.model_dump()
93
- event_path = self._event_path(req.session_id, req.event_type)
94
- self._write_json(event_path, payload)
95
-
96
- session["event_count"] = int(session.get("event_count", 0)) + 1
97
- self._write_json(self._session_path(req.session_id), session)
98
-
99
- self._push_file_to_hub(event_path, f"events/{req.session_id}/{event_path.name}")
100
- self._push_file_to_hub(self._session_path(req.session_id), f"sessions/{req.session_id}.json")
101
-
102
- return {
103
- "ok": True,
104
- "session_id": req.session_id,
105
- "event_file": event_path.name,
106
- "event_count": session["event_count"],
107
  }
 
 
108
 
109
- def finalize_session(self, req: SessionFinalizeRequest) -> Dict[str, Any]:
110
- session = self._load_session(req.session_id)
111
- if session is None:
112
- raise FileNotFoundError(f"Unknown session_id: {req.session_id}")
 
 
 
 
 
113
 
114
- session["finished_at"] = req.finished_at
115
- session["summary"] = req.summary
116
- self._write_json(self._session_path(req.session_id), session)
117
- self._push_file_to_hub(self._session_path(req.session_id), f"sessions/{req.session_id}.json")
118
- return {"ok": True, "session_id": req.session_id, "finished_at": req.finished_at}
 
 
 
 
 
 
 
 
 
119
 
120
  def list_sessions(self) -> List[Dict[str, Any]]:
121
- items: List[Dict[str, Any]] = []
122
- for path in sorted(self.sessions_dir.glob("*.json"), reverse=True):
123
- try:
124
- items.append(json.loads(path.read_text(encoding="utf-8")))
125
- except Exception:
126
- continue
127
- return items
128
-
129
- def read_session_bundle(self, session_id: str) -> Dict[str, Any]:
130
- session = self._load_session(session_id)
131
- if session is None:
132
- raise FileNotFoundError(f"Unknown session_id: {session_id}")
133
-
134
- events: List[Dict[str, Any]] = []
135
- session_dir = self.events_dir / session_id
136
- for path in sorted(session_dir.glob("*.json")):
137
- try:
138
- events.append(json.loads(path.read_text(encoding="utf-8")))
139
- except Exception:
140
- continue
141
 
142
- return {"session": session, "events": events}
 
 
 
 
1
  from __future__ import annotations
2
 
3
  import json
4
+ import os
5
  from datetime import datetime, timezone
 
6
  from typing import Any, Dict, List, Optional
7
 
 
 
 
 
 
 
 
 
 
8
 
9
  class LoggingStore:
10
+ def __init__(self, root: str = "logs"):
11
+ self.root = root
12
+ os.makedirs(self.root, exist_ok=True)
13
+ self.sessions_path = os.path.join(self.root, "sessions.jsonl")
14
+ self.events_path = os.path.join(self.root, "events.jsonl")
15
+
16
+ def _append(self, path: str, payload: Dict[str, Any]) -> None:
17
+ with open(path, "a", encoding="utf-8") as f:
18
+ f.write(json.dumps(payload, ensure_ascii=False) + "\n")
19
+
20
+ def _now(self) -> str:
21
+ return datetime.now(timezone.utc).isoformat()
22
+
23
+ def start_session(self, session_id: str, user_id: Optional[str], condition: Optional[str], metadata: Optional[Dict[str, Any]]) -> Dict[str, Any]:
24
+ record = {
25
+ "session_id": session_id,
26
+ "user_id": user_id,
27
+ "condition": condition,
28
+ "metadata": metadata or {},
29
+ "started_at": self._now(),
30
+ "type": "session_start",
31
+ }
32
+ self._append(self.sessions_path, record)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  return record
34
 
35
+ def log_event(self, session_id: str, event_type: str, payload: Optional[Dict[str, Any]], timestamp: Optional[str]) -> Dict[str, Any]:
36
+ record = {
37
+ "session_id": session_id,
38
+ "event_type": event_type,
39
+ "timestamp": timestamp or self._now(),
40
+ "payload": payload or {},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  }
42
+ self._append(self.events_path, record)
43
+ return record
44
 
45
+ def finalize_session(self, session_id: str, summary: Optional[Dict[str, Any]]) -> Dict[str, Any]:
46
+ record = {
47
+ "session_id": session_id,
48
+ "summary": summary or {},
49
+ "finalized_at": self._now(),
50
+ "type": "session_finalize",
51
+ }
52
+ self._append(self.sessions_path, record)
53
+ return record
54
 
55
+ def _read_jsonl(self, path: str) -> List[Dict[str, Any]]:
56
+ if not os.path.exists(path):
57
+ return []
58
+ rows = []
59
+ with open(path, "r", encoding="utf-8") as f:
60
+ for line in f:
61
+ line = line.strip()
62
+ if not line:
63
+ continue
64
+ try:
65
+ rows.append(json.loads(line))
66
+ except Exception:
67
+ continue
68
+ return rows
69
 
70
  def list_sessions(self) -> List[Dict[str, Any]]:
71
+ return self._read_jsonl(self.sessions_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
+ def get_session(self, session_id: str) -> Dict[str, Any]:
74
+ sessions = [r for r in self._read_jsonl(self.sessions_path) if r.get("session_id") == session_id]
75
+ events = [r for r in self._read_jsonl(self.events_path) if r.get("session_id") == session_id]
76
+ return {"session_id": session_id, "records": sessions, "events": events}
models.py CHANGED
@@ -1,19 +1,9 @@
1
  from __future__ import annotations
2
 
3
  from dataclasses import dataclass, field
4
- from datetime import datetime, timezone
5
  from typing import Any, Dict, List, Optional
6
 
7
- from pydantic import BaseModel, Field
8
-
9
-
10
- def utc_now_iso() -> str:
11
- return datetime.now(timezone.utc).isoformat()
12
-
13
-
14
- class ChatMessage(BaseModel):
15
- role: str = "user"
16
- text: str = ""
17
 
18
 
19
  class ChatRequest(BaseModel):
@@ -23,93 +13,59 @@ class ChatRequest(BaseModel):
23
  text: Optional[str] = None
24
  user_message: Optional[str] = None
25
 
26
- chat_history: List[ChatMessage] = Field(default_factory=list)
27
- history: List[ChatMessage] = Field(default_factory=list)
28
-
29
- tone: float = 0.5
30
- verbosity: float = 0.5
31
- transparency: float = 0.5
32
 
33
  help_mode: Optional[str] = None
 
 
34
 
35
-
36
- @dataclass
37
- class GameContext:
38
- raw_hidden_context: str = ""
39
- question_text: str = ""
40
- options_text: str = ""
41
- question_category: str = ""
42
- question_difficulty: str = ""
43
- player_balance: Optional[float] = None
44
- last_outcome: str = ""
45
-
46
- @property
47
- def combined_question_block(self) -> str:
48
- parts = [p for p in [self.question_text, self.options_text] if p.strip()]
49
- return "\n".join(parts).strip()
50
-
51
-
52
- @dataclass
53
- class RetrievalChunk:
54
- chunk_id: str
55
- text: str
56
- source_name: str = ""
57
- topic_guess: str = ""
58
- score: float = 0.0
59
-
60
-
61
- @dataclass
62
- class SolverResult:
63
- solved: bool
64
- explanation: str
65
- domain: str = "general"
66
- internal_answer_value: Optional[str] = None
67
- internal_answer_letter: Optional[str] = None
68
- detected_topic: str = ""
69
- steps: List[str] = field(default_factory=list)
70
-
71
-
72
- @dataclass
73
- class ResponsePackage:
74
- reply: str
75
- meta: Dict[str, Any]
76
 
77
 
78
  class SessionStartRequest(BaseModel):
79
- participant_id: str
80
- session_id: Optional[str] = None
81
  condition: Optional[str] = None
82
- study_id: Optional[str] = None
83
- game_version: Optional[str] = None
84
- metadata: Dict[str, Any] = Field(default_factory=dict)
85
- started_at: str = Field(default_factory=utc_now_iso)
86
 
87
 
88
  class EventLogRequest(BaseModel):
89
- participant_id: str
90
  session_id: str
91
  event_type: str
92
- timestamp: str = Field(default_factory=utc_now_iso)
93
- question_index: Optional[int] = None
94
- turn_index: Optional[int] = None
95
- payload: Dict[str, Any] = Field(default_factory=dict)
96
 
97
 
98
  class SessionFinalizeRequest(BaseModel):
99
- participant_id: str
100
  session_id: str
101
- finished_at: str = Field(default_factory=utc_now_iso)
102
- summary: Dict[str, Any] = Field(default_factory=dict)
103
 
104
 
105
- class SessionRecord(BaseModel):
106
- participant_id: str
107
- session_id: str
108
- started_at: str
109
- finished_at: Optional[str] = None
110
- condition: Optional[str] = None
111
- study_id: Optional[str] = None
112
- game_version: Optional[str] = None
113
- metadata: Dict[str, Any] = Field(default_factory=dict)
114
- summary: Dict[str, Any] = Field(default_factory=dict)
115
- event_count: int = 0
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from __future__ import annotations
2
 
3
  from dataclasses import dataclass, field
 
4
  from typing import Any, Dict, List, Optional
5
 
6
+ from pydantic import BaseModel
 
 
 
 
 
 
 
 
 
7
 
8
 
9
  class ChatRequest(BaseModel):
 
13
  text: Optional[str] = None
14
  user_message: Optional[str] = None
15
 
16
+ tone: Optional[float] = 0.5
17
+ verbosity: Optional[float] = 0.5
18
+ transparency: Optional[float] = 0.5
 
 
 
19
 
20
  help_mode: Optional[str] = None
21
+ chat_history: Optional[List[Dict[str, Any]]] = None
22
+ history: Optional[List[Dict[str, Any]]] = None
23
 
24
+ question_text: Optional[str] = None
25
+ question_id: Optional[str] = None
26
+ session_id: Optional[str] = None
27
+ user_id: Optional[str] = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
 
30
  class SessionStartRequest(BaseModel):
31
+ session_id: str
32
+ user_id: Optional[str] = None
33
  condition: Optional[str] = None
34
+ metadata: Optional[Dict[str, Any]] = None
 
 
 
35
 
36
 
37
  class EventLogRequest(BaseModel):
 
38
  session_id: str
39
  event_type: str
40
+ timestamp: Optional[str] = None
41
+ payload: Optional[Dict[str, Any]] = None
 
 
42
 
43
 
44
  class SessionFinalizeRequest(BaseModel):
 
45
  session_id: str
46
+ summary: Optional[Dict[str, Any]] = None
 
47
 
48
 
49
+ @dataclass
50
+ class RetrievedChunk:
51
+ text: str
52
+ topic: str = "general"
53
+ source: str = "local"
54
+ score: float = 0.0
55
+
56
+
57
+ @dataclass
58
+ class SolverResult:
59
+ reply: str = ""
60
+ domain: str = "fallback"
61
+ solved: bool = False
62
+ help_mode: str = "answer"
63
+ answer_letter: Optional[str] = None
64
+ answer_value: Optional[str] = None
65
+ topic: Optional[str] = None
66
+ used_retrieval: bool = False
67
+ used_generator: bool = False
68
+ internal_answer: Optional[str] = None
69
+ steps: List[str] = field(default_factory=list)
70
+ teaching_chunks: List[RetrievedChunk] = field(default_factory=list)
71
+ meta: Dict[str, Any] = field(default_factory=dict)
quant_solver.py CHANGED
@@ -2,9 +2,8 @@ from __future__ import annotations
2
 
3
  import math
4
  import re
5
- from fractions import Fraction
6
  from statistics import mean, median
7
- from typing import Dict, List, Optional
8
 
9
  try:
10
  import sympy as sp
@@ -12,268 +11,214 @@ except Exception:
12
  sp = None
13
 
14
  from models import SolverResult
15
- from utils import clean_math_text, normalize_spaces, safe_div
16
-
17
- CHOICE_LETTERS = ["A", "B", "C", "D", "E"]
18
-
19
 
20
 
21
  def extract_choices(text: str) -> Dict[str, str]:
 
22
  matches = list(
23
  re.finditer(
24
- r"(?im)(?:^|\n)\s*([A-E])[\)\.:]\s*(.*?)(?=(?:\n\s*[A-E][\)\.:]\s)|$)",
25
- text or "",
26
  )
27
  )
28
  return {m.group(1).upper(): normalize_spaces(m.group(2)) for m in matches}
29
 
30
 
31
-
32
  def has_answer_choices(text: str) -> bool:
33
  return len(extract_choices(text)) >= 3
34
 
35
 
36
-
37
  def is_quant_question(text: str) -> bool:
38
  lower = clean_math_text(text).lower()
39
- quant_keywords = [
40
- "integer", "divisible", "remainder", "percent", "ratio", "probability",
41
- "mean", "median", "average", "sum", "difference", "product", "triangle",
42
- "circle", "rectangle", "area", "perimeter", "volume", "x", "y", "equation",
43
- "inequality", "consecutive", "mixture", "speed", "distance", "work", "algebra",
44
  ]
45
- return any(k in lower for k in quant_keywords) or (bool(re.search(r"\d", lower)) and ("?" in lower or has_answer_choices(lower)))
46
-
 
 
 
 
 
47
 
48
 
49
  def _prepare_expression(expr: str) -> str:
50
  expr = clean_math_text(expr).strip()
51
  expr = expr.replace("^", "**")
52
- expr = expr.replace("%", "/100")
53
  expr = re.sub(r"(\d)\s*\(", r"\1*(", expr)
54
  expr = re.sub(r"\)\s*(\d)", r")*\1", expr)
55
  expr = re.sub(r"(\d)([a-zA-Z])", r"\1*\2", expr)
56
  return expr
57
 
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
- def _parse_numeric_text(text: str) -> Optional[float]:
61
- raw = clean_math_text(text).strip().lower()
62
- raw_no_space = raw.replace(" ", "")
63
-
64
- pct_match = re.fullmatch(r"(-?\d+(?:\.\d+)?)%", raw_no_space)
65
- if pct_match:
66
- return float(pct_match.group(1)) / 100.0
67
-
68
- frac_match = re.fullmatch(r"(-?\d+)\s*/\s*(-?\d+)", raw)
69
- if frac_match:
70
- num, den = float(frac_match.group(1)), float(frac_match.group(2))
71
- return None if den == 0 else num / den
72
 
 
 
 
 
 
 
 
 
 
 
 
73
  try:
74
  return float(eval(_prepare_expression(raw), {"__builtins__": {}}, {"sqrt": math.sqrt, "pi": math.pi}))
75
  except Exception:
76
  return None
77
 
78
 
79
-
80
- def compare_to_choices_numeric(answer_value: float, choices: Dict[str, str], tolerance: float = 1e-6) -> Optional[str]:
81
  best_letter = None
82
  best_diff = float("inf")
83
  for letter, raw in choices.items():
84
- parsed = _parse_numeric_text(raw)
85
  if parsed is None:
86
  continue
87
  diff = abs(parsed - answer_value)
88
  if diff < best_diff:
89
  best_diff = diff
90
  best_letter = letter
91
- if best_letter is not None and best_diff <= tolerance:
92
  return best_letter
93
  return None
94
 
95
 
96
-
97
- def _solve_percent_patterns(text: str) -> Optional[SolverResult]:
98
  lower = clean_math_text(text).lower()
99
  choices = extract_choices(text)
100
 
101
- m = re.search(r"what is\s+(\d+(?:\.\d+)?)\s*(?:%|percent)\s+of\s+(\d+(?:\.\d+)?)", lower)
102
  if m:
103
- p, n = float(m.group(1)), float(m.group(2))
104
- ans = p / 100.0 * n
 
105
  return SolverResult(
106
- solved=True,
107
  domain="quant",
108
- explanation="Convert the percent to a decimal and multiply by the base quantity.",
109
- internal_answer_value=f"{ans:g}",
110
- internal_answer_letter=compare_to_choices_numeric(ans, choices) if choices else None,
111
- detected_topic="percent",
 
112
  steps=[
113
- f"Rewrite {p}% as {p/100:g}.",
114
- f"Multiply {p/100:g} by {n}.",
115
- "Then compare your result with the answer choices.",
116
  ],
117
  )
118
 
119
- m = re.search(r"(\d+(?:\.\d+)?)\s+is\s+what percent of\s+(\d+(?:\.\d+)?)", lower)
120
  if m:
121
- x, y = float(m.group(1)), float(m.group(2))
122
- if y == 0:
123
- return None
124
- ans = x / y * 100.0
125
  return SolverResult(
126
- solved=True,
127
  domain="quant",
128
- explanation="Set up part ÷ whole, then convert to a percent.",
129
- internal_answer_value=f"{ans:g}%",
130
- internal_answer_letter=None,
131
- detected_topic="percent",
132
- steps=[
133
- f"Treat {x} as the part and {y} as the whole.",
134
- f"Compute {x}/{y}.",
135
- "Multiply by 100 to convert to a percent, then match to the choices.",
136
- ],
137
  )
138
  return None
139
 
140
 
141
-
142
- def _solve_average_patterns(text: str) -> Optional[SolverResult]:
143
  lower = clean_math_text(text).lower()
144
  nums = [float(n) for n in re.findall(r"-?\d+(?:\.\d+)?", lower)]
145
  if not nums:
146
  return None
147
-
148
  if "mean" in lower or "average" in lower:
149
- avg = mean(nums)
150
- return SolverResult(
151
- solved=True,
152
- domain="quant",
153
- explanation="Add the values and divide by how many values there are.",
154
- internal_answer_value=f"{avg:g}",
155
- detected_topic="statistics",
156
- steps=[f"Add the listed values: total them carefully.", f"Count how many values there are: {len(nums)}.", "Divide total by count, then check the choices."],
157
- )
158
-
159
  if "median" in lower:
160
- med = median(nums)
161
- return SolverResult(
162
- solved=True,
163
- domain="quant",
164
- explanation="Order the values first, then identify the middle position.",
165
- internal_answer_value=f"{med:g}",
166
- detected_topic="statistics",
167
- steps=["Write the numbers in increasing order.", "Find the middle value, or average the two middle values if there are an even number of terms.", "Compare that result with the choices."],
168
- )
169
- return None
170
-
171
-
172
-
173
- def _solve_ratio_patterns(text: str) -> Optional[SolverResult]:
174
- lower = clean_math_text(text).lower()
175
- m = re.search(r"ratio of\s+(\w+)\s+to\s+(\w+)\s+is\s+(\d+)\s*:\s*(\d+)", lower)
176
- if m:
177
- return SolverResult(
178
- solved=True,
179
- domain="quant",
180
- explanation="Use the common multiplier that turns the ratio into the actual quantities.",
181
- detected_topic="ratio",
182
- steps=["Call the common multiplier k.", f"Write the two quantities as {m.group(3)}k and {m.group(4)}k.", "Use the extra condition in the question to solve for k, then compute the quantity you need."],
183
- )
184
- return None
185
-
186
-
187
-
188
- def _solve_divisibility_patterns(text: str) -> Optional[SolverResult]:
189
- lower = clean_math_text(text).lower()
190
- if "divisible by" in lower and "integer" in lower:
191
- expr_match = re.search(r"if\s+([a-z])\s+is an integer and\s+(.+?)\s+is divisible by\s+(\d+)", lower)
192
- if expr_match:
193
- var = expr_match.group(1)
194
- expr = expr_match.group(2)
195
- divisor = int(expr_match.group(3))
196
- choices = extract_choices(text)
197
- valid_letters: List[str] = []
198
- valid_values: List[str] = []
199
- if sp:
200
- symbol = sp.symbols(var)
201
- parsed_expr = sp.sympify(_prepare_expression(expr))
202
- for letter, raw_choice in choices.items():
203
- try:
204
- value = int(float(_parse_numeric_text(raw_choice)))
205
- except Exception:
206
- continue
207
- if parsed_expr.subs(symbol, value) % divisor == 0:
208
- valid_letters.append(letter)
209
- valid_values.append(str(value))
210
- return SolverResult(
211
- solved=bool(valid_letters),
212
- domain="quant",
213
- explanation="Test each answer choice in the divisibility condition instead of solving abstractly first.",
214
- internal_answer_value=", ".join(valid_values) if valid_values else None,
215
- internal_answer_letter=valid_letters[0] if len(valid_letters) == 1 else None,
216
- detected_topic="number_theory",
217
- steps=[
218
- f"Use the condition that the expression must be divisible by {divisor}.",
219
- "Substitute each answer choice into the expression.",
220
- "Keep the value that makes the result a multiple of the divisor.",
221
- ],
222
- )
223
  return None
224
 
225
 
226
-
227
  def _solve_linear_equation(text: str) -> Optional[SolverResult]:
228
- if not sp:
229
  return None
230
- cleaned = clean_math_text(text)
231
- lower = cleaned.lower()
232
- if "value of x" not in lower and not re.search(r"\bsolve\b", lower):
233
  return None
234
-
235
- eq_match = re.search(r"([\d\sa-zA-Z\+\-\*/\^\(\)=]+)", cleaned)
236
- if not eq_match or "=" not in eq_match.group(1):
237
- return None
238
- expr = eq_match.group(1)
239
  try:
240
  lhs, rhs = expr.split("=", 1)
241
- x = sp.symbols("x")
242
- sol = sp.solve(sp.Eq(sp.sympify(_prepare_expression(lhs)), sp.sympify(_prepare_expression(rhs))), x)
 
 
 
 
243
  if not sol:
244
  return None
245
  value = sol[0]
 
 
 
 
 
246
  return SolverResult(
247
- solved=True,
248
  domain="quant",
249
- explanation="Isolate the variable by performing inverse operations on both sides.",
250
- internal_answer_value=str(value),
251
- detected_topic="algebra",
252
- steps=["Simplify each side first if needed.", "Move variable terms to one side and constants to the other.", "Divide by the remaining coefficient, then compare with the choices."],
 
 
 
 
 
 
253
  )
254
  except Exception:
255
  return None
256
 
257
 
258
-
259
  def solve_quant(text: str) -> SolverResult:
260
  text = text or ""
261
- for solver in [_solve_percent_patterns, _solve_average_patterns, _solve_ratio_patterns, _solve_divisibility_patterns, _solve_linear_equation]:
262
- result = solver(text)
263
- if result:
264
  return result
265
-
266
- topic = "general_quant"
267
- steps = [
268
- "Identify exactly what quantity the question wants.",
269
- "Translate the words into an equation, ratio, table, or diagram.",
270
- "Do the calculation carefully.",
271
- "Use the answer choices to check reasonableness and units.",
272
- ]
273
  return SolverResult(
274
- solved=False,
275
  domain="quant",
276
- explanation="This looks quantitative, but it does not match a strong rule-based pattern yet.",
277
- detected_topic=topic,
278
- steps=steps,
 
 
 
 
 
279
  )
 
2
 
3
  import math
4
  import re
 
5
  from statistics import mean, median
6
+ from typing import Dict, Optional
7
 
8
  try:
9
  import sympy as sp
 
11
  sp = None
12
 
13
  from models import SolverResult
14
+ from utils import clean_math_text, normalize_spaces
 
 
 
15
 
16
 
17
  def extract_choices(text: str) -> Dict[str, str]:
18
+ text = text or ""
19
  matches = list(
20
  re.finditer(
21
+ r"(?i)\b([A-E])[\)\.:]\s*(.*?)(?=\s+\b[A-E][\)\.:]\s*|$)",
22
+ text,
23
  )
24
  )
25
  return {m.group(1).upper(): normalize_spaces(m.group(2)) for m in matches}
26
 
27
 
 
28
  def has_answer_choices(text: str) -> bool:
29
  return len(extract_choices(text)) >= 3
30
 
31
 
 
32
  def is_quant_question(text: str) -> bool:
33
  lower = clean_math_text(text).lower()
34
+ keywords = [
35
+ "solve", "equation", "percent", "ratio", "probability", "mean", "median",
36
+ "average", "sum", "difference", "product", "quotient", "triangle", "circle",
37
+ "rectangle", "area", "perimeter", "volume", "algebra", "integer", "divisible",
38
+ "number", "fraction", "decimal", "geometry", "distance", "speed", "work",
39
  ]
40
+ if any(k in lower for k in keywords):
41
+ return True
42
+ if "=" in lower and re.search(r"[a-z]", lower):
43
+ return True
44
+ if re.search(r"\d", lower) and ("?" in lower or has_answer_choices(lower)):
45
+ return True
46
+ return False
47
 
48
 
49
  def _prepare_expression(expr: str) -> str:
50
  expr = clean_math_text(expr).strip()
51
  expr = expr.replace("^", "**")
 
52
  expr = re.sub(r"(\d)\s*\(", r"\1*(", expr)
53
  expr = re.sub(r"\)\s*(\d)", r")*\1", expr)
54
  expr = re.sub(r"(\d)([a-zA-Z])", r"\1*\2", expr)
55
  return expr
56
 
57
 
58
+ def _extract_equation(text: str) -> Optional[str]:
59
+ cleaned = clean_math_text(text)
60
+ if "=" not in cleaned:
61
+ return None
62
+ patterns = [
63
+ r"([A-Za-z0-9\.\+\-\*/\^\(\)\s]*[a-zA-Z][A-Za-z0-9\.\+\-\*/\^\(\)\s]*=[A-Za-z0-9\.\+\-\*/\^\(\)\s]+)",
64
+ r"([0-9A-Za-z\.\+\-\*/\^\(\)\s]+=[0-9A-Za-z\.\+\-\*/\^\(\)\s]+)",
65
+ ]
66
+ for pattern in patterns:
67
+ for m in re.finditer(pattern, cleaned):
68
+ candidate = m.group(1).strip()
69
+ tokens = re.findall(r"[a-z]", candidate.lower())
70
+ if tokens and not candidate.lower().startswith(("how do", "can you", "please", "what is", "solve ")):
71
+ return candidate
72
+ eq_index = cleaned.find("=")
73
+ left = re.findall(r"[A-Za-z0-9\.\+\-\*/\^\(\)\s]+$", cleaned[:eq_index])
74
+ right = re.findall(r"^[A-Za-z0-9\.\+\-\*/\^\(\)\s]+", cleaned[eq_index + 1:])
75
+ if left and right:
76
+ candidate = left[0].strip().split()[-1] + " = " + right[0].strip().split()[0]
77
+ if re.search(r"[a-z]", candidate.lower()):
78
+ return candidate
79
+ return None
80
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
+ def _parse_number(text: str) -> Optional[float]:
83
+ raw = clean_math_text(text).strip().lower()
84
+ pct = re.fullmatch(r"(-?\d+(?:\.\d+)?)%", raw.replace(" ", ""))
85
+ if pct:
86
+ return float(pct.group(1)) / 100.0
87
+ frac = re.fullmatch(r"(-?\d+)\s*/\s*(-?\d+)", raw)
88
+ if frac:
89
+ den = float(frac.group(2))
90
+ if den == 0:
91
+ return None
92
+ return float(frac.group(1)) / den
93
  try:
94
  return float(eval(_prepare_expression(raw), {"__builtins__": {}}, {"sqrt": math.sqrt, "pi": math.pi}))
95
  except Exception:
96
  return None
97
 
98
 
99
+ def _best_choice(answer_value: float, choices: Dict[str, str]) -> Optional[str]:
 
100
  best_letter = None
101
  best_diff = float("inf")
102
  for letter, raw in choices.items():
103
+ parsed = _parse_number(raw)
104
  if parsed is None:
105
  continue
106
  diff = abs(parsed - answer_value)
107
  if diff < best_diff:
108
  best_diff = diff
109
  best_letter = letter
110
+ if best_letter is not None and best_diff <= 1e-6:
111
  return best_letter
112
  return None
113
 
114
 
115
+ def _solve_percent(text: str) -> Optional[SolverResult]:
 
116
  lower = clean_math_text(text).lower()
117
  choices = extract_choices(text)
118
 
119
+ m = re.search(r"(\d+(?:\.\d+)?)\s*(?:%|percent)\s+of\s+(?:a\s+)?number\s+is\s+(\d+(?:\.\d+)?)", lower)
120
  if m:
121
+ p = float(m.group(1))
122
+ value = float(m.group(2))
123
+ ans = value / (p / 100.0)
124
  return SolverResult(
 
125
  domain="quant",
126
+ solved=True,
127
+ topic="percent",
128
+ answer_value=f"{ans:g}",
129
+ answer_letter=_best_choice(ans, choices) if choices else None,
130
+ internal_answer=f"{ans:g}",
131
  steps=[
132
+ f"Let the number be n.",
133
+ f"Write {p}% of n as {p/100:g}n.",
134
+ f"Set {p/100:g}n = {value} and solve for n.",
135
  ],
136
  )
137
 
138
+ m = re.search(r"what is\s+(\d+(?:\.\d+)?)\s*(?:%|percent)\s+of\s+(\d+(?:\.\d+)?)", lower)
139
  if m:
140
+ p = float(m.group(1))
141
+ n = float(m.group(2))
142
+ ans = p / 100.0 * n
 
143
  return SolverResult(
 
144
  domain="quant",
145
+ solved=True,
146
+ topic="percent",
147
+ answer_value=f"{ans:g}",
148
+ answer_letter=_best_choice(ans, choices) if choices else None,
149
+ internal_answer=f"{ans:g}",
150
+ steps=[f"Convert {p}% to {p/100:g}.", f"Multiply by {n}."]
 
 
 
151
  )
152
  return None
153
 
154
 
155
+ def _solve_mean_median(text: str) -> Optional[SolverResult]:
 
156
  lower = clean_math_text(text).lower()
157
  nums = [float(n) for n in re.findall(r"-?\d+(?:\.\d+)?", lower)]
158
  if not nums:
159
  return None
 
160
  if "mean" in lower or "average" in lower:
161
+ ans = mean(nums)
162
+ return SolverResult(domain="quant", solved=True, topic="statistics", answer_value=f"{ans:g}", internal_answer=f"{ans:g}", steps=["Add the values.", f"Divide by {len(nums)}."])
 
 
 
 
 
 
 
 
163
  if "median" in lower:
164
+ ans = median(nums)
165
+ return SolverResult(domain="quant", solved=True, topic="statistics", answer_value=f"{ans:g}", internal_answer=f"{ans:g}", steps=["Order the values.", "Take the middle value."])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
  return None
167
 
168
 
 
169
  def _solve_linear_equation(text: str) -> Optional[SolverResult]:
170
+ if sp is None:
171
  return None
172
+ expr = _extract_equation(text)
173
+ if not expr:
 
174
  return None
 
 
 
 
 
175
  try:
176
  lhs, rhs = expr.split("=", 1)
177
+ symbols = sorted(set(re.findall(r"\b[a-z]\b", expr)))
178
+ if not symbols:
179
+ return None
180
+ var_name = symbols[0]
181
+ var = sp.symbols(var_name)
182
+ sol = sp.solve(sp.Eq(sp.sympify(_prepare_expression(lhs)), sp.sympify(_prepare_expression(rhs))), var)
183
  if not sol:
184
  return None
185
  value = sol[0]
186
+ try:
187
+ as_float = float(value)
188
+ except Exception:
189
+ as_float = None
190
+ choices = extract_choices(text)
191
  return SolverResult(
 
192
  domain="quant",
193
+ solved=True,
194
+ topic="algebra",
195
+ answer_value=str(value),
196
+ answer_letter=_best_choice(as_float, choices) if (as_float is not None and choices) else None,
197
+ internal_answer=f"{var_name} = {value}",
198
+ steps=[
199
+ "Treat the statement as an equation.",
200
+ "Undo operations on both sides to isolate the variable.",
201
+ f"That gives {var_name} = {value}.",
202
+ ],
203
  )
204
  except Exception:
205
  return None
206
 
207
 
 
208
  def solve_quant(text: str) -> SolverResult:
209
  text = text or ""
210
+ for fn in (_solve_percent, _solve_mean_median, _solve_linear_equation):
211
+ result = fn(text)
212
+ if result is not None:
213
  return result
 
 
 
 
 
 
 
 
214
  return SolverResult(
 
215
  domain="quant",
216
+ solved=False,
217
+ topic="general_quant",
218
+ reply="This looks quantitative, but it does not match a strong rule-based pattern yet.",
219
+ steps=[
220
+ "Identify the quantity the question wants.",
221
+ "Translate the wording into an equation, ratio, or diagram.",
222
+ "Carry out the calculation carefully.",
223
+ ],
224
  )
retrieval_engine.py CHANGED
@@ -1,133 +1,98 @@
1
  from __future__ import annotations
2
 
3
  import json
4
- from pathlib import Path
5
- from typing import List
6
 
7
- import numpy as np
 
8
 
9
  try:
10
- from datasets import load_dataset
11
  except Exception:
12
- load_dataset = None
13
 
14
  try:
15
- from sentence_transformers import CrossEncoder, SentenceTransformer
16
  except Exception:
17
- CrossEncoder = None
18
  SentenceTransformer = None
19
 
20
- from config import settings
21
- from models import RetrievalChunk
22
- from utils import normalize_spaces
23
-
24
 
25
  class RetrievalEngine:
26
- def __init__(self) -> None:
27
- self.rows = self._load_rows()
28
- self.texts = [normalize_spaces(row.get("text", "")) for row in self.rows]
29
- self.embedder = None
30
- self.reranker = None
31
  self.embeddings = None
32
-
33
- if SentenceTransformer is None or CrossEncoder is None:
34
- return
35
-
36
- try:
37
- self.embedder = SentenceTransformer(settings.embedding_model)
38
- self.reranker = CrossEncoder(settings.cross_encoder_model)
39
- self.embeddings = self.embedder.encode(
40
- self.texts,
41
- batch_size=64,
42
- convert_to_numpy=True,
43
- normalize_embeddings=True,
44
- show_progress_bar=False,
45
- )
46
- except Exception:
47
- self.embedder = None
48
- self.reranker = None
49
- self.embeddings = None
50
-
51
- def _load_rows(self) -> List[dict]:
52
- local_path = Path(settings.local_chunks_path)
53
- if local_path.exists():
54
- rows = []
55
- with local_path.open("r", encoding="utf-8") as f:
56
- for line in f:
57
- line = line.strip()
58
- if line:
59
- rows.append(json.loads(line))
60
- if rows:
61
- return rows
62
-
63
- if settings.enable_remote_dataset_fallback and load_dataset is not None:
64
- ds = load_dataset(settings.dataset_repo_id, split=settings.dataset_split)
65
- return [dict(row) for row in ds]
66
-
67
- raise FileNotFoundError(
68
- f"Could not load retrieval corpus from {local_path} or {settings.dataset_repo_id}."
69
- )
70
-
71
- def _lexical_search(self, query: str, k: int) -> List[RetrievalChunk]:
72
- tokens = [t for t in normalize_spaces(query).lower().split() if t]
73
- scored = []
74
- for idx, text in enumerate(self.texts):
75
- lower = text.lower()
76
- score = sum(lower.count(tok) for tok in tokens)
77
- if score > 0:
78
- scored.append((idx, float(score)))
79
- scored.sort(key=lambda x: x[1], reverse=True)
80
- results: List[RetrievalChunk] = []
81
- for idx, score in scored[:k]:
82
- row = self.rows[int(idx)]
83
- results.append(
84
- RetrievalChunk(
85
- chunk_id=str(row.get("id", idx)),
86
- text=self.texts[int(idx)],
87
- source_name=str(row.get("source_name", "")),
88
- topic_guess=str(row.get("topic_guess", "")),
89
- score=float(score),
90
- )
91
- )
92
- return results
93
-
94
- def search(self, query: str, k: int | None = None) -> List[RetrievalChunk]:
95
- query = normalize_spaces(query)
96
- if not query:
97
  return []
98
-
99
- top_k = k or settings.retrieval_k
100
-
101
- if self.embedder is None or self.reranker is None or self.embeddings is None:
102
- return self._lexical_search(query, settings.rerank_k)
103
-
104
- query_emb = self.embedder.encode(
105
- [query],
106
- convert_to_numpy=True,
107
- normalize_embeddings=True,
108
- show_progress_bar=False,
109
- )[0]
110
-
111
- scores = np.dot(self.embeddings, query_emb)
112
- candidate_idx = np.argsort(scores)[::-1][:top_k]
113
- pairs = [[query, self.texts[i]] for i in candidate_idx]
114
- rerank_scores = self.reranker.predict(pairs)
115
- reranked = sorted(
116
- zip(candidate_idx, rerank_scores),
117
- key=lambda x: float(x[1]),
118
- reverse=True,
119
- )[: settings.rerank_k]
120
-
121
- results: List[RetrievalChunk] = []
122
- for idx, score in reranked:
123
- row = self.rows[int(idx)]
124
- results.append(
125
- RetrievalChunk(
126
- chunk_id=str(row.get("id", idx)),
127
- text=self.texts[int(idx)],
128
- source_name=str(row.get("source_name", "")),
129
- topic_guess=str(row.get("topic_guess", "")),
130
- score=float(score),
131
- )
132
- )
133
  return results
 
1
  from __future__ import annotations
2
 
3
  import json
4
+ import os
5
+ from typing import List, Optional
6
 
7
+ from models import RetrievedChunk
8
+ from utils import clean_math_text, score_token_overlap
9
 
10
  try:
11
+ import numpy as np
12
  except Exception:
13
+ np = None
14
 
15
  try:
16
+ from sentence_transformers import SentenceTransformer
17
  except Exception:
 
18
  SentenceTransformer = None
19
 
 
 
 
 
20
 
21
  class RetrievalEngine:
22
+ def __init__(self, data_path: str = "data/gmat_hf_chunks.jsonl"):
23
+ self.data_path = data_path
24
+ self.rows = self._load_rows(data_path)
25
+ self.encoder = None
 
26
  self.embeddings = None
27
+ if SentenceTransformer is not None and self.rows:
28
+ try:
29
+ self.encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
30
+ self.embeddings = self.encoder.encode([r["text"] for r in self.rows], convert_to_numpy=True, normalize_embeddings=True)
31
+ except Exception:
32
+ self.encoder = None
33
+ self.embeddings = None
34
+
35
+ def _load_rows(self, data_path: str):
36
+ rows = []
37
+ if not os.path.exists(data_path):
38
+ return rows
39
+ with open(data_path, "r", encoding="utf-8") as f:
40
+ for line in f:
41
+ line = line.strip()
42
+ if not line:
43
+ continue
44
+ try:
45
+ item = json.loads(line)
46
+ except Exception:
47
+ continue
48
+ rows.append({
49
+ "text": item.get("text", ""),
50
+ "topic": item.get("topic", item.get("section", "general")) or "general",
51
+ "source": item.get("source", "local_corpus"),
52
+ })
53
+ return rows
54
+
55
+ def _topic_bonus(self, desired_topic: str, row_topic: str, intent: str) -> float:
56
+ desired_topic = (desired_topic or "").lower()
57
+ row_topic = (row_topic or "").lower()
58
+ intent = (intent or "").lower()
59
+ bonus = 0.0
60
+ if desired_topic and desired_topic in row_topic:
61
+ bonus += 1.25
62
+ if desired_topic == "algebra" and row_topic in {"algebra", "linear equations", "equations"}:
63
+ bonus += 1.0
64
+ if desired_topic == "percent" and "percent" in row_topic:
65
+ bonus += 1.0
66
+ if intent in {"method", "step_by_step", "full_working", "hint"}:
67
+ if any(k in row_topic for k in ["algebra", "percent", "fractions", "word_problems", "general"]):
68
+ bonus += 0.25
69
+ return bonus
70
+
71
+ def search(self, query: str, topic: str = "", intent: str = "answer", k: int = 3) -> List[RetrievedChunk]:
72
+ if not self.rows:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  return []
74
+ combined_query = clean_math_text(query)
75
+
76
+ scores = []
77
+ if self.encoder is not None and self.embeddings is not None and np is not None:
78
+ try:
79
+ q = self.encoder.encode([combined_query], convert_to_numpy=True, normalize_embeddings=True)[0]
80
+ semantic_scores = self.embeddings @ q
81
+ for row, sem in zip(self.rows, semantic_scores.tolist()):
82
+ lexical = score_token_overlap(combined_query, row["text"])
83
+ bonus = self._topic_bonus(topic, row["topic"], intent)
84
+ scores.append((0.7 * sem + 0.3 * lexical + bonus, row))
85
+ except Exception:
86
+ scores = []
87
+
88
+ if not scores:
89
+ for row in self.rows:
90
+ lexical = score_token_overlap(combined_query, row["text"])
91
+ bonus = self._topic_bonus(topic, row["topic"], intent)
92
+ scores.append((lexical + bonus, row))
93
+
94
+ scores.sort(key=lambda x: x[0], reverse=True)
95
+ results = []
96
+ for score, row in scores[:k]:
97
+ results.append(RetrievedChunk(text=row["text"], topic=row["topic"], source=row["source"], score=float(score)))
 
 
 
 
 
 
 
 
 
 
 
98
  return results
ui_html.py CHANGED
@@ -2,68 +2,37 @@ HOME_HTML = """
2
  <!doctype html>
3
  <html>
4
  <head>
5
- <meta charset="utf-8" />
6
- <title>Trading Game Study AI</title>
 
7
  <style>
8
- body { font-family: Arial, sans-serif; max-width: 980px; margin: 32px auto; line-height: 1.45; }
9
- textarea, input { width: 100%; margin-top: 8px; padding: 10px; }
10
- button { padding: 10px 14px; margin-top: 12px; cursor: pointer; }
11
- pre { background: #f6f6f6; padding: 12px; overflow-x: auto; white-space: pre-wrap; }
12
- .row { margin-bottom: 24px; }
13
  </style>
14
  </head>
15
  <body>
16
- <h1>Trading Game Study AI</h1>
17
- <p>This Space supports three things: quant-help chat, non-quant generative conversation, and researcher-facing study logging endpoints.</p>
18
-
19
- <div class="row">
20
- <h2>Chat test</h2>
21
- <textarea id="msg" rows="8">What can you do in this game?</textarea>
22
- <button onclick="sendChat()">Send /chat</button>
23
- <pre id="chatOut"></pre>
24
- </div>
25
-
26
- <div class="row">
27
- <h2>Session start test</h2>
28
- <input id="participant" value="demo_participant_001" />
29
- <button onclick="startSession()">Send /log/session/start</button>
30
- <pre id="sessionOut"></pre>
31
- </div>
32
-
33
  <script>
34
- let currentSessionId = null;
35
-
36
- async function sendChat() {
37
- const body = {
38
- message: document.getElementById("msg").value,
39
- tone: 0.7,
40
- verbosity: 0.6,
41
- transparency: 0.5,
42
- chat_history: []
43
- };
44
- const res = await fetch("/chat", {
45
- method: "POST",
46
- headers: {"Content-Type": "application/json"},
47
- body: JSON.stringify(body)
48
- });
49
- document.getElementById("chatOut").textContent = JSON.stringify(await res.json(), null, 2);
50
- }
51
-
52
- async function startSession() {
53
- const body = {
54
- participant_id: document.getElementById("participant").value,
55
- condition: "demo",
56
- study_id: "pilot",
57
- metadata: {"source": "browser_test_page"}
58
- };
59
- const res = await fetch("/log/session/start", {
60
- method: "POST",
61
- headers: {"Content-Type": "application/json"},
62
- body: JSON.stringify(body)
63
- });
64
  const data = await res.json();
65
- currentSessionId = data.session_id || null;
66
- document.getElementById("sessionOut").textContent = JSON.stringify(data, null, 2);
67
  }
68
  </script>
69
  </body>
 
2
  <!doctype html>
3
  <html>
4
  <head>
5
+ <meta charset=\"utf-8\">
6
+ <meta name=\"viewport\" content=\"width=device-width, initial-scale=1\">
7
+ <title>Trading Game AI V2</title>
8
  <style>
9
+ body { font-family: Arial, sans-serif; max-width: 900px; margin: 40px auto; padding: 0 16px; }
10
+ textarea { width: 100%; min-height: 180px; }
11
+ button { margin-top: 12px; padding: 10px 16px; }
12
+ pre { background: #f5f5f5; padding: 16px; white-space: pre-wrap; }
 
13
  </style>
14
  </head>
15
  <body>
16
+ <h1>Trading Game AI V2</h1>
17
+ <p>Test the <code>/chat</code> endpoint here.</p>
18
+ <textarea id=\"payload\">{
19
+ \"message\": \"Solve: x/5 = 12\",
20
+ \"chat_history\": [],
21
+ \"tone\": 0.5,
22
+ \"verbosity\": 0.6,
23
+ \"transparency\": 0.6,
24
+ \"session_id\": \"test-session-1\",
25
+ \"user_id\": \"test-user-1\"
26
+ }</textarea>
27
+ <br>
28
+ <button onclick=\"send()\">Send</button>
29
+ <pre id=\"out\"></pre>
 
 
 
30
  <script>
31
+ async function send() {
32
+ const payload = JSON.parse(document.getElementById('payload').value);
33
+ const res = await fetch('/chat', {method:'POST', headers:{'Content-Type':'application/json'}, body: JSON.stringify(payload)});
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  const data = await res.json();
35
+ document.getElementById('out').textContent = JSON.stringify(data, null, 2);
 
36
  }
37
  </script>
38
  </body>
utils.py CHANGED
@@ -1,104 +1,102 @@
1
  from __future__ import annotations
2
 
 
 
3
  import math
4
  import re
5
- from typing import Any, Iterable, List, Tuple
6
 
7
- from models import ChatRequest, GameContext
8
 
9
 
10
- USER_MESSAGE_MARKER = "USER_MESSAGE:"
11
-
12
-
13
- def clamp01(value: Any, default: float = 0.5) -> float:
14
  try:
15
- v = float(value)
16
  return max(0.0, min(1.0, v))
17
  except Exception:
18
  return default
19
 
20
 
21
  def normalize_spaces(text: str) -> str:
22
- return re.sub(r"\s+", " ", (text or "")).strip()
23
 
24
 
25
  def clean_math_text(text: str) -> str:
26
- text = (text or "")
27
- text = text.replace("−", "-").replace("×", "*").replace("÷", "/")
28
- text = text.replace("\u2212", "-").replace("\u00d7", "*").replace("\u00f7", "/")
29
- text = text.replace("\u2264", "<=").replace("\u2265", ">=")
30
- return text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
 
33
  def get_user_text(req: ChatRequest, raw_body: Any = None) -> str:
34
- for field in [req.message, req.prompt, req.query, req.text, req.user_message]:
35
- if isinstance(field, str) and field.strip():
36
- return field.strip()
37
- if isinstance(raw_body, str):
38
- return raw_body.strip()
39
- return ""
40
-
41
-
42
- def split_unity_message(full_text: str) -> Tuple[str, str]:
43
- if not full_text:
44
- return "", ""
45
- idx = full_text.find(USER_MESSAGE_MARKER)
46
- if idx == -1:
47
- return "", full_text.strip()
48
- return full_text[:idx].strip(), full_text[idx + len(USER_MESSAGE_MARKER):].strip()
49
-
50
-
51
-
52
- def _search_field(pattern: str, text: str) -> str:
53
- m = re.search(pattern, text, flags=re.IGNORECASE | re.DOTALL)
54
- return m.group(1).strip() if m else ""
55
-
56
-
57
-
58
- def parse_hidden_context(hidden_context: str) -> GameContext:
59
- ctx = GameContext(raw_hidden_context=hidden_context or "")
60
- if not hidden_context:
61
- return ctx
62
-
63
- ctx.question_category = _search_field(r"Category:\s*(.+)", hidden_context).splitlines()[0] if "Category:" in hidden_context else ""
64
- ctx.question_difficulty = _search_field(r"Difficulty:\s*(.+)", hidden_context).splitlines()[0] if "Difficulty:" in hidden_context else ""
65
- ctx.last_outcome = _search_field(r"Last outcome:\s*(.+)", hidden_context).splitlines()[0] if "Last outcome:" in hidden_context else ""
66
-
67
- question = _search_field(r"Question:\s*(.+?)(?:\nOptions:|\nPlayer balance:|\nLast outcome:|$)", hidden_context)
68
- options = _search_field(r"Options:\s*(.+?)(?:\nPlayer balance:|\nLast outcome:|$)", hidden_context)
69
- balance_text = _search_field(r"Player balance:\s*([0-9]+(?:\.[0-9]+)?)", hidden_context)
70
-
71
- ctx.question_text = question.strip()
72
- ctx.options_text = options.strip()
73
- try:
74
- ctx.player_balance = float(balance_text) if balance_text else None
75
- except Exception:
76
- ctx.player_balance = None
77
- return ctx
78
-
79
-
80
-
81
- def soft_truncate(text: str, limit: int) -> str:
82
- text = (text or "").strip()
83
- if len(text) <= limit:
84
- return text
85
- trimmed = text[: limit - 1].rsplit(" ", 1)[0].strip()
86
- return (trimmed or text[: limit - 1]).rstrip() + "…"
87
-
88
-
89
-
90
- def safe_div(a: float, b: float) -> float:
91
- return a / b if b else math.inf
92
-
93
-
94
-
95
- def flatten_history(items: Iterable[Any]) -> List[dict]:
96
- out: List[dict] = []
97
- for item in items or []:
98
- if isinstance(item, dict):
99
- out.append({"role": str(item.get("role", "user")), "text": str(item.get("text", ""))})
100
- else:
101
- role = getattr(item, "role", "user")
102
- text = getattr(item, "text", "")
103
- out.append({"role": str(role), "text": str(text)})
104
  return out
 
1
  from __future__ import annotations
2
 
3
+ import ast
4
+ import json
5
  import math
6
  import re
7
+ from typing import Any, Iterable, List
8
 
9
+ from models import ChatRequest
10
 
11
 
12
+ def clamp01(x: Any, default: float = 0.5) -> float:
 
 
 
13
  try:
14
+ v = float(x)
15
  return max(0.0, min(1.0, v))
16
  except Exception:
17
  return default
18
 
19
 
20
  def normalize_spaces(text: str) -> str:
21
+ return re.sub(r"\s+", " ", str(text or "")).strip()
22
 
23
 
24
  def clean_math_text(text: str) -> str:
25
+ t = str(text or "")
26
+ t = t.replace("×", "*").replace("÷", "/")
27
+ t = t.replace("", "-").replace("", "-").replace("", "-")
28
+ t = t.replace("\u00a0", " ")
29
+ return t
30
+
31
+
32
+ def tokenize(text: str) -> List[str]:
33
+ return re.findall(r"[a-z0-9]+", clean_math_text(text).lower())
34
+
35
+
36
+ def score_token_overlap(query: str, text: str) -> float:
37
+ q = set(tokenize(query))
38
+ t = set(tokenize(text))
39
+ if not q or not t:
40
+ return 0.0
41
+ overlap = len(q & t)
42
+ return overlap / max(1, len(q))
43
+
44
+
45
+ def extract_text_from_any_payload(payload: Any) -> str:
46
+ if payload is None:
47
+ return ""
48
+
49
+ if isinstance(payload, str):
50
+ s = payload.strip()
51
+ if not s:
52
+ return ""
53
+ if (s.startswith("{") and s.endswith("}")) or (s.startswith("[") and s.endswith("]")):
54
+ try:
55
+ decoded = json.loads(s)
56
+ return extract_text_from_any_payload(decoded)
57
+ except Exception:
58
+ pass
59
+ try:
60
+ decoded = ast.literal_eval(s)
61
+ if isinstance(decoded, (dict, list)):
62
+ return extract_text_from_any_payload(decoded)
63
+ except Exception:
64
+ pass
65
+ return s
66
+
67
+ if isinstance(payload, dict):
68
+ for key in [
69
+ "message", "prompt", "query", "text", "user_message",
70
+ "input", "data", "payload", "body", "content",
71
+ ]:
72
+ if key in payload:
73
+ maybe = extract_text_from_any_payload(payload[key])
74
+ if maybe:
75
+ return maybe
76
+ parts = [extract_text_from_any_payload(v) for v in payload.values()]
77
+ return "\n".join([p for p in parts if p]).strip()
78
+
79
+ if isinstance(payload, list):
80
+ parts = [extract_text_from_any_payload(x) for x in payload]
81
+ return "\n".join([p for p in parts if p]).strip()
82
+
83
+ return str(payload).strip()
84
 
85
 
86
  def get_user_text(req: ChatRequest, raw_body: Any = None) -> str:
87
+ for field in ["message", "prompt", "query", "text", "user_message"]:
88
+ value = getattr(req, field, None)
89
+ if isinstance(value, str) and value.strip():
90
+ return value.strip()
91
+ return extract_text_from_any_payload(raw_body).strip()
92
+
93
+
94
+ def short_lines(items: Iterable[str], limit: int) -> List[str]:
95
+ out: List[str] = []
96
+ for item in items:
97
+ item = normalize_spaces(item)
98
+ if item:
99
+ out.append(item)
100
+ if len(out) >= limit:
101
+ break
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  return out