# nl2sql-bench / env_server
# Uploaded by ritvik360 via huggingface_hub (commit a39d8ef, verified; 921 bytes).
import json
import os
import sys

import uvicorn
from fastapi import FastAPI, HTTPException, Request

# Project modules live in ./server (resolved against the current working
# directory), so the path must be added before they are imported.
sys.path.insert(0, "./server")
from environment import NL2SQLEnvironment  # noqa: E402
from models import NL2SQLAction  # noqa: E402
# HTTP application exposing the environment's reset/step operations.
app: FastAPI = FastAPI()
# A single environment instance, created once at import time and shared by
# every request this process handles (both endpoints operate on it).
env: NL2SQLEnvironment = NL2SQLEnvironment()
async def _read_json_object(request: Request) -> dict:
    """Return the request body parsed as a JSON object.

    An empty (or whitespace-only) body is treated as ``{}`` so the endpoints'
    per-key defaults apply. A body that is not valid JSON, or is valid JSON
    but not an object, is rejected with HTTP 400 rather than escaping as an
    unhandled 500.
    """
    raw = await request.body()
    if not raw.strip():
        return {}
    try:
        data = json.loads(raw)
    except ValueError as err:  # covers json.JSONDecodeError and bad encodings
        raise HTTPException(
            status_code=400, detail=f"Request body is not valid JSON: {err}"
        ) from err
    if not isinstance(data, dict):
        raise HTTPException(
            status_code=400, detail="Request body must be a JSON object."
        )
    return data


@app.post("/reset")
async def reset(request: Request):
    """Reset the shared environment and return its initial observation.

    Body (optional JSON object):
        task_name: task to load; defaults to ``"simple-filter"`` when absent.

    Returns:
        ``{"observation": obs.__dict__}`` for the observation from ``env.reset``.

    Raises:
        HTTPException: 400 if a body is sent but is not a JSON object.
    """
    # Take task_name directly from the API call; an empty body uses the default.
    data = await _read_json_object(request)
    task_name = data.get("task_name", "simple-filter")
    print(f"🔄 Environment Resetting for Task: {task_name}")
    # NOTE(review): env.reset is called synchronously inside an async handler;
    # if it does blocking work it stalls the event loop for other requests.
    obs = env.reset(task_name=task_name)
    return {"observation": obs.__dict__}


@app.post("/step")
async def step(request: Request):
    """Execute one SQL query against the shared environment.

    Body (JSON object):
        query: SQL string to execute; defaults to ``""`` when absent.

    Returns:
        ``{"observation": obs.__dict__}`` for the observation from ``env.step``.

    Raises:
        HTTPException: 400 if the body is not a JSON object or ``query`` is
        not a string.
    """
    data = await _read_json_object(request)
    query = data.get("query", "")
    if not isinstance(query, str):
        # Without this check, e.g. a number or null would raise TypeError on
        # the slice below and surface as a 500.
        raise HTTPException(status_code=400, detail="'query' must be a string.")
    print(f"⏩ Executing SQL: {query[:60]}...")
    action = NL2SQLAction(query=query)
    # NOTE(review): synchronous call inside an async handler, same caveat as /reset.
    obs = env.step(action)
    return {"observation": obs.__dict__}
if __name__ == "__main__":
    # Serve on all interfaces. The PORT environment variable overrides the
    # listen port; without it the server uses 8000, as before.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "8000")))