File size: 5,117 Bytes
698f4d8
 
 
 
 
 
 
 
 
 
 
 
80b3b2e
 
 
 
 
 
698f4d8
 
 
14577ec
698f4d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108bc34
 
 
 
 
 
 
 
213dee8
 
 
 
 
 
698f4d8
 
 
 
14577ec
 
 
 
698f4d8
 
2568517
 
 
 
 
 
 
 
 
108bc34
 
 
 
 
 
 
 
 
00a2188
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9d82eed
 
 
 
 
 
 
 
 
14577ec
698f4d8
 
 
 
14577ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
"""
Parlay — main application entry point.
Starts FastAPI with Dashboard + OpenEnv WebSocket + static file serving.

Usage:
    uvicorn main:app --host 0.0.0.0 --port 8000 --reload
"""
import logging
import os
from contextlib import asynccontextmanager
from pathlib import Path

from dotenv import load_dotenv

# uvicorn does not read .env itself, so without this only shell-exported
# variables would be visible. The file is resolved next to this module (not
# the CWD) and loaded into os.environ before the project imports below run,
# so GOOGLE_API_KEY and other secrets are already set when they are imported.
_project_root = Path(__file__).resolve().parent
load_dotenv(_project_root / ".env")

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse, Response

from parlay_env.server import router as env_router
from dashboard.api import router as dashboard_router

