"""HTTP and WebSocket API exposing ContextFlow environments.

Each episode is backed by one ``ContextFlowEnvironment`` kept in the
module-level ``environments`` registry, keyed by episode id.
"""

import asyncio
import json
from typing import Optional

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import JSONResponse

from models import (
    Observation,
    Action,
    Reward,
    State,
    StepResult,
    TaskDifficulty,
)
from server.contextflow_environment import ContextFlowEnvironment

app = FastAPI(title="ContextFlow OpenEnv")

# Open WebSocket per episode id (only the most recent socket is tracked).
connections: dict[str, WebSocket] = {}
# Live environments per episode id; entries are dropped once a step reports done.
environments: dict[str, ContextFlowEnvironment] = {}


def _parse_difficulty(value: object) -> TaskDifficulty:
    """Convert a client-supplied difficulty into a ``TaskDifficulty``.

    Matching is case-insensitive. Anything that is not a recognised string
    (including ``None``) falls back to ``TaskDifficulty.MEDIUM``.
    """
    if isinstance(value, str):
        try:
            return TaskDifficulty(value.lower())
        except ValueError:
            pass
    return TaskDifficulty.MEDIUM


def _handle_ws_message(episode_id: str, message: object) -> dict:
    """Apply one decoded WebSocket message and return the reply payload.

    Supported ``type`` values are ``"reset"``, ``"step"`` and ``"state"``.
    Malformed or unknown messages produce an ``{"error": ...}`` payload
    instead of raising, so one bad message does not tear down the socket.
    """
    if not isinstance(message, dict):
        return {"error": "Message must be a JSON object"}

    message_type = message.get("type")

    if message_type == "reset":
        env = ContextFlowEnvironment(
            task_difficulty=_parse_difficulty(message.get("difficulty", "medium"))
        )
        observation = env.reset()
        # NOTE(review): the environment is stored under the URL's episode id,
        # not observation.episode_id -- confirm this is the intended key.
        environments[episode_id] = env
        return {"type": "reset", "observation": observation.model_dump()}

    if message_type in ("step", "state"):
        env = environments.get(episode_id)
        if env is None:
            return {"error": "Episode not found"}

        if message_type == "state":
            return {"type": "state", "state": env.state().model_dump()}

        payload = message.get("action")
        if not isinstance(payload, dict):
            return {"error": "Missing or invalid 'action' payload"}
        try:
            action = Action(**payload)
        except (TypeError, ValueError) as exc:
            # Pydantic's ValidationError derives from ValueError.
            return {"error": f"Invalid action: {exc}"}

        result = env.step(action)
        if result.done:
            environments.pop(episode_id, None)
        return {"type": "step", "result": result.model_dump()}

    return {"error": f"Unknown message type: {message_type!r}"}


@app.get("/")
async def root():
    """Return a short service banner with the API version."""
    return {"message": "ContextFlow OpenEnv Environment", "version": "1.0.0"}


@app.get("/health")
async def health():
    """Liveness probe; always reports healthy."""
    return {"status": "healthy"}


@app.post("/reset")
async def reset(difficulty: Optional[str] = "medium"):
    """Create a new environment and return its first observation.

    Args:
        difficulty: Name of a ``TaskDifficulty`` member (case-insensitive).
            Unknown or missing values fall back to medium.

    Returns:
        A dict with the serialized ``observation`` and the new ``episode_id``.
    """
    env = ContextFlowEnvironment(task_difficulty=_parse_difficulty(difficulty))
    observation = env.reset()
    env_id = observation.episode_id
    environments[env_id] = env
    return {
        "observation": observation.model_dump(),
        "episode_id": env_id,
    }


@app.post("/step")
async def step(action: Action):
    """Advance the episode named by ``action.episode_id`` by one step.

    Returns the serialized step result, or a 400 response when the episode
    id is missing or unknown. A finished episode is removed from the registry.
    """
    if not action.episode_id or action.episode_id not in environments:
        return JSONResponse(
            status_code=400,
            content={"error": "Invalid or missing episode_id"},
        )

    env = environments[action.episode_id]
    result = env.step(action)

    if result.done:
        environments.pop(action.episode_id, None)

    return result.model_dump()


@app.get("/state/{episode_id}")
async def get_state(episode_id: str):
    """Return the serialized state of an episode, or 404 if it is unknown."""
    env = environments.get(episode_id)
    if env is None:
        return JSONResponse(
            status_code=404,
            content={"error": "Episode not found"},
        )
    return env.state().model_dump()


@app.websocket("/ws/{episode_id}")
async def websocket_endpoint(websocket: WebSocket, episode_id: str):
    """Serve reset/step/state messages for one episode over a WebSocket.

    The connection is rejected with an error message when the episode does
    not exist at connect time. Each incoming text frame must be a JSON object
    with a ``type`` field; every frame gets exactly one JSON reply.
    """
    await websocket.accept()

    if episode_id not in environments:
        await websocket.send_json({"error": "Episode not found"})
        await websocket.close()
        return

    # Register only after the episode check so a rejected socket is never
    # left behind in (or overwrites a live entry of) the connections registry.
    connections[episode_id] = websocket

    try:
        while True:
            data = await websocket.receive_text()
            try:
                message = json.loads(data)
            except json.JSONDecodeError:
                await websocket.send_json({"error": "Invalid JSON"})
                continue

            await websocket.send_json(_handle_ws_message(episode_id, message))
    except WebSocketDisconnect:
        pass
    finally:
        # Only unregister our own socket; a newer connection for the same
        # episode may have replaced the entry in the meantime.
        if connections.get(episode_id) is websocket:
            del connections[episode_id]


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)