namish10 commited on
Commit
788411f
·
verified ·
1 Parent(s): 879da30

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +97 -0
app.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ContextFlow OpenEnv - Simple API Server
3
+ """
4
+
5
+ from fastapi import FastAPI
6
+ from pydantic import BaseModel
7
+ from typing import Optional, List, Dict, Any
8
+ import uvicorn
9
+
10
+ from models import Observation, Action, Reward, State, StepResult, TaskDifficulty, ActionType
11
+ from server.contextflow_environment import ContextFlowEnvironment
12
+
13
+ app = FastAPI(title="ContextFlow OpenEnv")
14
+
15
+ environments: Dict[str, ContextFlowEnvironment] = {}
16
+
17
+
18
+ class ResetResponse(BaseModel):
19
+ observation: dict
20
+ episode_id: str
21
+
22
+
23
+ class StepRequest(BaseModel):
24
+ action_type: str
25
+ predicted_confusion: Optional[float] = None
26
+ intervention_type: Optional[str] = None
27
+ intervention_intensity: Optional[float] = None
28
+ episode_id: Optional[str] = None
29
+
30
+
31
+ @app.get("/")
32
+ async def root():
33
+ return {"message": "ContextFlow OpenEnv Environment", "version": "1.0.0"}
34
+
35
+
36
+ @app.get("/health")
37
+ async def health():
38
+ return {"status": "healthy"}
39
+
40
+
41
+ @app.post("/reset", response_model=ResetResponse)
42
+ async def reset(difficulty: str = "medium"):
43
+ try:
44
+ difficulty_enum = TaskDifficulty(difficulty.lower())
45
+ except ValueError:
46
+ difficulty_enum = TaskDifficulty.MEDIUM
47
+
48
+ env = ContextFlowEnvironment(task_difficulty=difficulty_enum)
49
+ observation = env.reset()
50
+
51
+ env_id = observation.episode_id
52
+ environments[env_id] = env
53
+
54
+ return ResetResponse(
55
+ observation=observation.model_dump(),
56
+ episode_id=env_id,
57
+ )
58
+
59
+
60
+ @app.post("/step")
61
+ async def step(request: StepRequest):
62
+ if not request.episode_id or request.episode_id not in environments:
63
+ return {"error": "Invalid or missing episode_id"}
64
+
65
+ env = environments[request.episode_id]
66
+
67
+ action = Action(
68
+ action_type=ActionType(request.action_type),
69
+ predicted_confusion=request.predicted_confusion,
70
+ intervention_type=request.intervention_type,
71
+ intervention_intensity=request.intervention_intensity,
72
+ )
73
+
74
+ result = env.step(action)
75
+
76
+ if result.done:
77
+ del environments[request.episode_id]
78
+
79
+ return result.model_dump()
80
+
81
+
82
+ @app.get("/state/{episode_id}")
83
+ async def get_state(episode_id: str):
84
+ if episode_id not in environments:
85
+ return {"error": "Episode not found"}
86
+
87
+ env = environments[episode_id]
88
+ return env.get_state().model_dump()
89
+
90
+
91
+ @app.get("/")
92
+ async def read_root():
93
+ return {"message": "ContextFlow OpenEnv", "version": "1.0.0"}
94
+
95
+
96
+ if __name__ == "__main__":
97
+ uvicorn.run(app, host="0.0.0.0", port=7860)