| """ |
| Parlay OpenEnv WebSocket server. |
| Implements the standard reset/step/state protocol. |
| |
When this module is run directly, the WebSocket URL is (port from --port, else the
ENV_PORT environment variable, else 8001):
    ws://localhost:8001/env/ws
The path is /env/ws because the router uses prefix="/env" and defines @router.websocket("/ws").
| When the same router is mounted on main:app (uvicorn main:app --port 8000): |
| ws://localhost:8000/env/ws |
| """ |
| import json |
| import logging |
| import os |
| import uuid |
| from typing import Any |
|
|
| import numpy as np |
| from fastapi import APIRouter, FastAPI, WebSocket, WebSocketDisconnect |
|
|
| from agent.tom_tracker import ToMTracker |
| from game.scenarios import get_scenario |
|
|
| from .exceptions import ( |
| EpisodeAlreadyDoneError, |
| InvalidActionError, |
| InvalidScenarioError, |
| SessionNotFoundError, |
| ) |
| from .game_theory import compute_nash_bargaining_solution, compute_zopa |
| from .grader import compute_step_reward |
| from .models import ( |
| BeliefState, |
| HiddenState, |
| ParlayAction, |
| ParlayObservation, |
| ParlayState, |
| PersonaType, |
| TacticalMove, |
| ) |
| from .reward import ( |
| ZOPA_EROSION_CONSECUTIVE_TURNS, |
| ZOPA_EROSION_RATE, |
| ZOPA_EROSION_TENSION_THRESHOLD, |
| ) |
|
|
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/env", tags=["OpenEnv"])


# Placeholder belief/observation that env_websocket sends (flagged "_fallback")
# when a command fails with an unexpected exception, so the client still gets a
# well-formed observation.
_FALLBACK_BELIEF = BeliefState(
    est_budget=0.0,
    est_walk_away=0.0,
    est_urgency=0.5,
    est_has_alternative=False,
    confidence=0.1,
)
FALLBACK_OBSERVATION = ParlayObservation(
    step_count=0,
    episode_done=False,
    current_offer=0.0,
    opponent_offer=0.0,
    zopa_lower=0.0,
    zopa_upper=0.0,
    nash_point=0.0,
    tension_score=0.0,
    belief_state=_FALLBACK_BELIEF,
    last_utterance="[Connection issue - AI is thinking]",
    available_moves=list(TacticalMove),
    credibility_points=100,
    reward=0.0,
    cumulative_reward=0.0,
)


# In-memory session store: session_id -> {"state": ParlayState, "tom_tracker": ToMTracker}.
# NOTE(review): nothing in this module removes entries, so one entry accumulates
# per reset for the lifetime of the process.
_sessions: dict[str, dict[str, Any]] = {}


# Episode length. Clamped to >= 1 because _compute_tension divides by it; with a
# value of 0 every step would hit ZeroDivisionError and return the fallback.
MAX_TURNS = max(1, int(os.getenv("MAX_TURNS_PER_EPISODE", "20")))
# Credibility points at reset (also the upper cap) and the amount regained per step.
CP_START = int(os.getenv("CREDIBILITY_POINTS_START", "100"))
CP_REGEN = int(os.getenv("CREDIBILITY_REGEN_PER_TURN", "5"))


# Baseline opponent parameters per scenario; _get_scenario_hidden_state adds
# seeded noise. Keys: budget (budget ceiling), walk (walk-away price),
# urgency (urgency score), alt (whether the opponent has an alternative).
_SCENARIO_DEFAULTS: dict[str, dict[str, Any]] = {
    "saas_enterprise": dict(budget=165_000, walk=125_000, urgency=0.55, alt=True),
    "hiring_package": dict(budget=230_000, walk=195_000, urgency=0.60, alt=False),
    "acquisition_term_sheet": dict(budget=16_000_000, walk=10_500_000, urgency=0.65, alt=True),
}


