MridulNegi2005's picture
Add OpenEnv server (app.py) β€” expose /reset /step /state /ws /health endpoints
d9d8c27
"""
OpenEnv Server β€” Strategic Negotiation Environment
FastAPI + WebSocket server exposing reset(), step(), state() endpoints.
Compatible with the OpenEnv client protocol.
"""
import json
import os
import uuid
from typing import Optional
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Query
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from env_wrapper import EnvWrapper, Observation
from tasks import ALL_TASKS, TaskConfig, get_grader
# ─────────────────────────────────────────────
# Task Registry
# ─────────────────────────────────────────────
TASK_MAP = {task.name: task for task in ALL_TASKS}
# ─────────────────────────────────────────────
# Request / Response Models
# ─────────────────────────────────────────────
class ResetRequest(BaseModel):
task: str = "task_a_easy"
class StepRequest(BaseModel):
action_type: str # "OFFER", "ACCEPT", "REJECT"
price: int = 0
# ─────────────────────────────────────────────
# Session Management
# ─────────────────────────────────────────────
class SessionState:
"""Holds a single environment episode."""
def __init__(self, task_config: TaskConfig):
self.env = EnvWrapper(
opp_type=task_config.opp_type,
a_val=task_config.agent_value,
o_val=task_config.opponent_value,
agent_role=task_config.agent_role,
max_rounds=task_config.max_rounds,
)
self.task_config = task_config
self.done = False
self.rewards = []
self.steps = 0
self.deal_made = False
# In-memory session store (keyed by session_id)
sessions: dict[str, SessionState] = {}
MAX_SESSIONS = 200
def _cleanup_sessions():
"""Evict oldest sessions when limit is exceeded."""
while len(sessions) > MAX_SESSIONS:
oldest_key = next(iter(sessions))
del sessions[oldest_key]
# ─────────────────────────────────────────────
# FastAPI Application
# ─────────────────────────────────────────────
app = FastAPI(
title="Strategic Negotiation Environment",
description="OpenEnv-compliant negotiation simulation where AI agents learn to negotiate under uncertainty.",
version="1.0.0",
)
# ── Health & Info ──
@app.get("/health")
async def health():
"""Health check endpoint."""
return {"status": "healthy"}
@app.get("/")
async def root():
"""Environment info and available tasks."""
return {
"status": "running",
"environment": "negotiation-env",
"version": "1.0.0",
"tasks": [
{
"name": t.name,
"difficulty": t.difficulty,
"description": t.description,
"success_threshold": t.success_threshold,
}
for t in ALL_TASKS
],
}
# ── HTTP Endpoints ──
@app.post("/reset")
async def reset(request: ResetRequest = None):
"""Reset the environment and start a new episode."""
if request is None:
request = ResetRequest()
task_name = request.task
if task_name not in TASK_MAP:
return JSONResponse(
status_code=400,
content={"error": f"Unknown task: {task_name}. Available: {list(TASK_MAP.keys())}"},
)
task_config = TASK_MAP[task_name]
session_id = str(uuid.uuid4())
session = SessionState(task_config)
obs = session.env.reset()
sessions[session_id] = session
_cleanup_sessions()
return {
"session_id": session_id,
"observation": obs.model_dump(),
"reward": 0.0,
"done": False,
"info": {
"task": task_name,
"difficulty": task_config.difficulty,
"max_rounds": task_config.max_rounds,
},
}
@app.post("/step")
async def step(request: StepRequest, session_id: str = Query(...)):
"""Execute one step in the environment."""
if session_id not in sessions:
return JSONResponse(
status_code=404,
content={"error": "Session not found. Call POST /reset first."},
)
session = sessions[session_id]
if session.done:
obs = session.env.state()
return {
"observation": obs.model_dump(),
"reward": 0.0,
"done": True,
"info": {"error": "Episode already ended. Call /reset for a new one."},
}
# Parse action
action_type = request.action_type.upper()
price = request.price
if action_type == "OFFER":
action_str = f"OFFER {price}"
else:
action_str = action_type
price = 0
# Step environment
obs, reward, done, info = session.env.step(action_str, price)
session.done = done
session.rewards.append(reward)
session.steps += 1
if done and info.get("deal_type") in ("agent_accepted", "opponent_accepted"):
session.deal_made = True
# If done, compute final graded score
if done:
grader = get_grader(session.task_config)
grade_result = grader.grade(session.rewards, session.steps, session.deal_made)
info["grader_score"] = grade_result["score"]
info["grader_success"] = grade_result["success"]
info["threshold"] = grade_result["threshold"]
return {
"observation": obs.model_dump(),
"reward": reward,
"done": done,
"info": info,
}
@app.get("/state")
async def state(session_id: str = Query(...)):
"""Get current environment state without taking an action."""
if session_id not in sessions:
return JSONResponse(
status_code=404,
content={"error": "Session not found."},
)
session = sessions[session_id]
obs = session.env.state()
return {
"observation": obs.model_dump(),
"done": session.done,
"info": {
"task": session.task_config.name,
"steps": session.steps,
},
}
# ─────────────────────────────────────────────
# WebSocket Endpoint (persistent session)
# ─────────────────────────────────────────────
@app.websocket("/ws")
async def websocket_endpoint(ws: WebSocket):
"""WebSocket endpoint for persistent environment sessions."""
await ws.accept()
session: Optional[SessionState] = None
try:
while True:
data = await ws.receive_text()
msg = json.loads(data)
msg_type = msg.get("type", "")
if msg_type == "reset":
task_name = msg.get("task", "task_a_easy")
if task_name not in TASK_MAP:
await ws.send_json({"error": f"Unknown task: {task_name}. Available: {list(TASK_MAP.keys())}"})
continue
task_config = TASK_MAP[task_name]
session = SessionState(task_config)
obs = session.env.reset()
await ws.send_json({
"type": "reset",
"observation": obs.model_dump(),
"reward": 0.0,
"done": False,
"info": {
"task": task_name,
"difficulty": task_config.difficulty,
"max_rounds": task_config.max_rounds,
},
})
elif msg_type == "step":
if session is None:
await ws.send_json({"error": "No active session. Send a reset message first."})
continue
action = msg.get("action", {})
action_type = action.get("action_type", "REJECT").upper()
price = action.get("price", 0)
if action_type == "OFFER":
action_str = f"OFFER {price}"
else:
action_str = action_type
price = 0
obs, reward, done, info = session.env.step(action_str, price)
session.done = done
session.rewards.append(reward)
session.steps += 1
if done and info.get("deal_type") in ("agent_accepted", "opponent_accepted"):
session.deal_made = True
if done:
grader = get_grader(session.task_config)
grade_result = grader.grade(session.rewards, session.steps, session.deal_made)
info["grader_score"] = grade_result["score"]
info["grader_success"] = grade_result["success"]
info["threshold"] = grade_result["threshold"]
await ws.send_json({
"type": "step",
"observation": obs.model_dump(),
"reward": reward,
"done": done,
"info": info,
})
elif msg_type == "state":
if session is None:
await ws.send_json({"error": "No active session."})
continue
obs = session.env.state()
await ws.send_json({
"type": "state",
"observation": obs.model_dump(),
"done": session.done,
"info": {
"task": session.task_config.name,
"steps": session.steps,
},
})
else:
await ws.send_json({"error": f"Unknown message type: {msg_type}. Use: reset, step, state"})
except WebSocketDisconnect:
pass
except Exception as e:
try:
await ws.send_json({"error": str(e)})
except Exception:
pass
# ─────────────────────────────────────────────
# Startup
# ─────────────────────────────────────────────
if __name__ == "__main__":
import uvicorn
port = int(os.environ.get("PORT", "7860"))
uvicorn.run(app, host="0.0.0.0", port=port)