Spaces:
Sleeping
Sleeping
File size: 2,569 Bytes
ac030bc 51fe7a3 ac030bc c5a9938 ac030bc 51fe7a3 ac030bc 0f991d0 51fe7a3 ac030bc 51fe7a3 551943d 51fe7a3 ac030bc 51fe7a3 ac030bc 51fe7a3 ac030bc 51fe7a3 1af5ac7 51fe7a3 1af5ac7 51fe7a3 c5a9938 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 | from dotenv import load_dotenv
load_dotenv()
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from models import Action, Observation, StepResult
from .config import ACTION_SCHEMA
from .data import TASK_CONFIGS
from .environment import EmailSortingEnvironment
# ASGI application object; the route decorators further down register handlers on it.
app = FastAPI(title="Sieve")
# CORS is left fully open: every origin, HTTP method and request header is allowed.
# NOTE(review): presumably intentional for a publicly hosted demo - confirm before
# serving anything sensitive from this app.
app.add_middleware(
CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"]
)
# A single module-level environment instance used by all request handlers in this
# module, so every client reads and mutates the same episode state.
# NOTE(review): confirm single-client use is the intended deployment model.
env = EmailSortingEnvironment()
@app.post("/reset", response_model=Observation)
def reset(task_id: str = "email_classification"):
    """Start a new episode for ``task_id`` and return the first observation.

    Args:
        task_id: Identifier of the task to load; defaults to
            ``"email_classification"``.

    Returns:
        Whatever ``env.reset(task_id)`` returns, serialized as ``Observation``.

    Raises:
        HTTPException: Status 400 when ``env.reset`` raises ``ValueError``
            (presumably for an unknown ``task_id``); the detail is the
            original error message.
    """
    try:
        return env.reset(task_id)
    except ValueError as exc:
        # `exc` rather than `exec` so the builtin is not shadowed; `from exc`
        # keeps the original error attached as the explicit cause.
        raise HTTPException(status_code=400, detail=str(exc)) from exc
@app.post("/step", response_model=StepResult)
def step(action: Action):
    """Apply one action to the running episode and return the step outcome.

    Responds with HTTP 400 when no task has been selected yet or when the
    current episode is already finished; in both cases the client must call
    ``/reset`` first.
    """
    # Pick the rejection reason (if any) first, then raise from a single place.
    # `env.done` is only consulted once a task is known to be set.
    rejection = None
    if not env.task_id:
        rejection = "Not initialized, call /reset first."
    elif env.done:
        rejection = "Episode already finished, call /reset first."
    if rejection is not None:
        raise HTTPException(status_code=400, detail=rejection)

    next_observation, step_reward, finished, extra = env.step(action)
    return StepResult(
        observation=next_observation,
        reward=step_reward,
        done=finished,
        info=extra,
    )
@app.get("/state")
def state():
    """Return the environment's current state as reported by ``env.state()``."""
    snapshot = env.state()
    return snapshot
@app.get("/tasks")
def list_tasks():
    """List every configured task together with the shared action schema.

    The response is ``{"tasks": [...]}`` with one entry per item of
    ``TASK_CONFIGS``, in that mapping's iteration order.
    """
    # Fields copied verbatim from each task's config, in output order.
    copied_fields = ("name", "difficulty", "description", "max_steps")
    catalogue = [
        {
            "id": key,
            **{field: spec[field] for field in copied_fields},
            "action_schema": ACTION_SCHEMA,
        }
        for key, spec in TASK_CONFIGS.items()
    ]
    return {"tasks": catalogue}
@app.get("/grader")
def grader():
    """Report the score and progress of the current episode.

    Uses ``env.last_grader_score`` when it is set; otherwise, if any actions
    have been recorded, asks the environment to compute a score. The score
    stays ``None`` when neither is available.
    """
    final_score = env.last_grader_score
    if final_score is None:
        # No stored score: compute one, but only if the episode has actions.
        if env.episode_actions:
            final_score = env.compute_final_score()

    # Both bounds are exclusive, so exactly 0.0 or 1.0 is reported as not successful.
    # NOTE(review): confirm the exclusive range is intended.
    if final_score is None:
        succeeded = False
    else:
        succeeded = 0.0 < final_score < 1.0

    report = {
        "task_id": env.task_id,
        "score": final_score,
        "success": succeeded,
        "done": env.done,
        "processed_count": len(env.processed_emails),
        "total_emails": len(env.email_queue),
    }
    # "correct_action" is optional on a recorded action; the other two keys are required.
    report["episode_actions_summary"] = [
        {
            "email_id": record["email_id"],
            "action_type": record["action_type"],
            "correct_action": record.get("correct_action"),
        }
        for record in env.episode_actions
    ]
    return report
def main():
    """Serve the app with uvicorn on all interfaces, port 7860."""
    # Imported here so uvicorn is only needed when the server is actually launched.
    import uvicorn

    bind_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run("server.app:app", **bind_options)


if __name__ == "__main__":
    main()
|