Spaces:
Sleeping
Sleeping
"""
ContextFlow OpenEnv - Simple API Server

FastAPI application exposing ContextFlowEnvironment episodes over HTTP.
``reset`` creates a new environment and stores it in the module-level
``environments`` dict under its episode_id; ``step`` applies one action to a
stored episode and removes it once the step result reports ``done``;
``get_state`` returns the current state of a stored episode. Nothing else
removes entries, so episodes that are never finished remain in memory.
"""
# NOTE(review): no route decorators (e.g. @app.get / @app.post) or
# app.add_api_route calls are visible for the handler functions in this file,
# so as written none of them is registered on ``app``. Confirm they were not
# lost when this file was exported/copied.
| from fastapi import FastAPI | |
| from pydantic import BaseModel | |
| from typing import Optional, List, Dict, Any | |
| import uvicorn | |
| from models import Observation, Action, Reward, State, StepResult, TaskDifficulty, ActionType | |
| from server.contextflow_environment import ContextFlowEnvironment | |
# FastAPI application. The version matches the "version" value returned by the
# root endpoints so the OpenAPI metadata agrees with the API's own payloads
# (FastAPI otherwise reports its own default version).
app = FastAPI(title="ContextFlow OpenEnv", version="1.0.0")

# Live episodes keyed by episode_id. reset() adds entries and step() removes an
# entry once its episode reports done; episodes that are never finished stay
# here for the lifetime of the process.
environments: Dict[str, ContextFlowEnvironment] = {}
class ResetResponse(BaseModel):
    """Response body of ``reset``: the initial observation and its episode id."""

    # Initial observation of the new episode, serialized via model_dump().
    observation: dict
    # Id of the new episode; clients pass it back to step / get_state.
    episode_id: str
class StepRequest(BaseModel):
    """Request body of ``step``: one action plus the episode it applies to."""

    # Converted with ActionType(...) in step, so it must be a valid ActionType value.
    action_type: str
    # Optional action parameters, forwarded unchanged into Action(...).
    predicted_confusion: Optional[float] = None
    intervention_type: Optional[str] = None
    intervention_intensity: Optional[float] = None
    # Episode id returned by reset; step rejects a missing or unknown id.
    episode_id: Optional[str] = None
async def root():
    """Return the service name and API version."""
    payload = {"message": "ContextFlow OpenEnv Environment"}
    payload["version"] = "1.0.0"
    return payload
async def health():
    """Liveness check: always reports a healthy status."""
    status_value = "healthy"
    return dict(status=status_value)
async def reset(difficulty: str = "medium"):
    """Start a new episode and register it under its episode id.

    Args:
        difficulty: Task difficulty name, matched case-insensitively against
            TaskDifficulty values. Unrecognized names fall back to
            TaskDifficulty.MEDIUM instead of being rejected.

    Returns:
        ResetResponse with the initial observation (as a dict) and episode id.
    """
    normalized = difficulty.lower()
    try:
        chosen_difficulty = TaskDifficulty(normalized)
    except ValueError:
        chosen_difficulty = TaskDifficulty.MEDIUM

    new_env = ContextFlowEnvironment(task_difficulty=chosen_difficulty)
    first_observation = new_env.reset()
    episode_id = first_observation.episode_id
    # NOTE(review): entries are only removed when an episode finishes in step();
    # abandoned episodes accumulate here.
    environments[episode_id] = new_env

    return ResetResponse(
        observation=first_observation.model_dump(),
        episode_id=episode_id,
    )
async def step(request: StepRequest):
    """Apply one action to a live episode and return the step result.

    Args:
        request: Action fields plus the ``episode_id`` obtained from reset.

    Returns:
        ``result.model_dump()`` of the environment's step result, or an
        ``{"error": ...}`` dict when the episode id is missing/unknown or the
        action fields are invalid.
    """
    episode_id = request.episode_id
    if not episode_id or episode_id not in environments:
        return {"error": "Invalid or missing episode_id"}
    env = environments[episode_id]

    # Client-supplied values are validated here: an unknown action_type would
    # otherwise raise ValueError out of the handler (an opaque HTTP 500), so it
    # is reported in the same {"error": ...} shape used for a bad episode_id.
    try:
        action_type = ActionType(request.action_type)
    except ValueError:
        return {"error": f"Invalid action_type: {request.action_type!r}"}

    try:
        action = Action(
            action_type=action_type,
            predicted_confusion=request.predicted_confusion,
            intervention_type=request.intervention_type,
            intervention_intensity=request.intervention_intensity,
        )
    except ValueError as exc:
        # Assumes Action is a pydantic model, whose ValidationError subclasses
        # ValueError — TODO confirm against models.Action.
        return {"error": f"Invalid action: {exc}"}

    result = env.step(action)
    if result.done:
        # Finished episodes are dropped; later step/get_state calls with this
        # id are then reported as unknown.
        del environments[episode_id]
    return result.model_dump()
async def get_state(episode_id: str):
    """Return the current state of a live episode as a dict.

    Unknown (or already finished and removed) episodes yield
    ``{"error": "Episode not found"}``.
    """
    try:
        live_env = environments[episode_id]
    except KeyError:
        return {"error": "Episode not found"}
    snapshot = live_env.get_state()
    return snapshot.model_dump()
async def read_root():
    """Return a short service banner with the API version."""
    banner = dict(message="ContextFlow OpenEnv", version="1.0.0")
    return banner
if __name__ == "__main__":
    # When executed directly, serve the app with uvicorn on all interfaces
    # (0.0.0.0), port 7860.
    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)