# Credibility-point cost per tactical move; moves not listed here cost 0.
_CP_COSTS: dict[TacticalMove, int] = {
    TacticalMove.ANCHOR_HIGH: 0,
    TacticalMove.BATNA_REVEAL: 20,
    TacticalMove.SILENCE: 5,
}
|
|
|
|
def _get_scenario_hidden_state(scenario_id: str, rng_seed: int = 42) -> HiddenState:
    """Build the opponent's hidden state for *scenario_id*.

    The budget ceiling and walk-away price are both scaled by one shared factor
    drawn uniformly from [0.95, 1.05) and rounded to 2 decimals. Urgency gets
    additive noise drawn from [-0.1, 0.1) and is clipped to [0, 1]. Identical
    *rng_seed* values give identical results.

    Raises:
        InvalidScenarioError: if *scenario_id* is not a key of ``_SCENARIO_DEFAULTS``.
    """
    if scenario_id not in _SCENARIO_DEFAULTS:
        raise InvalidScenarioError(f"Unknown scenario: {scenario_id}")

    generator = np.random.default_rng(rng_seed)
    spec = _SCENARIO_DEFAULTS[scenario_id]

    # Draw order matters for reproducibility: price scale first, then urgency jitter.
    scale = generator.uniform(0.95, 1.05)
    urgency_jitter = generator.uniform(-0.1, 0.1)
    urgency = float(np.clip(spec["urgency"] + urgency_jitter, 0.0, 1.0))

    return HiddenState(
        budget_ceiling=round(spec["budget"] * scale, 2),
        walk_away_price=round(spec["walk"] * scale, 2),
        urgency_score=urgency,
        has_alternative=bool(spec["alt"]),
        persona_drifted=False,
    )
|
|
|
|
def _initial_belief(hidden: HiddenState) -> BeliefState:
    """Return the starting belief about the opponent, deliberately imprecise.

    The budget guess is 80% of the true ceiling and the walk-away guess is 115%
    of the true walk-away price. Urgency starts at 0.5, no alternative is
    assumed, and confidence is 0.3.
    """
    budget_guess = hidden.budget_ceiling * 0.80
    walk_away_guess = hidden.walk_away_price * 1.15
    return BeliefState(
        confidence=0.30,
        est_has_alternative=False,
        est_urgency=0.50,
        est_walk_away=walk_away_guess,
        est_budget=budget_guess,
    )
|
|
|
|
def _get_cp_cost(move: TacticalMove | None) -> int:
    """Return the credibility-point price of *move*.

    No move (``None``), or a move without an entry in ``_CP_COSTS``, costs 0.
    """
    return 0 if move is None else _CP_COSTS.get(move, 0)
|
|
|
|
| def _compute_tension(state: ParlayState, action: ParlayAction) -> float: |
| """Compute the current tension score for the turn.""" |
| base = 20.0 + ((state.step_count + 1) / MAX_TURNS) * 60.0 |
| if action.tactical_move == TacticalMove.ANCHOR_HIGH: |
| base += 15.0 |
| elif action.tactical_move == TacticalMove.BATNA_REVEAL: |
| base += 10.0 |
| elif action.tactical_move == TacticalMove.SILENCE: |
| base += 5.0 |
| return float(max(0.0, min(100.0, base))) |
|
|
|
|
def _apply_drift_event(state: ParlayState, tom: ToMTracker) -> str | None:
    """Fire the scenario drift event scheduled for ``state.step_count``, if any.

    On a match, *state* is mutated in place:

    * ``drift_events_fired`` is incremented;
    * ``hidden_state.persona_drifted`` is set;
    * ``belief_history`` is refreshed from the tracker after calling
      ``tom.drift_event``.

    At most one event fires per call: the first entry of
    ``scenario.drift_events`` whose ``trigger_turn`` equals the current step.

    Returns:
        The fired event's description, or ``None`` when nothing is scheduled
        for this turn or the scenario could not be loaded.
    """
    try:
        scenario = get_scenario(state.scenario_id)
    except Exception:
        # Best-effort: the step proceeds without drift. Log it, though;
        # otherwise a scenario that fails to load silently never drifts.
        logger.warning(
            "Could not load scenario %r; skipping drift events for turn %d",
            state.scenario_id,
            state.step_count,
            exc_info=True,
        )
        return None

    for event in scenario.drift_events:
        if event.trigger_turn != state.step_count:
            continue
        state.drift_events_fired += 1
        state.hidden_state.persona_drifted = True
        tom.drift_event(
            event.effect_on_urgency,
            event.effect_on_has_alternative,
            event_description=event.event,
        )
        state.belief_history = list(tom.history)
        return event.event
    return None
