Upload 11 files
Browse files- app.py +53 -84
- context_parser.py +52 -23
- conversation_logic.py +100 -74
- formatting.py +18 -34
- generator_engine.py +20 -168
- logging_store.py +60 -126
- models.py +39 -83
- quant_solver.py +116 -171
- retrieval_engine.py +81 -116
- ui_html.py +25 -56
- utils.py +81 -83
app.py
CHANGED
|
@@ -1,27 +1,27 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
-
import
|
| 4 |
from typing import Any, Dict
|
| 5 |
|
| 6 |
-
from fastapi import FastAPI,
|
| 7 |
from fastapi.middleware.cors import CORSMiddleware
|
| 8 |
from fastapi.responses import HTMLResponse, JSONResponse
|
| 9 |
|
| 10 |
-
from
|
| 11 |
-
from
|
| 12 |
-
from
|
| 13 |
-
from generator_engine import GenerativeEngine
|
| 14 |
from logging_store import LoggingStore
|
| 15 |
from models import ChatRequest, EventLogRequest, SessionFinalizeRequest, SessionStartRequest
|
| 16 |
from retrieval_engine import RetrievalEngine
|
| 17 |
from ui_html import HOME_HTML
|
| 18 |
-
from utils import clamp01,
|
| 19 |
|
| 20 |
-
app = FastAPI(title=settings.app_name, version=settings.app_version)
|
| 21 |
retriever = RetrievalEngine()
|
| 22 |
-
generator =
|
| 23 |
-
|
|
|
|
| 24 |
|
|
|
|
| 25 |
app.add_middleware(
|
| 26 |
CORSMiddleware,
|
| 27 |
allow_origins=["*"],
|
|
@@ -31,14 +31,9 @@ app.add_middleware(
|
|
| 31 |
)
|
| 32 |
|
| 33 |
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
def _require_research_key(x_research_key: str | None) -> None:
|
| 40 |
-
if settings.research_api_key and not secrets.compare_digest(x_research_key or "", settings.research_api_key):
|
| 41 |
-
raise HTTPException(status_code=401, detail="Invalid research key")
|
| 42 |
|
| 43 |
|
| 44 |
@app.get("/", response_class=HTMLResponse)
|
|
@@ -46,20 +41,6 @@ def home() -> str:
|
|
| 46 |
return HOME_HTML
|
| 47 |
|
| 48 |
|
| 49 |
-
@app.get("/health")
|
| 50 |
-
def health() -> Dict[str, Any]:
|
| 51 |
-
return {
|
| 52 |
-
"ok": True,
|
| 53 |
-
"app": settings.app_name,
|
| 54 |
-
"version": settings.app_version,
|
| 55 |
-
"retrieval_rows": len(retriever.rows),
|
| 56 |
-
"generator_model": settings.generator_model,
|
| 57 |
-
"generator_task": settings.generator_task,
|
| 58 |
-
"hub_log_push_enabled": settings.push_logs_to_hub,
|
| 59 |
-
"hub_log_repo": settings.log_dataset_repo_id,
|
| 60 |
-
}
|
| 61 |
-
|
| 62 |
-
|
| 63 |
@app.post("/chat")
|
| 64 |
async def chat(request: Request) -> JSONResponse:
|
| 65 |
raw_body: Any = None
|
|
@@ -76,86 +57,74 @@ async def chat(request: Request) -> JSONResponse:
|
|
| 76 |
|
| 77 |
full_text = get_user_text(req, raw_body)
|
| 78 |
hidden_context, actual_user_message = split_unity_message(full_text)
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
|
|
|
| 82 |
|
| 83 |
tone = clamp01(req_data.get("tone", req.tone), 0.5)
|
| 84 |
verbosity = clamp01(req_data.get("verbosity", req.verbosity), 0.5)
|
| 85 |
transparency = clamp01(req_data.get("transparency", req.transparency), 0.5)
|
| 86 |
-
help_mode = detect_help_mode(actual_user_message or full_text, req_data.get("help_mode", req.help_mode))
|
| 87 |
-
chat_history = flatten_history(req_data.get("chat_history") or req_data.get("history") or req.chat_history or req.history)
|
| 88 |
|
| 89 |
-
|
| 90 |
-
|
|
|
|
|
|
|
|
|
|
| 91 |
tone=tone,
|
| 92 |
verbosity=verbosity,
|
| 93 |
transparency=transparency,
|
|
|
|
| 94 |
help_mode=help_mode,
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
generator=generator,
|
| 99 |
)
|
| 100 |
|
| 101 |
-
return JSONResponse(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
|
| 103 |
|
| 104 |
@app.post("/log/session/start")
|
| 105 |
-
def
|
| 106 |
-
|
| 107 |
-
x_ingest_key: str | None = Header(default=None),
|
| 108 |
-
) -> Dict[str, Any]:
|
| 109 |
-
_require_ingest_key(x_ingest_key)
|
| 110 |
-
record = logging_store.start_session(req)
|
| 111 |
-
return {"ok": True, "session_id": record.session_id, "record": record.model_dump()}
|
| 112 |
|
| 113 |
|
| 114 |
@app.post("/log/event")
|
| 115 |
-
def log_event(
|
| 116 |
-
|
| 117 |
-
x_ingest_key: str | None = Header(default=None),
|
| 118 |
-
) -> Dict[str, Any]:
|
| 119 |
-
_require_ingest_key(x_ingest_key)
|
| 120 |
-
try:
|
| 121 |
-
return logging_store.append_event(req)
|
| 122 |
-
except FileNotFoundError as e:
|
| 123 |
-
raise HTTPException(status_code=404, detail=str(e)) from e
|
| 124 |
|
| 125 |
|
| 126 |
@app.post("/log/session/finalize")
|
| 127 |
-
def
|
| 128 |
-
|
| 129 |
-
x_ingest_key: str | None = Header(default=None),
|
| 130 |
-
) -> Dict[str, Any]:
|
| 131 |
-
_require_ingest_key(x_ingest_key)
|
| 132 |
-
try:
|
| 133 |
-
return logging_store.finalize_session(req)
|
| 134 |
-
except FileNotFoundError as e:
|
| 135 |
-
raise HTTPException(status_code=404, detail=str(e)) from e
|
| 136 |
|
| 137 |
|
| 138 |
@app.get("/research/sessions")
|
| 139 |
-
def research_sessions(
|
| 140 |
-
|
| 141 |
-
) -> Dict[str, Any]:
|
| 142 |
-
_require_research_key(x_research_key)
|
| 143 |
-
return {"ok": True, "sessions": logging_store.list_sessions()}
|
| 144 |
|
| 145 |
|
| 146 |
@app.get("/research/session/{session_id}")
|
| 147 |
-
def research_session(
|
| 148 |
-
|
| 149 |
-
x_research_key: str | None = Header(default=None),
|
| 150 |
-
) -> Dict[str, Any]:
|
| 151 |
-
_require_research_key(x_research_key)
|
| 152 |
-
try:
|
| 153 |
-
return {"ok": True, **logging_store.read_session_bundle(session_id)}
|
| 154 |
-
except FileNotFoundError as e:
|
| 155 |
-
raise HTTPException(status_code=404, detail=str(e)) from e
|
| 156 |
|
| 157 |
|
| 158 |
if __name__ == "__main__":
|
| 159 |
import uvicorn
|
| 160 |
|
| 161 |
-
|
|
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
+
import os
|
| 4 |
from typing import Any, Dict
|
| 5 |
|
| 6 |
+
from fastapi import FastAPI, Request
|
| 7 |
from fastapi.middleware.cors import CORSMiddleware
|
| 8 |
from fastapi.responses import HTMLResponse, JSONResponse
|
| 9 |
|
| 10 |
+
from context_parser import detect_intent, extract_game_context_fields, intent_to_help_mode, split_unity_message
|
| 11 |
+
from conversation_logic import ConversationEngine
|
| 12 |
+
from generator_engine import GeneratorEngine
|
|
|
|
| 13 |
from logging_store import LoggingStore
|
| 14 |
from models import ChatRequest, EventLogRequest, SessionFinalizeRequest, SessionStartRequest
|
| 15 |
from retrieval_engine import RetrievalEngine
|
| 16 |
from ui_html import HOME_HTML
|
| 17 |
+
from utils import clamp01, get_user_text
|
| 18 |
|
|
|
|
| 19 |
retriever = RetrievalEngine()
|
| 20 |
+
generator = GeneratorEngine()
|
| 21 |
+
engine = ConversationEngine(retriever=retriever, generator=generator)
|
| 22 |
+
store = LoggingStore()
|
| 23 |
|
| 24 |
+
app = FastAPI(title="Trading Game AI V2", version="2.2.0")
|
| 25 |
app.add_middleware(
|
| 26 |
CORSMiddleware,
|
| 27 |
allow_origins=["*"],
|
|
|
|
| 31 |
)
|
| 32 |
|
| 33 |
|
| 34 |
+
@app.get("/health")
def health() -> Dict[str, Any]:
    """Liveness probe: reports app identity and whether the generator model loaded."""
    return {"ok": True, "app": "Trading Game AI V2", "generator_available": generator.available()}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
|
| 39 |
@app.get("/", response_class=HTMLResponse)
|
|
|
|
| 41 |
return HOME_HTML
|
| 42 |
|
| 43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
@app.post("/chat")
|
| 45 |
async def chat(request: Request) -> JSONResponse:
|
| 46 |
raw_body: Any = None
|
|
|
|
| 57 |
|
| 58 |
full_text = get_user_text(req, raw_body)
|
| 59 |
hidden_context, actual_user_message = split_unity_message(full_text)
|
| 60 |
+
game_fields = extract_game_context_fields(hidden_context)
|
| 61 |
+
|
| 62 |
+
question_text = (req.question_text or "").strip() or game_fields["question"]
|
| 63 |
+
options_text = game_fields["options"]
|
| 64 |
|
| 65 |
tone = clamp01(req_data.get("tone", req.tone), 0.5)
|
| 66 |
verbosity = clamp01(req_data.get("verbosity", req.verbosity), 0.5)
|
| 67 |
transparency = clamp01(req_data.get("transparency", req.transparency), 0.5)
|
|
|
|
|
|
|
| 68 |
|
| 69 |
+
intent = detect_intent(actual_user_message or full_text, req_data.get("help_mode", req.help_mode))
|
| 70 |
+
help_mode = intent_to_help_mode(intent)
|
| 71 |
+
|
| 72 |
+
result = engine.generate_response(
|
| 73 |
+
raw_user_text=actual_user_message or full_text,
|
| 74 |
tone=tone,
|
| 75 |
verbosity=verbosity,
|
| 76 |
transparency=transparency,
|
| 77 |
+
intent=intent,
|
| 78 |
help_mode=help_mode,
|
| 79 |
+
chat_history=req.chat_history or req.history or [],
|
| 80 |
+
question_text=question_text,
|
| 81 |
+
options_text=options_text,
|
|
|
|
| 82 |
)
|
| 83 |
|
| 84 |
+
return JSONResponse(
|
| 85 |
+
{
|
| 86 |
+
"reply": result.reply,
|
| 87 |
+
"meta": {
|
| 88 |
+
"domain": result.domain,
|
| 89 |
+
"solved": result.solved,
|
| 90 |
+
"help_mode": result.help_mode,
|
| 91 |
+
"answer_letter": result.answer_letter,
|
| 92 |
+
"answer_value": result.answer_value,
|
| 93 |
+
"topic": result.topic,
|
| 94 |
+
"used_retrieval": result.used_retrieval,
|
| 95 |
+
"used_generator": result.used_generator,
|
| 96 |
+
},
|
| 97 |
+
}
|
| 98 |
+
)
|
| 99 |
|
| 100 |
|
| 101 |
@app.post("/log/session/start")
def log_session_start(payload: SessionStartRequest) -> Dict[str, Any]:
    """Create a new logging-session record in the store and return the store's response."""
    return store.start_session(payload.session_id, payload.user_id, payload.condition, payload.metadata)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
|
| 105 |
|
| 106 |
@app.post("/log/event")
def log_event(payload: EventLogRequest) -> Dict[str, Any]:
    """Append a single timestamped event to an existing session's log."""
    return store.log_event(payload.session_id, payload.event_type, payload.payload, payload.timestamp)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
