| """ |
| FastAPI Backend for Trading Game + AI Chatbot Integration |
| Supports tunable AI parameters for psychological experiments |
| """ |
| from fastapi import FastAPI, HTTPException |
| from fastapi.middleware.cors import CORSMiddleware |
| from fastapi.staticfiles import StaticFiles |
| from pydantic import BaseModel |
| from typing import Optional, List, Dict |
| import json |
| import os |
| from datetime import datetime |
|
|
| |
| from app import chain, llm, ChatOpenAI, PromptTemplate |
| from langchain_core.prompts import ChatPromptTemplate |
|
|
# ASGI application object; the title is what FastAPI is given for this API.
app = FastAPI(title="Trading Game AI Experiment API")

# Cross-origin access: every origin, method and header is accepted.
# NOTE(review): the FastAPI CORS documentation advises against combining a
# wildcard origin with allow_credentials=True -- confirm whether credentialed
# cross-origin requests are really needed before tightening this.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Serve the files in the local "game" directory under the /game URL prefix.
app.mount("/game", StaticFiles(directory="game", html=True), name="game")

# Process-local, in-memory stores. game_sessions is keyed by session id and
# filled in lazily by the chat and decision endpoints.
experiment_state = {}
game_sessions = {}
|
|
class AIMessageRequest(BaseModel):
    """Request body accepted by the AI chat endpoint."""

    # Text of the player's question.
    question: str
    # Earlier conversation turns. The chat endpoint keeps only items that are
    # lists/tuples with at least two elements and stringifies the first two.
    chat_history: List[tuple] = []
    # Experiment knob; the chat endpoint renders it as "<value>/10".
    risk_level: float = 5.0
    # LLM temperature setting; stored with every logged interaction.
    temperature: float = 0.7
    # Confidence adjustment setting; stored with every logged interaction.
    confidence_boost: float = 0.0
    # Key of the in-memory session this message belongs to.
    session_id: str = "default"
|
|
class TradingDecision(BaseModel):
    """Request body describing one trade made by the player."""

    # Key of the in-memory session the decision is logged under.
    session_id: str
    # Identifier of the traded instrument.
    symbol: str
    # Trade action label; stored as given, no validation is applied here.
    action: str
    # Number of units traded.
    quantity: int
    # Price attached to the trade.
    price: float
    # Whether the player reports having followed the AI's advice.
    ai_advice_followed: bool
    # Optional trust rating supplied with the decision.
    trust_score: Optional[float] = None
|
|
class ScenarioTrigger(BaseModel):
    """Request body asking the server for a situational scenario prompt."""

    # Session the scenario is raised for.
    session_id: str
    # Scenario key; the scenario endpoint recognises "volatility",
    # "large_position", "loss_recovery" and "news_event".
    scenario_type: str
    # Arbitrary context payload; the scenario endpoint echoes it back.
    context: Dict
    # Returned to the client as "requires_ai".
    ai_required: bool = True
|
|
def get_ai_with_params(risk_level: float, temperature: float, confidence_boost: float):
    """Create a chat LLM client whose temperature reflects the experiment settings.

    Args:
        risk_level: Risk setting; each point adds 0.03 to the temperature
            (so level 10 adds 0.3).
        temperature: Base sampling temperature before the risk adjustment.
        confidence_boost: Accepted for interface compatibility only; it does
            not influence how the client is configured.

    Returns:
        A ``ChatOpenAI`` client pointed at the Hugging Face router, using the
        model named by ``HF_MODEL_NAME`` (or the built-in default) and a
        512-token response limit.

    Raises:
        ValueError: If the ``HF_TOKEN`` environment variable is missing or empty.
    """
    # Riskier profiles get a slightly more random model.
    adjusted_temp = temperature + (risk_level / 10.0) * 0.3
    # Clamp on both sides: the original only capped the upper bound, so a
    # negative base temperature or risk level produced a negative value.
    bounded_temp = max(0.0, min(adjusted_temp, 2.0))

    model_name = os.getenv("HF_MODEL_NAME", "meta-llama/Llama-3.1-8B-Instruct:novita")
    api_key = os.getenv("HF_TOKEN")
    if not api_key:
        raise ValueError("HF_TOKEN environment variable is not set.")

    return ChatOpenAI(
        model=model_name,
        base_url="https://router.huggingface.co/v1",
        api_key=api_key,
        temperature=bounded_temp,
        max_tokens=512,
    )
