Zenoharsh01 commited on
Commit
bef3e25
·
verified ·
1 Parent(s): 3d64a53

Update server/app.py

Browse files
Files changed (1) hide show
  1. server/app.py +35 -30
server/app.py CHANGED
@@ -1,30 +1,35 @@
1
- from fastapi import FastAPI, Request
2
- from env import SOCEnvironment
3
- from models import Action
4
-
5
- app = FastAPI()
6
- soc_env = SOCEnvironment()
7
-
8
- @app.post("/reset")
9
- async def reset(request: Request):
10
- # The validation script sends an empty JSON payload during the ping test
11
- body = await request.json() if await request.body() else {}
12
- task_id = body.get("task_id", "task_1_triage")
13
-
14
- obs = soc_env.reset(task_id)
15
- return {"observation": obs.model_dump()}
16
-
17
- @app.post("/step")
18
- async def step(action: Action):
19
- # Receives the action from the external agent and executes it
20
- obs, reward, done, info = soc_env.step(action)
21
- return {
22
- "observation": obs.model_dump(),
23
- "reward": reward.score_delta,
24
- "done": done,
25
- "info": info
26
- }
27
-
28
- @app.get("/health")
29
- def health():
30
- return {"status": "ok"}
 
 
 
 
 
 
1
+ import uvicorn
2
+ from fastapi import FastAPI, Request
3
+ from env import SOCEnvironment
4
+ from models import Action
5
+
6
+ app = FastAPI()
7
+ soc_env = SOCEnvironment()
8
+
9
+ @app.post("/reset")
10
+ async def reset(request: Request):
11
+ body = await request.json() if await request.body() else {}
12
+ task_id = body.get("task_id", "task_1_triage")
13
+ obs = soc_env.reset(task_id)
14
+ return {"observation": obs.model_dump()}
15
+
16
+ @app.post("/step")
17
+ async def step(action: Action):
18
+ obs, reward, done, info = soc_env.step(action)
19
+ return {
20
+ "observation": obs.model_dump(),
21
+ "reward": reward.score_delta,
22
+ "done": done,
23
+ "info": info
24
+ }
25
+
26
+ @app.get("/health")
27
+ def health():
28
+ return {"status": "ok"}
29
+
30
+ # Added to satisfy the multi-mode deployment validator
31
+ def main():
32
+ uvicorn.run("server.app:app", host="0.0.0.0", port=7860)
33
+
34
+ if __name__ == "__main__":
35
+ main()