Parlay / main.py
sh4shv4t's picture
Relocate training notebooks, add BLOG and Google Colab links (SFT + GRPO HF Job), dashboard updates, and eval artifacts
00a2188
"""
Parlay — main application entry point.
Starts FastAPI with Dashboard + OpenEnv WebSocket + static file serving.
Usage:
uvicorn main:app --host 0.0.0.0 --port 8000 --reload
"""
import logging
import os
from contextlib import asynccontextmanager
from pathlib import Path
from dotenv import load_dotenv
# uvicorn never reads .env files by itself, so explicitly load the one that sits
# next to this module. This makes GOOGLE_API_KEY and other secrets visible in
# os.environ before the project routers below are imported; without it only
# variables exported in the launching shell would be available.
load_dotenv(dotenv_path=Path(__file__).resolve().parent.joinpath(".env"))
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse, Response
from parlay_env.server import router as env_router
from dashboard.api import router as dashboard_router
# Root logging configuration shared by the whole app: INFO and above, with a
# timestamp, level and logger name on every line.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Run startup/shutdown work around the application's lifetime.

    On startup this initialises the database via ``scripts.init_db.init_db``.
    Any failure is logged as a warning and startup continues, including the
    import of that module failing.

    Args:
        app: The FastAPI application (required by the lifespan protocol; not used here).
    """
    logger.info("Parlay starting up...")
    try:
        # The import lives inside the try so an ImportError is also treated as
        # a non-fatal DB init failure.
        from scripts.init_db import init_db

        await init_db()
    except Exception as exc:
        # Deliberate best-effort: the server still starts without a database.
        logger.warning("DB init failed (continuing): %s", exc)
    else:
        logger.info("Database initialized")
    yield
    logger.info("Parlay shutting down.")
# The ASGI application object served by uvicorn (``uvicorn main:app``); the
# lifespan hook runs the DB initialisation on startup.
app = FastAPI(
    lifespan=lifespan,
    title="Parlay",
    version="1.0.0",
    description="OpenEnv-compliant RL negotiation environment. Train agents, play scenarios.",
)
# CORS — permissive by default for dev. To restrict origins in prod, set
# PARLAY_CORS_ORIGINS to a comma-separated list of allowed origins, e.g.
# "https://app.example.com,https://admin.example.com". Unset or blank keeps "*".
# NOTE(review): FastAPI's docs advise against combining allow_credentials=True
# with wildcard origins/methods/headers — configure explicit origins in prod.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("PARLAY_CORS_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Routers
app.include_router(env_router)
app.include_router(dashboard_router)
# Serve static files. These paths are relative to the current working
# directory, so the server is expected to be launched from the project root.
static_dir = Path("dashboard/static")
if static_dir.is_dir():
    app.mount("/static", StaticFiles(directory=str(static_dir)), name="static")

# Artifact directories (eval results, images) are created on demand and mounted
# best-effort: a failure is logged instead of aborting application startup.
for _artifact_dir in ("results", "images"):
    try:
        # makedirs is inside the try because it is the call that can raise
        # OSError (e.g. PermissionError on a read-only filesystem); StaticFiles
        # itself signals a missing directory with RuntimeError.
        os.makedirs(_artifact_dir, exist_ok=True)
        app.mount(
            f"/{_artifact_dir}",
            StaticFiles(directory=_artifact_dir),
            name=_artifact_dir,
        )
    except (OSError, RuntimeError) as exc:
        logger.warning("Could not mount /%s: %s", _artifact_dir, exc)
# Cache policy shared by every dashboard HTML page: clients must revalidate on
# each load (presumably so UI updates appear without a hard refresh).
_HTML_NO_CACHE_HEADERS = {"Cache-Control": "no-cache, must-revalidate"}


def _dashboard_page(filename: str) -> FileResponse:
    """Return ``dashboard/<filename>`` as a FileResponse with the no-cache headers."""
    # Pass a copy so no response can ever mutate the shared header template.
    return FileResponse(
        f"dashboard/{filename}",
        headers=dict(_HTML_NO_CACHE_HEADERS),
    )


@app.get("/", include_in_schema=False)
async def serve_index() -> FileResponse:
    """Serve the main game dashboard."""
    return _dashboard_page("index.html")


@app.get("/spectate", include_in_schema=False)
async def serve_spectate() -> FileResponse:
    """Serve the spectator dashboard."""
    return _dashboard_page("spectate.html")


@app.get("/train", include_in_schema=False)
async def serve_train_results() -> FileResponse:
    """Training results: plots, eval JSON, model hub CTA."""
    return _dashboard_page("train_results.html")


@app.get("/judge", include_in_schema=False)
async def serve_judge_demo() -> FileResponse:
    """GRPO (trained) negotiator: same game UI with opponent forced to HF model."""
    return _dashboard_page("judge.html")


@app.get("/interact", include_in_schema=False)
async def serve_interact() -> FileResponse:
    """Direct model inference page — talk to the GRPO model without game scaffolding."""
    return _dashboard_page("interact.html")
@app.get("/favicon.ico", include_in_schema=False, response_model=None)
async def favicon():
    """
    Return the site icon, or an empty 204 response when none is installed.

    The icon is looked up at ``dashboard/static/favicon/favicon.ico``; drop an
    .ico file into that folder (see README.txt there) to enable it.
    """
    icon_path = Path("dashboard") / "static" / "favicon" / "favicon.ico"
    if not icon_path.is_file():
        return Response(status_code=204)
    return FileResponse(icon_path, media_type="image/x-icon")
@app.get("/health")
async def health() -> dict:
    """
    Global health check.

    Returns:
        status: "ok" when server is running.
        db: "ok" if an existing parlay.db can be opened and queried, "error" otherwise.
        gemini: "configured" when GOOGLE_API_KEY is set to a non-blank value, "mock" otherwise.
        version: Application version string (read from ``app.version``).
    """
    import aiosqlite

    db_status = "error"
    try:
        # mode=rw requires the file to already exist; a plain connect() would
        # silently create an empty parlay.db and wrongly report "ok".
        async with aiosqlite.connect("file:parlay.db?mode=rw", uri=True) as db:
            await db.execute("SELECT 1")
        db_status = "ok"
    except Exception as exc:
        # Best-effort probe: report the failure instead of failing the endpoint.
        logger.warning("Health check DB probe failed: %s", exc)
    gemini_status = "configured" if os.environ.get("GOOGLE_API_KEY", "").strip() else "mock"
    return {
        "status": "ok",
        "db": db_status,
        "gemini": gemini_status,
        # Single source of truth: the version passed to FastAPI(...).
        "version": app.version,
    }