|
| 110 |
|
| 111 |
@app.post("/log/session/finalize")
def log_session_finalize(payload: SessionFinalizeRequest) -> Dict[str, Any]:
    """Mark a session as finished, attaching the caller-supplied summary."""
    return store.finalize_session(payload.session_id, payload.summary)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
|
| 115 |
|
| 116 |
@app.get("/research/sessions")
def research_sessions() -> Dict[str, Any]:
    """List all logged sessions.

    NOTE(review): this endpoint performs no research-key check — confirm that
    exposing the session listing unauthenticated is intentional.
    """
    return {"sessions": store.list_sessions()}
|
|
|
|
|
|
|
|
|
|
| 119 |
|
| 120 |
|
| 121 |
@app.get("/research/session/{session_id}")
def research_session(session_id: str) -> Dict[str, Any]:
    """Return the full stored record for one session.

    NOTE(review): no auth check and no FileNotFoundError handling here — an
    unknown session id may surface as a 500; confirm whether a 404 is wanted.
    """
    return store.get_session(session_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
|
| 125 |
|
| 126 |
if __name__ == "__main__":
    import uvicorn

    # Port comes from the environment when provided (e.g. by the hosting
    # platform); otherwise defaults to 7860.
    port = int(os.getenv("PORT", "7860"))
    uvicorn.run(app, host="0.0.0.0", port=port)
|
context_parser.py
CHANGED
|
@@ -1,36 +1,65 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
import re
|
|
|
|
| 4 |
|
| 5 |
|
| 6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
-
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
return "hint"
|
| 18 |
-
if any(
|
| 19 |
-
return "
|
| 20 |
-
if any(
|
| 21 |
return "answer"
|
| 22 |
-
return "
|
| 23 |
|
| 24 |
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
if
|
| 29 |
-
return
|
| 30 |
-
|
| 31 |
-
r"^(hi|hello|hey|thanks|thank you|ok|okay|cool|great|nice)\b",
|
| 32 |
-
r"how are you",
|
| 33 |
-
r"what can you do",
|
| 34 |
-
r"who are you",
|
| 35 |
-
]
|
| 36 |
-
return any(re.search(p, lower) for p in social_patterns)
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
import re
|
| 4 |
+
from typing import Dict, Optional
|
| 5 |
|
| 6 |
|
| 7 |
+
def split_unity_message(full_text: str) -> tuple[str, str]:
    """Split a Unity-composed message into ``(hidden_context, user_message)``.

    The game client prepends hidden context and marks the player's own text
    with a ``USER_MESSAGE:`` sentinel.  When the sentinel is absent the whole
    text is treated as the user message and the hidden part is empty.
    """
    if not full_text:
        return "", ""
    before, sentinel, after = full_text.partition("USER_MESSAGE:")
    if not sentinel:
        # No marker: nothing was hidden by the client.
        return "", full_text.strip()
    return before.strip(), after.strip()
|
| 17 |
|
| 18 |
|
| 19 |
+
def extract_game_context_fields(hidden_context: str) -> Dict[str, str]:
    """Parse the labelled fields out of the hidden game-context block.

    Returns a dict with keys ``category``, ``difficulty``, ``question`` and
    ``options``; fields absent from the context are left as empty strings.
    """
    fields = {"category": "", "difficulty": "", "question": "", "options": ""}
    if not hidden_context:
        return fields

    # (pattern, flags): category/difficulty are single-line values, so DOTALL
    # must NOT apply to them — under DOTALL the greedy `.+` swallows every
    # following line of the context block (the previous revision applied
    # DOTALL to all four patterns, corrupting category and difficulty).
    # Question/options may span lines, so they keep DOTALL together with
    # explicit terminators.
    patterns = {
        "category": (r"Category:\s*(.+)", 0),
        "difficulty": (r"Difficulty:\s*(.+)", 0),
        "question": (r"Question:\s*(.+?)(?:\nOptions:|\nPlayer balance:|\nLast outcome:|$)", re.DOTALL),
        "options": (r"Options:\s*(.+?)(?:\nPlayer balance:|\nLast outcome:|$)", re.DOTALL),
    }
    for key, (pattern, flags) in patterns.items():
        m = re.search(pattern, hidden_context, flags)
        if m:
            fields[key] = m.group(1).strip()
    return fields
|
| 35 |
|
| 36 |
+
|
| 37 |
+
def detect_intent(user_text: str, supplied: Optional[str] = None) -> str:
    """Classify what kind of help the player is asking for.

    An explicitly supplied (and recognised) intent wins outright; otherwise
    the message is scanned against phrase lists in priority order.  Anything
    unrecognised defaults to ``"answer"``.
    """
    lowered = (user_text or "").strip().lower()

    if supplied:
        candidate = supplied.strip().lower()
        if candidate in {"hint", "method", "walkthrough", "step_by_step", "full_working", "answer"}:
            return candidate

    # Priority-ordered rules: the first bucket with a matching phrase wins.
    ordered_rules = (
        ("full_working", ("full working", "full working out", "show all the working", "complete working")),
        ("step_by_step", ("step by step", "walkthrough", "work through", "explain step by step")),
        ("method", ("how do i solve", "how to solve", "method", "what method", "approach")),
        ("hint", ("hint", "nudge", "first step", "how do i start", "what do i do first")),
        ("step_by_step", ("why", "explain", "breakdown")),
        ("answer", ("solve", "what is", "answer", "give me the answer")),
    )
    for intent, phrases in ordered_rules:
        if any(phrase in lowered for phrase in phrases):
            return intent
    return "answer"
|
| 58 |
|
| 59 |
|
| 60 |
+
def intent_to_help_mode(intent: str) -> str:
    """Collapse a fine-grained intent into one of the three help modes."""
    mode_by_intent = {
        "hint": "hint",
        "method": "walkthrough",
        "step_by_step": "walkthrough",
        "full_working": "walkthrough",
    }
    # Anything unmapped (including "answer" itself) is answer mode.
    return mode_by_intent.get(intent, "answer")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
conversation_logic.py
CHANGED
|
@@ -1,80 +1,106 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
-
from typing import Any, Dict, List
|
| 4 |
|
| 5 |
-
from context_parser import
|
| 6 |
-
from formatting import
|
| 7 |
-
from generator_engine import
|
| 8 |
-
from models import
|
| 9 |
from quant_solver import is_quant_question, solve_quant
|
| 10 |
-
from
|
| 11 |
-
from utils import
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
def
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
game_context=game_context,
|
| 64 |
-
retrieval_summary=retrieval_text,
|
| 65 |
-
tone=tone,
|
| 66 |
-
verbosity=verbosity,
|
| 67 |
-
transparency=transparency,
|
| 68 |
-
)
|
| 69 |
-
reply = enforce_study_guardrails(reply)
|
| 70 |
-
return ResponsePackage(
|
| 71 |
-
reply=reply,
|
| 72 |
-
meta={
|
| 73 |
-
"domain": "general",
|
| 74 |
-
"solved": False,
|
| 75 |
-
"help_mode": help_mode,
|
| 76 |
-
"topic": "social" if is_social_or_meta_message(user_text) else "general",
|
| 77 |
-
"used_retrieval": bool(retrieval_chunks),
|
| 78 |
-
"used_generator": True,
|
| 79 |
-
},
|
| 80 |
-
)
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
+
from typing import Any, Dict, List, Optional
|
| 4 |
|
| 5 |
+
from context_parser import intent_to_help_mode
|
| 6 |
+
from formatting import format_reply
|
| 7 |
+
from generator_engine import GeneratorEngine
|
| 8 |
+
from models import RetrievedChunk, SolverResult
|
| 9 |
from quant_solver import is_quant_question, solve_quant
|
| 10 |
+
from retrieval_engine import RetrievalEngine
|
| 11 |
+
from utils import short_lines
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _teaching_lines(chunks: List[RetrievedChunk]) -> List[str]:
|
| 15 |
+
lines = []
|
| 16 |
+
for chunk in chunks:
|
| 17 |
+
text = chunk.text.strip().replace("\n", " ")
|
| 18 |
+
if len(text) > 220:
|
| 19 |
+
text = text[:217].rstrip() + "…"
|
| 20 |
+
lines.append(f"- {chunk.topic}: {text}")
|
| 21 |
+
return lines
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _compose_quant_reply(result: SolverResult, intent: str, reveal_answer: bool, verbosity: float) -> str:
|
| 25 |
+
steps = result.steps or []
|
| 26 |
+
internal = result.internal_answer or result.answer_value or ""
|
| 27 |
+
|
| 28 |
+
if intent == "hint":
|
| 29 |
+
return steps[0] if steps else "Start by translating the wording into an equation."
|
| 30 |
+
|
| 31 |
+
if intent == "method":
|
| 32 |
+
body = "Use this method:\n" + "\n".join(f"- {s}" for s in steps[:3])
|
| 33 |
+
if reveal_answer and internal:
|
| 34 |
+
body += f"\n\nInternal result: {internal}."
|
| 35 |
+
return body
|
| 36 |
+
|
| 37 |
+
if intent in {"step_by_step", "full_working"}:
|
| 38 |
+
body = "\n".join(f"{i+1}. {s}" for i, s in enumerate(steps[:4])) if steps else "Work through the algebra one step at a time."
|
| 39 |
+
if reveal_answer and internal:
|
| 40 |
+
body += f"\n\nSo the result is {internal}."
|
| 41 |
+
return body
|
| 42 |
+
|
| 43 |
+
if reveal_answer and internal:
|
| 44 |
+
return f"The result is {internal}."
|
| 45 |
+
|
| 46 |
+
if steps:
|
| 47 |
+
return "\n".join(f"- {s}" for s in steps[:2])
|
| 48 |
+
return result.reply or "I can help solve this, but I need a little more structure from the question."
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class ConversationEngine:
    """Routes a player message to either the quant solver or the generator.

    Quant-looking questions are solved deterministically and enriched with
    retrieved study notes; everything else goes to the (optional) generative
    model, with a fixed-prose fallback when no generator output is available.
    """

    def __init__(self, retriever: RetrievalEngine, generator: Optional[GeneratorEngine] = None):
        self.retriever = retriever
        # A default engine is constructed when none is injected.
        self.generator = generator or GeneratorEngine()

    def generate_response(
        self,
        raw_user_text: str,
        tone: float,
        verbosity: float,
        transparency: float,
        intent: str,
        help_mode: str,
        chat_history: Optional[List[Dict[str, Any]]] = None,
        question_text: str = "",
        options_text: str = "",
        retrieval_context: str = "",
    ) -> SolverResult:
        """Produce a SolverResult whose ``reply`` holds the formatted answer text.

        NOTE(review): ``chat_history`` and ``retrieval_context`` are accepted
        but never read in this body — confirm whether they are planned inputs
        or dead parameters.
        """
        user_text = (raw_user_text or "").strip()
        question_block = "\n".join([x for x in [question_text.strip(), options_text.strip()] if x]).strip()
        # Prefer the free-text message; fall back to the question (plus options).
        solver_input = user_text or question_text or question_block

        if is_quant_question(solver_input) or (question_text and is_quant_question(question_text)):
            # Solve against whichever text actually looks quantitative.
            result = solve_quant(solver_input if is_quant_question(solver_input) else question_text)
            result.help_mode = help_mode
            # The numeric answer is only disclosed for explicit answer-style
            # intents or at high transparency.
            reveal_answer = intent in {"answer", "full_working"} or transparency >= 0.85

            # Retrieve supporting notes, scoped to the solver's topic when known.
            if result.topic:
                chunks = self.retriever.search(question_text or user_text, topic=result.topic, intent=intent, k=3)
            else:
                chunks = self.retriever.search(question_text or user_text, topic="general", intent=intent, k=3)
            result.teaching_chunks = chunks
            result.used_retrieval = bool(chunks)

            core = _compose_quant_reply(result, intent, reveal_answer=reveal_answer, verbosity=verbosity)
            # Study notes are appended only for teaching-oriented intents.
            if chunks and intent in {"hint", "method", "step_by_step", "full_working"}:
                core += "\n\nRelevant study notes:\n" + "\n".join(_teaching_lines(chunks))
            result.reply = format_reply(core, tone, verbosity, transparency, help_mode)
            return result

        # Non-quant conversational support
        result = SolverResult(domain="general", solved=False, help_mode=help_mode)
        prompt = (
            "You are a helpful study assistant. Reply naturally and briefly. "
            "Do not invent facts. If the user is asking for emotional support or general help, be supportive and practical.\n\n"
            f"User message: {user_text}"
        )
        generated = self.generator.generate(prompt) if self.generator and self.generator.available() else None
        if generated:
            result.reply = format_reply(generated, tone, verbosity, transparency, help_mode)
            result.used_generator = True
            return result

        # Deterministic fallback when the generator is unavailable or silent.
        fallback = "I can help with the current question, explain a method, or talk through the next step."
        result.reply = format_reply(fallback, tone, verbosity, transparency, help_mode)
        return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
formatting.py
CHANGED
|
@@ -1,43 +1,27 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
-
from typing import List
|
| 4 |
|
| 5 |
-
|
| 6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
return ""
|
| 13 |
-
if tone < 0.67:
|
| 14 |
-
return "You’ve got this. "
|
| 15 |
-
return "You’re doing the right thing by breaking it down. "
|
| 16 |
|
|
|
|
|
|
|
| 17 |
|
|
|
|
|
|
|
| 18 |
|
| 19 |
-
|
| 20 |
-
lines: List[str] = []
|
| 21 |
-
prefix = tone_prefix(tone)
|
| 22 |
-
if prefix:
|
| 23 |
-
lines.append(prefix.strip())
|
| 24 |
-
|
| 25 |
-
lines.append(result.explanation)
|
| 26 |
-
|
| 27 |
-
if verbosity >= 0.34 and result.steps:
|
| 28 |
-
max_steps = 2 if verbosity < 0.67 else 4
|
| 29 |
-
chosen_steps = result.steps[:max_steps]
|
| 30 |
-
if transparency < 0.34:
|
| 31 |
-
lines.append("Focus on the setup rather than the final value:")
|
| 32 |
-
elif transparency < 0.67:
|
| 33 |
-
lines.append("Use this structure:")
|
| 34 |
-
else:
|
| 35 |
-
lines.append("Here is the reasoning structure to follow:")
|
| 36 |
-
lines.extend(f"- {step}" for step in chosen_steps)
|
| 37 |
-
|
| 38 |
-
if transparency >= 0.67 and result.internal_answer_value:
|
| 39 |
-
lines.append("When you finish the calculation, compare your value with the available choices rather than jumping straight to a choice now.")
|
| 40 |
-
else:
|
| 41 |
-
lines.append("Then compare your result against the options yourself.")
|
| 42 |
-
|
| 43 |
-
return soft_truncate("\n".join(lines).strip(), 1600)
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
|
|
|
| 3 |
|
| 4 |
+
def style_prefix(tone: float) -> str:
    """Map the tone slider (0..1) onto an opening phrase for replies."""
    if tone < 0.33:
        return "Let’s solve it efficiently."
    if tone < 0.66:
        return "Let’s work through it."
    return "You’ve got this — let’s solve it cleanly."


def format_reply(core: str, tone: float, verbosity: float, transparency: float, help_mode: str) -> str:
    """Wrap the core reply text with a tone-driven prefix and a mode header.

    Hints always get a ``Hint:`` header; walkthroughs get a ``Walkthrough:``
    header only when verbosity is at least 0.4.  Everything else (including
    low-verbosity walkthroughs) is returned as prefix + body.
    """
    prefix = style_prefix(tone)
    core = (core or "").strip()
    if not core:
        return prefix

    if help_mode == "hint":
        return f"{prefix}\n\nHint:\n{core}"

    if help_mode == "walkthrough" and verbosity >= 0.4:
        return f"{prefix}\n\nWalkthrough:\n{core}"

    # NOTE: a previous revision special-cased transparency >= 0.75 in "answer"
    # mode, but that branch produced exactly this string, so the dead
    # duplicate was folded into the single default return.
    return f"{prefix}\n\n{core}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
generator_engine.py
CHANGED
|
@@ -1,181 +1,33 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
-
from typing import
|
| 4 |
|
| 5 |
try:
|
| 6 |
from transformers import pipeline
|
| 7 |
except Exception:
|
| 8 |
pipeline = None
|
| 9 |
|
| 10 |
-
from config import settings
|
| 11 |
-
from models import GameContext
|
| 12 |
-
from utils import normalize_spaces, soft_truncate
|
| 13 |
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
self.task = settings.generator_task
|
| 18 |
-
self.model_name = settings.generator_model
|
| 19 |
self.pipe = None
|
| 20 |
-
if pipeline is None:
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
model=self.model_name,
|
| 26 |
-
tokenizer=self.model_name,
|
| 27 |
-
device=-1,
|
| 28 |
-
)
|
| 29 |
-
except Exception:
|
| 30 |
-
self.pipe = None
|
| 31 |
-
|
| 32 |
-
def _style_instruction(self, tone: float, verbosity: float, transparency: float) -> str:
|
| 33 |
-
if tone < 0.34:
|
| 34 |
-
tone_label = "neutral and direct"
|
| 35 |
-
elif tone < 0.67:
|
| 36 |
-
tone_label = "friendly and professional"
|
| 37 |
-
else:
|
| 38 |
-
tone_label = "warm, encouraging, and supportive"
|
| 39 |
-
|
| 40 |
-
if verbosity < 0.34:
|
| 41 |
-
verbosity_label = "brief"
|
| 42 |
-
elif verbosity < 0.67:
|
| 43 |
-
verbosity_label = "moderate length"
|
| 44 |
-
else:
|
| 45 |
-
verbosity_label = "detailed"
|
| 46 |
-
|
| 47 |
-
if transparency < 0.34:
|
| 48 |
-
transparency_label = "do not expose chain-of-thought; keep reasoning concise"
|
| 49 |
-
elif transparency < 0.67:
|
| 50 |
-
transparency_label = "give a short explanation of reasoning"
|
| 51 |
-
else:
|
| 52 |
-
transparency_label = "explain the logic clearly but do not reveal hidden chain-of-thought"
|
| 53 |
-
|
| 54 |
-
return (
|
| 55 |
-
f"Style: {tone_label}. "
|
| 56 |
-
f"Length: {verbosity_label}. "
|
| 57 |
-
f"Reasoning visibility: {transparency_label}."
|
| 58 |
-
)
|
| 59 |
-
|
| 60 |
-
def _build_prompt(
|
| 61 |
-
self,
|
| 62 |
-
*,
|
| 63 |
-
user_text: str,
|
| 64 |
-
chat_history: List[Dict[str, Any]],
|
| 65 |
-
game_context: GameContext,
|
| 66 |
-
retrieval_summary: str,
|
| 67 |
-
tone: float,
|
| 68 |
-
verbosity: float,
|
| 69 |
-
transparency: float,
|
| 70 |
-
) -> str:
|
| 71 |
-
history_lines: List[str] = []
|
| 72 |
-
for item in chat_history[-8:]:
|
| 73 |
-
role = str(item.get("role", "user")).strip().lower()
|
| 74 |
-
text = normalize_spaces(str(item.get("text", "")))
|
| 75 |
-
if text:
|
| 76 |
-
history_lines.append(f"{role.title()}: {text}")
|
| 77 |
-
|
| 78 |
-
context_lines: List[str] = []
|
| 79 |
-
if game_context.question_text:
|
| 80 |
-
context_lines.append(f"Current question: {game_context.question_text}")
|
| 81 |
-
if game_context.options_text:
|
| 82 |
-
context_lines.append(f"Answer options: {game_context.options_text}")
|
| 83 |
-
if game_context.question_category:
|
| 84 |
-
context_lines.append(f"Question category: {game_context.question_category}")
|
| 85 |
-
if game_context.question_difficulty:
|
| 86 |
-
context_lines.append(f"Question difficulty: {game_context.question_difficulty}")
|
| 87 |
-
if retrieval_summary:
|
| 88 |
-
context_lines.append(f"Reference material:\n{retrieval_summary}")
|
| 89 |
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
"For non-quant messages, respond naturally and helpfully. "
|
| 93 |
-
"If the message is about gameplay, AI usage, sliders, or study procedure, answer clearly and accurately. "
|
| 94 |
-
"Do not invent access to hidden systems or researcher-only information. "
|
| 95 |
-
"Never claim certainty about unknown participant details. "
|
| 96 |
-
+ self._style_instruction(tone, verbosity, transparency)
|
| 97 |
-
)
|
| 98 |
|
| 99 |
-
|
| 100 |
-
prompt = "\n\n".join(
|
| 101 |
-
part for part in [
|
| 102 |
-
system,
|
| 103 |
-
"Recent conversation:\n" + "\n".join(history_lines) if history_lines else "",
|
| 104 |
-
"Game context:\n" + "\n".join(context_lines) if context_lines else "",
|
| 105 |
-
user_block,
|
| 106 |
-
"Assistant:"
|
| 107 |
-
] if part
|
| 108 |
-
)
|
| 109 |
-
return prompt
|
| 110 |
-
|
| 111 |
-
def _fallback_reply(self, user_text: str, retrieval_summary: str) -> str:
|
| 112 |
-
lower = (user_text or "").lower()
|
| 113 |
-
if "what can you do" in lower:
|
| 114 |
-
return (
|
| 115 |
-
"I can help with GMAT-style quant questions, explain how the sliders affect the AI, "
|
| 116 |
-
"talk through strategy in the trading-game study, and answer general non-quant questions."
|
| 117 |
-
)
|
| 118 |
-
if "token" in lower or "cost" in lower:
|
| 119 |
-
return (
|
| 120 |
-
"AI use can be treated as a limited resource in the study. "
|
| 121 |
-
"A prompt can have a visible token cost before sending, and that cost can be logged for later analysis."
|
| 122 |
-
)
|
| 123 |
-
if retrieval_summary:
|
| 124 |
-
return "I can help with that. Here are the most relevant notes I found:\n\n" + retrieval_summary
|
| 125 |
-
return "I can help with game questions, AI-use questions, and general non-quant conversation."
|
| 126 |
-
|
| 127 |
-
def generate(
|
| 128 |
-
self,
|
| 129 |
-
*,
|
| 130 |
-
user_text: str,
|
| 131 |
-
chat_history: List[Dict[str, Any]],
|
| 132 |
-
game_context: GameContext,
|
| 133 |
-
retrieval_summary: str,
|
| 134 |
-
tone: float,
|
| 135 |
-
verbosity: float,
|
| 136 |
-
transparency: float,
|
| 137 |
-
) -> str:
|
| 138 |
if self.pipe is None:
|
| 139 |
-
return
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
verbosity=verbosity,
|
| 148 |
-
transparency=transparency,
|
| 149 |
-
)
|
| 150 |
-
|
| 151 |
-
kwargs: Dict[str, Any] = {
|
| 152 |
-
"max_new_tokens": settings.generator_max_new_tokens,
|
| 153 |
-
}
|
| 154 |
-
if self.task == "text-generation":
|
| 155 |
-
kwargs.update(
|
| 156 |
-
{
|
| 157 |
-
"do_sample": settings.generator_do_sample,
|
| 158 |
-
"temperature": settings.generator_temperature,
|
| 159 |
-
"top_p": settings.generator_top_p,
|
| 160 |
-
"return_full_text": False,
|
| 161 |
-
}
|
| 162 |
-
)
|
| 163 |
-
else:
|
| 164 |
-
kwargs.update(
|
| 165 |
-
{
|
| 166 |
-
"do_sample": settings.generator_do_sample,
|
| 167 |
-
"temperature": settings.generator_temperature,
|
| 168 |
-
"top_p": settings.generator_top_p,
|
| 169 |
-
}
|
| 170 |
-
)
|
| 171 |
-
|
| 172 |
-
result = self.pipe(prompt, **kwargs)
|
| 173 |
-
if not result:
|
| 174 |
-
return self._fallback_reply(user_text, retrieval_summary)
|
| 175 |
-
|
| 176 |
-
first = result[0]
|
| 177 |
-
text = first.get("generated_text") or first.get("summary_text") or ""
|
| 178 |
-
text = normalize_spaces(text)
|
| 179 |
-
if not text:
|
| 180 |
-
return self._fallback_reply(user_text, retrieval_summary)
|
| 181 |
-
return soft_truncate(text, 1200)
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
+
from typing import Optional
|
| 4 |
|
| 5 |
# transformers is an optional dependency: when it cannot be imported the
# engine degrades to a stub whose available() is False and generate() is None.
try:
    from transformers import pipeline
