namish10's picture
Upload folder using huggingface_hub
a2896bf verified
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import JSONResponse
import asyncio
import json
from typing import Optional
from models import Observation, Action, Reward, State, StepResult, TaskDifficulty
from server.contextflow_environment import ContextFlowEnvironment
# Registries keyed by episode id, shared by every route in this module.
environments: dict[str, ContextFlowEnvironment] = dict()  # active environments; a step reporting done removes its entry
connections: dict[str, WebSocket] = dict()  # WebSocket currently attached to each episode

# The ASGI application that all HTTP and WebSocket routes below are registered on.
app = FastAPI(title="ContextFlow OpenEnv")
@app.get("/")
async def root():
    """Return a static banner naming the service and its version."""
    banner = dict(message="ContextFlow OpenEnv Environment", version="1.0.0")
    return banner
@app.get("/health")
async def health():
    """Report a constant ``{"status": "healthy"}`` payload; no checks are performed."""
    status_payload = {"status": "healthy"}
    return status_payload
@app.post("/reset")
async def reset(difficulty: Optional[str] = "medium"):
    """Start a new episode and register its environment.

    Args:
        difficulty: Case-insensitive name of a ``TaskDifficulty`` value.
            ``None``, an empty string, or an unrecognised value falls back
            to ``TaskDifficulty.MEDIUM``.

    Returns:
        A dict with the initial ``observation`` (serialised via
        ``model_dump()``) and the ``episode_id`` under which the environment
        is stored; that id is what ``/step``, ``/state/{episode_id}`` and
        ``/ws/{episode_id}`` look up.
    """
    # ``difficulty`` is Optional, so guard against None before ``.lower()``
    # instead of letting an AttributeError escape.
    try:
        difficulty_enum = TaskDifficulty((difficulty or "medium").lower())
    except ValueError:
        difficulty_enum = TaskDifficulty.MEDIUM

    env = ContextFlowEnvironment(task_difficulty=difficulty_enum)
    observation = env.reset()
    env_id = observation.episode_id
    # NOTE(review): entries are only removed when a step reports done;
    # nothing in this module evicts abandoned episodes.
    environments[env_id] = env
    return {
        "observation": observation.model_dump(),
        "episode_id": env_id,
    }
@app.post("/step")
async def step(action: Action):
    """Apply ``action`` to the environment of the episode it names.

    Returns the step result serialised with ``model_dump()``, or a 400
    ``JSONResponse`` when ``action.episode_id`` is empty or not registered.
    The environment is unregistered once the result reports ``done``.
    """
    episode_id = action.episode_id
    known = bool(episode_id) and episode_id in environments
    if not known:
        return JSONResponse(
            status_code=400,
            content={"error": "Invalid or missing episode_id"},
        )

    result = environments[episode_id].step(action)
    if result.done:
        environments.pop(episode_id)
    return result.model_dump()
@app.get("/state/{episode_id}")
async def get_state(episode_id: str):
    """Return the serialised state of a registered episode, or a 404 ``JSONResponse``."""
    env = environments.get(episode_id)
    if env is None:
        return JSONResponse(
            status_code=404,
            content={"error": "Episode not found"},
        )
    snapshot = env.state()
    return snapshot.model_dump()
@app.websocket("/ws/{episode_id}")
async def websocket_endpoint(websocket: WebSocket, episode_id: str):
    """Drive an existing episode over a WebSocket.

    The episode must already be registered (e.g. by ``POST /reset``);
    otherwise an error is sent and the socket is closed.

    Each incoming text frame must be a JSON object with a ``"type"`` key:

    * ``{"type": "reset", "difficulty": "<name>"}``: replace the episode's
      environment with a fresh one and reply
      ``{"type": "reset", "observation": ...}``. Unrecognised difficulties
      fall back to ``TaskDifficulty.MEDIUM``, matching ``POST /reset``.
    * ``{"type": "step", "action": {...}}``: build an ``Action`` from the
      payload, step the environment and reply ``{"type": "step", "result": ...}``.
      The environment is unregistered when the result reports ``done``.
    * ``{"type": "state"}``: reply ``{"type": "state", "state": ...}``.

    Malformed frames, invalid actions and unknown types are answered with an
    ``{"error": ...}`` message instead of tearing down the connection.
    """
    await websocket.accept()
    if episode_id not in environments:
        # Reject before registering: this early return bypasses the
        # ``finally`` below, so registering first would leak the entry.
        await websocket.send_json({"error": "Episode not found"})
        await websocket.close()
        return

    connections[episode_id] = websocket
    try:
        while True:
            data = await websocket.receive_text()
            try:
                message = json.loads(data)
            except ValueError:  # json.JSONDecodeError subclasses ValueError
                await websocket.send_json({"error": "Invalid JSON"})
                continue
            if not isinstance(message, dict):
                await websocket.send_json({"error": "Message must be a JSON object"})
                continue

            msg_type = message.get("type")
            if msg_type == "reset":
                difficulty = message.get("difficulty", "medium")
                try:
                    difficulty_enum = TaskDifficulty(str(difficulty).lower())
                except ValueError:
                    difficulty_enum = TaskDifficulty.MEDIUM
                env = ContextFlowEnvironment(task_difficulty=difficulty_enum)
                observation = env.reset()
                # NOTE(review): stored under the URL's episode_id, which may
                # differ from observation.episode_id of the new environment.
                environments[episode_id] = env
                await websocket.send_json({
                    "type": "reset",
                    "observation": observation.model_dump(),
                })
            elif msg_type == "step":
                env = environments.get(episode_id)
                if env is None:
                    await websocket.send_json({"error": "Episode not found"})
                    continue
                payload = message.get("action")
                if not isinstance(payload, dict):
                    await websocket.send_json({"error": "Missing or invalid action"})
                    continue
                try:
                    action = Action(**payload)
                except (TypeError, ValueError) as exc:
                    # pydantic's ValidationError is a ValueError subclass.
                    await websocket.send_json({"error": "Invalid action", "detail": str(exc)})
                    continue
                result = env.step(action)
                if result.done:
                    environments.pop(episode_id, None)
                await websocket.send_json({
                    "type": "step",
                    "result": result.model_dump(),
                })
            elif msg_type == "state":
                env = environments.get(episode_id)
                if env is None:
                    await websocket.send_json({"error": "Episode not found"})
                    continue
                await websocket.send_json({
                    "type": "state",
                    "state": env.state().model_dump(),
                })
            else:
                await websocket.send_json({"error": f"Unknown message type: {msg_type!r}"})
    except WebSocketDisconnect:
        pass  # client closed the socket; cleanup happens in ``finally``
    finally:
        # Only unregister this socket: a newer connection for the same
        # episode may have replaced it in the registry.
        if connections.get(episode_id) is websocket:
            del connections[episode_id]
if __name__ == "__main__":
    # Run the app directly with uvicorn, listening on all interfaces at port 8000.
    from uvicorn import run as serve_app

    serve_app(app, port=8000, host="0.0.0.0")