Spaces:
Sleeping
Sleeping
| # salespath_env/client.py | |
| """ | |
| HTTP client for the SalesPath environment. | |
| Used by training scripts to talk to the hosted FastAPI server. | |
| """ | |
| from __future__ import annotations | |
| import requests | |
class SalesPathClient:
    """
    Thin wrapper around the /reset and /step HTTP endpoints.

    All requests share one ``requests.Session``. Call :meth:`close` (or use
    the client as a context manager) to release it when finished.

    Example
    -------
    >>> client = SalesPathClient("http://localhost:7860")
    >>> obs = client.reset(difficulty=1)
    >>> obs = client.step("PROSPECT", "Hi, tell me about your pain points.")
    >>> print(obs["reward"])
    """

    def __init__(
        self,
        base_url: str = "http://localhost:7860",
        timeout: float = 30.0,
    ):
        """
        Parameters
        ----------
        base_url:
            Server root URL; any trailing "/" is stripped.
        timeout:
            Seconds to wait for /reset and /step responses. Increase it for
            slow or cold-starting servers.
        """
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self._session = requests.Session()

    # ------------------------------------------------------------------
    # Session lifecycle
    # ------------------------------------------------------------------
    def close(self) -> None:
        """Close the underlying HTTP session."""
        self._session.close()

    def __enter__(self) -> "SalesPathClient":
        return self

    def __exit__(self, *exc_info: object) -> None:
        self.close()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _post(self, path: str, payload: dict) -> dict:
        """
        POST ``payload`` as JSON to ``base_url + path`` and return the
        decoded JSON body.

        Raises ``requests.HTTPError`` for 4xx/5xx responses.
        """
        resp = self._session.post(
            f"{self.base_url}{path}",
            json=payload,
            timeout=self.timeout,
        )
        resp.raise_for_status()
        return resp.json()

    @staticmethod
    def _flatten(data: dict) -> dict:
        """
        Flatten an ``{observation: {...}, reward, done}`` payload.

        The observation fields are copied into a new dict. Top-level
        ``reward``/``done`` override same-named observation fields. When
        neither level provides them, they default to 0.0 and False.
        Payloads without an ``observation`` key are assumed to be flat
        already and are returned unchanged.
        """
        if "observation" not in data:
            return data
        flat = dict(data["observation"])
        flat["reward"] = data.get("reward", flat.get("reward", 0.0))
        flat["done"] = data.get("done", flat.get("done", False))
        return flat

    # ------------------------------------------------------------------
    # Core API
    # ------------------------------------------------------------------
    def reset(self, difficulty: int = 1) -> dict:
        """
        Reset the environment for a new episode.

        The server may reply with a raw observation dict or with a wrapped
        ``{observation: {...}, reward, done}`` payload. Wrapped payloads are
        flattened with the same rules as :meth:`step`. Raw ones are returned
        as-is.
        """
        return self._flatten(self._post("/reset", {"difficulty": difficulty}))

    def step(
        self,
        action_type: str,
        content: str = "",
        target: str = "",
    ) -> dict:
        """
        Take one action in the environment.

        The action is sent as ``{"action": {action_type, content, target}}``.
        A wrapped ``{observation: {...}, reward, done}`` reply is flattened
        so callers get one dict with the observation fields plus ``reward``
        and ``done`` at the top level.

        Returns
        -------
        dict, expected to contain keys such as:
            prospect_response, workflow_stage, constraints_violated,
            steps_completed, turn_number, reward, reward_components,
            done, info
        """
        payload = {
            "action": {
                "action_type": action_type,
                "content": content,
                "target": target,
            }
        }
        return self._flatten(self._post("/step", payload))

    def health(self) -> dict:
        """GET /health and return the decoded JSON body (10 s timeout)."""
        resp = self._session.get(f"{self.base_url}/health", timeout=10)
        resp.raise_for_status()
        return resp.json()

    # ------------------------------------------------------------------
    # Convenience: run a full hard-coded demo episode
    # ------------------------------------------------------------------
    def run_demo_episode(self, difficulty: int = 1, verbose: bool = True) -> float:
        """
        Run one scripted episode and return the cumulative reward.

        Useful for smoke-testing the server end-to-end. The fixed
        PROSPECT -> QUALIFY -> PRESENT -> CLOSE script was written with
        difficulty 1 in mind but is sent unchanged at any difficulty. The
        loop stops as soon as a step reports ``done``. A missing or None
        reward counts as 0.0, and missing observation fields are tolerated.
        """
        obs = self.reset(difficulty)
        if verbose:
            print(f"\n=== Episode start (difficulty={difficulty}) ===")
            print(f"Prospect: {obs.get('prospect_response', '')}\n")

        # Scripted sequence intended for difficulty 1 (not adapted otherwise)
        script = [
            ("PROSPECT", "Hello! I'd love to learn about your current challenges."),
            ("QUALIFY", "Can you tell me about your budget and decision process?"),
            ("PRESENT", "Here's how our platform solves your inventory problem."),
            ("CLOSE", "Based on everything, shall we move forward?"),
        ]
        total_reward = 0.0
        for action_type, content in script:
            obs = self.step(action_type, content)
            # A null reward (e.g. "reward": null in the JSON) would make both
            # the sum and the ":.3f" format below raise TypeError.
            reward = obs.get("reward") or 0.0
            # Flat (non-wrapped) replies may omit "done".
            done = obs.get("done", False)
            total_reward += reward
            if verbose:
                print(f"[Turn {obs.get('turn_number', '?')}] Agent: {action_type}")
                print(f" Prospect: {obs.get('prospect_response', '')}")
                print(f" Reward: {reward:.3f} | Done: {done}")
                if obs.get("constraints_violated"):
                    print(f" ⚠ Violations: {obs['constraints_violated']}")
                print()
            if done:
                break
        if verbose:
            print(f"=== Episode done. Cumulative reward: {total_reward:.3f} ===\n")
        return total_reward