Spaces:
Sleeping
Sleeping
File size: 4,939 Bytes
57eab70 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 | # salespath_env/client.py
"""
HTTP client for the SalesPath environment.
Used by training scripts to talk to the hosted FastAPI server.
"""
from __future__ import annotations
import requests
class SalesPathClient:
    """
    Thin wrapper around the /reset and /step HTTP endpoints.

    Every response is normalised into one flat dict: the observation's
    fields plus top-level ``reward`` and ``done`` keys. Those two keys are
    always present; they default to ``0.0`` / ``False`` when the server
    omits them or sends ``null``.

    The client owns a ``requests.Session``. Call :meth:`close`, or use the
    client as a context manager, to release its pooled connections.

    Example
    -------
    >>> client = SalesPathClient("http://localhost:7860")  # doctest: +SKIP
    >>> obs = client.reset(difficulty=1)  # doctest: +SKIP
    >>> obs = client.step("PROSPECT", "Hi, tell me about your pain points.")  # doctest: +SKIP
    >>> print(obs["reward"])  # doctest: +SKIP
    """

    def __init__(
        self,
        base_url: str = "http://localhost:7860",
        timeout: float = 30.0,
    ):
        """
        Parameters
        ----------
        base_url:
            Server root URL. A trailing slash is stripped.
        timeout:
            Seconds to wait for /reset and /step responses. The default of
            30 matches the value previously hard-coded. Raise it for a
            server that may be slow to respond (e.g. waking from sleep).
        """
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self._session = requests.Session()

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------
    def close(self) -> None:
        """Close the underlying HTTP session and its pooled connections."""
        self._session.close()

    def __enter__(self) -> "SalesPathClient":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    # ------------------------------------------------------------------
    # Core API
    # ------------------------------------------------------------------
    def reset(self, difficulty: int = 1) -> dict:
        """
        Reset the environment for a new episode.

        The server may return either the raw observation dict or an
        OpenEnv-style envelope ``{observation: {...}, reward, done}``.
        Both shapes are normalised by :meth:`_flatten`.

        Returns
        -------
        dict
            Flat dict of observation fields plus ``reward`` and ``done``.

        Raises
        ------
        requests.HTTPError
            If the server answers with a 4xx/5xx status.
        TypeError
            If the response body is not a JSON object.
        """
        return self._post("/reset", {"difficulty": difficulty})

    def step(
        self,
        action_type: str,
        content: str = "",
        target: str = "",
    ) -> dict:
        """
        Take one action in the environment.

        OpenEnv /step returns ``{observation: {...}, reward, done}``. This
        method flattens it, so callers get a single dict with all
        observation fields plus ``reward`` and ``done`` at the top level.

        Returns
        -------
        dict with keys:
            prospect_response, workflow_stage, constraints_violated,
            steps_completed, turn_number, reward, reward_components,
            done, info
            (``reward`` and ``done`` are guaranteed; the rest depend on
            what the server sends.)

        Raises
        ------
        requests.HTTPError
            If the server answers with a 4xx/5xx status.
        TypeError
            If the response body is not a JSON object.
        """
        payload = {
            "action": {
                "action_type": action_type,
                "content": content,
                "target": target,
            }
        }
        return self._post("/step", payload)

    def health(self) -> dict:
        """
        Query the server's /health endpoint and return its decoded JSON body.

        Uses a fixed 10-second timeout, independent of ``self.timeout``.

        Raises
        ------
        requests.HTTPError
            If the server answers with a 4xx/5xx status.
        """
        resp = self._session.get(f"{self.base_url}/health", timeout=10)
        resp.raise_for_status()
        return resp.json()

    # ------------------------------------------------------------------
    # Internals
    # ------------------------------------------------------------------
    def _post(self, path: str, payload: dict) -> dict:
        """POST *payload* as JSON to *path* and return the flattened response."""
        resp = self._session.post(
            f"{self.base_url}{path}",
            json=payload,
            timeout=self.timeout,
        )
        resp.raise_for_status()
        return self._flatten(resp.json())

    @staticmethod
    def _flatten(data: dict) -> dict:
        """
        Normalise a server response into one flat observation dict.

        Accepts either ``{observation: {...}, reward, done}`` or a bare
        observation dict. For ``reward`` and ``done``, a non-null top-level
        envelope value wins. Otherwise the observation's own value is used,
        and otherwise ``0.0`` / ``False``. This guarantees both keys are
        present and never ``None``, so callers can do arithmetic and
        formatting on them safely. The input is never mutated.

        Raises
        ------
        TypeError
            If *data* is not a dict (e.g. the server returned a JSON list).
        """
        if not isinstance(data, dict):
            raise TypeError(
                f"Expected a JSON object from the server, got {type(data).__name__}"
            )
        inner = data.get("observation")
        if isinstance(inner, dict):
            flat = dict(inner)
            envelope = data
        else:
            flat = dict(data)
            envelope = {}
        for key, default in (("reward", 0.0), ("done", False)):
            value = envelope.get(key)
            if value is None:
                value = flat.get(key)
            flat[key] = default if value is None else value
        return flat

    # ------------------------------------------------------------------
    # Convenience: run a full hard-coded demo episode
    # ------------------------------------------------------------------
    def run_demo_episode(self, difficulty: int = 1, verbose: bool = True) -> float:
        """
        Run one scripted episode and return the total cumulative reward.

        Useful for smoke-testing the server end-to-end. The same four-action
        script is sent whatever the difficulty. The loop stops early once
        the server reports ``done``.
        """
        obs = self.reset(difficulty)
        if verbose:
            print(f"\n=== Episode start (difficulty={difficulty}) ===")
            print(f"Prospect: {obs.get('prospect_response', '')}\n")

        # Scripted optimal sequence for difficulty 1
        script = [
            ("PROSPECT", "Hello! I'd love to learn about your current challenges."),
            ("QUALIFY", "Can you tell me about your budget and decision process?"),
            ("PRESENT", "Here's how our platform solves your inventory problem."),
            ("CLOSE", "Based on everything, shall we move forward?"),
        ]
        total_reward = 0.0
        for action_type, content in script:
            obs = self.step(action_type, content)
            # _flatten guarantees "reward" and "done" exist and are non-None.
            total_reward += obs["reward"]
            if verbose:
                # Other fields are server-defined, so read them defensively.
                print(f"[Turn {obs.get('turn_number', '?')}] Agent: {action_type}")
                print(f" Prospect: {obs.get('prospect_response', '')}")
                print(f" Reward: {obs['reward']:.3f} | Done: {obs['done']}")
                if obs.get("constraints_violated"):
                    print(f" ⚠ Violations: {obs['constraints_violated']}")
                print()
            if obs["done"]:
                break
        if verbose:
            print(f"=== Episode done. Cumulative reward: {total_reward:.3f} ===\n")
        return total_reward
|