|
|
def get_contextual_prompt(context: Dict, risk_level: float, confidence_boost: float):
    """Build the advisor prompt template for a risk level and confidence boost.

    Args:
        context: Optional "market", "portfolio" and "events" entries that are
            written into the prompt text; missing entries fall back to
            neutral defaults. ``None`` is treated as an empty mapping.
        risk_level: Risk setting; truncated to an integer and clamped to 0-10
            before being translated into a descriptive label.
        confidence_boost: Above 20 adds a "highly confident" sentence, above 0
            a "confident" sentence, below -20 a cautionary sentence; anything
            else adds no sentence.

    Returns:
        A ``PromptTemplate`` whose remaining input variables are ``context``
        and ``question``.
    """
    # Every integer level has an entry. The previous table had gaps, so
    # levels such as 9 silently fell back to "moderate".
    risk_descriptor = {
        0: "extremely conservative",
        1: "very conservative",
        2: "very conservative",
        3: "conservative",
        4: "conservative",
        5: "moderate",
        6: "moderately aggressive",
        7: "moderately aggressive",
        8: "aggressive",
        9: "aggressive",
        10: "very aggressive",
    }
    # Out-of-range settings map to the nearest end of the scale.
    risk_text = risk_descriptor[max(0, min(10, int(risk_level)))]

    if confidence_boost > 20:
        confidence_text = "You are highly confident in your recommendations."
    elif confidence_boost > 0:
        confidence_text = "You are confident in your recommendations."
    elif confidence_boost < -20:
        confidence_text = "You are uncertain and should express caution in your recommendations."
    else:
        confidence_text = ""

    def _escape_braces(value) -> str:
        # The formatted text becomes a brace-delimited template, so literal
        # braces in caller-supplied values must be doubled or they would be
        # read as extra template variables.
        return str(value).replace("{", "{{").replace("}", "}}")

    context = context or {}

    # Single-brace fields are filled in right here; the double-brace fields
    # survive str.format as {context}/{question} for the PromptTemplate.
    base_template = """
You are an AI trading advisor for the Quantum Financial Network. Your risk profile is {risk_level}.
{confidence_text}

You provide trading advice based on market data. Consider:
- Current market conditions: {market_context}
- Player portfolio status: {portfolio_context}
- Recent events: {events_context}

Answer the question with appropriate caution/certainty based on your risk profile.

Context:
{{context}}

Question: {{question}}

Answer:
"""

    return PromptTemplate(
        input_variables=["context", "question"],
        template=base_template.format(
            risk_level=risk_text,
            confidence_text=confidence_text,
            market_context=_escape_braces(context.get("market", "normal conditions")),
            portfolio_context=_escape_braces(context.get("portfolio", "standard portfolio")),
            events_context=_escape_braces(context.get("events", "no major events")),
        ),
    )
|
|
@app.post("/api/ai/chat")
async def chat_with_ai(request: AIMessageRequest):
    """Answer a player question through the shared chain and log the exchange.

    The question is prefixed with a "Risk Profile" line unless it already
    mentions "risk tolerance" (case-insensitive) or contains
    "Current Market Scenario". The request's temperature and confidence_boost
    are recorded with the interaction but are not passed to the chain call.

    NOTE(review): ``chain(...)`` is invoked synchronously inside this async
    handler; if it blocks on network I/O, other requests wait meanwhile.

    Returns:
        A dict with "answer", "sources" (the first 100 characters of each
        returned source document) and "interaction_id" (index of this exchange
        in the session's "params_history").

    Raises:
        HTTPException: Status 500 when anything in the handler fails.
    """
    try:
        # Create the session record on first contact.
        session = game_sessions.setdefault(
            request.session_id,
            {
                "chat_history": [],
                "decisions": [],
                "trust_scores": [],
                "params_history": [],
            },
        )

        # Keep only well-formed turns and coerce both halves to strings.
        chat_history_tuples = [
            (str(item[0]), str(item[1]))
            for item in (request.chat_history or [])
            if isinstance(item, (list, tuple)) and len(item) >= 2
        ]

        enhanced_question = request.question
        needs_risk_prefix = (
            request.risk_level is not None
            and "risk tolerance" not in enhanced_question.lower()
            and "Current Market Scenario" not in enhanced_question
        )
        if needs_risk_prefix:
            # Every integer level has an entry. The previous table had gaps,
            # so levels such as 8 silently fell back to "moderate".
            risk_desc = {
                0: "extremely conservative",
                1: "very conservative",
                2: "very conservative",
                3: "conservative",
                4: "conservative",
                5: "moderate",
                6: "moderately aggressive",
                7: "moderately aggressive",
                8: "aggressive",
                9: "aggressive",
                10: "very aggressive",
            }
            # Out-of-range settings map to the nearest end of the scale.
            risk_text = risk_desc[max(0, min(10, int(request.risk_level)))]
            enhanced_question = f"Risk Profile: {risk_text} ({request.risk_level}/10)\n\n{enhanced_question}"

        result = chain({
            "question": enhanced_question,
            "chat_history": chat_history_tuples,
        })

        # Log the original (un-prefixed) question with the parameters in use.
        session["params_history"].append({
            "timestamp": datetime.now().isoformat(),
            "question": request.question,
            "response": result["answer"],
            "risk_level": request.risk_level,
            "temperature": request.temperature,
            "confidence_boost": request.confidence_boost,
        })

        return {
            "answer": result["answer"],
            "sources": [doc.page_content[:100] for doc in result.get("source_documents", [])],
            "interaction_id": len(session["params_history"]) - 1,
        }

    except Exception as exc:
        # Keep the traceback in the server log; sending it in the response
        # body would expose internals to any client.
        import logging

        logging.getLogger(__name__).exception(
            "AI chat request failed for session %r", request.session_id
        )
        raise HTTPException(
            status_code=500,
            detail="AI chat request failed; see server logs for details.",
        ) from exc
