Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect | |
| from fastapi.responses import JSONResponse | |
| import asyncio | |
| import json | |
| from typing import Optional | |
| from models import Observation, Action, Reward, State, StepResult, TaskDifficulty | |
| from server.contextflow_environment import ContextFlowEnvironment | |
# The ASGI application object served by uvicorn in the ``__main__`` guard.
# NOTE(review): the handler functions in this copy carry no visible route
# decorators (``@app.get`` / ``@app.post`` / ``@app.websocket``); confirm they
# are actually registered on ``app``.
app = FastAPI(title="ContextFlow OpenEnv")

# Most recent WebSocket per episode id; maintained by ``websocket_endpoint``.
connections: dict[str, WebSocket] = dict()

# Live environments keyed by episode id. Entries are dropped when a step
# reports ``done``. NOTE(review): nothing visible here evicts episodes that
# are abandoned before finishing — confirm whether that growth matters.
environments: dict[str, ContextFlowEnvironment] = dict()
async def root():
    """Identify the service by name and version."""
    info = dict(message="ContextFlow OpenEnv Environment", version="1.0.0")
    return info
async def health():
    """Liveness check; unconditionally reports a healthy status."""
    status = "healthy"
    return dict(status=status)
async def reset(difficulty: Optional[str] = "medium"):
    """Create and reset a new environment, then register it by episode id.

    Args:
        difficulty: Case-insensitive ``TaskDifficulty`` value. ``None`` or an
            unrecognised value falls back to ``TaskDifficulty.MEDIUM``.

    Returns:
        A dict with the serialized initial ``observation`` and its
        ``episode_id``. The id is the key under which the environment is
        stored in ``environments`` for later ``step`` / ``get_state`` calls.
    """
    # ``difficulty`` is declared Optional, so guard None before calling
    # ``.lower()``; previously a None value raised an uncaught AttributeError.
    if difficulty is None:
        difficulty_enum = TaskDifficulty.MEDIUM
    else:
        try:
            difficulty_enum = TaskDifficulty(difficulty.lower())
        except ValueError:
            difficulty_enum = TaskDifficulty.MEDIUM

    env = ContextFlowEnvironment(task_difficulty=difficulty_enum)
    observation = env.reset()
    env_id = observation.episode_id
    environments[env_id] = env
    return {
        "observation": observation.model_dump(),
        "episode_id": env_id,
    }
async def step(action: Action):
    """Apply one action to the episode named by ``action.episode_id``.

    Returns the serialized step result. When the result reports ``done``,
    the environment is removed from ``environments``. Responds with HTTP 400
    when the action has no episode id, or one that is not registered.
    """
    episode_id = action.episode_id
    env = environments.get(episode_id) if episode_id else None
    if env is None:
        return JSONResponse(
            status_code=400,
            content={"error": "Invalid or missing episode_id"},
        )

    result = env.step(action)
    if result.done:
        # The episode is finished; stop tracking its environment.
        environments.pop(episode_id)
    return result.model_dump()
async def get_state(episode_id: str):
    """Return the serialized state of a live episode, or HTTP 404 if unknown."""
    env = environments.get(episode_id)
    if env is None:
        return JSONResponse(
            status_code=404,
            content={"error": "Episode not found"},
        )
    snapshot = env.state()
    return snapshot.model_dump()
async def websocket_endpoint(websocket: WebSocket, episode_id: str):
    """Serve reset/step/state requests for one episode over a WebSocket.

    Each incoming text frame must be a JSON object whose ``"type"`` is one of
    ``"reset"``, ``"step"`` or ``"state"``. Successful replies echo the type;
    failures reply with ``{"error": ...}`` and keep the connection open.

    Args:
        websocket: The client connection.
        episode_id: Key into ``environments`` that this connection operates
            on. It must already be registered when the socket connects; a
            ``"reset"`` message replaces the environment stored under it.
    """
    await websocket.accept()
    if episode_id not in environments:
        await websocket.send_json({"error": "Episode not found"})
        await websocket.close()
        return

    # Register only after validation so a rejected socket never lingers in
    # ``connections`` (the cleanup below is not reached on the early return).
    connections[episode_id] = websocket
    try:
        while True:
            data = await websocket.receive_text()

            # Malformed input gets an error reply instead of killing the loop.
            try:
                message = json.loads(data)
            except json.JSONDecodeError:
                await websocket.send_json({"error": "Invalid JSON"})
                continue
            if not isinstance(message, dict):
                await websocket.send_json({"error": "Message must be a JSON object"})
                continue

            msg_type = message.get("type")
            if msg_type == "reset":
                # Same case-insensitive, fall-back-to-MEDIUM parsing as the
                # HTTP ``reset`` handler.
                difficulty = message.get("difficulty", "medium")
                difficulty_enum = TaskDifficulty.MEDIUM
                if isinstance(difficulty, str):
                    try:
                        difficulty_enum = TaskDifficulty(difficulty.lower())
                    except ValueError:
                        pass  # unknown value: keep the MEDIUM default
                env = ContextFlowEnvironment(task_difficulty=difficulty_enum)
                observation = env.reset()
                # NOTE(review): stored under the path ``episode_id``, not
                # ``observation.episode_id`` — confirm clients rely on that.
                environments[episode_id] = env
                await websocket.send_json({
                    "type": "reset",
                    "observation": observation.model_dump(),
                })
            elif msg_type == "step":
                env = environments.get(episode_id)
                if env is None:
                    await websocket.send_json({"error": "Episode not found"})
                    continue
                try:
                    # KeyError: no "action"; TypeError: not a mapping;
                    # ValueError covers model validation errors.
                    action = Action(**message["action"])
                except (KeyError, TypeError, ValueError) as exc:
                    await websocket.send_json({"error": f"Invalid action: {exc}"})
                    continue
                result = env.step(action)
                if result.done:
                    environments.pop(episode_id, None)
                await websocket.send_json({
                    "type": "step",
                    "result": result.model_dump(),
                })
            elif msg_type == "state":
                env = environments.get(episode_id)
                if env is None:
                    await websocket.send_json({"error": "Episode not found"})
                    continue
                await websocket.send_json({
                    "type": "state",
                    "state": env.state().model_dump(),
                })
            else:
                await websocket.send_json({"error": f"Unknown message type: {msg_type!r}"})
    except WebSocketDisconnect:
        pass  # client went away; normal termination
    finally:
        # Only drop the registration if it is still ours: a newer connection
        # for the same episode may have replaced it in the meantime.
        if connections.get(episode_id) is websocket:
            del connections[episode_id]
if __name__ == "__main__":
    # Imported lazily: uvicorn is only needed when this module is run as a
    # script, not when ``app`` is imported by another server process.
    import uvicorn

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
    )