|
|
|
|
def _make_observation(
    state: ParlayState,
    reward: float,
    utterance: str,
    drift_event: str | None = None,
) -> ParlayObservation:
    """Project *state* into the ``ParlayObservation`` sent to the client.

    * ZOPA bounds: ``compute_zopa`` on the hidden budget ceiling and walk-away
      price. When it returns a falsy value, the bounds fall back to
      (walk-away price, budget ceiling).
    * Nash point: ``compute_nash_bargaining_solution`` on the same two numbers.
    * Opponent offer: 90% of the upper ZOPA bound.
    * Latest offer: the last entry of the offer history, or 0.0 if it is empty.
    * Belief: the last entry of the belief history, or ``_initial_belief`` if
      it is empty.
    * ``zopa_width_pct_remaining``: current ZOPA width divided by
      ``state.original_zopa_width``, or 0.0 when that original width is not
      positive.
    """
    hidden = state.hidden_state
    ceiling = hidden.budget_ceiling
    walk_away = hidden.walk_away_price

    zopa = compute_zopa(ceiling, walk_away)
    if zopa:
        zopa_lower, zopa_upper = zopa[0], zopa[1]
    else:
        zopa_lower, zopa_upper = walk_away, ceiling
    nash = compute_nash_bargaining_solution(ceiling, walk_away)

    offers = state.offer_history
    beliefs = state.belief_history
    current_offer = offers[-1] if offers else 0.0
    belief = beliefs[-1] if beliefs else _initial_belief(hidden)

    original_width = state.original_zopa_width
    remaining_width = max(0.0, ceiling - walk_away)
    if original_width > 0:
        width_pct = remaining_width / original_width
    else:
        width_pct = 0.0

    return ParlayObservation(
        step_count=state.step_count,
        episode_done=state.episode_done,
        current_offer=current_offer,
        opponent_offer=zopa_upper * 0.9,
        zopa_lower=zopa_lower,
        zopa_upper=zopa_upper,
        nash_point=nash,
        tension_score=state.tension_score,
        belief_state=belief,
        last_utterance=utterance,
        available_moves=list(TacticalMove),
        credibility_points=state.credibility_points,
        reward=reward,
        cumulative_reward=state.cumulative_reward,
        drift_event=drift_event,
        zopa_erosion_ticks=state.zopa_erosion_ticks,
        zopa_width_pct_remaining=width_pct,
    )
