namish10's picture
Upload app.py with huggingface_hub
788411f verified
"""
ContextFlow OpenEnv - Simple API Server
"""
from fastapi import FastAPI
from pydantic import BaseModel
from typing import Optional, List, Dict, Any
import uvicorn
from models import Observation, Action, Reward, State, StepResult, TaskDifficulty, ActionType
from server.contextflow_environment import ContextFlowEnvironment
# FastAPI application object; the route decorators in this module attach
# their handlers to it.
app = FastAPI(title="ContextFlow OpenEnv")
# In-process registry of live environments keyed by episode id. /reset adds
# an entry and /step removes it once the step result reports done.
# NOTE(review): nothing in the visible code evicts episodes that are never
# finished, so this dict can grow without bound — confirm whether a cleanup
# policy exists elsewhere.
environments: Dict[str, ContextFlowEnvironment] = {}
class ResetResponse(BaseModel):
    """Response body returned by the /reset endpoint."""

    # Dumped form of the first observation produced by the new environment.
    observation: dict
    # Identifier clients must send back with subsequent /step calls.
    episode_id: str
class StepRequest(BaseModel):
    """Request body accepted by the /step endpoint.

    Only ``action_type`` is required by the schema; the remaining fields
    default to ``None`` and are forwarded to ``Action`` unchanged.
    """

    # Raw action name; the /step handler converts it with ActionType(...).
    action_type: str
    predicted_confusion: Optional[float] = None
    intervention_type: Optional[str] = None
    intervention_intensity: Optional[float] = None
    # Optional in the schema, but /step rejects requests that omit it or
    # name an episode that is not registered.
    episode_id: Optional[str] = None
@app.get("/")
async def root():
    """Return a small banner identifying the service and its version."""
    banner = {
        "message": "ContextFlow OpenEnv Environment",
        "version": "1.0.0",
    }
    return banner
@app.get("/health")
async def health():
    """Liveness probe: unconditionally report a healthy status."""
    status_payload = {"status": "healthy"}
    return status_payload
@app.post("/reset", response_model=ResetResponse)
async def reset(difficulty: str = "medium"):
    """Create a fresh environment, reset it, and register it by episode id.

    Args:
        difficulty: Name of the task difficulty (matched case-insensitively).
            Values that ``TaskDifficulty`` rejects fall back to MEDIUM.

    Returns:
        A ``ResetResponse`` carrying the dumped initial observation and the
        episode id under which the environment was stored.
    """
    try:
        level = TaskDifficulty(difficulty.lower())
    except ValueError:
        # Unknown difficulty names are tolerated rather than rejected.
        level = TaskDifficulty.MEDIUM

    new_env = ContextFlowEnvironment(task_difficulty=level)
    first_observation = new_env.reset()

    # The observation's episode id doubles as the registry key.
    episode_key = first_observation.episode_id
    environments[episode_key] = new_env

    return ResetResponse(
        observation=first_observation.model_dump(),
        episode_id=episode_key,
    )
@app.post("/step")
async def step(request: StepRequest):
    """Apply one action to a registered episode and return the step result.

    Args:
        request: Action fields plus the ``episode_id`` obtained from /reset.

    Returns:
        The dumped step result on success. On failure, a dict of the form
        ``{"error": <message>}`` — either because the episode id is missing
        or unknown, or because the action could not be constructed.
    """
    episode_id = request.episode_id
    if not episode_id or episode_id not in environments:
        return {"error": "Invalid or missing episode_id"}
    env = environments[episode_id]

    try:
        action = Action(
            action_type=ActionType(request.action_type),
            predicted_confusion=request.predicted_confusion,
            intervention_type=request.intervention_type,
            intervention_intensity=request.intervention_intensity,
        )
    except ValueError as exc:
        # Previously an unrecognised action_type escaped as an unhandled
        # ValueError (HTTP 500). Report it in the same error shape the
        # other failure paths of this API use instead.
        # NOTE(review): assumes ActionType is an Enum and Action is a
        # pydantic model (whose validation errors derive from ValueError)
        # — confirm against models.
        return {"error": f"Invalid action: {exc}"}

    result = env.step(action)
    if result.done:
        # Finished episodes are dropped from the registry; pop() tolerates
        # the entry having been removed already.
        environments.pop(episode_id, None)
    return result.model_dump()
@app.get("/state/{episode_id}")
async def get_state(episode_id: str):
    """Return the dumped current state of a registered episode.

    Args:
        episode_id: Key of the episode in the environment registry.

    Returns:
        The dumped state, or ``{"error": "Episode not found"}`` when no
        environment is registered under ``episode_id``.
    """
    try:
        active_env = environments[episode_id]
    except KeyError:
        return {"error": "Episode not found"}

    current_state = active_env.get_state()
    return current_state.model_dump()
async def read_root():
    """Return the service banner (legacy helper, not registered as a route).

    This coroutine used to carry its own ``@app.get("/")`` decorator, but
    ``root`` is already registered for GET / earlier in this module and
    routes are matched in registration order, so this handler could never
    be reached over HTTP. The duplicate registration is dropped; the
    function itself is kept so existing imports keep working.
    """
    return {"message": "ContextFlow OpenEnv", "version": "1.0.0"}
if __name__ == "__main__":
    # Script entry point: serve the app on all interfaces, port 7860.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
    )