YashashMathur commited on
Commit
36e8c44
·
verified ·
1 Parent(s): f56b653

Fix: Use query parameters instead of body for API endpoints

Browse files
Files changed (1) hide show
  1. env/server.py +13 -15
env/server.py CHANGED
@@ -5,7 +5,7 @@ Provides REST and WebSocket endpoints for HuggingFace Spaces deployment.
5
  """
6
 
7
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
8
- from pydantic import BaseModel
9
  from typing import Optional, Dict, Any
10
  import json
11
  import asyncio
@@ -19,7 +19,7 @@ envs: Dict[str, SQLAnalystEnv] = {}
19
 
20
 
21
  class ResetRequest(BaseModel):
22
- task_id: str = "monthly_signups"
23
 
24
 
25
  class StepRequest(BaseModel):
@@ -42,10 +42,10 @@ async def root():
42
 
43
 
44
  @app.post("/reset")
45
- async def reset(req: ResetRequest) -> Dict[str, Any]:
46
- session_id = req.task_id
47
 
48
- env = SQLAnalystEnv(task_id=req.task_id)
49
  result = env.reset()
50
  envs[session_id] = env
51
 
@@ -65,15 +65,17 @@ async def reset(req: ResetRequest) -> Dict[str, Any]:
65
 
66
 
67
  @app.post("/step")
68
- async def step(req: StepRequest) -> Dict[str, Any]:
69
- session_id = req.session_id
70
-
 
 
71
  if session_id not in envs:
72
  return {"error": "Session not found. Call /reset first."}
73
 
74
  env = envs[session_id]
75
 
76
- action = Action(sql_query=req.sql_query, submit_answer=req.submit_answer)
77
 
78
  result = env.step(action)
79
 
@@ -108,9 +110,7 @@ async def step(req: StepRequest) -> Dict[str, Any]:
108
 
109
 
110
  @app.post("/state")
111
- async def state(req: StateRequest) -> Dict[str, Any]:
112
- session_id = req.session_id
113
-
114
  if session_id not in envs:
115
  return {"error": "Session not found. Call /reset first."}
116
 
@@ -129,9 +129,7 @@ async def state(req: StateRequest) -> Dict[str, Any]:
129
 
130
 
131
  @app.post("/delete")
132
- async def delete_session(req: StateRequest) -> Dict[str, str]:
133
- session_id = req.session_id
134
-
135
  if session_id in envs:
136
  del envs[session_id]
137
  return {"status": "deleted", "session_id": session_id}
 
5
  """
6
 
7
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
8
+ from pydantic import BaseModel, Field
9
  from typing import Optional, Dict, Any
10
  import json
11
  import asyncio
 
19
 
20
 
21
  class ResetRequest(BaseModel):
22
+ task_id: str = Field(default="monthly_signups")
23
 
24
 
25
  class StepRequest(BaseModel):
 
42
 
43
 
44
  @app.post("/reset")
45
+ async def reset(task_id: str = "monthly_signups") -> Dict[str, Any]:
46
+ session_id = task_id
47
 
48
+ env = SQLAnalystEnv(task_id=task_id)
49
  result = env.reset()
50
  envs[session_id] = env
51
 
 
65
 
66
 
67
  @app.post("/step")
68
+ async def step(
69
+ session_id: str,
70
+ sql_query: Optional[str] = None,
71
+ submit_answer: Optional[str] = None,
72
+ ) -> Dict[str, Any]:
73
  if session_id not in envs:
74
  return {"error": "Session not found. Call /reset first."}
75
 
76
  env = envs[session_id]
77
 
78
+ action = Action(sql_query=sql_query, submit_answer=submit_answer)
79
 
80
  result = env.step(action)
81
 
 
110
 
111
 
112
  @app.post("/state")
113
+ async def state(session_id: str) -> Dict[str, Any]:
 
 
114
  if session_id not in envs:
115
  return {"error": "Session not found. Call /reset first."}
116
 
 
129
 
130
 
131
  @app.post("/delete")
132
+ async def delete_session(session_id: str) -> Dict[str, str]:
 
 
133
  if session_id in envs:
134
  del envs[session_id]
135
  return {"status": "deleted", "session_id": session_id}