|
|
|
|
| def _coerce_message_params(msg: dict[str, Any]) -> tuple[str, dict[str, Any]]: |
| """Accept both {cmd, ...} and {method, params} envelope formats.""" |
| if "method" in msg: |
| return str(msg["method"]), dict(msg.get("params", {})) |
| command = str(msg.get("cmd", "")) |
| params = {k: v for k, v in msg.items() if k != "cmd"} |
| return command, params |
|
|
|
|
| @router.websocket("/ws") |
| async def env_websocket(websocket: WebSocket) -> None: |
| """OpenEnv WebSocket endpoint.""" |
| await websocket.accept() |
| logger.info("OpenEnv WebSocket client connected") |
| try: |
| while True: |
| raw = await websocket.receive_text() |
| try: |
| msg = json.loads(raw) |
| except json.JSONDecodeError: |
| await websocket.send_json({"error": "Invalid JSON"}) |
| continue |
|
|
| command, params = _coerce_message_params(msg) |
| try: |
| match command: |
| case "reset": |
| result = await _handle_reset(params) |
| case "step": |
| result = await _handle_step(params) |
| case "state": |
| result = await _handle_state(params) |
| case _: |
| result = {"error": f"Unknown command: {command}"} |
| except ( |
| InvalidActionError, |
| SessionNotFoundError, |
| EpisodeAlreadyDoneError, |
| InvalidScenarioError, |
| ) as exc: |
| result = {"error": str(exc)} |
| except Exception: |
| logger.exception("Unhandled error in env WebSocket - returning fallback observation") |
| result = { |
| "observation": FALLBACK_OBSERVATION.model_dump(), |
| "done": False, |
| "_fallback": True, |
| } |
|
|
| await websocket.send_json(result) |
| except WebSocketDisconnect: |
| logger.info("OpenEnv WebSocket client disconnected") |
|
|
|
|
async def _handle_reset(msg: dict[str, Any]) -> dict:
    """Handle ``reset``: create a fresh episode and register it in ``_sessions``.

    Recognised keys in *msg* (all optional):

    * ``scenario_id``: a key of ``_SCENARIO_DEFAULTS`` (default ``"saas_enterprise"``);
    * ``persona``: a ``PersonaType`` value (default ``"shark"``);
    * ``seed``: a non-negative integer seeding the hidden-state noise (default ``42``).

    Returns:
        ``{"session_id": ..., "observation": <initial observation dump>, "done": False}``.

    Raises:
        InvalidScenarioError: unknown scenario or persona, or a seed that is not
            a non-negative integer.
    """
    scenario_id = msg.get("scenario_id", "saas_enterprise")
    persona_str = msg.get("persona", "shark")

    raw_seed = msg.get("seed", 42)
    try:
        seed = int(raw_seed)
    except (TypeError, ValueError, OverflowError) as exc:
        # Report a bad seed to the client instead of a generic fallback observation.
        raise InvalidScenarioError(f"Invalid seed: {raw_seed!r}") from exc
    # numpy.random.default_rng rejects negative seeds with ValueError, which
    # would otherwise also surface only as a fallback observation.
    if seed < 0:
        raise InvalidScenarioError(f"Invalid seed: {raw_seed!r} (must be non-negative)")

    try:
        persona = PersonaType(persona_str)
    except ValueError as exc:
        raise InvalidScenarioError(f"Unknown persona: {persona_str}") from exc

    hidden = _get_scenario_hidden_state(scenario_id, seed)
    initial_belief = _initial_belief(hidden)
    session_id = str(uuid.uuid4())
    original_zopa_width = hidden.budget_ceiling - hidden.walk_away_price

    state = ParlayState(
        session_id=session_id,
        scenario_id=scenario_id,
        persona=persona,
        step_count=0,
        cumulative_reward=0.0,
        hidden_state=hidden,
        belief_history=[initial_belief],
        offer_history=[],
        drift_events_fired=0,
        episode_done=False,
        credibility_points=CP_START,
        original_zopa_width=original_zopa_width,
        zopa_width_pct_remaining=1.0,
    )
    _sessions[session_id] = {
        "state": state,
        "tom_tracker": ToMTracker(initial_belief, persona),
    }

    observation = _make_observation(state, 0.0, "Negotiation started. Make your opening move.")
    logger.info("Reset: session=%s, scenario=%s, persona=%s", session_id, scenario_id, persona_str)
    return {"session_id": session_id, "observation": observation.model_dump(), "done": False}
