Spaces:
Sleeping
Sleeping
"""
OpenEnv Server — Strategic Negotiation Environment
FastAPI + WebSocket server exposing reset(), step(), state() endpoints.
Compatible with the OpenEnv client protocol.
"""
| import json | |
| import os | |
| import uuid | |
| from typing import Optional | |
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Query | |
| from fastapi.responses import JSONResponse | |
| from pydantic import BaseModel | |
| from env_wrapper import EnvWrapper, Observation | |
| from tasks import ALL_TASKS, TaskConfig, get_grader | |
# ─────────────────────────────────────────────
# Task Registry
# ─────────────────────────────────────────────
# Name -> task configuration lookup, built once at import time from ALL_TASKS.
TASK_MAP = {cfg.name: cfg for cfg in ALL_TASKS}
# ─────────────────────────────────────────────
# Request / Response Models
# ─────────────────────────────────────────────
class ResetRequest(BaseModel):
    """Request body for starting a new episode."""

    # Name of the task to run; the reset handler looks it up in TASK_MAP.
    task: str = "task_a_easy"
class StepRequest(BaseModel):
    """Request body describing a single agent action."""

    # "OFFER", "ACCEPT", "REJECT" (the step handler upper-cases this value).
    action_type: str
    # Only meaningful for "OFFER"; the step handler zeroes it for other actions.
    price: int = 0
# ─────────────────────────────────────────────
# Session Management
# ─────────────────────────────────────────────
class SessionState:
    """Bookkeeping for one environment episode.

    Bundles the wrapped environment with the task it was built from and the
    running counters that the step handlers update and later pass to the grader.
    """

    def __init__(self, task_config: TaskConfig):
        # Kept so handlers can report the task name and fetch its grader.
        self.task_config = task_config

        # Environment parameters are taken one-to-one from the task config.
        env_kwargs = {
            "opp_type": task_config.opp_type,
            "a_val": task_config.agent_value,
            "o_val": task_config.opponent_value,
            "agent_role": task_config.agent_role,
            "max_rounds": task_config.max_rounds,
        }
        self.env = EnvWrapper(**env_kwargs)

        # Episode progress; all start "empty" and are mutated by the handlers.
        self.steps = 0
        self.rewards = []
        self.done = False
        self.deal_made = False
# Maximum number of sessions kept in memory before the oldest are evicted.
MAX_SESSIONS = 200

# In-memory session store (keyed by session_id). Being a plain dict, it keeps
# insertion order, which _cleanup_sessions() relies on to find the oldest entry.
sessions: dict[str, SessionState] = {}
def _cleanup_sessions():
    """Drop the oldest stored sessions until at most MAX_SESSIONS remain."""
    # A non-positive overflow yields an empty range, i.e. nothing is evicted.
    overflow = len(sessions) - MAX_SESSIONS
    for _ in range(overflow):
        # Dicts iterate in insertion order, so the first key is the oldest.
        stale_key = next(iter(sessions))
        sessions.pop(stale_key)
# ─────────────────────────────────────────────
# FastAPI Application
# ─────────────────────────────────────────────
# The ASGI application object; also what uvicorn serves in the __main__ block.
app = FastAPI(
    version="1.0.0",
    title="Strategic Negotiation Environment",
    description="OpenEnv-compliant negotiation simulation where AI agents learn to negotiate under uncertainty.",
)
# ── Health & Info ──
# Without a path-operation decorator the handler is never registered on `app`.
# NOTE(review): the "/health" path is inferred from the handler name — confirm
# against the clients / platform probes that call it.
@app.get("/health")
async def health():
    """Health check endpoint.

    Returns:
        The constant payload ``{"status": "healthy"}``.
    """
    return {"status": "healthy"}
# Without a path-operation decorator the handler is never registered on `app`.
# NOTE(review): serving this at "/" is an assumption — confirm the intended path.
@app.get("/")
async def root():
    """Environment info and available tasks.

    Returns:
        A dict with the server status, environment name, version, and one
        summary entry (name, difficulty, description, success threshold) per
        task in ``ALL_TASKS``.
    """
    task_summaries = [
        {
            "name": t.name,
            "difficulty": t.difficulty,
            "description": t.description,
            "success_threshold": t.success_threshold,
        }
        for t in ALL_TASKS
    ]
    return {
        "status": "running",
        "environment": "negotiation-env",
        "version": "1.0.0",
        "tasks": task_summaries,
    }