except Exception:
    pipeline = None


class GeneratorEngine:
    """Thin wrapper around a Hugging Face text2text-generation pipeline."""

    def __init__(self, model_name: str = "google/flan-t5-small"):
        self.model_name = model_name
        self.pipe = None
        if pipeline is None:
            return
        try:
            self.pipe = pipeline("text2text-generation", model=model_name)
        except Exception:
            # Model-load failures are non-fatal; callers check available().
            self.pipe = None

    def available(self) -> bool:
        """Return True when the underlying pipeline loaded successfully."""
        return self.pipe is not None

    def generate(self, prompt: str, max_new_tokens: int = 96) -> Optional[str]:
        """Run the pipeline on *prompt*.

        Returns the stripped generated text, or None when the pipeline is
        unavailable, errors out, or yields an unexpected result shape.
        """
        if self.pipe is None:
            return None
        try:
            outputs = self.pipe(prompt, max_new_tokens=max_new_tokens, do_sample=False)
            if outputs and isinstance(outputs, list):
                return str(outputs[0].get("generated_text", "")).strip()
        except Exception:
            return None
        return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logging_store.py
CHANGED
|
@@ -1,142 +1,76 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
import json
|
| 4 |
-
import
|
| 5 |
from datetime import datetime, timezone
|
| 6 |
-
from pathlib import Path
|
| 7 |
from typing import Any, Dict, List, Optional
|
| 8 |
|
| 9 |
-
from huggingface_hub import CommitOperationAdd, HfApi
|
| 10 |
-
|
| 11 |
-
from config import settings
|
| 12 |
-
from models import EventLogRequest, SessionFinalizeRequest, SessionRecord, SessionStartRequest
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
def _utc_stamp() -> str:
|
| 16 |
-
return datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%S%fZ")
|
| 17 |
-
|
| 18 |
|
| 19 |
class LoggingStore:
|
| 20 |
-
def __init__(self
|
| 21 |
-
self.root =
|
| 22 |
-
self.root
|
| 23 |
-
self.
|
| 24 |
-
self.
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
def _write_json(self, path: Path, payload: Dict[str, Any]) -> None:
|
| 45 |
-
path.write_text(json.dumps(payload, indent=2, ensure_ascii=False), encoding="utf-8")
|
| 46 |
-
|
| 47 |
-
def _push_file_to_hub(self, local_path: Path, repo_path: str) -> None:
|
| 48 |
-
if not (settings.push_logs_to_hub and settings.log_dataset_repo_id and self.hf_api):
|
| 49 |
-
return
|
| 50 |
-
self.hf_api.create_repo(
|
| 51 |
-
repo_id=settings.log_dataset_repo_id,
|
| 52 |
-
repo_type="dataset",
|
| 53 |
-
private=settings.log_dataset_private,
|
| 54 |
-
exist_ok=True,
|
| 55 |
-
)
|
| 56 |
-
self.hf_api.create_commit(
|
| 57 |
-
repo_id=settings.log_dataset_repo_id,
|
| 58 |
-
repo_type="dataset",
|
| 59 |
-
commit_message=f"Add log file {repo_path}",
|
| 60 |
-
operations=[
|
| 61 |
-
CommitOperationAdd(
|
| 62 |
-
path_in_repo=repo_path,
|
| 63 |
-
path_or_fileobj=str(local_path),
|
| 64 |
-
)
|
| 65 |
-
],
|
| 66 |
-
)
|
| 67 |
-
|
| 68 |
-
def start_session(self, req: SessionStartRequest) -> SessionRecord:
|
| 69 |
-
session_id = req.session_id or str(uuid.uuid4())
|
| 70 |
-
record = SessionRecord(
|
| 71 |
-
participant_id=req.participant_id,
|
| 72 |
-
session_id=session_id,
|
| 73 |
-
started_at=req.started_at,
|
| 74 |
-
condition=req.condition,
|
| 75 |
-
study_id=req.study_id,
|
| 76 |
-
game_version=req.game_version,
|
| 77 |
-
metadata=req.metadata,
|
| 78 |
-
summary={},
|
| 79 |
-
event_count=0,
|
| 80 |
-
)
|
| 81 |
-
path = self._session_path(session_id)
|
| 82 |
-
payload = record.model_dump()
|
| 83 |
-
self._write_json(path, payload)
|
| 84 |
-
self._push_file_to_hub(path, f"sessions/{session_id}.json")
|
| 85 |
return record
|
| 86 |
|
| 87 |
-
def
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
event_path = self._event_path(req.session_id, req.event_type)
|
| 94 |
-
self._write_json(event_path, payload)
|
| 95 |
-
|
| 96 |
-
session["event_count"] = int(session.get("event_count", 0)) + 1
|
| 97 |
-
self._write_json(self._session_path(req.session_id), session)
|
| 98 |
-
|
| 99 |
-
self._push_file_to_hub(event_path, f"events/{req.session_id}/{event_path.name}")
|
| 100 |
-
self._push_file_to_hub(self._session_path(req.session_id), f"sessions/{req.session_id}.json")
|
| 101 |
-
|
| 102 |
-
return {
|
| 103 |
-
"ok": True,
|
| 104 |
-
"session_id": req.session_id,
|
| 105 |
-
"event_file": event_path.name,
|
| 106 |
-
"event_count": session["event_count"],
|
| 107 |
}
|
|
|
|
|
|
|
| 108 |
|
| 109 |
-
def finalize_session(self,
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
|
| 120 |
def list_sessions(self) -> List[Dict[str, Any]]:
|
| 121 |
-
|
| 122 |
-
for path in sorted(self.sessions_dir.glob("*.json"), reverse=True):
|
| 123 |
-
try:
|
| 124 |
-
items.append(json.loads(path.read_text(encoding="utf-8")))
|
| 125 |
-
except Exception:
|
| 126 |
-
continue
|
| 127 |
-
return items
|
| 128 |
-
|
| 129 |
-
def read_session_bundle(self, session_id: str) -> Dict[str, Any]:
|
| 130 |
-
session = self._load_session(session_id)
|
| 131 |
-
if session is None:
|
| 132 |
-
raise FileNotFoundError(f"Unknown session_id: {session_id}")
|
| 133 |
-
|
| 134 |
-
events: List[Dict[str, Any]] = []
|
| 135 |
-
session_dir = self.events_dir / session_id
|
| 136 |
-
for path in sorted(session_dir.glob("*.json")):
|
| 137 |
-
try:
|
| 138 |
-
events.append(json.loads(path.read_text(encoding="utf-8")))
|
| 139 |
-
except Exception:
|
| 140 |
-
continue
|
| 141 |
|
| 142 |
-
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
import json
|
| 4 |
+
import os
|
| 5 |
from datetime import datetime, timezone
|
|
|
|
| 6 |
from typing import Any, Dict, List, Optional
|
| 7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
class LoggingStore:
    """Append-only JSONL store for study sessions and their events.

    Two files live under *root*: ``sessions.jsonl`` holds session lifecycle
    records (start/finalize), ``events.jsonl`` holds per-event records.
    """

    def __init__(self, root: str = "logs"):
        self.root = root
        os.makedirs(self.root, exist_ok=True)
        self.sessions_path = os.path.join(self.root, "sessions.jsonl")
        self.events_path = os.path.join(self.root, "events.jsonl")

    def _append(self, path: str, payload: Dict[str, Any]) -> None:
        """Serialize *payload* and append it as a single JSON line."""
        encoded = json.dumps(payload, ensure_ascii=False)
        with open(path, "a", encoding="utf-8") as handle:
            handle.write(encoded + "\n")

    def _now(self) -> str:
        """Current UTC time as an ISO-8601 string."""
        return datetime.now(timezone.utc).isoformat()

    def start_session(self, session_id: str, user_id: Optional[str], condition: Optional[str], metadata: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        """Record the beginning of a session and return the stored record."""
        record: Dict[str, Any] = dict(
            session_id=session_id,
            user_id=user_id,
            condition=condition,
            metadata=metadata or {},
            started_at=self._now(),
            type="session_start",
        )
        self._append(self.sessions_path, record)
        return record

    def log_event(self, session_id: str, event_type: str, payload: Optional[Dict[str, Any]], timestamp: Optional[str]) -> Dict[str, Any]:
        """Append one event; a missing timestamp is filled with current UTC."""
        record: Dict[str, Any] = dict(
            session_id=session_id,
            event_type=event_type,
            timestamp=timestamp or self._now(),
            payload=payload or {},
        )
        self._append(self.events_path, record)
        return record

    def finalize_session(self, session_id: str, summary: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        """Record the end of a session and return the stored record."""
        record: Dict[str, Any] = dict(
            session_id=session_id,
            summary=summary or {},
            finalized_at=self._now(),
            type="session_finalize",
        )
        self._append(self.sessions_path, record)
        return record

    def _read_jsonl(self, path: str) -> List[Dict[str, Any]]:
        """Parse every well-formed JSON line in *path*; missing file -> []."""
        if not os.path.exists(path):
            return []
        records: List[Dict[str, Any]] = []
        with open(path, "r", encoding="utf-8") as handle:
            for raw in handle:
                raw = raw.strip()
                if not raw:
                    continue
                try:
                    records.append(json.loads(raw))
                except Exception:
                    # Best-effort read: skip corrupt lines rather than fail.
                    continue
        return records

    def list_sessions(self) -> List[Dict[str, Any]]:
        """All session lifecycle records, in file (append) order."""
        return self._read_jsonl(self.sessions_path)

    def get_session(self, session_id: str) -> Dict[str, Any]:
        """Bundle every session record and event belonging to *session_id*."""
        records = [r for r in self._read_jsonl(self.sessions_path) if r.get("session_id") == session_id]
        events = [e for e in self._read_jsonl(self.events_path) if e.get("session_id") == session_id]
        return {"session_id": session_id, "records": records, "events": events}
|
models.py
CHANGED
|
@@ -1,19 +1,9 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
from dataclasses import dataclass, field
|
| 4 |
-
from datetime import datetime, timezone
|
| 5 |
from typing import Any, Dict, List, Optional
|
| 6 |
|
| 7 |
-
from pydantic import BaseModel
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
def utc_now_iso() -> str:
|
| 11 |
-
return datetime.now(timezone.utc).isoformat()
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
class ChatMessage(BaseModel):
|
| 15 |
-
role: str = "user"
|
| 16 |
-
text: str = ""
|
| 17 |
|
| 18 |
|
| 19 |
class ChatRequest(BaseModel):
|
|
@@ -23,93 +13,59 @@ class ChatRequest(BaseModel):
|
|
| 23 |
text: Optional[str] = None
|
| 24 |
user_message: Optional[str] = None
|
| 25 |
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
tone: float = 0.5
|
| 30 |
-
verbosity: float = 0.5
|
| 31 |
-
transparency: float = 0.5
|
| 32 |
|
| 33 |
help_mode: Optional[str] = None
|
|
|
|
|
|
|
| 34 |
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
question_text: str = ""
|
| 40 |
-
options_text: str = ""
|
| 41 |
-
question_category: str = ""
|
| 42 |
-
question_difficulty: str = ""
|
| 43 |
-
player_balance: Optional[float] = None
|
| 44 |
-
last_outcome: str = ""
|
| 45 |
-
|
| 46 |
-
@property
|
| 47 |
-
def combined_question_block(self) -> str:
|
| 48 |
-
parts = [p for p in [self.question_text, self.options_text] if p.strip()]
|
| 49 |
-
return "\n".join(parts).strip()
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
@dataclass
|
| 53 |
-
class RetrievalChunk:
|
| 54 |
-
chunk_id: str
|
| 55 |
-
text: str
|
| 56 |
-
source_name: str = ""
|
| 57 |
-
topic_guess: str = ""
|
| 58 |
-
score: float = 0.0
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
@dataclass
|
| 62 |
-
class SolverResult:
|
| 63 |
-
solved: bool
|
| 64 |
-
explanation: str
|
| 65 |
-
domain: str = "general"
|
| 66 |
-
internal_answer_value: Optional[str] = None
|
| 67 |
-
internal_answer_letter: Optional[str] = None
|
| 68 |
-
detected_topic: str = ""
|
| 69 |
-
steps: List[str] = field(default_factory=list)
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
@dataclass
|
| 73 |
-
class ResponsePackage:
|
| 74 |
-
reply: str
|
| 75 |
-
meta: Dict[str, Any]
|
| 76 |
|
| 77 |
|
| 78 |
class SessionStartRequest(BaseModel):
|
| 79 |
-
|
| 80 |
-
|
| 81 |
condition: Optional[str] = None
|
| 82 |
-
|
| 83 |
-
game_version: Optional[str] = None
|
| 84 |
-
metadata: Dict[str, Any] = Field(default_factory=dict)
|
| 85 |
-
started_at: str = Field(default_factory=utc_now_iso)
|
| 86 |
|
| 87 |
|
| 88 |
class EventLogRequest(BaseModel):
|
| 89 |
-
participant_id: str
|
| 90 |
session_id: str
|
| 91 |
event_type: str
|
| 92 |
-
timestamp: str =
|
| 93 |
-
|
| 94 |
-
turn_index: Optional[int] = None
|
| 95 |
-
payload: Dict[str, Any] = Field(default_factory=dict)
|
| 96 |
|
| 97 |
|
| 98 |
class SessionFinalizeRequest(BaseModel):
|
| 99 |
-
participant_id: str
|
| 100 |
session_id: str
|
| 101 |
-
|
| 102 |
-
summary: Dict[str, Any] = Field(default_factory=dict)
|
| 103 |
|
| 104 |
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
from dataclasses import dataclass, field
|
|
|
|
| 4 |
from typing import Any, Dict, List, Optional
|
| 5 |
|
| 6 |
+
from pydantic import BaseModel
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
|
| 9 |
class ChatRequest(BaseModel):
|
|
|
|
| 13 |
text: Optional[str] = None
|
| 14 |
user_message: Optional[str] = None
|
| 15 |
|
| 16 |
+
tone: Optional[float] = 0.5
|
| 17 |
+
verbosity: Optional[float] = 0.5
|
| 18 |
+
transparency: Optional[float] = 0.5
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
help_mode: Optional[str] = None
|
| 21 |
+
chat_history: Optional[List[Dict[str, Any]]] = None
|
| 22 |
+
history: Optional[List[Dict[str, Any]]] = None
|
| 23 |
|
| 24 |
+
question_text: Optional[str] = None
|
| 25 |
+
question_id: Optional[str] = None
|
| 26 |
+
session_id: Optional[str] = None
|
| 27 |
+
user_id: Optional[str] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
|
| 30 |
class SessionStartRequest(BaseModel):
    """Request body for registering the start of a participant session."""

    session_id: str
    user_id: Optional[str] = None
    # Experimental condition label, when the study assigns one.
    condition: Optional[str] = None
    # Free-form client-supplied context stored verbatim with the session.
    metadata: Optional[Dict[str, Any]] = None
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
|
| 37 |
class EventLogRequest(BaseModel):
    """Request body for appending a single event to a session's log."""

    session_id: str
    event_type: str
    # ISO-8601 timestamp; when omitted the store substitutes current UTC time.
    timestamp: Optional[str] = None
    payload: Optional[Dict[str, Any]] = None
|
|
|
|
|
|
|
| 42 |
|
| 43 |
|
| 44 |
class SessionFinalizeRequest(BaseModel):
    """Request body for marking a session as finished."""

    session_id: str
    # Optional aggregate results for the session, stored verbatim.
    summary: Optional[Dict[str, Any]] = None
|
|
|
|
| 47 |
|
| 48 |
|
| 49 |
+
@dataclass
class RetrievedChunk:
    """One retrieved text snippet with light provenance and a score."""

    text: str
    # Coarse topic label for the chunk.
    topic: str = "general"
    # Where the chunk came from.
    source: str = "local"
    # Retrieval/relevance score — presumably higher is better; TODO confirm
    # against the retrieval engine's scoring.
    score: float = 0.0
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
@dataclass
class SolverResult:
    """Outcome of attempting to answer a question (rule-based solver et al.)."""

    # User-facing reply text, when already composed.
    reply: str = ""
    # e.g. "quant" for quantitative questions (see quant_solver).
    domain: str = "fallback"
    # True when a concrete answer was derived.
    solved: bool = False
    help_mode: str = "answer"
    # Matched multiple-choice letter (A-E), when known.
    answer_letter: Optional[str] = None
    # Computed value, rendered as text (e.g. "12" or "x = 3").
    answer_value: Optional[str] = None
    # Finer topic, e.g. "percent", "statistics", "algebra".
    topic: Optional[str] = None
    used_retrieval: bool = False
    used_generator: bool = False
    # Answer kept for internal use — presumably logging/inspection rather than
    # direct display; TODO confirm against conversation_logic.
    internal_answer: Optional[str] = None
    # Suggested solution steps, in order.
    steps: List[str] = field(default_factory=list)
    teaching_chunks: List[RetrievedChunk] = field(default_factory=list)
    meta: Dict[str, Any] = field(default_factory=dict)
|
quant_solver.py
CHANGED
|
@@ -2,9 +2,8 @@ from __future__ import annotations
|
|
| 2 |
|
| 3 |
import math
|
| 4 |
import re
|
| 5 |
-
from fractions import Fraction
|
| 6 |
from statistics import mean, median
|
| 7 |
-
from typing import Dict,
|
| 8 |
|
| 9 |
try:
|
| 10 |
import sympy as sp
|
|
@@ -12,268 +11,214 @@ except Exception:
|
|
| 12 |
sp = None
|
| 13 |
|
| 14 |
from models import SolverResult
|
| 15 |
-
from utils import clean_math_text, normalize_spaces
|
| 16 |
-
|
| 17 |
-
CHOICE_LETTERS = ["A", "B", "C", "D", "E"]
|
| 18 |
-
|
| 19 |
|
| 20 |
|
| 21 |
def extract_choices(text: str) -> Dict[str, str]:
|
|
|
|
| 22 |
matches = list(
|
| 23 |
re.finditer(
|
| 24 |
-
r"(?
|
| 25 |
-
text
|
| 26 |
)
|
| 27 |
)
|
| 28 |
return {m.group(1).upper(): normalize_spaces(m.group(2)) for m in matches}
|
| 29 |
|
| 30 |
|
| 31 |
-
|
| 32 |
def has_answer_choices(text: str) -> bool:
|
| 33 |
return len(extract_choices(text)) >= 3
|
| 34 |
|
| 35 |
|
| 36 |
-
|
| 37 |
def is_quant_question(text: str) -> bool:
|
| 38 |
lower = clean_math_text(text).lower()
|
| 39 |
-
|
| 40 |
-
"
|
| 41 |
-
"
|
| 42 |
-
"
|
| 43 |
-
"
|
| 44 |
]
|
| 45 |
-
|
| 46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
|
| 48 |
|
| 49 |
def _prepare_expression(expr: str) -> str:
|
| 50 |
expr = clean_math_text(expr).strip()
|
| 51 |
expr = expr.replace("^", "**")
|
| 52 |
-
expr = expr.replace("%", "/100")
|
| 53 |
expr = re.sub(r"(\d)\s*\(", r"\1*(", expr)
|
| 54 |
expr = re.sub(r"\)\s*(\d)", r")*\1", expr)
|
| 55 |
expr = re.sub(r"(\d)([a-zA-Z])", r"\1*\2", expr)
|
| 56 |
return expr
|
| 57 |
|
| 58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
|
| 60 |
-
def _parse_numeric_text(text: str) -> Optional[float]:
|
| 61 |
-
raw = clean_math_text(text).strip().lower()
|
| 62 |
-
raw_no_space = raw.replace(" ", "")
|
| 63 |
-
|
| 64 |
-
pct_match = re.fullmatch(r"(-?\d+(?:\.\d+)?)%", raw_no_space)
|
| 65 |
-
if pct_match:
|
| 66 |
-
return float(pct_match.group(1)) / 100.0
|
| 67 |
-
|
| 68 |
-
frac_match = re.fullmatch(r"(-?\d+)\s*/\s*(-?\d+)", raw)
|
| 69 |
-
if frac_match:
|
| 70 |
-
num, den = float(frac_match.group(1)), float(frac_match.group(2))
|
| 71 |
-
return None if den == 0 else num / den
|
| 72 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
try:
|
| 74 |
return float(eval(_prepare_expression(raw), {"__builtins__": {}}, {"sqrt": math.sqrt, "pi": math.pi}))
|
| 75 |
except Exception:
|
| 76 |
return None
|
| 77 |
|
| 78 |
|
| 79 |
-
|
| 80 |
-
def compare_to_choices_numeric(answer_value: float, choices: Dict[str, str], tolerance: float = 1e-6) -> Optional[str]:
|
| 81 |
best_letter = None
|
| 82 |
best_diff = float("inf")
|
| 83 |
for letter, raw in choices.items():
|
| 84 |
-
parsed =
|
| 85 |
if parsed is None:
|
| 86 |
continue
|
| 87 |
diff = abs(parsed - answer_value)
|
| 88 |
if diff < best_diff:
|
| 89 |
best_diff = diff
|
| 90 |
best_letter = letter
|
| 91 |
-
if best_letter is not None and best_diff <=
|
| 92 |
return best_letter
|
| 93 |
return None
|
| 94 |
|
| 95 |
|
| 96 |
-
|
| 97 |
-
def _solve_percent_patterns(text: str) -> Optional[SolverResult]:
|
| 98 |
lower = clean_math_text(text).lower()
|
| 99 |
choices = extract_choices(text)
|
| 100 |
|
| 101 |
-
m = re.search(r"
|
| 102 |
if m:
|
| 103 |
-
p
|
| 104 |
-
|
|
|
|
| 105 |
return SolverResult(
|
| 106 |
-
solved=True,
|
| 107 |
domain="quant",
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
|
|
|
| 112 |
steps=[
|
| 113 |
-
f"
|
| 114 |
-
f"
|
| 115 |
-
"
|
| 116 |
],
|
| 117 |
)
|
| 118 |
|
| 119 |
-
m = re.search(r"(\d+(?:\.\d+)?)\s
|
| 120 |
if m:
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
ans = x / y * 100.0
|
| 125 |
return SolverResult(
|
| 126 |
-
solved=True,
|
| 127 |
domain="quant",
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
f"Compute {x}/{y}.",
|
| 135 |
-
"Multiply by 100 to convert to a percent, then match to the choices.",
|
| 136 |
-
],
|
| 137 |
)
|
| 138 |
return None
|
| 139 |
|
| 140 |
|
| 141 |
-
|
| 142 |
-
def _solve_average_patterns(text: str) -> Optional[SolverResult]:
|
| 143 |
lower = clean_math_text(text).lower()
|
| 144 |
nums = [float(n) for n in re.findall(r"-?\d+(?:\.\d+)?", lower)]
|
| 145 |
if not nums:
|
| 146 |
return None
|
| 147 |
-
|
| 148 |
if "mean" in lower or "average" in lower:
|
| 149 |
-
|
| 150 |
-
return SolverResult(
|
| 151 |
-
solved=True,
|
| 152 |
-
domain="quant",
|
| 153 |
-
explanation="Add the values and divide by how many values there are.",
|
| 154 |
-
internal_answer_value=f"{avg:g}",
|
| 155 |
-
detected_topic="statistics",
|
| 156 |
-
steps=[f"Add the listed values: total them carefully.", f"Count how many values there are: {len(nums)}.", "Divide total by count, then check the choices."],
|
| 157 |
-
)
|
| 158 |
-
|
| 159 |
if "median" in lower:
|
| 160 |
-
|
| 161 |
-
return SolverResult(
|
| 162 |
-
solved=True,
|
| 163 |
-
domain="quant",
|
| 164 |
-
explanation="Order the values first, then identify the middle position.",
|
| 165 |
-
internal_answer_value=f"{med:g}",
|
| 166 |
-
detected_topic="statistics",
|
| 167 |
-
steps=["Write the numbers in increasing order.", "Find the middle value, or average the two middle values if there are an even number of terms.", "Compare that result with the choices."],
|
| 168 |
-
)
|
| 169 |
-
return None
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
def _solve_ratio_patterns(text: str) -> Optional[SolverResult]:
|
| 174 |
-
lower = clean_math_text(text).lower()
|
| 175 |
-
m = re.search(r"ratio of\s+(\w+)\s+to\s+(\w+)\s+is\s+(\d+)\s*:\s*(\d+)", lower)
|
| 176 |
-
if m:
|
| 177 |
-
return SolverResult(
|
| 178 |
-
solved=True,
|
| 179 |
-
domain="quant",
|
| 180 |
-
explanation="Use the common multiplier that turns the ratio into the actual quantities.",
|
| 181 |
-
detected_topic="ratio",
|
| 182 |
-
steps=["Call the common multiplier k.", f"Write the two quantities as {m.group(3)}k and {m.group(4)}k.", "Use the extra condition in the question to solve for k, then compute the quantity you need."],
|
| 183 |
-
)
|
| 184 |
-
return None
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
def _solve_divisibility_patterns(text: str) -> Optional[SolverResult]:
|
| 189 |
-
lower = clean_math_text(text).lower()
|
| 190 |
-
if "divisible by" in lower and "integer" in lower:
|
| 191 |
-
expr_match = re.search(r"if\s+([a-z])\s+is an integer and\s+(.+?)\s+is divisible by\s+(\d+)", lower)
|
| 192 |
-
if expr_match:
|
| 193 |
-
var = expr_match.group(1)
|
| 194 |
-
expr = expr_match.group(2)
|
| 195 |
-
divisor = int(expr_match.group(3))
|
| 196 |
-
choices = extract_choices(text)
|
| 197 |
-
valid_letters: List[str] = []
|
| 198 |
-
valid_values: List[str] = []
|
| 199 |
-
if sp:
|
| 200 |
-
symbol = sp.symbols(var)
|
| 201 |
-
parsed_expr = sp.sympify(_prepare_expression(expr))
|
| 202 |
-
for letter, raw_choice in choices.items():
|
| 203 |
-
try:
|
| 204 |
-
value = int(float(_parse_numeric_text(raw_choice)))
|
| 205 |
-
except Exception:
|
| 206 |
-
continue
|
| 207 |
-
if parsed_expr.subs(symbol, value) % divisor == 0:
|
| 208 |
-
valid_letters.append(letter)
|
| 209 |
-
valid_values.append(str(value))
|
| 210 |
-
return SolverResult(
|
| 211 |
-
solved=bool(valid_letters),
|
| 212 |
-
domain="quant",
|
| 213 |
-
explanation="Test each answer choice in the divisibility condition instead of solving abstractly first.",
|
| 214 |
-
internal_answer_value=", ".join(valid_values) if valid_values else None,
|
| 215 |
-
internal_answer_letter=valid_letters[0] if len(valid_letters) == 1 else None,
|
| 216 |
-
detected_topic="number_theory",
|
| 217 |
-
steps=[
|
| 218 |
-
f"Use the condition that the expression must be divisible by {divisor}.",
|
| 219 |
-
"Substitute each answer choice into the expression.",
|
| 220 |
-
"Keep the value that makes the result a multiple of the divisor.",
|
| 221 |
-
],
|
| 222 |
-
)
|
| 223 |
return None
|
| 224 |
|
| 225 |
|
| 226 |
-
|
| 227 |
def _solve_linear_equation(text: str) -> Optional[SolverResult]:
|
| 228 |
-
if
|
| 229 |
return None
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
if "value of x" not in lower and not re.search(r"\bsolve\b", lower):
|
| 233 |
return None
|
| 234 |
-
|
| 235 |
-
eq_match = re.search(r"([\d\sa-zA-Z\+\-\*/\^\(\)=]+)", cleaned)
|
| 236 |
-
if not eq_match or "=" not in eq_match.group(1):
|
| 237 |
-
return None
|
| 238 |
-
expr = eq_match.group(1)
|
| 239 |
try:
|
| 240 |
lhs, rhs = expr.split("=", 1)
|
| 241 |
-
|
| 242 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 243 |
if not sol:
|
| 244 |
return None
|
| 245 |
value = sol[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 246 |
return SolverResult(
|
| 247 |
-
solved=True,
|
| 248 |
domain="quant",
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 253 |
)
|
| 254 |
except Exception:
|
| 255 |
return None
|
| 256 |
|
| 257 |
|
| 258 |
-
|
| 259 |
def solve_quant(text: str) -> SolverResult:
|
| 260 |
text = text or ""
|
| 261 |
-
for
|
| 262 |
-
result =
|
| 263 |
-
if result:
|
| 264 |
return result
|
| 265 |
-
|
| 266 |
-
topic = "general_quant"
|
| 267 |
-
steps = [
|
| 268 |
-
"Identify exactly what quantity the question wants.",
|
| 269 |
-
"Translate the words into an equation, ratio, table, or diagram.",
|
| 270 |
-
"Do the calculation carefully.",
|
| 271 |
-
"Use the answer choices to check reasonableness and units.",
|
| 272 |
-
]
|
| 273 |
return SolverResult(
|
| 274 |
-
solved=False,
|
| 275 |
domain="quant",
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 279 |
)
|
|
|
|
| 2 |
|
| 3 |
import math
|
| 4 |
import re
|
|
|
|
| 5 |
from statistics import mean, median
|
| 6 |
+
from typing import Dict, Optional
|
| 7 |
|
| 8 |
try:
|
| 9 |
import sympy as sp
|
|
|
|
| 11 |
sp = None
|
| 12 |
|
| 13 |
from models import SolverResult
|
| 14 |
+
from utils import clean_math_text, normalize_spaces
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
|
| 17 |
def extract_choices(text: str) -> Dict[str, str]:
    """Map answer-choice letters A-E to their whitespace-normalized text.

    Accepts "A)", "A.", or "A:" markers (case-insensitive); later duplicates
    of a letter overwrite earlier ones.
    """
    text = text or ""
    pattern = r"(?i)\b([A-E])[\)\.:]\s*(.*?)(?=\s+\b[A-E][\)\.:]\s*|$)"
    choices: Dict[str, str] = {}
    for match in re.finditer(pattern, text):
        choices[match.group(1).upper()] = normalize_spaces(match.group(2))
    return choices
|
| 26 |
|
| 27 |
|
|
|
|
| 28 |
def has_answer_choices(text: str) -> bool:
    """True when the text contains at least three lettered answer choices."""
    choice_count = len(extract_choices(text))
    return choice_count >= 3
|
| 31 |
|
|
|
|
| 32 |
def is_quant_question(text: str) -> bool:
    """Heuristic check: does the text look like a quantitative question?"""
    lowered = clean_math_text(text).lower()
    keywords = [
        "solve", "equation", "percent", "ratio", "probability", "mean", "median",
        "average", "sum", "difference", "product", "quotient", "triangle", "circle",
        "rectangle", "area", "perimeter", "volume", "algebra", "integer", "divisible",
        "number", "fraction", "decimal", "geometry", "distance", "speed", "work",
    ]
    # Any math vocabulary is a strong enough signal on its own.
    for keyword in keywords:
        if keyword in lowered:
            return True
    # An "=" together with a letter suggests an equation with a variable.
    if "=" in lowered and re.search(r"[a-z]", lowered):
        return True
    # Digits plus a question mark or lettered choices also count.
    return bool(re.search(r"\d", lowered) and ("?" in lowered or has_answer_choices(lowered)))
|
| 47 |
|
| 48 |
|
| 49 |
def _prepare_expression(expr: str) -> str:
    """Rewrite a human-typed math expression into Python/sympy syntax."""
    prepared = clean_math_text(expr).strip().replace("^", "**")
    # Insert explicit multiplication: "2(x)" -> "2*(x)", "(x)2" -> "(x)*2",
    # and "2x" -> "2*x".
    prepared = re.sub(r"(\d)\s*\(", r"\1*(", prepared)
    prepared = re.sub(r"\)\s*(\d)", r")*\1", prepared)
    prepared = re.sub(r"(\d)([a-zA-Z])", r"\1*\2", prepared)
    return prepared
|
| 57 |
|
| 58 |
+
def _extract_equation(text: str) -> Optional[str]:
    """Pull the most plausible "lhs = rhs" substring out of free text.

    Returns None when no equation-like span containing a letter variable is
    found. Pattern order matters: the first regex insists on a letter before
    the "="; the second is a looser fallback over the same character class.
    """
    cleaned = clean_math_text(text)
    if "=" not in cleaned:
        return None
    patterns = [
        # Span guaranteed to contain at least one letter left of "=".
        r"([A-Za-z0-9\.\+\-\*/\^\(\)\s]*[a-zA-Z][A-Za-z0-9\.\+\-\*/\^\(\)\s]*=[A-Za-z0-9\.\+\-\*/\^\(\)\s]+)",
        # Looser fallback: any expression characters around "=".
        r"([0-9A-Za-z\.\+\-\*/\^\(\)\s]+=[0-9A-Za-z\.\+\-\*/\^\(\)\s]+)",
    ]
    for pattern in patterns:
        for m in re.finditer(pattern, cleaned):
            candidate = m.group(1).strip()
            tokens = re.findall(r"[a-z]", candidate.lower())
            # Reject spans that are really the question's prose lead-in.
            if tokens and not candidate.lower().startswith(("how do", "can you", "please", "what is", "solve ")):
                return candidate
    # Last resort: stitch together the tokens immediately around the first "=".
    eq_index = cleaned.find("=")
    left = re.findall(r"[A-Za-z0-9\.\+\-\*/\^\(\)\s]+$", cleaned[:eq_index])
    right = re.findall(r"^[A-Za-z0-9\.\+\-\*/\^\(\)\s]+", cleaned[eq_index + 1:])
    if left and right:
        candidate = left[0].strip().split()[-1] + " = " + right[0].strip().split()[0]
        if re.search(r"[a-z]", candidate.lower()):
            return candidate
    return None
|
| 80 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
|
| 82 |
+
def _parse_number(text: str) -> Optional[float]:
    """Parse a human-written numeric answer ("25%", "3/4", "2*(3+1)") to a float.

    Returns None when the text cannot be interpreted numerically.
    """
    raw = clean_math_text(text).strip().lower()
    # Percentages: "25%" -> 0.25.
    pct = re.fullmatch(r"(-?\d+(?:\.\d+)?)%", raw.replace(" ", ""))
    if pct:
        return float(pct.group(1)) / 100.0
    # Simple integer fractions: "3/4" -> 0.75; division by zero -> None.
    frac = re.fullmatch(r"(-?\d+)\s*/\s*(-?\d+)", raw)
    if frac:
        den = float(frac.group(2))
        if den == 0:
            return None
        return float(frac.group(1)) / den
    try:
        # NOTE(review): eval on user-derived text. Builtins are stripped and
        # only sqrt/pi are exposed, which limits the blast radius, but a
        # dedicated expression parser would be safer — flagging for review.
        return float(eval(_prepare_expression(raw), {"__builtins__": {}}, {"sqrt": math.sqrt, "pi": math.pi}))
    except Exception:
        return None
|
| 97 |
|
| 98 |
|
| 99 |
+
def _best_choice(answer_value: float, choices: Dict[str, str]) -> Optional[str]:
    """Letter of the choice numerically closest to *answer_value*.

    Returns None unless the closest parsable choice matches to within 1e-6.
    """
    best_letter: Optional[str] = None
    best_diff = float("inf")
    for letter, raw in choices.items():
        parsed = _parse_number(raw)
        if parsed is None:
            # Non-numeric options cannot be matched.
            continue
        diff = abs(parsed - answer_value)
        if diff < best_diff:
            best_letter, best_diff = letter, diff
    if best_letter is not None and best_diff <= 1e-6:
        return best_letter
    return None
|
| 113 |
|
| 114 |
|
| 115 |
+
def _solve_percent(text: str) -> Optional[SolverResult]:
    """Solve two common percent phrasings; None when neither pattern matches.

    Handles "P% of a number is V" (solve for the number) and
    "what is P% of N" (compute the percentage directly).
    """
    lower = clean_math_text(text).lower()
    choices = extract_choices(text)

    # Pattern 1: "P% of a number is V"  =>  number = V / (P/100).
    m = re.search(r"(\d+(?:\.\d+)?)\s*(?:%|percent)\s+of\s+(?:a\s+)?number\s+is\s+(\d+(?:\.\d+)?)", lower)
    if m:
        p = float(m.group(1))
        value = float(m.group(2))
        ans = value / (p / 100.0)
        return SolverResult(
            domain="quant",
            solved=True,
            topic="percent",
            answer_value=f"{ans:g}",
            # Only try to match a letter when lettered choices are present.
            answer_letter=_best_choice(ans, choices) if choices else None,
            internal_answer=f"{ans:g}",
            steps=[
                f"Let the number be n.",
                f"Write {p}% of n as {p/100:g}n.",
                f"Set {p/100:g}n = {value} and solve for n.",
            ],
        )

    # Pattern 2: "what is P% of N"  =>  (P/100) * N.
    m = re.search(r"what is\s+(\d+(?:\.\d+)?)\s*(?:%|percent)\s+of\s+(\d+(?:\.\d+)?)", lower)
    if m:
        p = float(m.group(1))
        n = float(m.group(2))
        ans = p / 100.0 * n
        return SolverResult(
            domain="quant",
            solved=True,
            topic="percent",
            answer_value=f"{ans:g}",
            answer_letter=_best_choice(ans, choices) if choices else None,
            internal_answer=f"{ans:g}",
            steps=[f"Convert {p}% to {p/100:g}.", f"Multiply by {n}."]
        )
    return None
|
| 153 |
|
| 154 |
|
| 155 |
+
def _solve_mean_median(text: str) -> Optional[SolverResult]:
    """Answer simple mean/median prompts over every number found in the text."""
    lowered = clean_math_text(text).lower()
    numbers = [float(n) for n in re.findall(r"-?\d+(?:\.\d+)?", lowered)]
    if not numbers:
        return None
    if "mean" in lowered or "average" in lowered:
        result = mean(numbers)
        return SolverResult(
            domain="quant",
            solved=True,
            topic="statistics",
            answer_value=f"{result:g}",
            internal_answer=f"{result:g}",
            steps=["Add the values.", f"Divide by {len(numbers)}."],
        )
    if "median" in lowered:
        result = median(numbers)
        return SolverResult(
            domain="quant",
            solved=True,
            topic="statistics",
            answer_value=f"{result:g}",
            internal_answer=f"{result:g}",
            steps=["Order the values.", "Take the middle value."],
        )
    return None
|
| 167 |
|
| 168 |
|
|
|
|
| 169 |
def _solve_linear_equation(text: str) -> Optional[SolverResult]:
    """Solve a single-variable equation found in the text via sympy.

    Returns None when sympy is unavailable, no equation can be extracted,
    or parsing/solving fails for any reason.
    """
    # sympy is an optional dependency (import guarded at module top).
    if sp is None:
        return None
    expr = _extract_equation(text)
    if not expr:
        return None
    try:
        lhs, rhs = expr.split("=", 1)
        # Single-letter tokens are the candidate variables; when several
        # appear, the alphabetically first one is solved for.
        symbols = sorted(set(re.findall(r"\b[a-z]\b", expr)))
        if not symbols:
            return None
        var_name = symbols[0]
        var = sp.symbols(var_name)
        sol = sp.solve(sp.Eq(sp.sympify(_prepare_expression(lhs)), sp.sympify(_prepare_expression(rhs))), var)
        if not sol:
            return None
        value = sol[0]
        # Symbolic (non-numeric) solutions simply skip choice matching below.
        try:
            as_float = float(value)
        except Exception:
            as_float = None
        choices = extract_choices(text)
        return SolverResult(
            domain="quant",
            solved=True,
            topic="algebra",
            answer_value=str(value),
            answer_letter=_best_choice(as_float, choices) if (as_float is not None and choices) else None,
            internal_answer=f"{var_name} = {value}",
            steps=[
                "Treat the statement as an equation.",
                "Undo operations on both sides to isolate the variable.",
                f"That gives {var_name} = {value}.",
            ],
        )
    except Exception:
        # Any parsing/solving failure defers to the other solvers.
        return None
|
| 206 |
|
| 207 |
|
|
|
|
| 208 |
def solve_quant(text: str) -> SolverResult:
    """Run each rule-based quant solver in priority order.

    The first solver that recognizes the question wins; otherwise an
    unsolved result with generic problem-solving guidance is returned.
    """
    question = text or ""
    solvers = (_solve_percent, _solve_mean_median, _solve_linear_equation)
    matched = next(
        (outcome for outcome in (solver(question) for solver in solvers) if outcome is not None),
        None,
    )
    if matched is not None:
        return matched
    # No rule matched: report the miss rather than guessing an answer.
    return SolverResult(
        domain="quant",
        solved=False,
        topic="general_quant",
        reply="This looks quantitative, but it does not match a strong rule-based pattern yet.",
        steps=[
            "Identify the quantity the question wants.",
            "Translate the wording into an equation, ratio, or diagram.",
            "Carry out the calculation carefully.",
        ],
    )
|
retrieval_engine.py
CHANGED
|
@@ -1,133 +1,98 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
import json
|
| 4 |
-
|
| 5 |
-
from typing import List
|
| 6 |
|
| 7 |
-
|
|
|
|
| 8 |
|
| 9 |
try:
|
| 10 |
-
|
| 11 |
except Exception:
|
| 12 |
-
|
| 13 |
|
| 14 |
try:
|
| 15 |
-
from sentence_transformers import
|
| 16 |
except Exception:
|
| 17 |
-
CrossEncoder = None
|
| 18 |
SentenceTransformer = None
|
| 19 |
|
| 20 |
-
from config import settings
|
| 21 |
-
from models import RetrievalChunk
|
| 22 |
-
from utils import normalize_spaces
|
| 23 |
-
|
| 24 |
|
| 25 |
class RetrievalEngine:
|
| 26 |
-
def __init__(self
|
| 27 |
-
self.
|
| 28 |
-
self.
|
| 29 |
-
self.
|
| 30 |
-
self.reranker = None
|
| 31 |
self.embeddings = None
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
scored.append((idx, float(score)))
|
| 79 |
-
scored.sort(key=lambda x: x[1], reverse=True)
|
| 80 |
-
results: List[RetrievalChunk] = []
|
| 81 |
-
for idx, score in scored[:k]:
|
| 82 |
-
row = self.rows[int(idx)]
|
| 83 |
-
results.append(
|
| 84 |
-
RetrievalChunk(
|
| 85 |
-
chunk_id=str(row.get("id", idx)),
|
| 86 |
-
text=self.texts[int(idx)],
|
| 87 |
-
source_name=str(row.get("source_name", "")),
|
| 88 |
-
topic_guess=str(row.get("topic_guess", "")),
|
| 89 |
-
score=float(score),
|
| 90 |
-
)
|
| 91 |
-
)
|
| 92 |
-
return results
|
| 93 |
-
|
| 94 |
-
def search(self, query: str, k: int | None = None) -> List[RetrievalChunk]:
|
| 95 |
-
query = normalize_spaces(query)
|
| 96 |
-
if not query:
|
| 97 |
return []
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
if self.
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
for idx, score in reranked:
|
| 123 |
-
row = self.rows[int(idx)]
|
| 124 |
-
results.append(
|
| 125 |
-
RetrievalChunk(
|
| 126 |
-
chunk_id=str(row.get("id", idx)),
|
| 127 |
-
text=self.texts[int(idx)],
|
| 128 |
-
source_name=str(row.get("source_name", "")),
|
| 129 |
-
topic_guess=str(row.get("topic_guess", "")),
|
| 130 |
-
score=float(score),
|
| 131 |
-
)
|
| 132 |
-
)
|
| 133 |
return results
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
import json
|
| 4 |
+
import os
|
| 5 |
+
from typing import List, Optional
|
| 6 |
|
| 7 |
+
from models import RetrievedChunk
|
| 8 |
+
from utils import clean_math_text, score_token_overlap
|
| 9 |
|
| 10 |
try:
|
| 11 |
+
import numpy as np
|
| 12 |
except Exception:
|
| 13 |
+
np = None
|
| 14 |
|
| 15 |
try:
|
| 16 |
+
from sentence_transformers import SentenceTransformer
|
| 17 |
except Exception:
|
|
|
|
| 18 |
SentenceTransformer = None
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
class RetrievalEngine:
    """Hybrid (semantic + lexical) retrieval over a local JSONL chunk corpus.

    Degrades gracefully to pure token-overlap scoring when
    sentence-transformers or numpy is unavailable, or when embedding fails
    at runtime.
    """

    def __init__(self, data_path: str = "data/gmat_hf_chunks.jsonl"):
        self.data_path = data_path
        self.rows = self._load_rows(data_path)
        self.encoder = None
        self.embeddings = None
        if SentenceTransformer is None or not self.rows:
            return
        try:
            self.encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
            corpus = [row["text"] for row in self.rows]
            self.embeddings = self.encoder.encode(
                corpus, convert_to_numpy=True, normalize_embeddings=True
            )
        except Exception:
            # Any model/download failure leaves us in lexical-only mode.
            self.encoder = None
            self.embeddings = None

    def _load_rows(self, data_path: str):
        """Parse one JSON object per line; blank or malformed lines are skipped."""
        loaded = []
        if not os.path.exists(data_path):
            return loaded
        with open(data_path, "r", encoding="utf-8") as handle:
            for raw in handle:
                raw = raw.strip()
                if not raw:
                    continue
                try:
                    record = json.loads(raw)
                except Exception:
                    continue
                loaded.append({
                    "text": record.get("text", ""),
                    "topic": record.get("topic", record.get("section", "general")) or "general",
                    "source": record.get("source", "local_corpus"),
                })
        return loaded

    def _topic_bonus(self, desired_topic: str, row_topic: str, intent: str) -> float:
        """Additive score boost when a row's topic matches the requested one."""
        wanted = (desired_topic or "").lower()
        actual = (row_topic or "").lower()
        mode = (intent or "").lower()
        boost = 0.0
        if wanted and wanted in actual:
            boost += 1.25
        if wanted == "algebra" and actual in {"algebra", "linear equations", "equations"}:
            boost += 1.0
        if wanted == "percent" and "percent" in actual:
            boost += 1.0
        teaching_modes = {"method", "step_by_step", "full_working", "hint"}
        broad_topics = ["algebra", "percent", "fractions", "word_problems", "general"]
        if mode in teaching_modes and any(name in actual for name in broad_topics):
            boost += 0.25
        return boost

    def search(self, query: str, topic: str = "", intent: str = "answer", k: int = 3) -> List[RetrievedChunk]:
        """Return the top-*k* chunks for *query*, best score first."""
        if not self.rows:
            return []
        cleaned = clean_math_text(query)

        ranked = []
        semantic_ready = (
            self.encoder is not None and self.embeddings is not None and np is not None
        )
        if semantic_ready:
            try:
                query_vec = self.encoder.encode(
                    [cleaned], convert_to_numpy=True, normalize_embeddings=True
                )[0]
                # Embeddings are normalized, so the dot product is cosine similarity.
                similarity = self.embeddings @ query_vec
                ranked = [
                    (
                        0.7 * sem
                        + 0.3 * score_token_overlap(cleaned, row["text"])
                        + self._topic_bonus(topic, row["topic"], intent),
                        row,
                    )
                    for row, sem in zip(self.rows, similarity.tolist())
                ]
            except Exception:
                ranked = []

        if not ranked:
            # Lexical-only fallback path.
            ranked = [
                (
                    score_token_overlap(cleaned, row["text"])
                    + self._topic_bonus(topic, row["topic"], intent),
                    row,
                )
                for row in self.rows
            ]

        ranked.sort(key=lambda pair: pair[0], reverse=True)
        return [
            RetrievedChunk(text=row["text"], topic=row["topic"], source=row["source"], score=float(total))
            for total, row in ranked[:k]
        ]
|
ui_html.py
CHANGED
|
@@ -2,68 +2,37 @@ HOME_HTML = """
|
|
| 2 |
<!doctype html>
|
| 3 |
<html>
|
| 4 |
<head>
|
| 5 |
-
<meta charset="utf-8"
|
| 6 |
-
<
|
|
|
|
| 7 |
<style>
|
| 8 |
-
body { font-family: Arial, sans-serif; max-width:
|
| 9 |
-
textarea
|
| 10 |
-
button {
|
| 11 |
-
pre { background: #
|
| 12 |
-
.row { margin-bottom: 24px; }
|
| 13 |
</style>
|
| 14 |
</head>
|
| 15 |
<body>
|
| 16 |
-
<h1>Trading Game
|
| 17 |
-
<p>
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
<pre id="sessionOut"></pre>
|
| 31 |
-
</div>
|
| 32 |
-
|
| 33 |
<script>
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
const body = {
|
| 38 |
-
message: document.getElementById("msg").value,
|
| 39 |
-
tone: 0.7,
|
| 40 |
-
verbosity: 0.6,
|
| 41 |
-
transparency: 0.5,
|
| 42 |
-
chat_history: []
|
| 43 |
-
};
|
| 44 |
-
const res = await fetch("/chat", {
|
| 45 |
-
method: "POST",
|
| 46 |
-
headers: {"Content-Type": "application/json"},
|
| 47 |
-
body: JSON.stringify(body)
|
| 48 |
-
});
|
| 49 |
-
document.getElementById("chatOut").textContent = JSON.stringify(await res.json(), null, 2);
|
| 50 |
-
}
|
| 51 |
-
|
| 52 |
-
async function startSession() {
|
| 53 |
-
const body = {
|
| 54 |
-
participant_id: document.getElementById("participant").value,
|
| 55 |
-
condition: "demo",
|
| 56 |
-
study_id: "pilot",
|
| 57 |
-
metadata: {"source": "browser_test_page"}
|
| 58 |
-
};
|
| 59 |
-
const res = await fetch("/log/session/start", {
|
| 60 |
-
method: "POST",
|
| 61 |
-
headers: {"Content-Type": "application/json"},
|
| 62 |
-
body: JSON.stringify(body)
|
| 63 |
-
});
|
| 64 |
const data = await res.json();
|
| 65 |
-
|
| 66 |
-
document.getElementById("sessionOut").textContent = JSON.stringify(data, null, 2);
|
| 67 |
}
|
| 68 |
</script>
|
| 69 |
</body>
|
|
|
|
| 2 |
<!doctype html>
|
| 3 |
<html>
|
| 4 |
<head>
|
| 5 |
+
<meta charset=\"utf-8\">
|
| 6 |
+
<meta name=\"viewport\" content=\"width=device-width, initial-scale=1\">
|
| 7 |
+
<title>Trading Game AI V2</title>
|
| 8 |
<style>
|
| 9 |
+
body { font-family: Arial, sans-serif; max-width: 900px; margin: 40px auto; padding: 0 16px; }
|
| 10 |
+
textarea { width: 100%; min-height: 180px; }
|
| 11 |
+
button { margin-top: 12px; padding: 10px 16px; }
|
| 12 |
+
pre { background: #f5f5f5; padding: 16px; white-space: pre-wrap; }
|
|
|
|
| 13 |
</style>
|
| 14 |
</head>
|
| 15 |
<body>
|
| 16 |
+
<h1>Trading Game AI V2</h1>
|
| 17 |
+
<p>Test the <code>/chat</code> endpoint here.</p>
|
| 18 |
+
<textarea id=\"payload\">{
|
| 19 |
+
\"message\": \"Solve: x/5 = 12\",
|
| 20 |
+
\"chat_history\": [],
|
| 21 |
+
\"tone\": 0.5,
|
| 22 |
+
\"verbosity\": 0.6,
|
| 23 |
+
\"transparency\": 0.6,
|
| 24 |
+
\"session_id\": \"test-session-1\",
|
| 25 |
+
\"user_id\": \"test-user-1\"
|
| 26 |
+
}</textarea>
|
| 27 |
+
<br>
|
| 28 |
+
<button onclick=\"send()\">Send</button>
|
| 29 |
+
<pre id=\"out\"></pre>
|
|
|
|
|
|
|
|
|
|
| 30 |
<script>
|
| 31 |
+
async function send() {
|
| 32 |
+
const payload = JSON.parse(document.getElementById('payload').value);
|
| 33 |
+
const res = await fetch('/chat', {method:'POST', headers:{'Content-Type':'application/json'}, body: JSON.stringify(payload)});
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
const data = await res.json();
|
| 35 |
+
document.getElementById('out').textContent = JSON.stringify(data, null, 2);
|
|
|
|
| 36 |
}
|
| 37 |
</script>
|
| 38 |
</body>
|
utils.py
CHANGED
|
@@ -1,104 +1,102 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
|
|
|
|
|
|
| 3 |
import math
|
| 4 |
import re
|
| 5 |
-
from typing import Any, Iterable, List
|
| 6 |
|
| 7 |
-
from models import ChatRequest
|
| 8 |
|
| 9 |
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
def clamp01(value: Any, default: float = 0.5) -> float:
|
| 14 |
try:
|
| 15 |
-
v = float(
|
| 16 |
return max(0.0, min(1.0, v))
|
| 17 |
except Exception:
|
| 18 |
return default
|
| 19 |
|
| 20 |
|
| 21 |
def normalize_spaces(text: str) -> str:
|
| 22 |
-
return re.sub(r"\s+", " ", (text or "")).strip()
|
| 23 |
|
| 24 |
|
| 25 |
def clean_math_text(text: str) -> str:
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
|
| 33 |
def get_user_text(req: ChatRequest, raw_body: Any = None) -> str:
|
| 34 |
-
for field in [
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
def _search_field(pattern: str, text: str) -> str:
|
| 53 |
-
m = re.search(pattern, text, flags=re.IGNORECASE | re.DOTALL)
|
| 54 |
-
return m.group(1).strip() if m else ""
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
def parse_hidden_context(hidden_context: str) -> GameContext:
|
| 59 |
-
ctx = GameContext(raw_hidden_context=hidden_context or "")
|
| 60 |
-
if not hidden_context:
|
| 61 |
-
return ctx
|
| 62 |
-
|
| 63 |
-
ctx.question_category = _search_field(r"Category:\s*(.+)", hidden_context).splitlines()[0] if "Category:" in hidden_context else ""
|
| 64 |
-
ctx.question_difficulty = _search_field(r"Difficulty:\s*(.+)", hidden_context).splitlines()[0] if "Difficulty:" in hidden_context else ""
|
| 65 |
-
ctx.last_outcome = _search_field(r"Last outcome:\s*(.+)", hidden_context).splitlines()[0] if "Last outcome:" in hidden_context else ""
|
| 66 |
-
|
| 67 |
-
question = _search_field(r"Question:\s*(.+?)(?:\nOptions:|\nPlayer balance:|\nLast outcome:|$)", hidden_context)
|
| 68 |
-
options = _search_field(r"Options:\s*(.+?)(?:\nPlayer balance:|\nLast outcome:|$)", hidden_context)
|
| 69 |
-
balance_text = _search_field(r"Player balance:\s*([0-9]+(?:\.[0-9]+)?)", hidden_context)
|
| 70 |
-
|
| 71 |
-
ctx.question_text = question.strip()
|
| 72 |
-
ctx.options_text = options.strip()
|
| 73 |
-
try:
|
| 74 |
-
ctx.player_balance = float(balance_text) if balance_text else None
|
| 75 |
-
except Exception:
|
| 76 |
-
ctx.player_balance = None
|
| 77 |
-
return ctx
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
def soft_truncate(text: str, limit: int) -> str:
|
| 82 |
-
text = (text or "").strip()
|
| 83 |
-
if len(text) <= limit:
|
| 84 |
-
return text
|
| 85 |
-
trimmed = text[: limit - 1].rsplit(" ", 1)[0].strip()
|
| 86 |
-
return (trimmed or text[: limit - 1]).rstrip() + "…"
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
def safe_div(a: float, b: float) -> float:
|
| 91 |
-
return a / b if b else math.inf
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
def flatten_history(items: Iterable[Any]) -> List[dict]:
|
| 96 |
-
out: List[dict] = []
|
| 97 |
-
for item in items or []:
|
| 98 |
-
if isinstance(item, dict):
|
| 99 |
-
out.append({"role": str(item.get("role", "user")), "text": str(item.get("text", ""))})
|
| 100 |
-
else:
|
| 101 |
-
role = getattr(item, "role", "user")
|
| 102 |
-
text = getattr(item, "text", "")
|
| 103 |
-
out.append({"role": str(role), "text": str(text)})
|
| 104 |
return out
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
+
import ast
|
| 4 |
+
import json
|
| 5 |
import math
|
| 6 |
import re
|
| 7 |
+
from typing import Any, Iterable, List
|
| 8 |
|
| 9 |
+
from models import ChatRequest
|
| 10 |
|
| 11 |
|
| 12 |
+
def clamp01(x: Any, default: float = 0.5) -> float:
    """Coerce *x* to float and clamp it into [0.0, 1.0].

    Returns *default* when *x* cannot be converted to a float.
    """
    try:
        numeric = float(x)
    except Exception:
        # None, malformed strings, and other non-numerics fall back here.
        return default
    return max(0.0, min(1.0, numeric))
|
| 18 |
|
| 19 |
|
| 20 |
def normalize_spaces(text: str) -> str:
    """Collapse every whitespace run to a single space and trim the ends."""
    collapsed = re.sub(r"\s+", " ", str(text or ""))
    return collapsed.strip()
|
| 22 |
|
| 23 |
|
| 24 |
def clean_math_text(text: str) -> str:
    """Replace unicode math symbols, dashes, and NBSP with ASCII equivalents."""
    # Single-character substitutions done in one pass via str.translate.
    ascii_map = str.maketrans({
        "×": "*",
        "÷": "/",
        "–": "-",
        "—": "-",
        "−": "-",
        "\u00a0": " ",
    })
    return str(text or "").translate(ascii_map)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def tokenize(text: str) -> List[str]:
    """Lowercased alphanumeric tokens extracted from the cleaned text."""
    normalized = clean_math_text(text).lower()
    return re.findall(r"[a-z0-9]+", normalized)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def score_token_overlap(query: str, text: str) -> float:
    """Fraction of the query's distinct tokens that also appear in *text*.

    Returns 0.0 when either side tokenizes to nothing.
    """
    query_tokens = set(tokenize(query))
    text_tokens = set(tokenize(text))
    if not query_tokens or not text_tokens:
        return 0.0
    shared = query_tokens & text_tokens
    return len(shared) / max(1, len(query_tokens))
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def extract_text_from_any_payload(payload: Any) -> str:
    """Best-effort extraction of a user-text string from an arbitrary payload.

    Strings that look like serialized JSON (or Python literals) are decoded
    and re-processed recursively; dicts are probed for common message keys
    before falling back to joining all values; lists join their extracted
    elements with newlines. Anything else is stringified.
    """
    if payload is None:
        return ""

    if isinstance(payload, str):
        stripped = payload.strip()
        if not stripped:
            return ""
        looks_serialized = (stripped.startswith("{") and stripped.endswith("}")) or (
            stripped.startswith("[") and stripped.endswith("]")
        )
        if looks_serialized:
            try:
                return extract_text_from_any_payload(json.loads(stripped))
            except Exception:
                pass  # not valid JSON; try a Python literal next
        try:
            literal = ast.literal_eval(stripped)
            if isinstance(literal, (dict, list)):
                return extract_text_from_any_payload(literal)
        except Exception:
            pass
        return stripped

    if isinstance(payload, dict):
        # Well-known message keys win, in priority order.
        preferred_keys = (
            "message", "prompt", "query", "text", "user_message",
            "input", "data", "payload", "body", "content",
        )
        for key in preferred_keys:
            if key in payload:
                candidate = extract_text_from_any_payload(payload[key])
                if candidate:
                    return candidate
        pieces = (extract_text_from_any_payload(value) for value in payload.values())
        return "\n".join(piece for piece in pieces if piece).strip()

    if isinstance(payload, list):
        pieces = (extract_text_from_any_payload(item) for item in payload)
        return "\n".join(piece for piece in pieces if piece).strip()

    return str(payload).strip()
|
| 84 |
|
| 85 |
|
| 86 |
def get_user_text(req: ChatRequest, raw_body: Any = None) -> str:
    """Pull the user's message from the request model, else from the raw body."""
    candidate_fields = ("message", "prompt", "query", "text", "user_message")
    for name in candidate_fields:
        candidate = getattr(req, name, None)
        if isinstance(candidate, str):
            cleaned = candidate.strip()
            if cleaned:
                return cleaned
    # Nothing usable on the model: sniff the raw payload as a last resort.
    return extract_text_from_any_payload(raw_body).strip()
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def short_lines(items: Iterable[str], limit: int) -> List[str]:
    """Normalize each entry's whitespace and keep at most *limit* non-empty ones.

    The limit check runs every iteration (not only after an append) so a
    non-positive *limit* yields an empty list, matching the original contract.
    """
    kept: List[str] = []
    for raw in items:
        cleaned = normalize_spaces(raw)
        if cleaned:
            kept.append(cleaned)
        if len(kept) >= limit:
            break
    return kept
|