|
|
|
|
async def _handle_step(msg: dict[str, Any]) -> dict:
    """Advance the episode by one action.

    Steps:

    1. Look up the session named by ``msg["session_id"]``.
    2. Validate the action, taken from ``msg["action"]`` if present, otherwise
       from *msg* itself.
    3. Build the next ``ParlayState``: turn counter, offers, beliefs,
       credibility, tension, ZOPA erosion, and deal/drift/bluff bookkeeping.
    4. Score the transition with ``compute_step_reward``.
    5. Store the new state back in ``_sessions``.

    Returns:
        ``{"observation": <ParlayObservation dump>, "done": <bool>}``.

    Raises:
        SessionNotFoundError: ``session_id`` is missing, empty, or unknown.
        EpisodeAlreadyDoneError: the session's episode has already ended.
        InvalidActionError: the payload fails ``ParlayAction`` validation, or the
            move costs more credibility points than the session holds.
    """
    session_id = msg.get("session_id")
    if not session_id or session_id not in _sessions:
        raise SessionNotFoundError(f"Session {session_id} not found")

    session = _sessions[session_id]
    state: ParlayState = session["state"]
    tom: ToMTracker = session["tom_tracker"]
    if state.episode_done:
        raise EpisodeAlreadyDoneError(f"Episode {session_id} is already done")

    # Accept both {"session_id": ..., "action": {...}} and a flat message whose
    # top-level keys are the action fields.
    action_payload = msg.get("action", msg)
    try:
        action = ParlayAction.model_validate(action_payload)
    except Exception as exc:
        raise InvalidActionError(f"Invalid action: {exc}") from exc

    # Reject moves the session cannot afford before any state is touched.
    cp_cost = _get_cp_cost(action.tactical_move)
    if state.credibility_points < cp_cost:
        raise InvalidActionError("Insufficient credibility points for that move")

    new_offers = list(state.offer_history)
    if action.offer_amount is not None:
        new_offers.append(action.offer_amount)

    # NOTE(review): the tracker is mutated in place here. If anything below
    # raises, ``tom`` has advanced while the stored session state has not.
    tom.update(
        observed_offer=action.offer_amount,
        observed_move=action.tactical_move,
        utterance=action.utterance,
        turn=state.step_count + 1,
    )
    new_beliefs = list(tom.history)

    # Copy the previous state with the per-turn fields replaced. hidden_state is
    # a separate HiddenState object, so the in-place edits below leave
    # ``state`` (passed to compute_step_reward as the "before" state) untouched.
    next_state = ParlayState(
        **{
            **state.model_dump(),
            "step_count": state.step_count + 1,
            "offer_history": new_offers,
            "belief_history": new_beliefs,
            # Regain CP_REGEN, pay for this move, never exceed CP_START.
            "credibility_points": min(CP_START, state.credibility_points + CP_REGEN - cp_cost),
            "tension_score": _compute_tension(state, action),
            "hidden_state": HiddenState(**state.hidden_state.model_dump()),
        }
    )

    # NOTE(review): appears redundant with the belief_history passed to the constructor above.
    next_state.belief_history = new_beliefs

    # A BATNA reveal records the stated figure: the offer if one was made,
    # otherwise the hidden walk-away price.
    if action.tactical_move == TacticalMove.BATNA_REVEAL:
        revealed = action.offer_amount if action.offer_amount is not None else next_state.hidden_state.walk_away_price
        next_state.hidden_state.last_stated_batna = float(revealed)

    # Count consecutive turns at or above the erosion threshold; a calmer turn resets the streak.
    if next_state.tension_score >= ZOPA_EROSION_TENSION_THRESHOLD:
        next_state.high_tension_streak += 1
    else:
        next_state.high_tension_streak = 0

    # After ZOPA_EROSION_CONSECUTIVE_TURNS such turns, shrink the ZOPA from both
    # ends by ZOPA_EROSION_RATE of the original width (falling back to the
    # current width if no original width is recorded) and restart the streak.
    if next_state.high_tension_streak >= ZOPA_EROSION_CONSECUTIVE_TURNS:
        zopa_width = next_state.hidden_state.budget_ceiling - next_state.hidden_state.walk_away_price
        base_width = next_state.original_zopa_width or zopa_width
        shift = base_width * ZOPA_EROSION_RATE
        next_state.hidden_state.budget_ceiling -= shift
        next_state.hidden_state.walk_away_price += shift
        next_state.zopa_erosion_ticks += 1
        next_state.high_tension_streak = 0

    # An empty or inverted ZOPA counts as a walk-away, which ends the episode below.
    if next_state.hidden_state.budget_ceiling <= next_state.hidden_state.walk_away_price:
        next_state.walk_away = True
        next_state.termination_reason = "zopa_collapsed"

    # Fraction of the original ZOPA width still available (0.0 if the original width was not positive).
    current_zopa = max(0.0, next_state.hidden_state.budget_ceiling - next_state.hidden_state.walk_away_price)
    next_state.zopa_width_pct_remaining = (
        current_zopa / next_state.original_zopa_width if next_state.original_zopa_width > 0 else 0.0
    )

    # An offer inside [walk_away_price, budget_ceiling] closes the deal. A turn
    # without an offer keeps deal_reached as copied from the previous state.
    if action.offer_amount is not None:
        next_state.deal_reached = (
            next_state.hidden_state.walk_away_price
            <= action.offer_amount
            <= next_state.hidden_state.budget_ceiling
        )

    # Drift is applied to the post-step state; the reward sees (before, action, after).
    drift_event = _apply_drift_event(next_state, tom)
    step_reward = compute_step_reward(state, action, next_state, drift_event=drift_event)
    next_state.cumulative_reward = state.cumulative_reward + step_reward

    # Bluff check runs only on turns with a non-negative reward and no tactical
    # move, and only once the previous state holds a stated BATNA.
    # NOTE(review): the "true BATNA" passed is the previous state's budget_ceiling
    # (pre-erosion) - confirm that is the intended quantity.
    if step_reward >= 0.0 and action.tactical_move is None and state.hidden_state.last_stated_batna is not None:
        from .grader import detect_bluff_challenge

        if detect_bluff_challenge(
            utterance=action.utterance,
            opponent_stated_batna=state.hidden_state.last_stated_batna,
            opponent_true_batna=state.hidden_state.budget_ceiling,
        ):
            next_state.bluffs_caught = state.bluffs_caught + 1

    # The episode ends on any of:
    #   * max turns reached;
    #   * a step reward below -100;
    #   * a deal;
    #   * a walk-away.
    next_state.episode_done = (
        next_state.step_count >= MAX_TURNS
        or step_reward < -100.0
        or next_state.deal_reached
        or next_state.walk_away
    )
    # Keep a reason already set above (e.g. "zopa_collapsed"). Otherwise pick
    # one in priority order: deal > reward floor > walk-away > max turns.
    if next_state.episode_done and next_state.termination_reason is None:
        if next_state.deal_reached:
            next_state.termination_reason = "deal_reached"
        elif step_reward < -100.0:
            next_state.termination_reason = "reward_floor"
        elif next_state.walk_away:
            next_state.termination_reason = "walk_away"
        else:
            next_state.termination_reason = "max_turns"

    _sessions[session_id] = {"state": next_state, "tom_tracker": tom}
    observation = _make_observation(next_state, step_reward, action.utterance, drift_event=drift_event)
    return {"observation": observation.model_dump(), "done": next_state.episode_done}