# Process-wide logging: INFO and above, "timestamp LEVEL logger: message".
_LOG_FORMAT = "%(asctime)s %(levelname)s %(name)s: %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module logger used by the startup/shutdown hooks and route handlers below.
logger = logging.getLogger(__name__)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup work before serving and log on shutdown.

    On startup this tries to initialise the database through
    ``scripts.init_db.init_db``. That step is best-effort: any ``Exception``
    is caught, including the import of the script failing. The failure is
    logged with its traceback and startup continues.

    Args:
        app: The FastAPI application being started (not used here).

    Yields:
        Control to the running application until it shuts down.
    """
    logger.info("Parlay starting up...")
    try:
        # Imported inside the try so a missing/broken init script only
        # disables DB setup instead of preventing the app from starting.
        from scripts.init_db import init_db
        await init_db()
    except Exception as exc:
        # Deliberate best-effort; keep the traceback so the cause is visible.
        logger.warning("DB init failed (continuing): %s", exc, exc_info=True)
    else:
        logger.info("Database initialized")
    try:
        yield
    finally:
        # Log shutdown even when the application exits via an exception.
        logger.info("Parlay shutting down.")


# Application metadata (title / description / version) handed to FastAPI.
_APP_METADATA = {
    "title": "Parlay",
    "description": "OpenEnv-compliant RL negotiation environment. Train agents, play scenarios.",
    "version": "1.0.0",
}

# The ASGI application; ``lifespan`` runs the startup/shutdown hooks.
app = FastAPI(lifespan=lifespan, **_APP_METADATA)

# CORS. By default every origin is allowed (convenient for local dev). Set
# PARLAY_CORS_ORIGINS to a comma-separated list of origins to restrict it in
# prod. A blank or unset value falls back to "*".
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("PARLAY_CORS_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

# Browsers refuse credentialed responses that use a wildcard origin. With
# allow_credentials=True and "*", Starlette instead echoes back whatever Origin
# the request sent. That would give every site credentialed access, so
# credentials are only enabled when origins are listed explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Register API routers, in order: the environment router from
# parlay_env.server, then the dashboard router from dashboard.api.
for _router in (env_router, dashboard_router):
    app.include_router(_router)
del _router

# Serve dashboard static assets under /static when the folder is present.
# Path is relative to the working directory (run from the project root).
static_dir = Path("dashboard/static")
# is_dir() rather than exists(): StaticFiles rejects a path that is not a
# directory by raising at construction time, which would abort startup.
if static_dir.is_dir():
    app.mount("/static", StaticFiles(directory=str(static_dir)), name="static")
else:
    logger.warning("Static directory %s not found; /static will not be served", static_dir)

def _ensure_and_mount(directory: str) -> None:
    """Create ``directory`` (relative to the CWD) if needed and serve it at ``/<directory>``.

    This is best-effort. A failure to create the folder (``OSError``, e.g.
    permissions or a read-only filesystem) is logged and the server starts
    without that mount. So is a failure to build its ``StaticFiles`` app
    (Starlette raises ``RuntimeError`` for a missing or non-directory path).
    """
    try:
        # makedirs is inside the try: it is the call that can raise OSError.
        os.makedirs(directory, exist_ok=True)
        app.mount(f"/{directory}", StaticFiles(directory=directory), name=directory)
    except (OSError, RuntimeError) as exc:
        logger.warning("Could not mount /%s: %s", directory, exc)


_ensure_and_mount("results")
_ensure_and_mount("images")


@app.get("/", include_in_schema=False)
async def serve_index() -> FileResponse:
    """Return the main game dashboard page, asking clients to revalidate on every load."""
    page = "dashboard/index.html"
    no_cache = {"Cache-Control": "no-cache, must-revalidate"}
    return FileResponse(page, headers=no_cache)


@app.get("/spectate", include_in_schema=False)
async def serve_spectate() -> FileResponse:
    """Return the spectator dashboard page with a no-cache / must-revalidate header."""
    page = "dashboard/spectate.html"
    no_cache = {"Cache-Control": "no-cache, must-revalidate"}
    return FileResponse(page, headers=no_cache)


@app.get("/train", include_in_schema=False)
async def serve_train_results() -> FileResponse:
    """Return the training-results page (plots, eval JSON, model hub CTA), uncached."""
    page = "dashboard/train_results.html"
    no_cache = {"Cache-Control": "no-cache, must-revalidate"}
    return FileResponse(page, headers=no_cache)


@app.get("/judge", include_in_schema=False)
async def serve_judge_demo() -> FileResponse:
    """Return the judge demo page for the GRPO-trained negotiator.

    The page is the same game UI with the opponent forced to the HF model.
    It is served with a no-cache / must-revalidate header.
    """
    page = "dashboard/judge.html"
    no_cache = {"Cache-Control": "no-cache, must-revalidate"}
    return FileResponse(page, headers=no_cache)


@app.get("/interact", include_in_schema=False)
async def serve_interact() -> FileResponse:
    """Return the direct-inference page for talking to the GRPO model outside the game, uncached."""
    page = "dashboard/interact.html"
    no_cache = {"Cache-Control": "no-cache, must-revalidate"}
    return FileResponse(page, headers=no_cache)


@app.get("/favicon.ico", include_in_schema=False, response_model=None)
async def favicon():
    """
    Return ``dashboard/static/favicon/favicon.ico`` as ``image/x-icon`` if that file exists.

    If it is absent, reply with an empty 204 instead. To enable the icon, put
    an .ico in that folder; see README.txt there.
    """
    icon_path = Path("dashboard/static/favicon/favicon.ico")
    if not icon_path.is_file():
        return Response(status_code=204)
    return FileResponse(icon_path, media_type="image/x-icon")


@app.get("/health")
async def health() -> dict:
    """
    Global health check.

    Returns:
        status:  "ok" whenever the server is able to answer.
        db:      "ok" if ``parlay.db`` (relative to the CWD) exists and can be
                 read as a SQLite database, "error" otherwise.
        gemini:  "configured" when GOOGLE_API_KEY is set to a non-blank value,
                 "mock" otherwise.
        version: Application version string, taken from ``app.version``.
    """
    import aiosqlite

    db_status = "error"
    try:
        # Open read-only via a URI. A plain connect("parlay.db") would silently
        # create an empty database file and report "ok" even when the DB was
        # never initialised.
        async with aiosqlite.connect("file:parlay.db?mode=ro", uri=True) as db:
            # Query sqlite_master so the file is actually read as a database;
            # "SELECT 1" can succeed without touching the file.
            await db.execute("SELECT 1 FROM sqlite_master LIMIT 1")
        db_status = "ok"
    except Exception as exc:
        # Health probes must not raise; report the failure in the payload.
        logger.warning("Health check DB probe failed: %s", exc)

    gemini_status = "configured" if os.environ.get("GOOGLE_API_KEY", "").strip() else "mock"

    return {
        "status": "ok",
        "db": db_status,
        "gemini": gemini_status,
        # Single source of truth: the version passed to FastAPI(...).
        "version": app.version,
    }