Dev Shah commited on
Commit
5107b13
·
1 Parent(s): f4c15b1

fix: make task parsing robust to string IDs for openenv task validation

Browse files
Files changed (1) hide show
  1. server/app.py +20 -12
server/app.py CHANGED
@@ -5,8 +5,8 @@ Endpoints mirror the OpenEnv spec.
5
 
6
  from fastapi import FastAPI, HTTPException, Body
7
  from fastapi.middleware.cors import CORSMiddleware
8
- from pydantic import BaseModel
9
- from typing import Optional
10
  import uvicorn
11
  import os
12
  import sys
@@ -29,19 +29,27 @@ app.add_middleware(
29
  _envs: dict[int, EmailTriageEnv] = {}
30
 
31
 
 
 
 
 
 
 
 
32
  class ResetRequest(BaseModel):
33
- task: int = 1
34
 
35
 
36
  class StepRequest(BaseModel):
37
- task: int = 1
38
  action: Action
39
 
40
 
41
- def _get_env(task: int) -> EmailTriageEnv:
42
- if task not in _envs:
43
- raise HTTPException(status_code=400, detail=f"Task {task} not initialised. Call /reset first.")
44
- return _envs[task]
 
45
 
46
 
47
  @app.get("/health")
@@ -51,7 +59,7 @@ def health():
51
 
52
  @app.post("/reset")
53
  def reset(req: Optional[ResetRequest] = Body(default=None)):
54
- task = req.task if req else 1
55
  env = EmailTriageEnv(task=task)
56
  obs = env.reset()
57
  _envs[task] = env
@@ -72,15 +80,15 @@ def step(req: StepRequest):
72
 
73
 
74
  @app.get("/state")
75
- def state(task: int = 1):
76
  env = _get_env(task)
77
  return {"state": env.state(), "score": env.score()}
78
 
79
 
80
  @app.get("/score")
81
- def score(task: int = 1):
82
  env = _get_env(task)
83
- return {"score": env.score(), "task": task}
84
 
85
 
86
  def main():
 
5
 
6
  from fastapi import FastAPI, HTTPException, Body
7
  from fastapi.middleware.cors import CORSMiddleware
8
+ from pydantic import BaseModel, Field
9
+ from typing import Optional, Union
10
  import uvicorn
11
  import os
12
  import sys
 
29
  _envs: dict[int, EmailTriageEnv] = {}
30
 
31
 
32
+ def _parse_task(task: Union[int, str]) -> int:
33
+ if isinstance(task, str):
34
+ if task.startswith("task"):
35
+ return int(task[4:])
36
+ return int(task)
37
+ return task
38
+
39
  class ResetRequest(BaseModel):
40
+ task: Union[int, str] = 1
41
 
42
 
43
  class StepRequest(BaseModel):
44
+ task: Union[int, str] = 1
45
  action: Action
46
 
47
 
48
+ def _get_env(task: Union[int, str]) -> EmailTriageEnv:
49
+ task_int = _parse_task(task)
50
+ if task_int not in _envs:
51
+ raise HTTPException(status_code=400, detail=f"Task {task_int} not initialised. Call /reset first.")
52
+ return _envs[task_int]
53
 
54
 
55
  @app.get("/health")
 
59
 
60
  @app.post("/reset")
61
  def reset(req: Optional[ResetRequest] = Body(default=None)):
62
+ task = _parse_task(req.task if req else 1)
63
  env = EmailTriageEnv(task=task)
64
  obs = env.reset()
65
  _envs[task] = env
 
80
 
81
 
82
  @app.get("/state")
83
+ def state(task: Union[int, str] = 1):
84
  env = _get_env(task)
85
  return {"state": env.state(), "score": env.score()}
86
 
87
 
88
  @app.get("/score")
89
+ def score(task: Union[int, str] = 1):
90
  env = _get_env(task)
91
+ return {"score": env.score(), "task": _parse_task(task)}
92
 
93
 
94
  def main():