|
|
|
|
async def _handle_state(msg: dict[str, Any]) -> dict:
    """Handle ``state``: return the stored ``ParlayState`` dumped to a dict.

    Raises:
        SessionNotFoundError: if ``msg["session_id"]`` is missing, empty, or unknown.
    """
    session_id = msg.get("session_id")
    entry = _sessions.get(session_id) if session_id else None
    if entry is None:
        raise SessionNotFoundError(f"Session {session_id} not found")
    return {"state": entry["state"].model_dump()}
|
|
|
|
def get_session_state(session_id: str) -> ParlayState | None:
    """Look up the live ``ParlayState`` for *session_id*, for SSE and tests.

    Returns ``None`` when no such session exists.
    """
    entry = _sessions.get(session_id)
    return entry["state"] if entry else None
|
|
|
|
@router.get("/sessions/{session_id}")
async def get_session(session_id: str) -> dict:
    """REST counterpart of the WebSocket ``state`` command.

    Returns ``{"state": <ParlayState dump>}`` for *session_id*.

    Raises:
        SessionNotFoundError: if no session with *session_id* exists.
    """
    entry = _sessions.get(session_id)
    if entry is None:
        raise SessionNotFoundError(f"Session {session_id} not found")
    return {"state": entry["state"].model_dump()}
|
|
|
|
# Standalone app used when this module is run directly (see the __main__ block).
_env_app = FastAPI(title="Parlay OpenEnv", version="1.0.0")
_env_app.include_router(router)


async def _session_not_found_handler(request: Any, exc: Exception) -> Any:
    """Answer an unknown-session lookup on a REST route with HTTP 404 and a JSON ``detail``."""
    from fastapi.responses import JSONResponse

    return JSONResponse(status_code=404, content={"detail": str(exc)})


# get_session raises SessionNotFoundError for unknown ids. Without a registered
# handler the standalone app would presumably answer with a generic 500.
# NOTE(review): assumes SessionNotFoundError is not already an HTTPException
# subclass; the exceptions module is not visible here. The WebSocket endpoint
# catches this error itself, so only REST routes reach this handler.
_env_app.add_exception_handler(SessionNotFoundError, _session_not_found_handler)
|
|
|
|
if __name__ == "__main__":
    import argparse

    import uvicorn

    parser = argparse.ArgumentParser(description="Run the Parlay OpenEnv server")
    parser.add_argument(
        "--host",
        default="0.0.0.0",
        help="interface to bind (default: 0.0.0.0, all interfaces)",
    )
    # A string default is passed through ``type`` by argparse, so an invalid
    # ENV_PORT produces a normal usage error instead of a raw traceback.
    parser.add_argument(
        "--port",
        type=int,
        default=os.getenv("ENV_PORT", "8001"),
        help="port to listen on (default: $ENV_PORT, else 8001)",
    )
    args = parser.parse_args()

    # args.port is already an int thanks to type=int.
    uvicorn.run(_env_app, host=args.host, port=args.port)
|
|