File size: 4,939 Bytes
57eab70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
# salespath_env/client.py
"""
HTTP client for the SalesPath environment.
Used by training scripts to talk to the hosted FastAPI server.
"""

from __future__ import annotations

import requests


class SalesPathClient:
    """
    Thin wrapper around the /reset, /step and /health HTTP endpoints.

    All requests go through a single ``requests.Session`` so consecutive
    calls can reuse the connection. Call :meth:`close` when finished, or use
    the client as a context manager::

        with SalesPathClient("http://localhost:7860") as client:
            obs = client.reset(difficulty=1)

    Example
    -------
    >>> client = SalesPathClient("http://localhost:7860")  # doctest: +SKIP
    >>> obs = client.reset(difficulty=1)  # doctest: +SKIP
    >>> obs = client.step("PROSPECT", "Hi, tell me about your pain points.")  # doctest: +SKIP
    >>> print(obs["reward"])  # doctest: +SKIP
    """

    # Timeout (seconds) for the lightweight /health probe.
    HEALTH_TIMEOUT = 10.0

    def __init__(self, base_url: str = "http://localhost:7860", timeout: float = 30.0):
        """
        Parameters
        ----------
        base_url:
            Root URL of the environment server; any trailing "/" is stripped
            so endpoint paths can be appended directly.
        timeout:
            Per-request timeout in seconds used for /reset and /step.
        """
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self._session = requests.Session()

    # ------------------------------------------------------------------
    # Resource management
    # ------------------------------------------------------------------

    def close(self) -> None:
        """Close the underlying HTTP session and release its connections."""
        self._session.close()

    def __enter__(self) -> "SalesPathClient":
        return self

    def __exit__(self, *exc_info) -> None:
        # Always release the session; exceptions are not suppressed.
        self.close()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _post(self, path: str, payload: dict) -> dict:
        """
        POST *payload* as JSON to ``base_url + path`` and return the decoded
        JSON body.

        Raises ``requests.HTTPError`` (via ``raise_for_status``) on a 4xx/5xx
        response.
        """
        resp = self._session.post(
            f"{self.base_url}{path}",
            json=payload,
            timeout=self.timeout,
        )
        resp.raise_for_status()
        return resp.json()

    @staticmethod
    def _flatten_response(data: dict) -> dict:
        """
        Merge a wrapped ``{observation: {...}, reward, done}`` payload into a
        single flat dict of observation fields plus ``reward`` and ``done``.

        Top-level ``reward``/``done`` take precedence over same-named keys
        inside the observation. If neither level provides them they default
        to ``0.0`` and ``False``. A payload without an ``observation`` key is
        returned unchanged. The observation dict is copied, never mutated.
        """
        if not isinstance(data, dict) or "observation" not in data:
            return data
        flat = dict(data["observation"])
        flat["reward"] = data.get("reward", flat.get("reward", 0.0))
        flat["done"] = data.get("done", flat.get("done", False))
        return flat

    # ------------------------------------------------------------------
    # Core API
    # ------------------------------------------------------------------

    def reset(self, difficulty: int = 1) -> dict:
        """
        Reset the environment for a new episode.

        The server may answer with either the raw observation dict or the
        wrapped ``{observation: {...}, reward, done}`` form. A wrapped answer
        is flattened exactly like :meth:`step` does; a raw one is returned
        as-is.
        """
        data = self._post("/reset", {"difficulty": difficulty})
        return self._flatten_response(data)

    def step(
        self,
        action_type: str,
        content: str = "",
        target: str = "",
    ) -> dict:
        """
        Take one action in the environment.

        OpenEnv /step returns {observation:{...}, reward:float, done:bool}.
        This method flattens it so callers get a single dict with all
        observation fields plus reward and done at the top level. A
        top-level ``reward`` of ``null`` is passed through as ``None``.

        Returns
        -------
        dict with the observation fields reported by the server, e.g.:
            prospect_response, workflow_stage, constraints_violated,
            steps_completed, turn_number, reward, reward_components,
            done, info
        """
        data = self._post(
            "/step",
            {
                "action": {
                    "action_type": action_type,
                    "content": content,
                    "target": target,
                }
            },
        )
        return self._flatten_response(data)

    def health(self) -> dict:
        """
        GET /health and return the decoded JSON body.

        Raises ``requests.HTTPError`` on a 4xx/5xx response.
        """
        resp = self._session.get(
            f"{self.base_url}/health", timeout=self.HEALTH_TIMEOUT
        )
        resp.raise_for_status()
        return resp.json()

    # ------------------------------------------------------------------
    # Convenience: run a full hard-coded demo episode
    # ------------------------------------------------------------------

    def run_demo_episode(self, difficulty: int = 1, verbose: bool = True) -> float:
        """
        Run one scripted episode and return the total cumulative reward.

        Useful for smoke-testing the server end-to-end. The episode stops
        early once a step reports ``done``. A missing or ``None`` reward
        counts as ``0.0``, and missing observation fields are tolerated when
        printing.
        """
        obs = self.reset(difficulty)
        if verbose:
            print(f"\n=== Episode start (difficulty={difficulty}) ===")
            print(f"Prospect: {obs.get('prospect_response', '')}\n")

        # Sequence written as the optimal path for difficulty 1; it is
        # replayed unchanged at every difficulty.
        script = (
            ("PROSPECT", "Hello! I'd love to learn about your current challenges."),
            ("QUALIFY",  "Can you tell me about your budget and decision process?"),
            ("PRESENT",  "Here's how our platform solves your inventory problem."),
            ("CLOSE",    "Based on everything, shall we move forward?"),
        )

        total_reward = 0.0
        for action_type, content in script:
            obs = self.step(action_type, content)
            # step() can surface reward=None; count it as zero so neither the
            # running sum nor the ":.3f" format below fails on None.
            reward = obs.get("reward") or 0.0
            total_reward += reward
            if verbose:
                print(f"[Turn {obs.get('turn_number', '?')}] Agent: {action_type}")
                print(f"  Prospect: {obs.get('prospect_response', '')}")
                print(f"  Reward: {reward:.3f}  |  Done: {obs.get('done', False)}")
                if obs.get("constraints_violated"):
                    print(f"  ⚠ Violations: {obs['constraints_violated']}")
                print()
            if obs.get("done"):
                break

        if verbose:
            print(f"=== Episode done. Cumulative reward: {total_reward:.3f} ===\n")
        return total_reward