Spaces:
Running
Running
feat: enhance reset functionality and grading logic with optional request body
Browse files- adv_rebuild.py +14 -6
- server/app.py +10 -5
- src/models.py +4 -1
adv_rebuild.py
CHANGED
|
@@ -6,7 +6,10 @@ def write_file(path, content):
|
|
| 6 |
|
| 7 |
models_py = """
|
| 8 |
from pydantic import BaseModel, Field
|
| 9 |
-
from typing import Dict, Literal, List
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
class Observation(BaseModel):
|
| 12 |
time_step: int
|
|
@@ -184,8 +187,9 @@ TASKS = {
|
|
| 184 |
"""
|
| 185 |
|
| 186 |
main_py = """
|
| 187 |
-
from fastapi import FastAPI, HTTPException
|
| 188 |
-
from
|
|
|
|
| 189 |
from src.env import DesalEnv
|
| 190 |
from src.tasks import TASKS
|
| 191 |
import subprocess
|
|
@@ -198,7 +202,11 @@ def health_check():
|
|
| 198 |
return {"status": "ok", "message": "Advanced DesalEnv is running", "features": ["weather", "salinity", "mechanics"]}
|
| 199 |
|
| 200 |
@app.post("/reset")
|
| 201 |
-
def reset_env(task_id: str = "easy_spring"):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 202 |
if task_id not in TASKS:
|
| 203 |
raise HTTPException(status_code=404, detail="Task not found")
|
| 204 |
obs = env.reset(TASKS[task_id])
|
|
@@ -225,11 +233,11 @@ def list_tasks():
|
|
| 225 |
@app.get("/grader")
|
| 226 |
def grader():
|
| 227 |
if env.state is None:
|
| 228 |
-
return {"score": 0.
|
| 229 |
# Grade relative to typical maximum and minimum returns to generate a 0.0-1.0 range
|
| 230 |
baseline_offset = env.config.max_steps * 1000.0 # Compensate for penalties
|
| 231 |
scale_factor = env.config.max_steps * 1500.0
|
| 232 |
-
score = max(0.
|
| 233 |
return {"score": score}
|
| 234 |
|
| 235 |
@app.post("/baseline")
|
|
|
|
| 6 |
|
| 7 |
models_py = """
|
| 8 |
from pydantic import BaseModel, Field
|
| 9 |
+
from typing import Dict, Literal, List, Optional
|
| 10 |
+
|
| 11 |
+
class ResetRequest(BaseModel):
|
| 12 |
+
task_id: str = "easy_spring"
|
| 13 |
|
| 14 |
class Observation(BaseModel):
|
| 15 |
time_step: int
|
|
|
|
| 187 |
"""
|
| 188 |
|
| 189 |
main_py = """
|
| 190 |
+
from fastapi import FastAPI, HTTPException, Body
|
| 191 |
+
from typing import Optional
|
| 192 |
+
from src.models import Action, TaskConfig, ResetRequest
|
| 193 |
from src.env import DesalEnv
|
| 194 |
from src.tasks import TASKS
|
| 195 |
import subprocess
|
|
|
|
| 202 |
return {"status": "ok", "message": "Advanced DesalEnv is running", "features": ["weather", "salinity", "mechanics"]}
|
| 203 |
|
| 204 |
@app.post("/reset")
|
| 205 |
+
def reset_env(task_id: str = "easy_spring", req: Optional[ResetRequest] = None):
|
| 206 |
+
# Support both GET query params and POST JSON body for task_id
|
| 207 |
+
if req and req.task_id != "easy_spring":
|
| 208 |
+
task_id = req.task_id
|
| 209 |
+
|
| 210 |
if task_id not in TASKS:
|
| 211 |
raise HTTPException(status_code=404, detail="Task not found")
|
| 212 |
obs = env.reset(TASKS[task_id])
|
|
|
|
| 233 |
@app.get("/grader")
|
| 234 |
def grader():
|
| 235 |
if env.state is None:
|
| 236 |
+
return {"score": 0.001}
|
| 237 |
# Grade relative to typical maximum and minimum returns to generate a 0.0-1.0 range
|
| 238 |
baseline_offset = env.config.max_steps * 1000.0 # Compensate for penalties
|
| 239 |
scale_factor = env.config.max_steps * 1500.0
|
| 240 |
+
score = max(0.001, min(0.999, (env.total_reward + baseline_offset) / scale_factor))
|
| 241 |
return {"score": score}
|
| 242 |
|
| 243 |
@app.post("/baseline")
|
server/app.py
CHANGED
|
@@ -1,8 +1,9 @@
|
|
| 1 |
-
from fastapi import FastAPI, HTTPException
|
| 2 |
-
from src.models import Action, TaskConfig
|
| 3 |
from src.env import DesalEnv
|
| 4 |
from src.tasks import TASKS
|
| 5 |
import subprocess
|
|
|
|
| 6 |
|
| 7 |
app = FastAPI(title="Advanced Municipal Desalination Plant Env")
|
| 8 |
env = DesalEnv()
|
|
@@ -12,7 +13,11 @@ def health_check():
|
|
| 12 |
return {"status": "ok", "message": "Advanced DesalEnv is running", "features": ["weather", "salinity", "mechanics"]}
|
| 13 |
|
| 14 |
@app.post("/reset")
|
| 15 |
-
def reset_env(task_id: str = "easy_spring"):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
if task_id not in TASKS:
|
| 17 |
raise HTTPException(status_code=404, detail="Task not found")
|
| 18 |
obs = env.reset(TASKS[task_id])
|
|
@@ -39,11 +44,11 @@ def list_tasks():
|
|
| 39 |
@app.get("/grader")
|
| 40 |
def grader():
|
| 41 |
if env.state is None:
|
| 42 |
-
return {"score": 0.
|
| 43 |
# Grade relative to typical maximum and minimum returns to generate a 0.0-1.0 range
|
| 44 |
baseline_offset = env.config.max_steps * 1000.0 # Compensate for penalties
|
| 45 |
scale_factor = env.config.max_steps * 1500.0
|
| 46 |
-
score = max(0.
|
| 47 |
return {"score": score}
|
| 48 |
|
| 49 |
@app.post("/baseline")
|
|
|
|
| 1 |
+
from fastapi import FastAPI, HTTPException, Body
|
| 2 |
+
from src.models import Action, TaskConfig, ResetRequest
|
| 3 |
from src.env import DesalEnv
|
| 4 |
from src.tasks import TASKS
|
| 5 |
import subprocess
|
| 6 |
+
from typing import Optional
|
| 7 |
|
| 8 |
app = FastAPI(title="Advanced Municipal Desalination Plant Env")
|
| 9 |
env = DesalEnv()
|
|
|
|
| 13 |
return {"status": "ok", "message": "Advanced DesalEnv is running", "features": ["weather", "salinity", "mechanics"]}
|
| 14 |
|
| 15 |
@app.post("/reset")
|
| 16 |
+
def reset_env(task_id: str = "easy_spring", req: Optional[ResetRequest] = None):
|
| 17 |
+
# Support both GET query params and POST JSON body for task_id
|
| 18 |
+
if req and req.task_id != "easy_spring":
|
| 19 |
+
task_id = req.task_id
|
| 20 |
+
|
| 21 |
if task_id not in TASKS:
|
| 22 |
raise HTTPException(status_code=404, detail="Task not found")
|
| 23 |
obs = env.reset(TASKS[task_id])
|
|
|
|
| 44 |
@app.get("/grader")
|
| 45 |
def grader():
|
| 46 |
if env.state is None:
|
| 47 |
+
return {"score": 0.001}
|
| 48 |
# Grade relative to typical maximum and minimum returns to generate a 0.0-1.0 range
|
| 49 |
baseline_offset = env.config.max_steps * 1000.0 # Compensate for penalties
|
| 50 |
scale_factor = env.config.max_steps * 1500.0
|
| 51 |
+
score = max(0.001, min(0.999, (env.total_reward + baseline_offset) / scale_factor))
|
| 52 |
return {"score": score}
|
| 53 |
|
| 54 |
@app.post("/baseline")
|
src/models.py
CHANGED
|
@@ -1,5 +1,8 @@
|
|
| 1 |
from pydantic import BaseModel, Field
|
| 2 |
-
from typing import Dict, Literal, List
|
|
|
|
|
|
|
|
|
|
| 3 |
|
| 4 |
class Observation(BaseModel):
|
| 5 |
time_step: int
|
|
|
|
| 1 |
from pydantic import BaseModel, Field
|
| 2 |
+
from typing import Dict, Literal, List, Optional
|
| 3 |
+
|
| 4 |
+
class ResetRequest(BaseModel):
|
| 5 |
+
task_id: str = "easy_spring"
|
| 6 |
|
| 7 |
class Observation(BaseModel):
|
| 8 |
time_step: int
|