"""
Parlay OpenEnv WebSocket server.
Implements the standard reset/step/state protocol.
OpenEnv WebSocket URL (this module only; default port from ENV_PORT or --port, default 8001):
ws://localhost:8001/env/ws
(path is /env/ws because the router uses prefix="/env" and defines @router.websocket("/ws").)
When the same router is mounted on main:app (uvicorn main:app --port 8000):
ws://localhost:8000/env/ws
"""
import json
import logging
import os
import uuid
from typing import Any
import numpy as np
from fastapi import APIRouter, FastAPI, WebSocket, WebSocketDisconnect
from agent.tom_tracker import ToMTracker
from game.scenarios import get_scenario
from .exceptions import (
EpisodeAlreadyDoneError,
InvalidActionError,
InvalidScenarioError,
SessionNotFoundError,
)
from .game_theory import compute_nash_bargaining_solution, compute_zopa
from .grader import compute_step_reward
from .models import (
BeliefState,
HiddenState,
ParlayAction,
ParlayObservation,
ParlayState,
PersonaType,
TacticalMove,
)
from .reward import (
ZOPA_EROSION_CONSECUTIVE_TURNS,
ZOPA_EROSION_RATE,
ZOPA_EROSION_TENSION_THRESHOLD,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/env", tags=["OpenEnv"])
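# Neutral belief/observation pair returned to the client when an unhandled server
# error occurs mid-episode (see the exception handling in env_websocket below).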
_FALLBACK_BELIEF = BeliefState(
est_budget=0.0,
est_walk_away=0.0,
est_urgency=0.5,
est_has_alternative=False,
confidence=0.1,
)
FALLBACK_OBSERVATION = ParlayObservation(
step_count=0,
episode_done=False,
current_offer=0.0,
opponent_offer=0.0,
zopa_lower=0.0,
zopa_upper=0.0,
nash_point=0.0,
tension_score=0.0,
belief_state=_FALLBACK_BELIEF,
last_utterance="[Connection issue - AI is thinking]",
available_moves=list(TacticalMove),
credibility_points=100,
reward=0.0,
cumulative_reward=0.0,
)
_sessions: dict[str, dict[str, Any]] = {}
MAX_TURNS = int(os.getenv("MAX_TURNS_PER_EPISODE", "20"))
CP_START = int(os.getenv("CREDIBILITY_POINTS_START", "100"))
CP_REGEN = int(os.getenv("CREDIBILITY_REGEN_PER_TURN", "5"))
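# Per-scenario hidden-state defaults: opponent budget ceiling, walk-away price,
# urgency, and whether an outside alternative (BATNA) exists. Values are jittered
# per episode in _get_scenario_hidden_state.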
_SCENARIO_DEFAULTS: dict[str, dict[str, Any]] = {
"saas_enterprise": dict(budget=165_000, walk=125_000, urgency=0.55, alt=True),
"hiring_package": dict(budget=230_000, walk=195_000, urgency=0.60, alt=False),
"acquisition_term_sheet": dict(budget=16_000_000, walk=10_500_000, urgency=0.65, alt=True),
}
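# Credibility-point cost of each tactical move; moves not listed cost 0 (see _get_cp_cost).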
_CP_COSTS: dict[TacticalMove, int] = {
TacticalMove.ANCHOR_HIGH: 0,
TacticalMove.BATNA_REVEAL: 20,
TacticalMove.SILENCE: 5,
}
def _get_scenario_hidden_state(scenario_id: str, rng_seed: int = 42) -> HiddenState:
"""Return a HiddenState for the given scenario with slight randomisation."""
if scenario_id not in _SCENARIO_DEFAULTS:
raise InvalidScenarioError(f"Unknown scenario: {scenario_id}")
rng = np.random.default_rng(rng_seed)
defaults = _SCENARIO_DEFAULTS[scenario_id]
noise = rng.uniform(0.95, 1.05)
return HiddenState(
budget_ceiling=round(defaults["budget"] * noise, 2),
walk_away_price=round(defaults["walk"] * noise, 2),
urgency_score=float(np.clip(defaults["urgency"] + rng.uniform(-0.1, 0.1), 0.0, 1.0)),
has_alternative=bool(defaults["alt"]),
persona_drifted=False,
)
def _initial_belief(hidden: HiddenState) -> BeliefState:
"""Initial belief state - intentionally imprecise."""
return BeliefState(
est_budget=hidden.budget_ceiling * 0.80,
est_walk_away=hidden.walk_away_price * 1.15,
est_urgency=0.50,
est_has_alternative=False,
confidence=0.30,
)
def _get_cp_cost(move: TacticalMove | None) -> int:
if move is None:
return 0
return _CP_COSTS.get(move, 0)
def _compute_tension(state: ParlayState, action: ParlayAction) -> float:
"""Compute the current tension score for the turn."""
base = 20.0 + ((state.step_count + 1) / MAX_TURNS) * 60.0
if action.tactical_move == TacticalMove.ANCHOR_HIGH:
base += 15.0
elif action.tactical_move == TacticalMove.BATNA_REVEAL:
base += 10.0
elif action.tactical_move == TacticalMove.SILENCE:
base += 5.0
return float(max(0.0, min(100.0, base)))
def _apply_drift_event(state: ParlayState, tom: ToMTracker) -> str | None:
"""Apply scenario drift event at the current turn, if any."""
try:
scenario = get_scenario(state.scenario_id)
except Exception:
return None
for event in scenario.drift_events:
if event.trigger_turn == state.step_count:
state.drift_events_fired += 1
state.hidden_state.persona_drifted = True
tom.drift_event(
event.effect_on_urgency,
event.effect_on_has_alternative,
event_description=event.event,
)
state.belief_history = list(tom.history)
return event.event
return None
def _make_observation(
state: ParlayState,
reward: float,
utterance: str,
drift_event: str | None = None,
) -> ParlayObservation:
"""Build a ParlayObservation from the current state."""
zopa = compute_zopa(state.hidden_state.budget_ceiling, state.hidden_state.walk_away_price)
zopa_lower = zopa[0] if zopa else state.hidden_state.walk_away_price
zopa_upper = zopa[1] if zopa else state.hidden_state.budget_ceiling
nash = compute_nash_bargaining_solution(
state.hidden_state.budget_ceiling,
state.hidden_state.walk_away_price,
)
current_offer = state.offer_history[-1] if state.offer_history else 0.0
belief = state.belief_history[-1] if state.belief_history else _initial_belief(state.hidden_state)
original_zopa = state.original_zopa_width
current_zopa = max(0.0, state.hidden_state.budget_ceiling - state.hidden_state.walk_away_price)
width_pct = current_zopa / original_zopa if original_zopa > 0 else 0.0
return ParlayObservation(
step_count=state.step_count,
episode_done=state.episode_done,
current_offer=current_offer,
opponent_offer=zopa_upper * 0.9,
zopa_lower=zopa_lower,
zopa_upper=zopa_upper,
nash_point=nash,
tension_score=state.tension_score,
belief_state=belief,
last_utterance=utterance,
available_moves=list(TacticalMove),
credibility_points=state.credibility_points,
reward=reward,
cumulative_reward=state.cumulative_reward,
drift_event=drift_event,
zopa_erosion_ticks=state.zopa_erosion_ticks,
zopa_width_pct_remaining=width_pct,
)
def _coerce_message_params(msg: dict[str, Any]) -> tuple[str, dict[str, Any]]:
"""Accept both {cmd, ...} and {method, params} envelope formats."""
if "method" in msg:
return str(msg["method"]), dict(msg.get("params", {}))
command = str(msg.get("cmd", ""))
params = {k: v for k, v in msg.items() if k != "cmd"}
return command, params
@router.websocket("/ws")
async def env_websocket(websocket: WebSocket) -> None:
"""OpenEnv WebSocket endpoint."""
await websocket.accept()
logger.info("OpenEnv WebSocket client connected")
try:
while True:
raw = await websocket.receive_text()
try:
msg = json.loads(raw)
except json.JSONDecodeError:
await websocket.send_json({"error": "Invalid JSON"})
continue
command, params = _coerce_message_params(msg)
try:
match command:
case "reset":
result = await _handle_reset(params)
case "step":
result = await _handle_step(params)
case "state":
result = await _handle_state(params)
case _:
result = {"error": f"Unknown command: {command}"}
except (
InvalidActionError,
SessionNotFoundError,
EpisodeAlreadyDoneError,
InvalidScenarioError,
) as exc:
result = {"error": str(exc)}
except Exception:
logger.exception("Unhandled error in env WebSocket - returning fallback observation")
result = {
"observation": FALLBACK_OBSERVATION.model_dump(),
"done": False,
"_fallback": True,
}
await websocket.send_json(result)
except WebSocketDisconnect:
logger.info("OpenEnv WebSocket client disconnected")
async def _handle_reset(msg: dict[str, Any]) -> dict:
"""Handle reset: create a fresh episode."""
scenario_id = msg.get("scenario_id", "saas_enterprise")
persona_str = msg.get("persona", "shark")
seed = int(msg.get("seed", 42))
try:
persona = PersonaType(persona_str)
except ValueError as exc:
raise InvalidScenarioError(f"Unknown persona: {persona_str}") from exc
hidden = _get_scenario_hidden_state(scenario_id, seed)
initial_belief = _initial_belief(hidden)
session_id = str(uuid.uuid4())
original_zopa_width = hidden.budget_ceiling - hidden.walk_away_price
state = ParlayState(
session_id=session_id,
scenario_id=scenario_id,
persona=persona,
step_count=0,
cumulative_reward=0.0,
hidden_state=hidden,
belief_history=[initial_belief],
offer_history=[],
drift_events_fired=0,
episode_done=False,
credibility_points=CP_START,
original_zopa_width=original_zopa_width,
zopa_width_pct_remaining=1.0,
)
_sessions[session_id] = {
"state": state,
"tom_tracker": ToMTracker(initial_belief, persona),
}
observation = _make_observation(state, 0.0, "Negotiation started. Make your opening move.")
logger.info("Reset: session=%s, scenario=%s, persona=%s", session_id, scenario_id, persona_str)
return {"session_id": session_id, "observation": observation.model_dump(), "done": False}
async def _handle_step(msg: dict[str, Any]) -> dict:
"""Advance the episode by one action."""
session_id = msg.get("session_id")
if not session_id or session_id not in _sessions:
raise SessionNotFoundError(f"Session {session_id} not found")
session = _sessions[session_id]
state: ParlayState = session["state"]
tom: ToMTracker = session["tom_tracker"]
if state.episode_done:
raise EpisodeAlreadyDoneError(f"Episode {session_id} is already done")
action_payload = msg.get("action", msg)
try:
action = ParlayAction.model_validate(action_payload)
except Exception as exc:
raise InvalidActionError(f"Invalid action: {exc}") from exc
cp_cost = _get_cp_cost(action.tactical_move)
if state.credibility_points < cp_cost:
raise InvalidActionError("Insufficient credibility points for that move")
new_offers = list(state.offer_history)
if action.offer_amount is not None:
new_offers.append(action.offer_amount)
tom.update(
observed_offer=action.offer_amount,
observed_move=action.tactical_move,
utterance=action.utterance,
turn=state.step_count + 1,
)
new_beliefs = list(tom.history)
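    # Build the successor state from a copy of the current one; hidden_state is
    # re-instantiated so the ZOPA erosion below mutates only the new state.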
next_state = ParlayState(
**{
**state.model_dump(),
"step_count": state.step_count + 1,
"offer_history": new_offers,
"belief_history": new_beliefs,
"credibility_points": min(CP_START, state.credibility_points + CP_REGEN - cp_cost),
"tension_score": _compute_tension(state, action),
"hidden_state": HiddenState(**state.hidden_state.model_dump()),
}
)
# Keep belief history aligned with ToM tracker history (single source of truth).
next_state.belief_history = new_beliefs
if action.tactical_move == TacticalMove.BATNA_REVEAL:
revealed = action.offer_amount if action.offer_amount is not None else next_state.hidden_state.walk_away_price
next_state.hidden_state.last_stated_batna = float(revealed)
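    # ZOPA erosion: after ZOPA_EROSION_CONSECUTIVE_TURNS consecutive turns at or above
    # the tension threshold, both sides harden by ZOPA_EROSION_RATE of the original
    # width. If the window closes entirely, the opponent walks away.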
if next_state.tension_score >= ZOPA_EROSION_TENSION_THRESHOLD:
next_state.high_tension_streak += 1
else:
next_state.high_tension_streak = 0
if next_state.high_tension_streak >= ZOPA_EROSION_CONSECUTIVE_TURNS:
zopa_width = next_state.hidden_state.budget_ceiling - next_state.hidden_state.walk_away_price
base_width = next_state.original_zopa_width or zopa_width
shift = base_width * ZOPA_EROSION_RATE
next_state.hidden_state.budget_ceiling -= shift
next_state.hidden_state.walk_away_price += shift
next_state.zopa_erosion_ticks += 1
next_state.high_tension_streak = 0
if next_state.hidden_state.budget_ceiling <= next_state.hidden_state.walk_away_price:
next_state.walk_away = True
next_state.termination_reason = "zopa_collapsed"
current_zopa = max(0.0, next_state.hidden_state.budget_ceiling - next_state.hidden_state.walk_away_price)
next_state.zopa_width_pct_remaining = (
current_zopa / next_state.original_zopa_width if next_state.original_zopa_width > 0 else 0.0
)
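    # A deal is reached only when the latest offer lands inside the (possibly eroded) ZOPA.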
if action.offer_amount is not None:
next_state.deal_reached = (
next_state.hidden_state.walk_away_price
<= action.offer_amount
<= next_state.hidden_state.budget_ceiling
)
drift_event = _apply_drift_event(next_state, tom)
step_reward = compute_step_reward(state, action, next_state, drift_event=drift_event)
next_state.cumulative_reward = state.cumulative_reward + step_reward
if step_reward >= 0.0 and action.tactical_move is None and state.hidden_state.last_stated_batna is not None:
from .grader import detect_bluff_challenge # noqa: PLC0415
if detect_bluff_challenge(
utterance=action.utterance,
opponent_stated_batna=state.hidden_state.last_stated_batna,
opponent_true_batna=state.hidden_state.budget_ceiling,
):
next_state.bluffs_caught = state.bluffs_caught + 1
next_state.episode_done = (
next_state.step_count >= MAX_TURNS
or step_reward < -100.0
or next_state.deal_reached
or next_state.walk_away
)
if next_state.episode_done and next_state.termination_reason is None:
if next_state.deal_reached:
next_state.termination_reason = "deal_reached"
elif step_reward < -100.0:
next_state.termination_reason = "reward_floor"
elif next_state.walk_away:
next_state.termination_reason = "walk_away"
else:
next_state.termination_reason = "max_turns"
_sessions[session_id] = {"state": next_state, "tom_tracker": tom}
observation = _make_observation(next_state, step_reward, action.utterance, drift_event=drift_event)
return {"observation": observation.model_dump(), "done": next_state.episode_done}
async def _handle_state(msg: dict[str, Any]) -> dict:
"""Return raw session state."""
session_id = msg.get("session_id")
if not session_id or session_id not in _sessions:
raise SessionNotFoundError(f"Session {session_id} not found")
return {"state": _sessions[session_id]["state"].model_dump()}
def get_session_state(session_id: str) -> ParlayState | None:
"""Return the in-memory session state for SSE and tests."""
session = _sessions.get(session_id)
if not session:
return None
return session["state"]
@router.get("/sessions/{session_id}")
async def get_session(session_id: str) -> dict:
"""Get session state via REST."""
if session_id not in _sessions:
raise SessionNotFoundError(f"Session {session_id} not found")
return {"state": _sessions[session_id]["state"].model_dump()}
_env_app = FastAPI(title="Parlay OpenEnv", version="1.0.0")
_env_app.include_router(router)
if __name__ == "__main__":
import argparse
import uvicorn
parser = argparse.ArgumentParser(description="Run the Parlay OpenEnv server")
parser.add_argument("--port", type=int, default=int(os.getenv("ENV_PORT", "8001")))
args = parser.parse_args()
port = int(args.port)
uvicorn.run(_env_app, host="0.0.0.0", port=port)