# ── HTTP Endpoints ──
# Without a path-operation decorator the handler is never registered on `app`.
# The step handler's error text ("Call POST /reset first.") fixes method and path.
@app.post("/reset")
async def reset(request: Optional[ResetRequest] = None):
    """Reset the environment and start a new episode.

    Args:
        request: Optional body naming the task to run. When the body is
            omitted, the ``ResetRequest`` defaults are used.

    Returns:
        On success, a dict with the new ``session_id``, the initial
        observation, a zero reward, ``done=False`` and task info.
        For an unknown task name, a 400 ``JSONResponse`` listing the
        available tasks.
    """
    if request is None:
        request = ResetRequest()

    task_name = request.task
    if task_name not in TASK_MAP:
        return JSONResponse(
            status_code=400,
            content={"error": f"Unknown task: {task_name}. Available: {list(TASK_MAP.keys())}"},
        )
    task_config = TASK_MAP[task_name]

    session_id = str(uuid.uuid4())
    session = SessionState(task_config)
    obs = session.env.reset()

    # Store first, then evict: the new session is the newest entry, so the
    # cleanup only ever removes older ones.
    sessions[session_id] = session
    _cleanup_sessions()

    return {
        "session_id": session_id,
        "observation": obs.model_dump(),
        "reward": 0.0,
        "done": False,
        "info": {
            "task": task_name,
            "difficulty": task_config.difficulty,
            "max_rounds": task_config.max_rounds,
        },
    }
# Without a path-operation decorator the handler is never registered on `app`.
# NOTE(review): POST "/step" mirrors the "/reset" route named in the error
# messages below — confirm against the client protocol.
@app.post("/step")
async def step(request: StepRequest, session_id: str = Query(...)):
    """Execute one step in the environment.

    Args:
        request: The action to apply (type and, for offers, a price).
        session_id: Required query parameter identifying the session that
            an earlier reset call created.

    Returns:
        A dict with the resulting observation, reward, done flag and info.
        When the episode finishes, ``info`` additionally carries the grader's
        score, success flag and threshold. A 404 ``JSONResponse`` is returned
        for an unknown session; stepping a finished episode returns the
        current state with a zero reward and an error note in ``info``.
    """
    session = sessions.get(session_id)
    if session is None:
        return JSONResponse(
            status_code=404,
            content={"error": "Session not found. Call POST /reset first."},
        )

    # A finished episode is never stepped again; report its final state instead.
    if session.done:
        obs = session.env.state()
        return {
            "observation": obs.model_dump(),
            "reward": 0.0,
            "done": True,
            "info": {"error": "Episode already ended. Call /reset for a new one."},
        }

    # Parse action: only offers carry a price; everything else is sent bare.
    action_type = request.action_type.upper()
    price = request.price
    if action_type == "OFFER":
        action_str = f"OFFER {price}"
    else:
        action_str = action_type
        price = 0

    # Step environment and record the outcome on the session.
    obs, reward, done, info = session.env.step(action_str, price)
    session.done = done
    session.rewards.append(reward)
    session.steps += 1

    if done:
        if info.get("deal_type") in ("agent_accepted", "opponent_accepted"):
            session.deal_made = True

        # Episode over: attach the final graded score to the info payload.
        grader = get_grader(session.task_config)
        grade_result = grader.grade(session.rewards, session.steps, session.deal_made)
        info["grader_score"] = grade_result["score"]
        info["grader_success"] = grade_result["success"]
        info["threshold"] = grade_result["threshold"]

    return {
        "observation": obs.model_dump(),
        "reward": reward,
        "done": done,
        "info": info,
    }
# Without a path-operation decorator the handler is never registered on `app`.
# NOTE(review): GET "/state" is inferred from the handler name — confirm
# against the client protocol.
@app.get("/state")
async def state(session_id: str = Query(...)):
    """Get current environment state without taking an action.

    Args:
        session_id: Required query parameter identifying the session.

    Returns:
        A dict with the current observation, the session's done flag, and
        info (task name and number of steps taken), or a 404 ``JSONResponse``
        when the session is unknown.
    """
    session = sessions.get(session_id)
    if session is None:
        return JSONResponse(
            status_code=404,
            content={"error": "Session not found."},
        )

    obs = session.env.state()
    return {
        "observation": obs.model_dump(),
        "done": session.done,
        "info": {
            "task": session.task_config.name,
            "steps": session.steps,
        },
    }
