File size: 3,355 Bytes
57eab70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e6a02dd
 
57eab70
 
 
 
 
 
e6a02dd
57eab70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e6a02dd
 
 
 
 
 
57eab70
 
 
 
 
 
 
 
 
e6a02dd
57eab70
 
 
 
e6a02dd
57eab70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# salespath_env/server/app.py
"""
Custom stateful FastAPI server for SalesPath.

Why not create_fastapi_app?
  OpenEnv's built-in HTTP /reset and /step endpoints are STATELESS —
  they create a new Environment instance per request and destroy it.
  State is preserved only over WebSocket sessions.

  For our training loop (HTTP polling), we need a persistent environment
  that survives across /reset + multiple /step calls. This file provides
  that by keeping a single global SalesPathEnvironment instance.

  The response envelope matches OpenEnv exactly:
    { "observation": {...}, "reward": float, "done": bool }
  so all existing clients work without changes.
"""

from typing import Any, Dict, Optional

from fastapi import FastAPI
from pydantic import BaseModel

from ..models import SalesPathAction
from .salespath_environment import SalesPathEnvironment


# ---------------------------------------------------------------------------
# Single persistent environment instance
# ---------------------------------------------------------------------------

# Created once at import time. The /reset, /step and /state handlers in this
# module all operate on this same object, which is what lets episode state
# persist between separate HTTP requests.
# NOTE(review): no locking is visible in this module — presumably a single
# client drives the server at a time; confirm before serving concurrent callers.
_env: SalesPathEnvironment = SalesPathEnvironment()


# ---------------------------------------------------------------------------
# Request models
# ---------------------------------------------------------------------------

class ResetRequest(BaseModel):
    """Optional JSON body accepted by POST /reset.

    Every field has a default, so an empty object is a valid body. The
    values are forwarded unchanged to ``SalesPathEnvironment.reset``.
    """

    # Forwarded as ``difficulty``; defaults to 1 when omitted.
    difficulty: int = 1
    # Forwarded as ``seed``; None when the client does not supply one.
    seed: Optional[int] = None
    # Forwarded as ``episode_id``; None when the client does not supply one.
    episode_id: Optional[str] = None


class ActionPayload(BaseModel):
    """Wire format of a single agent action inside a /step request.

    The /step handler copies these four fields one-to-one into a
    ``SalesPathAction`` before passing it to the environment.
    """

    # Required: the only field without a default.
    action_type: str
    # Free-form text accompanying the action; empty string when omitted.
    content: str = ""
    # Empty string when omitted.
    target: str = ""
    # NOTE(review): presumably flags whether the agent's raw output was
    # well-formed — confirm against SalesPathAction. Defaults to True.
    format_ok: bool = True


class StepRequest(BaseModel):
    """JSON body accepted by POST /step: ``{"action": {...}}``."""

    # Required; the action to apply, nested under the "action" key.
    action: ActionPayload


# ---------------------------------------------------------------------------
# FastAPI app
# ---------------------------------------------------------------------------

# Application object for this server; the @app.post / @app.get decorators in
# this module register their handlers on it. The keyword arguments are the
# metadata FastAPI shows in its generated API documentation.
app = FastAPI(
    title="SalesPath Environment",
    description="OpenEnv-compatible RL environment for B2B sales agent training.",
    version="0.1.0",
)


@app.post("/reset")
def reset(req: ResetRequest = ResetRequest()):
    """Begin a new episode on the shared environment.

    The request body is optional; when it is omitted the ``ResetRequest``
    defaults are used. ``seed``, ``episode_id`` and ``difficulty`` are
    forwarded to the environment's ``reset``.

    Returns:
        A dict with keys ``observation`` (the dumped observation), ``reward``
        and ``done``, the latter two read from the observation itself.
    """
    initial_obs = _env.reset(
        seed=req.seed, episode_id=req.episode_id, difficulty=req.difficulty
    )
    dumped = initial_obs.model_dump()
    reward = initial_obs.reward
    finished = initial_obs.done
    return {"observation": dumped, "reward": reward, "done": finished}


@app.post("/step")
def step(req: StepRequest):
    """Apply one action to the episode currently held by the shared environment.

    The request's ``action`` payload is copied field by field into a
    ``SalesPathAction``, which is then passed to the environment's ``step``.

    Returns:
        A dict with keys ``observation`` (the dumped observation), ``reward``
        and ``done``, the latter two read from the observation itself.
    """
    payload = req.action
    next_obs = _env.step(
        SalesPathAction(
            action_type=payload.action_type,
            content=payload.content,
            target=payload.target,
            format_ok=payload.format_ok,
        )
    )
    dumped = next_obs.model_dump()
    reward = next_obs.reward
    finished = next_obs.done
    return {"observation": dumped, "reward": reward, "done": finished}


@app.get("/health")
def health():
    """Report a fixed healthy status; does not touch the environment."""
    status_report = {"status": "healthy"}
    return status_report


@app.get("/state")
def state():
    """Return a debugging snapshot of the shared environment's state.

    Only the attributes named in ``exposed`` are copied out of the
    environment's ``state`` object; anything else it holds is left out of
    the response.
    """
    # Order matters: it fixes the key order of the returned dict.
    exposed = (
        "episode_id",
        "turn_number",
        "workflow_stage",
        "steps_completed",
        "constraints_violated",
        "objections_handled",
        "difficulty",
        "done",
        "prospect_profile",
    )
    snapshot = _env.state
    return {field: getattr(snapshot, field) for field in exposed}