|
|
@app.post("/api/experiment/decision")
async def log_decision(decision: TradingDecision):
    """Append a player's trading decision to its session and return its index."""
    # First decision for an unknown session creates the session record.
    session = game_sessions.setdefault(
        decision.session_id,
        {
            "chat_history": [],
            "decisions": [],
            "trust_scores": [],
            "params_history": [],
        },
    )

    logged = session["decisions"]
    entry = {
        "timestamp": datetime.now().isoformat(),
        "symbol": decision.symbol,
        "action": decision.action,
        "quantity": decision.quantity,
        "price": decision.price,
        "ai_advice_followed": decision.ai_advice_followed,
        "trust_score": decision.trust_score,
    }
    logged.append(entry)

    # The id is the position of the entry just appended.
    return {"status": "logged", "decision_id": len(logged) - 1}
|
|
@app.post("/api/experiment/scenario")
async def trigger_scenario(scenario: ScenarioTrigger):
    """Return the canned prompt for a scenario type, echoing the request context.

    Raises:
        HTTPException: Status 400 when the scenario type is not recognised.
    """
    scenario_prompts = {
        "volatility": "The market is experiencing high volatility. A stock in your portfolio has moved 10% in the last hour. What should you do?",
        "large_position": "You're about to make a large position trade ($10,000+). This would represent a significant portion of your portfolio. Should you proceed?",
        "loss_recovery": "You're down 5% today. Would you like advice on whether to cut losses or hold your positions?",
        "news_event": "Breaking news just released that affects several stocks in your watchlist. How should this impact your trading decisions?",
    }

    # Every known key maps to a string, so None can only mean "unknown type".
    prompt = scenario_prompts.get(scenario.scenario_type)
    if prompt is None:
        raise HTTPException(status_code=400, detail="Unknown scenario type")

    return {
        "scenario_type": scenario.scenario_type,
        "prompt": prompt,
        "context": scenario.context,
        "requires_ai": scenario.ai_required,
    }
|
|
@app.get("/api/experiment/session/{session_id}")
async def get_session_data(session_id: str):
    """Return everything stored for a session, or 404 if the id is unknown."""
    try:
        return game_sessions[session_id]
    except KeyError:
        raise HTTPException(status_code=404, detail="Session not found") from None
|
|
@app.get("/api/experiment/export/{session_id}")
async def export_experiment_data(session_id: str):
    """Return a session's stored record for analysis, or 404 if the id is unknown."""
    # Stored session records are always dicts, so None means "not present".
    record = game_sessions.get(session_id)
    if record is None:
        raise HTTPException(status_code=404, detail="Session not found")
    return record
|
|
@app.get("/")
async def root():
    """Send visitors of the site root to the trading game page."""
    from fastapi.responses import RedirectResponse

    target = "/game/trade.html"
    return RedirectResponse(url=target)
|
|
@app.get("/game")
async def game_redirect():
    """Send requests for the bare /game path to the trading game page."""
    from fastapi.responses import RedirectResponse

    target = "/game/trade.html"
    return RedirectResponse(url=target)
|
|
if __name__ == "__main__":
    # Executed as a script: serve the app with uvicorn on all interfaces,
    # port 8000.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
|
|