# ─────────────────────────────────────────────
# WebSocket Endpoint (persistent session)
# ─────────────────────────────────────────────
# Without a route decorator the handler is never registered on `app`.
# NOTE(review): the "/ws" path is an assumption — confirm against the client.
@app.websocket("/ws")
async def websocket_endpoint(ws: WebSocket):
    """WebSocket endpoint for persistent environment sessions.

    The connection owns at most one session, held in a local variable rather
    than the shared ``sessions`` store. Each incoming text frame must be a
    JSON object with a ``type`` of ``"reset"``, ``"step"`` or ``"state"``:

    * ``reset`` — optional ``task`` name; starts a fresh episode.
    * ``step``  — ``action`` object with ``action_type`` and optional ``price``.
    * ``state`` — returns the current observation without acting.

    Malformed messages are answered with an ``{"error": ...}`` frame and the
    connection stays open. A step sent after the episode has ended returns
    the final state with a zero reward instead of stepping the environment
    again, matching the HTTP step handler.

    Args:
        ws: The WebSocket connection supplied by FastAPI.
    """
    await ws.accept()
    session: Optional[SessionState] = None

    try:
        while True:
            data = await ws.receive_text()

            # A single bad frame must not tear down the whole session.
            try:
                msg = json.loads(data)
            except json.JSONDecodeError as exc:
                await ws.send_json({"error": f"Invalid JSON: {exc}"})
                continue
            if not isinstance(msg, dict):
                await ws.send_json({"error": "Message must be a JSON object."})
                continue

            msg_type = msg.get("type", "")

            if msg_type == "reset":
                task_name = msg.get("task", "task_a_easy")
                if task_name not in TASK_MAP:
                    await ws.send_json({"error": f"Unknown task: {task_name}. Available: {list(TASK_MAP.keys())}"})
                    continue

                task_config = TASK_MAP[task_name]
                session = SessionState(task_config)
                obs = session.env.reset()
                await ws.send_json({
                    "type": "reset",
                    "observation": obs.model_dump(),
                    "reward": 0.0,
                    "done": False,
                    "info": {
                        "task": task_name,
                        "difficulty": task_config.difficulty,
                        "max_rounds": task_config.max_rounds,
                    },
                })

            elif msg_type == "step":
                if session is None:
                    await ws.send_json({"error": "No active session. Send a reset message first."})
                    continue

                # Do not step (and re-grade) an episode that already ended.
                if session.done:
                    obs = session.env.state()
                    await ws.send_json({
                        "type": "step",
                        "observation": obs.model_dump(),
                        "reward": 0.0,
                        "done": True,
                        "info": {"error": "Episode already ended. Send a reset message for a new one."},
                    })
                    continue

                action = msg.get("action", {})
                if not isinstance(action, dict):
                    await ws.send_json({"error": "Field 'action' must be a JSON object."})
                    continue

                # str() guards against non-string action types (e.g. numbers).
                action_type = str(action.get("action_type", "REJECT")).upper()
                price = action.get("price", 0)

                # Only offers carry a price; everything else is sent bare.
                if action_type == "OFFER":
                    action_str = f"OFFER {price}"
                else:
                    action_str = action_type
                    price = 0

                obs, reward, done, info = session.env.step(action_str, price)
                session.done = done
                session.rewards.append(reward)
                session.steps += 1

                if done:
                    if info.get("deal_type") in ("agent_accepted", "opponent_accepted"):
                        session.deal_made = True

                    # Episode over: attach the final graded score.
                    grader = get_grader(session.task_config)
                    grade_result = grader.grade(session.rewards, session.steps, session.deal_made)
                    info["grader_score"] = grade_result["score"]
                    info["grader_success"] = grade_result["success"]
                    info["threshold"] = grade_result["threshold"]

                await ws.send_json({
                    "type": "step",
                    "observation": obs.model_dump(),
                    "reward": reward,
                    "done": done,
                    "info": info,
                })

            elif msg_type == "state":
                if session is None:
                    await ws.send_json({"error": "No active session."})
                    continue

                obs = session.env.state()
                await ws.send_json({
                    "type": "state",
                    "observation": obs.model_dump(),
                    "done": session.done,
                    "info": {
                        "task": session.task_config.name,
                        "steps": session.steps,
                    },
                })

            else:
                await ws.send_json({"error": f"Unknown message type: {msg_type}. Use: reset, step, state"})

    except WebSocketDisconnect:
        # Client went away; nothing to clean up beyond the local session.
        pass
    except Exception as e:
        # Last-resort boundary: report the failure if the socket is still
        # usable. Sending is best-effort because the socket may be broken.
        try:
            await ws.send_json({"error": str(e)})
        except Exception:
            pass
# ─────────────────────────────────────────────
# Startup
# ─────────────────────────────────────────────
if __name__ == "__main__":
    # Imported here so that importing this module does not itself need uvicorn.
    import uvicorn

    # Listen on all interfaces; the PORT environment variable overrides 7860.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")))