Lomesh2000 commited on
Commit
57eab70
Β·
1 Parent(s): 385203e

multi agent environment learning

Browse files
Colab_Training.ipynb ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": []
7
+ },
8
+ "kernelspec": {
9
+ "name": "python3",
10
+ "display_name": "Python 3"
11
+ },
12
+ "language_info": {
13
+ "name": "python"
14
+ }
15
+ },
16
+ "cells": [
17
+ {
18
+ "cell_type": "markdown",
19
+ "source": [
20
+ "# SalesPath: OpenEnv RL Training via GRPO\n",
21
+ "\n",
22
+ "This notebook contains the complete training pipeline for the SalesPath environment. It performs:\n",
23
+ "1. **SFT Warm-start**: Fine-tunes a base model on expert sales demonstrations.\n",
24
+ "2. **GRPO RL**: Uses live rollouts against your hosted environment to optimize the agent.\n",
25
+ "\n",
26
+ "> **CRITICAL:** Before running this, ensure you are using a **T4 GPU** (`Runtime` -> `Change runtime type` -> `Hardware accelerator` -> `T4 GPU`).\n",
27
+ "> \n",
28
+ "> You must also have pushed your environment code to a **Hugging Face Space** so this notebook can interact with it."
29
+ ],
30
+ "metadata": {}
31
+ },
32
+ {
33
+ "cell_type": "code",
34
+ "execution_count": null,
35
+ "metadata": {},
36
+ "outputs": [],
37
+ "source": [
38
+ "# 1. Install required dependencies\n",
39
+ "!pip install \"unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git\"\n",
40
+ "!pip install --no-deps trl peft accelerate bitsandbytes datasets matplotlib openenv-core"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "code",
45
+ "execution_count": null,
46
+ "metadata": {},
47
+ "outputs": [],
48
+ "source": [
49
+ "# 2. Clone your environment repository from Hugging Face Spaces\n",
50
+ "# ⚠️ REPLACE WITH YOUR ACTUAL HF SPACE URL\n",
51
+ "HF_SPACE_URL = \"https://huggingface.co/spaces/YOUR_USERNAME/salespath-env\"\n",
52
+ "\n",
53
+ "import os\n",
54
+ "repo_name = HF_SPACE_URL.split(\"/\")[-1]\n",
55
+ "\n",
56
+ "!git clone {HF_SPACE_URL}\n",
57
+ "os.chdir(repo_name)\n",
58
+ "print(f\"\\nWorking directory changed to: {os.getcwd()}\")"
59
+ ]
60
+ },
61
+ {
62
+ "cell_type": "code",
63
+ "execution_count": null,
64
+ "metadata": {},
65
+ "outputs": [],
66
+ "source": [
67
+ "# 3. Run SFT Warm-start (~10-15 minutes)\n",
68
+ "# This trains the model to understand the basic output format and sales flow.\n",
69
+ "!python training/train_sft.py"
70
+ ]
71
+ },
72
+ {
73
+ "cell_type": "code",
74
+ "execution_count": null,
75
+ "metadata": {},
76
+ "outputs": [],
77
+ "source": [
78
+ "# 4. Run GRPO Reinforcement Learning (~45-60 minutes)\n",
79
+ "import os\n",
80
+ "\n",
81
+ "# Derive the direct API URL for the Hugging Face space\n",
82
+ "username = HF_SPACE_URL.split(\"/\")[-2]\n",
83
+ "space_name = HF_SPACE_URL.split(\"/\")[-1]\n",
84
+ "direct_url = f\"https://{username}-{space_name}.hf.space\"\n",
85
+ "\n",
86
+ "os.environ[\"SALESPATH_ENV_URL\"] = direct_url\n",
87
+ "os.environ[\"SFT_CHECKPOINT\"] = \"./sft_checkpoint\"\n",
88
+ "\n",
89
+ "print(f\"Targeting Environment API: {direct_url}\")\n",
90
+ "\n",
91
+ "# Run the GRPO training script\n",
92
+ "!python training/train_grpo.py"
93
+ ]
94
+ },
95
+ {
96
+ "cell_type": "code",
97
+ "execution_count": null,
98
+ "metadata": {},
99
+ "outputs": [],
100
+ "source": [
101
+ "# 5. Plot the Training Rewards\n",
102
+ "!python training/plot_rewards.py --log ./reward_log.jsonl --out ./plots\n",
103
+ "\n",
104
+ "from IPython.display import Image, display\n",
105
+ "print(\"\\n=== Reward Curve ===\")\n",
106
+ "display(Image(\"./plots/reward_curve.png\"))\n",
107
+ "\n",
108
+ "print(\"\\n=== Reward by Difficulty ===\")\n",
109
+ "display(Image(\"./plots/reward_by_difficulty.png\"))"
110
+ ]
111
+ }
112
+ ]
113
+ }
README.md CHANGED
@@ -7,22 +7,21 @@ sdk: docker
7
  app_port: 7860
8
  pinned: false
9
  license: mit
10
- short_description: RL gym environment for sales agent training
11
  ---
12
 
13
- # SalesPath Environment
14
 
15
- A [OpenEnv](https://github.com/openenv)-compatible Reinforcement Learning gym environment for training sales agents via LLM fine-tuning.
 
16
 
17
- ## API Endpoints
 
 
18
 
19
- | Method | Endpoint | Description |
20
- |--------|----------|-------------|
21
- | `POST` | `/reset` | Reset the environment, returns initial observation |
22
- | `POST` | `/step` | Take an action, returns next observation + reward |
23
- | `GET` | `/health` | Health check |
24
 
25
- ## Quick Start
26
 
27
  ### Reset
28
  ```bash
@@ -35,13 +34,141 @@ curl -X POST https://imsachin010-salespath-env.hf.space/reset \
35
  ```bash
36
  curl -X POST https://imsachin010-salespath-env.hf.space/step \
37
  -H "Content-Type: application/json" \
38
- -d '{"action": {"action_type": "PROSPECT", "content": "Hello, tell me about your workflow challenges."}}'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  ```
40
 
41
- ## Action Types
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
- - `PROSPECT` β€” Initial outreach and discovery
44
- - `QUALIFY` β€” Qualify the lead
45
- - `PITCH` β€” Deliver the sales pitch
46
- - `HANDLE_OBJECTION` β€” Handle prospect objections
47
- - `CLOSE` β€” Attempt to close the deal
 
7
  app_port: 7860
8
  pinned: false
9
  license: mit
10
+ short_description: RL gym environment for training B2B sales agents via GRPO
11
  ---
12
 
13
+ # SalesPath β€” RL Environment for B2B Sales Agents
14
 
15
+ A [OpenEnv](https://github.com/openenv)-compatible reinforcement learning gym
16
+ that trains an LLM to navigate the full B2B sales process through GRPO.
17
 
18
+ The agent must learn to qualify leads, handle objections, offer demos,
19
+ negotiate, and close β€” all while respecting business rules enforced by a
20
+ deterministic rule-based ProspectSimulator (no LLM on the environment side).
21
 
22
+ ---
 
 
 
 
23
 
24
+ ## Quick Start (hosted on HF Spaces)
25
 
26
  ### Reset
27
  ```bash
 
34
  ```bash
35
  curl -X POST https://imsachin010-salespath-env.hf.space/step \
36
  -H "Content-Type: application/json" \
37
+ -d '{
38
+ "action": {
39
+ "action_type": "PROSPECT",
40
+ "content": "Hello! I understand you have inventory tracking challenges. Tell me more."
41
+ }
42
+ }'
43
+ ```
44
+
45
+ ### Health check
46
+ ```bash
47
+ curl https://imsachin010-salespath-env.hf.space/health
48
+ ```
49
+
50
+ ---
51
+
52
+ ## Action Space
53
+
54
+ | Action | When to use |
55
+ |---|---|
56
+ | `PROSPECT` | Opening turn only β€” initial outreach |
57
+ | `QUALIFY` | Uncover budget, decision maker, pain points |
58
+ | `PRESENT` | Pitch the solution (requires QUALIFY first) |
59
+ | `HANDLE_OBJECTION` | Respond to pricing / timing objections |
60
+ | `OFFER_DEMO` | Schedule a live product demo |
61
+ | `NEGOTIATE` | Discuss pricing/terms (requires OFFER_DEMO + known budget) |
62
+ | `CLOSE` | Attempt to sign the deal |
63
+ | `FOLLOW_UP` | Re-engage after prospect silence |
64
+ | `DISQUALIFY` | End the conversation (correct only if budget < threshold AND no decision maker) |
65
+
66
+ ---
67
+
68
+ ## Business Rules Enforced
69
+
70
+ | Rule | Description |
71
+ |---|---|
72
+ | R01 | Must QUALIFY before PRESENT |
73
+ | R02 | Must OFFER_DEMO before NEGOTIATE |
74
+ | R03 | Cannot NEGOTIATE while budget is unknown |
75
+ | R04 | Discount in NEGOTIATE only after 2 objections handled |
76
+ | R05 | Cannot repeat the same action on consecutive turns |
77
+ | R06 | First action must be PROSPECT |
78
+ | R07 | FOLLOW_UP only valid after prospect silence (no response for 1+ turns) |
79
+ | R08 | DISQUALIFY valid only when budget < threshold AND no decision maker |
80
+ | R09 | Must OFFER_DEMO before CLOSE (difficulty 2+) |
81
+
82
+ ---
83
+
84
+ ## Reward Function
85
+
86
+ Composite weighted reward computed every step:
87
+
88
+ | Component | Weight | Description |
89
+ |---|---|---|
90
+ | `r_outcome` | 0.40 | +1.0 on successful close, +0.5 on valid DISQUALIFY, -0.5 on bad close |
91
+ | `r_compliance` | 0.30 | -0.2 per rule violation this turn |
92
+ | `r_ordering` | 0.15 | Fraction of workflow steps completed in correct order |
93
+ | `r_efficiency` | 0.10 | Penalty for turns beyond the optimal episode length |
94
+ | `r_format` | 0.05 | +1.0 for valid action type, -0.1 for invalid |
95
+
96
+ ---
97
+
98
+ ## Difficulty Levels
99
+
100
+ | Level | Description | Correct terminal action |
101
+ |---|---|---|
102
+ | 1 | Budget known, decision maker present, easy close | CLOSE |
103
+ | 2 | Budget hidden, 1 objection, demo required | CLOSE |
104
+ | 3 | Budget hidden, 2 objections, stalling prospect | CLOSE |
105
+ | 4 | Misleading signals, low budget, no decision maker | DISQUALIFY |
106
+
107
+ ---
108
+
109
+ ## Training Pipeline
110
+
111
+ ```
112
+ sft_demos.jsonl (14 expert demos)
113
+ ↓
114
+ train_sft.py ← SFT warm-start (SFTTrainer, TRL)
115
+ ↓
116
+ sft_checkpoint/
117
+ ↓
118
+ train_grpo.py ← GRPO RL fine-tuning (GRPOTrainer, TRL + Unsloth 4-bit)
119
+ ↓
120
+ grpo_checkpoint/ + reward_log.jsonl
121
+ ↓
122
+ plot_rewards.py ← reward curves
123
  ```
124
 
125
+ ### Commands
126
+
127
+ ```bash
128
+ # 1. Smoke test (no GPU, ~30 seconds)
129
+ python training/train_test.py
130
+
131
+ # 2. SFT warm-start (~15 min on T4)
132
+ python training/train_sft.py
133
+
134
+ # 3. Full GRPO training (~60 min on T4)
135
+ uvicorn salespath_env.server.app:app --port 7860 &
136
+ python training/train_grpo.py
137
+
138
+ # 4. Plot reward curves
139
+ python training/plot_rewards.py
140
+ ```
141
+
142
+ ---
143
+
144
+ ## File Structure
145
+
146
+ ```
147
+ salespath-env/
148
+ β”œβ”€β”€ salespath_env/
149
+ β”‚ β”œβ”€β”€ client.py ← HTTP client for training scripts
150
+ β”‚ β”œβ”€β”€ models.py ← SalesPathAction / Observation / State
151
+ β”‚ └── server/
152
+ β”‚ β”œβ”€β”€ app.py ← FastAPI app (OpenEnv)
153
+ β”‚ β”œβ”€β”€ salespath_environment.py
154
+ β”‚ β”œβ”€β”€ prospect_simulator.py ← Rule-based, no LLM
155
+ β”‚ β”œβ”€β”€ rules.py ← 9 business rules (R01–R09)
156
+ β”‚ β”œβ”€β”€ reward.py ← 5-component reward function
157
+ β”‚ └── task_bank.py ← Prospect profiles (4 difficulty levels)
158
+ β”œβ”€β”€ training/
159
+ β”‚ β”œβ”€β”€ sft_demos.jsonl ← Expert demonstration data
160
+ β”‚ β”œβ”€β”€ train_test.py ← Smoke test (no GPU)
161
+ β”‚ β”œβ”€β”€ train_sft.py ← SFT warm-start
162
+ β”‚ β”œβ”€β”€ train_grpo.py ← GRPO RL training
163
+ β”‚ └── plot_rewards.py ← Reward curve visualisation
164
+ β”œβ”€β”€ Dockerfile
165
+ └── requirements.txt
166
+ ```
167
+
168
+ ---
169
+
170
+ ## Links
171
 
172
+ - πŸ“ Blog post: _[add HuggingFace blog link here]_
173
+ - πŸŽ₯ Demo video: _[add YouTube link here]_
174
+ - πŸ€— HF Space: https://huggingface.co/spaces/imsachin010/salespath-env
 
 
pyproject.toml ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "salespath-env"
3
+ version = "0.1.0"
4
+ description = "OpenEnv RL environment for training B2B sales agents via GRPO"
5
+ requires-python = ">=3.10"
6
+ license = { text = "MIT" }
7
+
8
+ dependencies = [
9
+ "openenv-core>=0.2.3",
10
+ "fastapi>=0.110.0",
11
+ "uvicorn[standard]>=0.29.0",
12
+ "pydantic>=2.0",
13
+ ]
14
+
15
+ [project.optional-dependencies]
16
+ training = [
17
+ "trl>=0.8.6",
18
+ "transformers>=4.40.0",
19
+ "datasets>=2.18.0",
20
+ "peft>=0.10.0",
21
+ "bitsandbytes>=0.43.0",
22
+ "accelerate>=0.28.0",
23
+ "torch>=2.2.0",
24
+ "matplotlib>=3.8.0",
25
+ "unsloth",
26
+ ]
27
+
28
+ [build-system]
29
+ requires = ["setuptools>=68", "wheel"]
30
+ build-backend = "setuptools.backends.legacy:build"
31
+
32
+ [tool.setuptools.packages.find]
33
+ where = ["."]
34
+ include = ["salespath_env*"]
requirements.txt CHANGED
@@ -1,3 +1,4 @@
 
1
  fastapi>=0.110.0
2
  uvicorn[standard]>=0.29.0
3
  pydantic>=2.0
 
1
+ # Environment server (used by Dockerfile)
2
  fastapi>=0.110.0
3
  uvicorn[standard]>=0.29.0
4
  pydantic>=2.0
salespath_env/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (178 Bytes). View file
 
salespath_env/__pycache__/client.cpython-312.pyc ADDED
Binary file (5.72 kB). View file
 
salespath_env/__pycache__/models.cpython-312.pyc ADDED
Binary file (3.15 kB). View file
 
salespath_env/client.py CHANGED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # salespath_env/client.py
2
+ """
3
+ HTTP client for the SalesPath environment.
4
+ Used by training scripts to talk to the hosted FastAPI server.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import requests
10
+
11
+
12
+ class SalesPathClient:
13
+ """
14
+ Thin wrapper around the /reset and /step HTTP endpoints.
15
+
16
+ Example
17
+ -------
18
+ >>> client = SalesPathClient("http://localhost:7860")
19
+ >>> obs = client.reset(difficulty=1)
20
+ >>> obs = client.step("PROSPECT", "Hi, tell me about your pain points.")
21
+ >>> print(obs["reward"])
22
+ """
23
+
24
+ def __init__(self, base_url: str = "http://localhost:7860"):
25
+ self.base_url = base_url.rstrip("/")
26
+ self._session = requests.Session()
27
+
28
+ # ------------------------------------------------------------------
29
+ # Core API
30
+ # ------------------------------------------------------------------
31
+
32
+ def reset(self, difficulty: int = 1) -> dict:
33
+ """
34
+ Reset the environment for a new episode.
35
+
36
+ OpenEnv /reset returns the raw observation dict.
37
+ Returns a flat dict with all observation fields.
38
+ """
39
+ resp = self._session.post(
40
+ f"{self.base_url}/reset",
41
+ json={"difficulty": difficulty},
42
+ timeout=30,
43
+ )
44
+ resp.raise_for_status()
45
+ data = resp.json()
46
+ # /reset may return raw observation or wrapped {observation:{...}}
47
+ if "observation" in data:
48
+ flat = dict(data["observation"])
49
+ flat.setdefault("reward", data.get("reward", 0.0))
50
+ flat.setdefault("done", data.get("done", False))
51
+ return flat
52
+ return data
53
+
54
+ def step(
55
+ self,
56
+ action_type: str,
57
+ content: str = "",
58
+ target: str = "",
59
+ ) -> dict:
60
+ """
61
+ Take one action in the environment.
62
+
63
+ OpenEnv /step returns {observation:{...}, reward:float, done:bool}.
64
+ This method flattens it so callers get a single dict with all
65
+ observation fields plus reward and done at the top level.
66
+
67
+ Returns
68
+ -------
69
+ dict with keys:
70
+ prospect_response, workflow_stage, constraints_violated,
71
+ steps_completed, turn_number, reward, reward_components,
72
+ done, info
73
+ """
74
+ resp = self._session.post(
75
+ f"{self.base_url}/step",
76
+ json={
77
+ "action": {
78
+ "action_type": action_type,
79
+ "content": content,
80
+ "target": target,
81
+ }
82
+ },
83
+ timeout=30,
84
+ )
85
+ resp.raise_for_status()
86
+ data = resp.json()
87
+ # Flatten: {observation:{...}, reward, done} β†’ one flat dict
88
+ if "observation" in data:
89
+ flat = dict(data["observation"])
90
+ flat["reward"] = data.get("reward", flat.get("reward", 0.0))
91
+ flat["done"] = data.get("done", flat.get("done", False))
92
+ return flat
93
+ return data
94
+
95
+ def health(self) -> dict:
96
+ resp = self._session.get(f"{self.base_url}/health", timeout=10)
97
+ resp.raise_for_status()
98
+ return resp.json()
99
+
100
+ # ------------------------------------------------------------------
101
+ # Convenience: run a full hard-coded demo episode
102
+ # ------------------------------------------------------------------
103
+
104
+ def run_demo_episode(self, difficulty: int = 1, verbose: bool = True) -> float:
105
+ """
106
+ Run one scripted episode and return total cumulative reward.
107
+ Useful for smoke-testing the server end-to-end.
108
+ """
109
+ obs = self.reset(difficulty)
110
+ if verbose:
111
+ print(f"\n=== Episode start (difficulty={difficulty}) ===")
112
+ print(f"Prospect: {obs.get('prospect_response', '')}\n")
113
+
114
+ # Scripted optimal sequence for difficulty 1
115
+ script = [
116
+ ("PROSPECT", "Hello! I'd love to learn about your current challenges."),
117
+ ("QUALIFY", "Can you tell me about your budget and decision process?"),
118
+ ("PRESENT", "Here's how our platform solves your inventory problem."),
119
+ ("CLOSE", "Based on everything, shall we move forward?"),
120
+ ]
121
+
122
+ total_reward = 0.0
123
+ for action_type, content in script:
124
+ obs = self.step(action_type, content)
125
+ total_reward += obs.get("reward", 0.0)
126
+ if verbose:
127
+ print(f"[Turn {obs['turn_number']}] Agent: {action_type}")
128
+ print(f" Prospect: {obs['prospect_response']}")
129
+ print(f" Reward: {obs['reward']:.3f} | Done: {obs['done']}")
130
+ if obs.get("constraints_violated"):
131
+ print(f" ⚠ Violations: {obs['constraints_violated']}")
132
+ print()
133
+ if obs["done"]:
134
+ break
135
+
136
+ if verbose:
137
+ print(f"=== Episode done. Cumulative reward: {total_reward:.3f} ===\n")
138
+ return total_reward
salespath_env/server/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (173 Bytes). View file
 
salespath_env/server/__pycache__/app.cpython-312.pyc ADDED
Binary file (4.34 kB). View file
 
salespath_env/server/__pycache__/prospect_simulator.cpython-312.pyc ADDED
Binary file (5.03 kB). View file
 
salespath_env/server/__pycache__/reward.cpython-312.pyc ADDED
Binary file (3.01 kB). View file
 
salespath_env/server/__pycache__/rules.cpython-312.pyc ADDED
Binary file (7.36 kB). View file
 
salespath_env/server/__pycache__/salespath_environment.cpython-312.pyc ADDED
Binary file (6.62 kB). View file
 
salespath_env/server/__pycache__/task_bank.cpython-312.pyc ADDED
Binary file (2.66 kB). View file
 
salespath_env/server/app.py CHANGED
@@ -1,18 +1,116 @@
1
- # salespath_env/server/app.py
2
-
3
- from openenv.core.env_server import create_fastapi_app
4
-
5
- from ..models import (
6
- SalesPathAction,
7
- SalesPathObservation,
8
- )
9
- from .salespath_environment import (
10
- SalesPathEnvironment,
11
- )
12
-
13
-
14
- app = create_fastapi_app(
15
- SalesPathEnvironment,
16
- SalesPathAction,
17
- SalesPathObservation,
18
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # salespath_env/server/app.py
2
+ """
3
+ Custom stateful FastAPI server for SalesPath.
4
+
5
+ Why not create_fastapi_app?
6
+ OpenEnv's built-in HTTP /reset and /step endpoints are STATELESS β€”
7
+ they create a new Environment instance per request and destroy it.
8
+ State is preserved only over WebSocket sessions.
9
+
10
+ For our training loop (HTTP polling), we need a persistent environment
11
+ that survives across /reset + multiple /step calls. This file provides
12
+ that by keeping a single global SalesPathEnvironment instance.
13
+
14
+ The response envelope matches OpenEnv exactly:
15
+ { "observation": {...}, "reward": float, "done": bool }
16
+ so all existing clients work without changes.
17
+ """
18
+
19
+ from typing import Any, Dict, Optional
20
+
21
+ from fastapi import FastAPI
22
+ from pydantic import BaseModel
23
+
24
+ from ..models import SalesPathAction
25
+ from .salespath_environment import SalesPathEnvironment
26
+
27
+
28
+ # ---------------------------------------------------------------------------
29
+ # Single persistent environment instance
30
+ # ---------------------------------------------------------------------------
31
+
32
+ _env: SalesPathEnvironment = SalesPathEnvironment()
33
+
34
+
35
+ # ---------------------------------------------------------------------------
36
+ # Request models
37
+ # ---------------------------------------------------------------------------
38
+
39
+ class ResetRequest(BaseModel):
40
+ difficulty: int = 1
41
+
42
+
43
+ class ActionPayload(BaseModel):
44
+ action_type: str
45
+ content: str = ""
46
+ target: str = ""
47
+
48
+
49
+ class StepRequest(BaseModel):
50
+ action: ActionPayload
51
+
52
+
53
+ # ---------------------------------------------------------------------------
54
+ # FastAPI app
55
+ # ---------------------------------------------------------------------------
56
+
57
+ app = FastAPI(
58
+ title="SalesPath Environment",
59
+ description="OpenEnv-compatible RL environment for B2B sales agent training.",
60
+ version="0.1.0",
61
+ )
62
+
63
+
64
+ @app.post("/reset")
65
+ def reset(req: ResetRequest = ResetRequest()):
66
+ """
67
+ Start a new episode.
68
+ Resets the environment and returns the initial observation.
69
+ """
70
+ obs = _env.reset(difficulty=req.difficulty)
71
+ return {
72
+ "observation": obs.model_dump(),
73
+ "reward": obs.reward,
74
+ "done": obs.done,
75
+ }
76
+
77
+
78
+ @app.post("/step")
79
+ def step(req: StepRequest):
80
+ """
81
+ Take one action in the current episode.
82
+ Returns the next observation, reward, and done flag.
83
+ """
84
+ action = SalesPathAction(
85
+ action_type=req.action.action_type,
86
+ content=req.action.content,
87
+ target=req.action.target,
88
+ )
89
+ obs = _env.step(action)
90
+ return {
91
+ "observation": obs.model_dump(),
92
+ "reward": obs.reward,
93
+ "done": obs.done,
94
+ }
95
+
96
+
97
+ @app.get("/health")
98
+ def health():
99
+ return {"status": "healthy"}
100
+
101
+
102
+ @app.get("/state")
103
+ def state():
104
+ """Expose internal state (for debugging). Hidden state excluded."""
105
+ s = _env.state
106
+ return {
107
+ "episode_id": s.episode_id,
108
+ "turn_number": s.turn_number,
109
+ "workflow_stage": s.workflow_stage,
110
+ "steps_completed": s.steps_completed,
111
+ "constraints_violated": s.constraints_violated,
112
+ "objections_handled": s.objections_handled,
113
+ "difficulty": s.difficulty,
114
+ "done": s.done,
115
+ "prospect_profile": s.prospect_profile,
116
+ }
salespath_env/server/prospect_simulator.py CHANGED
@@ -1,162 +1,176 @@
1
- # salespath_env/server/prospect_simulator.py
2
-
3
- from ..models import SalesPathAction, SalesPathState
4
-
5
-
6
- RESPONSE_TEXT = {
7
- "open:positive_signal": "That sounds interesting. Tell me more about how this works.",
8
- "open:neutral_signal": "I see. We're evaluating a few options at the moment.",
9
-
10
- "objection:price": "The pricing seems higher than what we budgeted for.",
11
- "objection:timing": "The timing isn't ideal β€” we're in the middle of a quarter close.",
12
- "objection:premature_pitch": (
13
- "I'm not sure we're ready to discuss solutions yet. "
14
- "What do you know about our current situation?"
15
- ),
16
-
17
- "deflect:budget_not_discussed": (
18
- "We haven't really talked about what we're looking for yet."
19
- ),
20
- "deflect:stall": (
21
- "Let me get back to you on this. A lot is happening on our end."
22
- ),
23
-
24
- "accept:demo_scheduled": (
25
- "Yes, let's set up a demo. What time works next week?"
26
- ),
27
- "accept:close_success": (
28
- "Alright, I think we can move forward with this. "
29
- "Send over the paperwork."
30
- ),
31
-
32
- "reject:close_failed": (
33
- "I don't think we're ready to commit at this point."
34
- ),
35
-
36
- "silence": "",
37
-
38
- "exit:disqualified": (
39
- "I think we're done here. This isn't the right fit."
40
- ),
41
- }
42
-
43
-
44
- class ProspectSimulator:
45
- """
46
- Pure rule-based simulator.
47
- No LLM. No transformers. Deterministic behavior.
48
- """
49
-
50
- def respond(
51
- self,
52
- action: SalesPathAction,
53
- state: SalesPathState,
54
- ) -> tuple[str, str]:
55
- """
56
- Returns:
57
- (response_token, response_text)
58
- """
59
-
60
- token = self._get_token(action, state)
61
- text = RESPONSE_TEXT[token]
62
-
63
- return token, text
64
-
65
- def _get_token(
66
- self,
67
- action: SalesPathAction,
68
- state: SalesPathState,
69
- ) -> str:
70
- atype = action.action_type
71
- difficulty = state.difficulty
72
- turn = state.turn_number
73
- profile = state.prospect_profile
74
- hidden = state.hidden_state
75
- objections = state.objections_handled
76
-
77
- # -----------------------------
78
- # Rule-triggered responses first
79
- # -----------------------------
80
-
81
- if state.constraints_violated:
82
- latest = state.constraints_violated[-1]
83
-
84
- if latest == "R01":
85
- return "objection:premature_pitch"
86
-
87
- if latest == "R03":
88
- return "deflect:budget_not_discussed"
89
-
90
- # -----------------------------
91
- # Action-based responses
92
- # -----------------------------
93
-
94
- if atype == "PROSPECT":
95
- return "open:positive_signal"
96
-
97
- if atype == "QUALIFY":
98
- # Reveal budget if hidden
99
- if profile.get("budget_signal") == "unknown":
100
- state.prospect_profile["budget_signal"] = hidden.get(
101
- "revealed_budget",
102
- "medium",
103
- )
104
-
105
- return "open:neutral_signal"
106
-
107
- if atype == "PRESENT":
108
- if difficulty >= 2:
109
- if objections == 0:
110
- return "objection:price"
111
-
112
- return "open:positive_signal"
113
-
114
- if atype == "HANDLE_OBJECTION":
115
- state.objections_handled += 1
116
-
117
- required_objections = hidden.get("num_objections", 1)
118
-
119
- if state.objections_handled >= required_objections:
120
- return "open:positive_signal"
121
-
122
- if objections == 0:
123
- return "objection:timing"
124
-
125
- return "open:positive_signal"
126
-
127
- if atype == "OFFER_DEMO":
128
- return "accept:demo_scheduled"
129
-
130
- if atype == "NEGOTIATE":
131
- return "open:neutral_signal"
132
-
133
- if atype == "CLOSE":
134
- true_budget = hidden.get("true_budget", 0.7)
135
- close_threshold = hidden.get("close_threshold", 0.5)
136
- decision_maker = profile.get("decision_maker", True)
137
-
138
- if (
139
- true_budget >= close_threshold
140
- and decision_maker
141
- ):
142
- return "accept:close_success"
143
-
144
- return "reject:close_failed"
145
-
146
- if atype == "FOLLOW_UP":
147
- return "open:neutral_signal"
148
-
149
- if atype == "DISQUALIFY":
150
- return "exit:disqualified"
151
-
152
- # -----------------------------
153
- # Difficulty 3+ mode shift
154
- # -----------------------------
155
-
156
- if difficulty >= 3 and turn >= 10:
157
- import random
158
-
159
- if random.random() < hidden.get("stall_probability", 0.0):
160
- return "deflect:stall"
161
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  return "open:neutral_signal"
 
1
+ # salespath_env/server/prospect_simulator.py
2
+
3
+ import random
4
+
5
+ from ..models import SalesPathAction, SalesPathState
6
+
7
+
8
+ RESPONSE_TEXT = {
9
+ "open:positive_signal": "That sounds interesting. Tell me more about how this works.",
10
+ "open:neutral_signal": "I see. We're evaluating a few options at the moment.",
11
+
12
+ "objection:price": "The pricing seems higher than what we budgeted for.",
13
+ "objection:timing": "The timing isn't ideal β€” we're in the middle of a quarter close.",
14
+ "objection:premature_pitch": (
15
+ "I'm not sure we're ready to discuss solutions yet. "
16
+ "What do you know about our current situation?"
17
+ ),
18
+
19
+ "deflect:budget_not_discussed": (
20
+ "We haven't really talked about what we're looking for yet."
21
+ ),
22
+ "deflect:stall": (
23
+ "Let me get back to you on this. A lot is happening on our end."
24
+ ),
25
+
26
+ "accept:demo_scheduled": (
27
+ "Yes, let's set up a demo. What time works next week?"
28
+ ),
29
+ "accept:close_success": (
30
+ "Alright, I think we can move forward with this. "
31
+ "Send over the paperwork."
32
+ ),
33
+
34
+ "reject:close_failed": (
35
+ "I don't think we're ready to commit at this point."
36
+ ),
37
+
38
+ "silence": "",
39
+
40
+ "exit:disqualified": (
41
+ "I think we're done here. This isn't the right fit."
42
+ ),
43
+ }
44
+
45
+ # Prefix injected into QUALIFY response to reveal budget signal
46
+ # without mutating prospect_profile (immutable prospect state).
47
+ BUDGET_REVEAL_TEXT = {
48
+ "high": "We do have solid budget allocated for this initiative. ",
49
+ "medium": "We have some budget set aside, though flexibility is limited. ",
50
+ "low": "Our budget is quite constrained right now. ",
51
+ }
52
+
53
+
54
+ class ProspectSimulator:
55
+ """
56
+ Pure rule-based simulator. No LLM. No transformers.
57
+ Deterministic per action type.
58
+
59
+ Immutability guarantee:
60
+ This class NEVER mutates state.prospect_profile.
61
+ Budget reveal is surfaced via the response *text* only.
62
+ The environment (salespath_environment.py) owns all state writes.
63
+ """
64
+
65
+ def respond(
66
+ self,
67
+ action: SalesPathAction,
68
+ state: SalesPathState,
69
+ ) -> tuple[str, str]:
70
+ """
71
+ Returns:
72
+ (response_token, response_text)
73
+ """
74
+ token = self._get_token(action, state)
75
+ text = self._build_text(token, action, state)
76
+ return token, text
77
+
78
+ # ------------------------------------------------------------------
79
+ # Private helpers
80
+ # ------------------------------------------------------------------
81
+
82
+ def _build_text(
83
+ self,
84
+ token: str,
85
+ action: SalesPathAction,
86
+ state: SalesPathState,
87
+ ) -> str:
88
+ base = RESPONSE_TEXT[token]
89
+
90
+ # Inject budget reveal into QUALIFY response text.
91
+ # We read from hidden_state, not prospect_profile, so no mutation needed.
92
+ if action.action_type == "QUALIFY":
93
+ budget_signal = state.prospect_profile.get("budget_signal", "unknown")
94
+ if budget_signal == "unknown":
95
+ revealed = state.hidden_state.get("revealed_budget", "medium")
96
+ prefix = BUDGET_REVEAL_TEXT.get(revealed, "")
97
+ return prefix + base
98
+
99
+ return base
100
+
101
+ def _get_token(
102
+ self,
103
+ action: SalesPathAction,
104
+ state: SalesPathState,
105
+ ) -> str:
106
+ atype = action.action_type
107
+ difficulty = state.difficulty
108
+ turn = state.turn_number
109
+ profile = state.prospect_profile
110
+ hidden = state.hidden_state
111
+ objections = state.objections_handled
112
+
113
+ # --------------------------------------------------
114
+ # 1. Rule-violation responses (highest priority)
115
+ # --------------------------------------------------
116
+ if state.constraints_violated:
117
+ latest = state.constraints_violated[-1]
118
+ if latest == "R01":
119
+ return "objection:premature_pitch"
120
+ if latest == "R03":
121
+ return "deflect:budget_not_discussed"
122
+
123
+ # --------------------------------------------------
124
+ # 2. Stall injection for difficulty 3+
125
+ # FIX: moved BEFORE action branches so it can
126
+ # actually fire (was dead code in original).
127
+ # --------------------------------------------------
128
+ if difficulty >= 3 and turn >= 5:
129
+ stall_prob = hidden.get("stall_probability", 0.0)
130
+ if stall_prob > 0.0 and random.random() < stall_prob:
131
+ return "deflect:stall"
132
+
133
+ # --------------------------------------------------
134
+ # 3. Action-based deterministic responses
135
+ # --------------------------------------------------
136
+ if atype == "PROSPECT":
137
+ return "open:positive_signal"
138
+
139
+ if atype == "QUALIFY":
140
+ return "open:neutral_signal"
141
+
142
+ if atype == "PRESENT":
143
+ if difficulty >= 2 and objections == 0:
144
+ return "objection:price"
145
+ return "open:positive_signal"
146
+
147
+ if atype == "HANDLE_OBJECTION":
148
+ state.objections_handled += 1 # only non-profile mutation
149
+ required = hidden.get("num_objections", 1)
150
+ if state.objections_handled >= required:
151
+ return "open:positive_signal"
152
+ if objections == 0:
153
+ return "objection:timing"
154
+ return "open:positive_signal"
155
+
156
+ if atype == "OFFER_DEMO":
157
+ return "accept:demo_scheduled"
158
+
159
+ if atype == "NEGOTIATE":
160
+ return "open:neutral_signal"
161
+
162
+ if atype == "CLOSE":
163
+ true_budget = hidden.get("true_budget", 0.7)
164
+ close_threshold = hidden.get("close_threshold", 0.5)
165
+ decision_maker = profile.get("decision_maker", True)
166
+ if true_budget >= close_threshold and decision_maker:
167
+ return "accept:close_success"
168
+ return "reject:close_failed"
169
+
170
+ if atype == "FOLLOW_UP":
171
+ return "open:neutral_signal"
172
+
173
+ if atype == "DISQUALIFY":
174
+ return "exit:disqualified"
175
+
176
  return "open:neutral_signal"
salespath_env/server/rules.py CHANGED
@@ -78,10 +78,15 @@ def _no_repeat_action(
78
  """
79
  R05:
80
  Same action twice in a row is invalid.
81
- """
82
- if state.conversation_history:
83
- last_action = state.conversation_history[-1].get("action_type", "")
84
- return last_action == action.action_type
 
 
 
 
 
85
  return False
86
 
87
 
@@ -104,13 +109,35 @@ def _followup_timing(
104
  ) -> bool:
105
  """
106
  R07:
107
- FOLLOW_UP only valid after silence.
108
- If prospect just responded last turn, violation.
 
109
  """
110
  if action.action_type == "FOLLOW_UP":
111
- if state.conversation_history:
112
- last_speaker = state.conversation_history[-1].get("speaker", "agent")
113
- return last_speaker == "prospect"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  return False
115
 
116
 
@@ -120,15 +147,20 @@ def _disqualify_logic(
120
  ) -> bool:
121
  """
122
  R08:
123
- DISQUALIFY only when prospect is genuinely not closeable.
124
- Violation if prospect is actually closeable.
 
 
 
125
  """
126
  if action.action_type == "DISQUALIFY":
127
  true_budget = state.hidden_state.get("true_budget", 0.5)
128
  close_threshold = state.hidden_state.get("close_threshold", 0.5)
129
  decision_maker = state.prospect_profile.get("decision_maker", True)
130
 
131
- return (true_budget >= close_threshold) and decision_maker
 
 
132
 
133
  return False
134
 
 
78
  """
79
  R05:
80
  Same action twice in a row is invalid.
81
+ FIX: conversation_history alternates agent/prospect entries.
82
+ Must filter to agent-only turns before comparing.
83
+ """
84
+ agent_turns = [
85
+ e for e in state.conversation_history
86
+ if e.get("speaker") == "agent"
87
+ ]
88
+ if agent_turns:
89
+ return agent_turns[-1].get("action_type", "") == action.action_type
90
  return False
91
 
92
 
 
109
  ) -> bool:
110
  """
111
  R07:
112
+ FOLLOW_UP only valid after prospect silence (no response for 1+ agent turns).
113
+ Violation if the prospect HAS replied since the last agent action.
114
+ FIX: Previous logic was inverted β€” it was blocking valid FOLLOW_UP.
115
  """
116
  if action.action_type == "FOLLOW_UP":
117
+ if not state.conversation_history:
118
+ return True # Nothing happened yet β€” FOLLOW_UP makes no sense
119
+
120
+ agent_turns = [
121
+ e for e in state.conversation_history
122
+ if e.get("speaker") == "agent"
123
+ ]
124
+ prospect_turns = [
125
+ e for e in state.conversation_history
126
+ if e.get("speaker") == "prospect"
127
+ ]
128
+
129
+ if not agent_turns:
130
+ return True
131
+
132
+ last_agent_turn_num = agent_turns[-1]["turn"]
133
+ last_prospect_turn_num = max(
134
+ (e["turn"] for e in prospect_turns),
135
+ default=0,
136
+ )
137
+
138
+ # Violation if prospect already responded AFTER the last agent turn
139
+ return last_prospect_turn_num >= last_agent_turn_num
140
+
141
  return False
142
 
143
 
 
147
  ) -> bool:
148
  """
149
  R08:
150
+ DISQUALIFY is correct ONLY when:
151
+ - true_budget < close_threshold AND
152
+ - decision_maker is False
153
+ Violation if prospect is actually closeable OR has a decision maker.
154
+ FIX: Both conditions must hold for a valid disqualification.
155
  """
156
  if action.action_type == "DISQUALIFY":
157
  true_budget = state.hidden_state.get("true_budget", 0.5)
158
  close_threshold = state.hidden_state.get("close_threshold", 0.5)
159
  decision_maker = state.prospect_profile.get("decision_maker", True)
160
 
161
+ # Valid disqualify requires: low budget AND no decision maker
162
+ valid_disqualify = (true_budget < close_threshold) and (not decision_maker)
163
+ return not valid_disqualify # Violation if NOT a valid disqualify case
164
 
165
  return False
166
 
salespath_env/server/salespath_environment.py CHANGED
@@ -220,6 +220,20 @@ class SalesPathEnvironment(Environment):
220
  )
221
  )
222
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
  state.conversation_history.append(
224
  {
225
  "turn": state.turn_number,
 
220
  )
221
  )
222
 
223
+ # -----------------------------------
224
+ # Budget reveal (env owns state write)
225
+ # Simulator surfaced the info via text;
226
+ # now we update prospect_profile so rules
227
+ # (e.g. R03) can see the revealed value.
228
+ # -----------------------------------
229
+ if (
230
+ action.action_type == "QUALIFY"
231
+ and state.prospect_profile.get("budget_signal") == "unknown"
232
+ ):
233
+ state.prospect_profile["budget_signal"] = (
234
+ state.hidden_state.get("revealed_budget", "medium")
235
+ )
236
+
237
  state.conversation_history.append(
238
  {
239
  "turn": state.turn_number,
training/__pycache__/plot_rewards.cpython-312.pyc ADDED
Binary file (5.92 kB). View file
 
training/__pycache__/train_grpo.cpython-312.pyc ADDED
Binary file (13.6 kB). View file
 
training/__pycache__/train_sft.cpython-312.pyc ADDED
Binary file (5.39 kB). View file
 
training/__pycache__/train_test.cpython-312.pyc ADDED
Binary file (8.78 kB). View file
 
training/plot_rewards.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ plot_rewards.py β€” Visualise GRPO training progress
3
+ ====================================================
4
+ Reads reward_log.jsonl written by train_grpo.py and
5
+ produces two plots:
6
+
7
+ 1. Mean reward per step (with min/max band)
8
+ 2. Reward by difficulty level
9
+
10
+ Run:
11
+ python training/plot_rewards.py
12
+ python training/plot_rewards.py --log ./reward_log.jsonl --out ./plots/
13
+ """
14
+
15
+ import argparse
16
+ import json
17
+ import os
18
+ from collections import defaultdict
19
+
20
def load_log(path: str) -> list[dict]:
    """Parse a JSONL reward log into a list of record dicts.

    Args:
        path: Path to a reward_log.jsonl file (one JSON object per line).

    Returns:
        List of parsed records; blank lines are skipped.

    Raises:
        FileNotFoundError: if *path* does not exist.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    records = []
    # FIX: explicit encoding avoids locale-dependent decoding of the log file.
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                records.append(json.loads(line))
    return records
28
+
29
+
30
def plot(log_path: str, out_dir: str):
    """Render reward curves from a GRPO reward log.

    Produces two PNGs in *out_dir*:
      * reward_curve.png         — mean reward per step with a min/max band
        and vertical markers where the rollout difficulty changed.
      * reward_by_difficulty.png — boxplot of mean rewards per difficulty.

    Also prints a short textual summary (start/end/best mean reward).

    Args:
        log_path: Path to the reward_log.jsonl written by train_grpo.py.
        out_dir:  Directory for the output images (created if missing).
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("❌ matplotlib not installed. pip install matplotlib")
        return

    os.makedirs(out_dir, exist_ok=True)
    records = load_log(log_path)

    if not records:
        print(f"❌ No records found in {log_path}")
        return

    steps = [r["step"] for r in records]
    means = [r["mean_reward"] for r in records]
    maxes = [r["max_reward"] for r in records]
    mins = [r["min_reward"] for r in records]
    difficulties = [r["difficulty"] for r in records]

    # --- Plot 1: mean reward with band ---
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(steps, means, label="Mean reward", color="#4C6EF5", linewidth=2)
    ax.fill_between(steps, mins, maxes, alpha=0.2, color="#4C6EF5", label="Min/Max band")
    ax.axhline(0, color="gray", linestyle="--", linewidth=0.8)

    # Mark the step at which the curriculum difficulty changes.
    prev_d = None
    for s, d in zip(steps, difficulties):
        if d != prev_d:
            ax.axvline(s, color="orange", linestyle=":", linewidth=1.2, alpha=0.7)
            ax.text(s + 0.5, ax.get_ylim()[0] * 0.9, f"D{d}", fontsize=8, color="orange")
            prev_d = d

    ax.set_xlabel("Training Step")
    ax.set_ylabel("Episode Reward")
    ax.set_title("SalesPath GRPO β€” Mean Reward per Step")
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    path1 = os.path.join(out_dir, "reward_curve.png")
    plt.savefig(path1, dpi=150)
    plt.close(fig)  # FIX: release the figure so repeated calls don't leak memory
    print(f"βœ… Saved: {path1}")

    # --- Plot 2: per-difficulty box ---
    by_diff = defaultdict(list)
    for r in records:
        by_diff[r["difficulty"]].append(r["mean_reward"])

    fig2, ax2 = plt.subplots(figsize=(7, 5))
    labels = sorted(by_diff.keys())
    data = [by_diff[d] for d in labels]
    # NOTE(review): `labels=` is deprecated in matplotlib >= 3.9 (renamed to
    # `tick_labels`); kept here for compatibility with older versions.
    ax2.boxplot(data, labels=[f"Difficulty {d}" for d in labels], patch_artist=True)
    ax2.set_ylabel("Mean Episode Reward")
    ax2.set_title("Reward Distribution by Difficulty Level")
    ax2.grid(True, alpha=0.3)
    plt.tight_layout()
    path2 = os.path.join(out_dir, "reward_by_difficulty.png")
    plt.savefig(path2, dpi=150)
    plt.close(fig2)  # FIX: release the second figure as well
    print(f"βœ… Saved: {path2}")

    # Print quick stats
    print(f"\nπŸ“Š Training summary ({len(records)} steps)")
    print(f" Start mean reward : {means[0]:+.3f}")
    print(f" End mean reward : {means[-1]:+.3f}")
    print(f" Best mean reward : {max(means):+.3f} (step {steps[means.index(max(means))]})")
96
+
97
+
98
if __name__ == "__main__":
    # CLI entry point: allow overriding the log location and output folder.
    cli = argparse.ArgumentParser()
    cli.add_argument("--log", default="./reward_log.jsonl")
    cli.add_argument("--out", default="./plots")
    ns = cli.parse_args()
    plot(ns.log, ns.out)
training/sft_demos.jsonl ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {"messages": [{"role": "system", "content": "You are a professional B2B sales agent. Your goal is to close deals by following the correct sales process. Always respond with exactly one action in the format:\nACTION: <ACTION_TYPE>\nCONTENT: <your message to the prospect>\n\nValid actions: PROSPECT, QUALIFY, PRESENT, HANDLE_OBJECTION, OFFER_DEMO, NEGOTIATE, CLOSE, FOLLOW_UP, DISQUALIFY"}, {"role": "user", "content": "You are engaging Meridian Retail, a medium retail company. Pain points: manual inventory tracking, slow reporting. Begin the sales conversation. [Stage: START | Turn: 0]"}, {"role": "assistant", "content": "ACTION: PROSPECT\nCONTENT: Hi! I understand Meridian Retail is dealing with manual inventory tracking and slow reporting. I'd love to learn more about how these challenges are impacting your day-to-day operations."}]}
2
+ {"messages": [{"role": "system", "content": "You are a professional B2B sales agent. Your goal is to close deals by following the correct sales process. Always respond with exactly one action in the format:\nACTION: <ACTION_TYPE>\nCONTENT: <your message to the prospect>\n\nValid actions: PROSPECT, QUALIFY, PRESENT, HANDLE_OBJECTION, OFFER_DEMO, NEGOTIATE, CLOSE, FOLLOW_UP, DISQUALIFY"}, {"role": "user", "content": "You are engaging Meridian Retail, a medium retail company. Pain points: manual inventory tracking, slow reporting. [Stage: PROSPECT | Turn: 1 | Prospect said: That sounds interesting. Tell me more about how this works.]"}, {"role": "assistant", "content": "ACTION: QUALIFY\nCONTENT: Great! Before I dive in, could you share more about your budget range and who the key decision makers are for this type of initiative?"}]}
3
+ {"messages": [{"role": "system", "content": "You are a professional B2B sales agent. Your goal is to close deals by following the correct sales process. Always respond with exactly one action in the format:\nACTION: <ACTION_TYPE>\nCONTENT: <your message to the prospect>\n\nValid actions: PROSPECT, QUALIFY, PRESENT, HANDLE_OBJECTION, OFFER_DEMO, NEGOTIATE, CLOSE, FOLLOW_UP, DISQUALIFY"}, {"role": "user", "content": "You are engaging Meridian Retail, a medium retail company. Pain points: manual inventory tracking, slow reporting. [Stage: QUALIFY | Turn: 2 | Prospect said: I see. We're evaluating a few options at the moment. Budget signal: high. Steps completed: PROSPECT, QUALIFY]"}, {"role": "assistant", "content": "ACTION: PRESENT\nCONTENT: Perfect. Our platform automates inventory tracking in real-time and generates reports in seconds β€” exactly the pain points you mentioned. It integrates with your existing systems with zero downtime."}]}
4
+ {"messages": [{"role": "system", "content": "You are a professional B2B sales agent. Your goal is to close deals by following the correct sales process. Always respond with exactly one action in the format:\nACTION: <ACTION_TYPE>\nCONTENT: <your message to the prospect>\n\nValid actions: PROSPECT, QUALIFY, PRESENT, HANDLE_OBJECTION, OFFER_DEMO, NEGOTIATE, CLOSE, FOLLOW_UP, DISQUALIFY"}, {"role": "user", "content": "You are engaging Meridian Retail, a medium retail company. [Stage: PRESENT | Turn: 3 | Prospect said: That sounds interesting. Tell me more. | Steps completed: PROSPECT, QUALIFY, PRESENT]"}, {"role": "assistant", "content": "ACTION: CLOSE\nCONTENT: Based on everything we've discussed, I believe our solution is an excellent fit for Meridian Retail. Shall we move forward and get the paperwork started?"}]}
5
+ {"messages": [{"role": "system", "content": "You are a professional B2B sales agent. Your goal is to close deals by following the correct sales process. Always respond with exactly one action in the format:\nACTION: <ACTION_TYPE>\nCONTENT: <your message to the prospect>\n\nValid actions: PROSPECT, QUALIFY, PRESENT, HANDLE_OBJECTION, OFFER_DEMO, NEGOTIATE, CLOSE, FOLLOW_UP, DISQUALIFY"}, {"role": "user", "content": "You are engaging Apex Logistics, an enterprise logistics company. Pain points: route optimization, driver coordination, fuel tracking. Budget: unknown. Begin the sales conversation. [Stage: START | Turn: 0]"}, {"role": "assistant", "content": "ACTION: PROSPECT\nCONTENT: Hello! I see Apex Logistics is tackling route optimization and driver coordination challenges. These are critical in logistics. I'd love to explore how we can help β€” could you walk me through your current process?"}]}
6
+ {"messages": [{"role": "system", "content": "You are a professional B2B sales agent. Your goal is to close deals by following the correct sales process. Always respond with exactly one action in the format:\nACTION: <ACTION_TYPE>\nCONTENT: <your message to the prospect>\n\nValid actions: PROSPECT, QUALIFY, PRESENT, HANDLE_OBJECTION, OFFER_DEMO, NEGOTIATE, CLOSE, FOLLOW_UP, DISQUALIFY"}, {"role": "user", "content": "You are engaging Apex Logistics. [Stage: PROSPECT | Turn: 1 | Prospect said: That sounds interesting. | Budget: unknown]"}, {"role": "assistant", "content": "ACTION: QUALIFY\nCONTENT: To make sure I propose the right solution, could you share your approximate budget for this initiative and confirm who the final decision maker would be?"}]}
7
+ {"messages": [{"role": "system", "content": "You are a professional B2B sales agent. Your goal is to close deals by following the correct sales process. Always respond with exactly one action in the format:\nACTION: <ACTION_TYPE>\nCONTENT: <your message to the prospect>\n\nValid actions: PROSPECT, QUALIFY, PRESENT, HANDLE_OBJECTION, OFFER_DEMO, NEGOTIATE, CLOSE, FOLLOW_UP, DISQUALIFY"}, {"role": "user", "content": "You are engaging Apex Logistics. [Stage: QUALIFY | Turn: 2 | Prospect said: We have some budget set aside. Budget signal: medium | Steps: PROSPECT, QUALIFY]"}, {"role": "assistant", "content": "ACTION: PRESENT\nCONTENT: Our platform optimizes routes using real-time traffic and GPS data, cutting fuel costs by up to 20% and improving driver coordination through a unified dashboard. Companies your size see ROI within 6 months."}]}
8
+ {"messages": [{"role": "system", "content": "You are a professional B2B sales agent. Your goal is to close deals by following the correct sales process. Always respond with exactly one action in the format:\nACTION: <ACTION_TYPE>\nCONTENT: <your message to the prospect>\n\nValid actions: PROSPECT, QUALIFY, PRESENT, HANDLE_OBJECTION, OFFER_DEMO, NEGOTIATE, CLOSE, FOLLOW_UP, DISQUALIFY"}, {"role": "user", "content": "You are engaging Apex Logistics. [Stage: PRESENT | Turn: 3 | Prospect said: The pricing seems higher than what we budgeted. | Objections: 0 | Steps: PROSPECT, QUALIFY, PRESENT]"}, {"role": "assistant", "content": "ACTION: HANDLE_OBJECTION\nCONTENT: That's a fair concern. The upfront cost is offset by the fuel savings and reduced overtime β€” most clients break even within 4 months. We also offer flexible payment plans. Would that make it more accessible?"}]}
9
+ {"messages": [{"role": "system", "content": "You are a professional B2B sales agent. Your goal is to close deals by following the correct sales process. Always respond with exactly one action in the format:\nACTION: <ACTION_TYPE>\nCONTENT: <your message to the prospect>\n\nValid actions: PROSPECT, QUALIFY, PRESENT, HANDLE_OBJECTION, OFFER_DEMO, NEGOTIATE, CLOSE, FOLLOW_UP, DISQUALIFY"}, {"role": "user", "content": "You are engaging Apex Logistics. [Stage: HANDLE_OBJECTION | Turn: 4 | Prospect said: That sounds reasonable. | Objections handled: 1 | Steps: PROSPECT, QUALIFY, PRESENT, HANDLE_OBJECTION]"}, {"role": "assistant", "content": "ACTION: OFFER_DEMO\nCONTENT: I'd love to show you the platform in action with your actual data. Can we schedule a 30-minute live demo next week so your team can see the route optimization in real time?"}]}
10
+ {"messages": [{"role": "system", "content": "You are a professional B2B sales agent. Your goal is to close deals by following the correct sales process. Always respond with exactly one action in the format:\nACTION: <ACTION_TYPE>\nCONTENT: <your message to the prospect>\n\nValid actions: PROSPECT, QUALIFY, PRESENT, HANDLE_OBJECTION, OFFER_DEMO, NEGOTIATE, CLOSE, FOLLOW_UP, DISQUALIFY"}, {"role": "user", "content": "You are engaging Apex Logistics. [Stage: OFFER_DEMO | Turn: 5 | Prospect said: Yes, let's set up a demo. | Steps: PROSPECT, QUALIFY, PRESENT, HANDLE_OBJECTION, OFFER_DEMO]"}, {"role": "assistant", "content": "ACTION: CLOSE\nCONTENT: Excellent! Given the demo and the ROI we've discussed, I think we're aligned. Are you ready to move forward with the agreement?"}]}
11
+ {"messages": [{"role": "system", "content": "You are a professional B2B sales agent. Your goal is to close deals by following the correct sales process. Always respond with exactly one action in the format:\nACTION: <ACTION_TYPE>\nCONTENT: <your message to the prospect>\n\nValid actions: PROSPECT, QUALIFY, PRESENT, HANDLE_OBJECTION, OFFER_DEMO, NEGOTIATE, CLOSE, FOLLOW_UP, DISQUALIFY"}, {"role": "user", "content": "You are engaging Cipher Tech, a small tech company. Budget signal: high (but misleading β€” actual budget is low). No decision maker present. [Stage: START | Turn: 0]"}, {"role": "assistant", "content": "ACTION: PROSPECT\nCONTENT: Hi! I understand you're facing security and compliance challenges. Could you tell me more about your current setup and what you're trying to achieve?"}]}
12
+ {"messages": [{"role": "system", "content": "You are a professional B2B sales agent. Your goal is to close deals by following the correct sales process. Always respond with exactly one action in the format:\nACTION: <ACTION_TYPE>\nCONTENT: <your message to the prospect>\n\nValid actions: PROSPECT, QUALIFY, PRESENT, HANDLE_OBJECTION, OFFER_DEMO, NEGOTIATE, CLOSE, FOLLOW_UP, DISQUALIFY"}, {"role": "user", "content": "You are engaging Cipher Tech. [Stage: QUALIFY | Turn: 2 | Prospect said: Our budget is quite constrained. No decision maker available. | Budget: low | Decision maker: false | Steps: PROSPECT, QUALIFY]"}, {"role": "assistant", "content": "ACTION: DISQUALIFY\nCONTENT: Thank you for your transparency. Based on your current budget constraints and the absence of a decision maker, I don't think we're the right fit at this time. I'll leave the door open for when circumstances change."}]}
13
+ {"messages": [{"role": "system", "content": "You are a professional B2B sales agent. Your goal is to close deals by following the correct sales process. Always respond with exactly one action in the format:\nACTION: <ACTION_TYPE>\nCONTENT: <your message to the prospect>\n\nValid actions: PROSPECT, QUALIFY, PRESENT, HANDLE_OBJECTION, OFFER_DEMO, NEGOTIATE, CLOSE, FOLLOW_UP, DISQUALIFY"}, {"role": "user", "content": "You are engaging Nova Financial. [Stage: OFFER_DEMO | Turn: 7 | Prospect said: Let me get back to you. A lot is happening on our end. | Objections: 2 | Budget: medium | Steps: PROSPECT, QUALIFY, PRESENT, HANDLE_OBJECTION, HANDLE_OBJECTION, OFFER_DEMO]"}, {"role": "assistant", "content": "ACTION: NEGOTIATE\nCONTENT: I understand you're busy. To make this decision easier, I can offer a 3-month pilot at a reduced rate so your team can validate the ROI before full commitment."}]}
14
+ {"messages": [{"role": "system", "content": "You are a professional B2B sales agent. Your goal is to close deals by following the correct sales process. Always respond with exactly one action in the format:\nACTION: <ACTION_TYPE>\nCONTENT: <your message to the prospect>\n\nValid actions: PROSPECT, QUALIFY, PRESENT, HANDLE_OBJECTION, OFFER_DEMO, NEGOTIATE, CLOSE, FOLLOW_UP, DISQUALIFY"}, {"role": "user", "content": "You are engaging Atlas Health. [Stage: NEGOTIATE | Turn: 9 | Prospect said: I see. We're evaluating options. | Steps: PROSPECT, QUALIFY, PRESENT, HANDLE_OBJECTION, HANDLE_OBJECTION, OFFER_DEMO, NEGOTIATE]"}, {"role": "assistant", "content": "ACTION: CLOSE\nCONTENT: We've covered your pain points, seen the demo, and aligned on terms. I believe this is the right solution for Atlas Health. Can we finalize the agreement today?"}]}
training/train_grpo.py ADDED
@@ -0,0 +1,388 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ train_grpo.py β€” Full GRPO RL Training
3
+ =======================================
4
+ Stage 2: loads the SFT checkpoint and fine-tunes with GRPO
5
+ using live rollouts against the SalesPath environment.
6
+
7
+ Architecture
8
+ ------------
9
+ SFT checkpoint β†’ Unsloth 4-bit QLoRA β†’ GRPOTrainer (TRL)
10
+ ↓
11
+ SalesPath env (HTTP)
12
+ reward = composite score
13
+
14
+ Recommended hardware : A100 / T4 GPU (Google Colab)
15
+ Expected runtime : ~45-90 min for 200 steps on T4
16
+
17
+ Run:
18
+ # 1. Start the env server in another terminal:
19
+ # uvicorn salespath_env.server.app:app --port 7860
20
+ #
21
+ # 2. Then run this script:
22
+ python training/train_grpo.py
23
+
24
+ Outputs:
25
+ ./grpo_checkpoint/ ← final RL-trained model
26
+ reward_log.jsonl ← per-step reward components for plotting
27
+ """
28
+
29
+ from __future__ import annotations
30
+
31
+ import json
32
+ import os
33
+ import re
34
+ import sys
35
+ import time
36
+ from typing import Any
37
+
38
+ import torch
39
+
40
+ # ---------------------------------------------------------------------------
41
+ # Config
42
+ # ---------------------------------------------------------------------------
43
+ ENV_URL = os.environ.get("SALESPATH_ENV_URL", "http://localhost:7860")
44
+ SFT_CHECKPOINT = os.environ.get("SFT_CHECKPOINT", "./sft_checkpoint")
45
+ OUTPUT_DIR = "./grpo_checkpoint"
46
+ REWARD_LOG_PATH = "./reward_log.jsonl"
47
+
48
+ MODEL_NAME = SFT_CHECKPOINT # start from SFT weights
49
+ MAX_SEQ_LEN = 1024
50
+ LORA_R = 16
51
+ LORA_ALPHA = 16
52
+
53
+ # GRPO hyper-parameters
54
+ NUM_TRAIN_STEPS = 200 # increase to 500+ for best results
55
+ ROLLOUTS_PER_STEP = 8 # episodes collected before each gradient update
56
+ DIFFICULTY_SCHEDULE = { # step β†’ difficulty to use for rollouts
57
+ 0: 1,
58
+ 50: 2,
59
+ 100: 3,
60
+ 150: 4,
61
+ }
62
+ LR = 5e-6
63
+ KL_COEFF = 0.05 # keep close to SFT policy
64
+ GRAD_ACCUM = 4
65
+ BATCH_SIZE = 2
66
+
67
+ REPORT_TO = "none" # swap to "wandb" for live reward curves
68
+
69
+ # ---------------------------------------------------------------------------
70
+ # 1. Load model (Unsloth 4-bit QLoRA)
71
+ # ---------------------------------------------------------------------------
72
+
73
+ try:
74
+ from unsloth import FastLanguageModel
75
+ USE_UNSLOTH = True
76
+ except ImportError:
77
+ USE_UNSLOTH = False
78
+ print("⚠️ Unsloth not found β€” falling back to HuggingFace transformers.")
79
+
80
+ if USE_UNSLOTH:
81
+ model, tokenizer = FastLanguageModel.from_pretrained(
82
+ model_name=MODEL_NAME,
83
+ max_seq_length=MAX_SEQ_LEN,
84
+ dtype=None,
85
+ load_in_4bit=True,
86
+ )
87
+ model = FastLanguageModel.get_peft_model(
88
+ model,
89
+ r=LORA_R,
90
+ lora_alpha=LORA_ALPHA,
91
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
92
+ "gate_proj", "up_proj", "down_proj"],
93
+ lora_dropout=0.0,
94
+ bias="none",
95
+ use_gradient_checkpointing="unsloth",
96
+ random_state=42,
97
+ )
98
+ else:
99
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
100
+ from peft import get_peft_model, LoraConfig, TaskType
101
+
102
+ bnb = BitsAndBytesConfig(
103
+ load_in_4bit=True,
104
+ bnb_4bit_use_double_quant=True,
105
+ bnb_4bit_quant_type="nf4",
106
+ bnb_4bit_compute_dtype=torch.bfloat16,
107
+ )
108
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
109
+ model = AutoModelForCausalLM.from_pretrained(
110
+ MODEL_NAME, quantization_config=bnb, device_map="auto"
111
+ )
112
+ model = get_peft_model(model, LoraConfig(
113
+ r=LORA_R, lora_alpha=LORA_ALPHA,
114
+ target_modules=["q_proj", "v_proj"],
115
+ task_type=TaskType.CAUSAL_LM,
116
+ ))
117
+
118
+ tokenizer.pad_token = tokenizer.eos_token
119
+ tokenizer.padding_side = "right"
120
+ print(f"βœ… Model loaded from: {MODEL_NAME}")
121
+
122
+ # ---------------------------------------------------------------------------
123
+ # 2. Environment client
124
+ # ---------------------------------------------------------------------------
125
+
126
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
127
+ from salespath_env.client import SalesPathClient
128
+
129
+ client = SalesPathClient(ENV_URL)
130
+ print(f"βœ… Connected to env at {ENV_URL} β†’ {client.health()}")
131
+
132
+ # ---------------------------------------------------------------------------
133
+ # 3. Prompt / action helpers
134
+ # ---------------------------------------------------------------------------
135
+
136
+ SYSTEM_PROMPT = (
137
+ "You are a professional B2B sales agent. "
138
+ "Follow the correct sales process to close deals.\n"
139
+ "Always respond with exactly ONE action in this format:\n"
140
+ "ACTION: <ACTION_TYPE>\n"
141
+ "CONTENT: <your message>\n\n"
142
+ "Valid actions: PROSPECT, QUALIFY, PRESENT, HANDLE_OBJECTION, "
143
+ "OFFER_DEMO, NEGOTIATE, CLOSE, FOLLOW_UP, DISQUALIFY"
144
+ )
145
+
146
+ ACTION_RE = re.compile(
147
+ r"ACTION:\s*([A-Z_]+)\s*\nCONTENT:\s*(.+)",
148
+ re.DOTALL,
149
+ )
150
+
151
+
152
def obs_to_user_message(obs: dict, stage: str, turn: int) -> str:
    """Flatten an environment observation into a single user-turn string.

    The message starts with the prospect's reply, optionally lists completed
    workflow steps and rule violations, and always ends with a stage/turn tag.
    """
    lines = [obs.get("prospect_response", "")]

    completed = obs.get("steps_completed")
    if completed:
        lines.append(f"Steps completed: {', '.join(completed)}")

    violated = obs.get("constraints_violated")
    if violated:
        lines.append(f"⚠ Violations: {', '.join(violated)}")

    lines.append(f"[Stage: {stage} | Turn: {turn}]")
    return "\n".join(lines)
160
+
161
+
162
def parse_action(text: str) -> tuple[str, str]:
    """Extract (action_type, content) from model output."""
    stripped = text.strip()
    # Same pattern as the module-level ACTION_RE, inlined here.
    match = re.search(
        r"ACTION:\s*([A-Z_]+)\s*\nCONTENT:\s*(.+)",
        stripped,
        re.DOTALL,
    )
    if match is None:
        # Fallback: if the model doesn't follow format, treat whole text as QUALIFY
        return "QUALIFY", stripped
    return match.group(1).strip(), match.group(2).strip()
169
+
170
+
171
def generate_action(messages: list[dict]) -> str:
    """Sample one completion from the policy for the given chat history.

    Returns only the newly generated text (prompt tokens are stripped).
    """
    prompt_ids = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    # Inference only — no gradients needed for rollout collection.
    with torch.no_grad():
        generated = model.generate(
            prompt_ids,
            max_new_tokens=128,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the continuation, dropping the prompt tokens.
    continuation = generated[0, prompt_ids.shape[-1]:]
    return tokenizer.decode(continuation, skip_special_tokens=True)
192
+
193
+
194
+ # ---------------------------------------------------------------------------
195
+ # 4. Rollout collector
196
+ # ---------------------------------------------------------------------------
197
+
198
def run_episode(difficulty: int) -> list[dict]:
    """
    Run one complete episode; return list of
    {prompt_messages, completion, reward, reward_components} dicts.
    """
    obs = client.reset(difficulty=difficulty)
    history = [{"role": "system", "content": SYSTEM_PROMPT}]
    collected = []

    for _ in range(20):  # hard cap matches env MAX_TURNS
        history.append({
            "role": "user",
            "content": obs_to_user_message(
                obs,
                obs.get("workflow_stage", "START"),
                obs.get("turn_number", 0),
            ),
        })

        # Sample and parse one agent action.
        raw = generate_action(list(history))
        action_type, content = parse_action(raw)

        # Advance the environment with that action.
        obs = client.step(action_type, content)

        # Snapshot the prompt BEFORE appending the assistant turn.
        collected.append({
            "messages": list(history),
            "completion": raw,
            "reward": obs["reward"],
            "reward_components": obs.get("reward_components", {}),
        })

        history.append({"role": "assistant", "content": raw})

        if obs["done"]:
            break

    return collected
235
+
236
+
237
def collect_rollouts(
    n: int,
    difficulty: int,
) -> tuple[list[str], list[str], list[float]]:
    """
    Collect n episode rollouts.
    Returns (prompts, completions, rewards) as flat lists for GRPOTrainer.
    """
    prompts: list[str] = []
    completions: list[str] = []
    rewards: list[float] = []

    for ep_idx in range(n):
        episode = run_episode(difficulty)

        for sample in episode:
            # Render the chat prompt exactly as the trainer will see it.
            rendered = tokenizer.apply_chat_template(
                sample["messages"],
                tokenize=False,
                add_generation_prompt=True,
            )
            prompts.append(rendered)
            completions.append(sample["completion"])
            rewards.append(sample["reward"])

        total = sum(sample["reward"] for sample in episode)
        print(f" ep {ep_idx+1}/{n} steps={len(episode)} ep_reward={total:+.3f}")

    return prompts, completions, rewards
263
+
264
+
265
+ # ---------------------------------------------------------------------------
266
+ # 5. Reward log helper
267
+ # ---------------------------------------------------------------------------
268
+
269
+ reward_log: list[dict] = []
270
+
271
def log_rewards(step: int, rewards: list[float], difficulty: int) -> None:
    """Append a per-step reward summary to memory and to REWARD_LOG_PATH.

    Args:
        step:       1-based training step index.
        rewards:    Flat per-sample rewards collected this step.
        difficulty: Environment difficulty used for this step's rollouts.
    """
    if not rewards:
        # FIX: guard against an empty batch — mean/max/min on an empty list
        # would raise (ZeroDivisionError / ValueError) and kill training.
        print(f" ⚠ step={step} produced no samples; nothing to log")
        return

    entry = {
        "step": step,
        "difficulty": difficulty,
        "mean_reward": sum(rewards) / len(rewards),
        "max_reward": max(rewards),
        "min_reward": min(rewards),
        "n_samples": len(rewards),
    }
    reward_log.append(entry)
    with open(REWARD_LOG_PATH, "a") as f:
        f.write(json.dumps(entry) + "\n")
    print(
        f" πŸ“Š step={step:4d} diff={difficulty} "
        f"mean={entry['mean_reward']:+.3f} "
        f"max={entry['max_reward']:+.3f}"
    )
288
+
289
+
290
+ # ---------------------------------------------------------------------------
291
+ # 6. GRPOTrainer setup
292
+ # ---------------------------------------------------------------------------
293
+
294
+ from datasets import Dataset
295
+ from trl import GRPOTrainer, GRPOConfig
296
+
297
+
298
def make_reward_fn(precomputed: dict[str, float]):
    """
    GRPOTrainer calls reward_funcs(prompts, completions) β†’ list[float].
    We pre-run rollouts and store results; the reward_fn just looks them up.
    """
    def reward_fn(prompts: list[str], completions: list[str], **kwargs) -> list[float]:
        scores = []
        for prompt, completion in zip(prompts, completions):
            # Unknown (prompt, completion) pairs fall back to a neutral 0.0.
            scores.append(precomputed.get(prompt + completion, 0.0))
        return scores

    return reward_fn
309
+
310
+
311
+ grpo_config = GRPOConfig(
312
+ output_dir=OUTPUT_DIR,
313
+ num_train_epochs=1, # we control steps manually
314
+ per_device_train_batch_size=BATCH_SIZE,
315
+ gradient_accumulation_steps=GRAD_ACCUM,
316
+ learning_rate=LR,
317
+ kl_coeff=KL_COEFF,
318
+ logging_steps=1,
319
+ save_steps=50,
320
+ fp16=not USE_UNSLOTH,
321
+ report_to=REPORT_TO,
322
+ max_completion_length=128,
323
+ remove_unused_columns=False,
324
+ )
325
+
326
+
327
+ # ---------------------------------------------------------------------------
328
+ # 7. Training loop
329
+ # ---------------------------------------------------------------------------
330
+
331
print(f"\nπŸš€ Starting GRPO training for {NUM_TRAIN_STEPS} steps")
print(f" Rollouts per step : {ROLLOUTS_PER_STEP}")
print(f" KL coefficient : {KL_COEFF}")
print(f" Difficulty schedule: {DIFFICULTY_SCHEDULE}\n")

for step in range(NUM_TRAIN_STEPS):
    # Curriculum: use the difficulty of the highest schedule threshold
    # this step has reached (thresholds visited in ascending order).
    difficulty = 1
    for threshold in sorted(DIFFICULTY_SCHEDULE):
        if step >= threshold:
            difficulty = DIFFICULTY_SCHEDULE[threshold]

    print(f"\n[Step {step+1}/{NUM_TRAIN_STEPS}] difficulty={difficulty}")

    # -- Collect rollouts --
    rollout_prompts, rollout_completions, rollout_rewards = collect_rollouts(
        ROLLOUTS_PER_STEP, difficulty
    )
    log_rewards(step + 1, rollout_rewards, difficulty)

    # -- Build dataset for this step --
    lookup = {
        p + c: r
        for p, c, r in zip(rollout_prompts, rollout_completions, rollout_rewards)
    }
    step_dataset = Dataset.from_dict({
        "prompt": rollout_prompts,
        "completion": rollout_completions,
    })

    # -- GRPOTrainer one-step update --
    trainer = GRPOTrainer(
        model=model,
        reward_funcs=make_reward_fn(lookup),
        args=grpo_config,
        train_dataset=step_dataset,
        processing_class=tokenizer,
    )
    trainer.train()

    # Save checkpoint every 50 steps
    if (step + 1) % 50 == 0:
        ckpt_dir = os.path.join(OUTPUT_DIR, f"step_{step+1}")
        model.save_pretrained(ckpt_dir)
        tokenizer.save_pretrained(ckpt_dir)
        print(f" πŸ’Ύ Checkpoint saved: {ckpt_dir}")


# ---------------------------------------------------------------------------
# 8. Final save
# ---------------------------------------------------------------------------

model.save_pretrained(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
print(f"\nβœ… GRPO training complete.")
print(f" Model β†’ {OUTPUT_DIR}")
print(f" Rewards β†’ {REWARD_LOG_PATH}")
print("\nPlot rewards with: python training/plot_rewards.py")
training/train_sft.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
train_sft.py — SFT Warm-Start Stage
=====================================
Fine-tunes a base LLM on expert sales demonstrations BEFORE GRPO.
SFT teaches the model the correct action FORMAT and rough ordering,
giving GRPO a much better starting policy.

Recommended hardware : T4 GPU (Google Colab free tier)
Expected runtime : ~10-15 minutes for 14 demos × 3 epochs

Run:
 python training/train_sft.py

Outputs:
 ./sft_checkpoint/ ← load this as base in train_grpo.py
"""

import json
import os
import sys  # NOTE(review): unused in this module — kept in case of side-effect reliance; confirm and remove

# ---------------------------------------------------------------------------
# Config — tweak these
# ---------------------------------------------------------------------------
MODEL_NAME = "unsloth/Qwen2.5-1.5B-Instruct" # swap for 0.5B on tiny GPU
OUTPUT_DIR = "./sft_checkpoint"
# Demos live next to this script so the path works regardless of CWD.
DATA_PATH = os.path.join(os.path.dirname(__file__), "sft_demos.jsonl")
MAX_SEQ_LEN = 1024   # max tokens per training example
NUM_EPOCHS = 3       # small demo set -> a few epochs is enough
BATCH_SIZE = 2       # per-device; effective batch = BATCH_SIZE * GRAD_ACCUM
GRAD_ACCUM = 4
LR = 2e-4            # typical LoRA fine-tuning learning rate
LORA_R = 16          # LoRA rank
LORA_ALPHA = 16      # LoRA scaling factor
35
+
36
# ---------------------------------------------------------------------------
# 1. Load model with Unsloth 4-bit QLoRA
# ---------------------------------------------------------------------------
# Prefer Unsloth (faster, memory-efficient kernels); fall back to a plain
# HuggingFace + bitsandbytes + PEFT stack when it is not installed.
try:
    from unsloth import FastLanguageModel
    USE_UNSLOTH = True
except ImportError:
    USE_UNSLOTH = False
    print("⚠️ Unsloth not installed — falling back to plain HuggingFace.")
    print(" Install with: pip install unsloth")

if USE_UNSLOTH:
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=MODEL_NAME,
        max_seq_length=MAX_SEQ_LEN,
        dtype=None, # auto-detect: bf16 on Ampere+, fp16 otherwise
        load_in_4bit=True,
    )
    # Attach LoRA adapters to all attention + MLP projections.
    model = FastLanguageModel.get_peft_model(
        model,
        r=LORA_R,
        lora_alpha=LORA_ALPHA,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"],
        lora_dropout=0.05,
        bias="none",
        use_gradient_checkpointing="unsloth",  # Unsloth's memory-saving variant
        random_state=42,
    )
else:
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
    from peft import get_peft_model, LoraConfig, TaskType
    import torch

    # 4-bit NF4 quantization (QLoRA recipe).
    # NOTE(review): bf16 compute dtype assumes Ampere+ hardware; a T4 needs
    # torch.float16 here — confirm target GPU.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=bnb_config,
        device_map="auto",
    )
    # Fallback targets only q/v projections (smaller adapter than the
    # Unsloth branch above).
    lora_config = LoraConfig(
        r=LORA_R, lora_alpha=LORA_ALPHA,
        target_modules=["q_proj", "v_proj"],
        task_type=TaskType.CAUSAL_LM,
    )
    model = get_peft_model(model, lora_config)

# Causal-LM tokenizers often ship without a pad token; reuse EOS and pad on
# the right so labels align with SFT convention.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

print(f"✅ Model loaded: {MODEL_NAME} (4-bit QLoRA, r={LORA_R})")
93
+
94
+ # ---------------------------------------------------------------------------
95
+ # 2. Load & format SFT dataset
96
+ # ---------------------------------------------------------------------------
97
+
98
def load_sft_data(path: str) -> list[dict]:
    """Read SFT demonstration records from a JSONL file.

    Each non-blank line must be one standalone JSON object (one demo per
    line); blank lines are skipped so hand-edited files remain loadable.

    Args:
        path: Path to the ``.jsonl`` file of chat-format demonstrations.

    Returns:
        Parsed records (dicts) in file order; empty list for an empty file.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    records: list[dict] = []
    # Explicit UTF-8: demos may contain non-ASCII text (names, emoji) and
    # the platform default encoding is not guaranteed to be UTF-8.
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                records.append(json.loads(line))
    return records
106
+
107
+
108
def format_chat(record: dict) -> str:
    """
    Render one demo's ``messages`` list into a single training string via
    the model's own chat template (special tokens included; no generation
    prompt appended, since the assistant turn is already in the record).
    """
    messages = record["messages"]
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=False,
    )
117
+
118
+
119
raw_data = load_sft_data(DATA_PATH)
print(f"✅ Loaded {len(raw_data)} SFT demonstrations from {DATA_PATH}")

from datasets import Dataset

# One {"text": ...} row per demo — SFTTrainer consumes the plain rendered
# chat string via dataset_text_field below.
formatted = [{"text": format_chat(r)} for r in raw_data]
dataset = Dataset.from_list(formatted)
print(f" Sample:\n{formatted[0]['text'][:300]}\n...")

# ---------------------------------------------------------------------------
# 3. SFT Trainer
# ---------------------------------------------------------------------------

from trl import SFTTrainer, SFTConfig

# NOTE(review): TRL's SFT API has churned across versions (e.g. the
# `tokenizer=` kwarg was renamed `processing_class=`, and
# `max_seq_length` moved/renamed in newer releases). Pin the trl version
# or verify these kwargs against the installed release.
sft_config = SFTConfig(
    output_dir=OUTPUT_DIR,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    learning_rate=LR,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    logging_steps=1,  # tiny dataset — log every step
    save_strategy="epoch",
    fp16=not USE_UNSLOTH, # Unsloth handles this internally
    bf16=False,
    max_seq_length=MAX_SEQ_LEN,
    dataset_text_field="text",
    report_to="none", # swap to "wandb" if you have W&B set up
)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    args=sft_config,
)

print("🚀 Starting SFT training...")
trainer_stats = trainer.train()

print(f"\n✅ SFT done.")
print(f" Loss : {trainer_stats.training_loss:.4f}")
print(f" Saved : {OUTPUT_DIR}")

# ---------------------------------------------------------------------------
# 4. Save final checkpoint
# ---------------------------------------------------------------------------
# Saves the LoRA adapter weights + tokenizer; train_grpo.py loads this
# directory as its starting policy.
model.save_pretrained(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
print(f"✅ Checkpoint saved to {OUTPUT_DIR}")
print("\nNext step → run: python training/train_grpo.py")
training/train_test.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
train_test.py — Quick smoke test (no GPU, no LLM needed).

Tests the FULL pipeline end-to-end in ~30 seconds:
 1. Starts the env server in a subprocess
 2. Runs 4 scripted episodes (one per difficulty)
 3. Prints reward traces and rule-violation checks
 4. Verifies the three fixed bugs (R05, R07, R08) behave correctly

Run:
 python training/train_test.py
"""

import subprocess
import sys
import time
import os

# Add project root to path so we can import client
# (this script lives in training/, the package one level up).
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
21
+
22
+
23
+ # ---------------------------------------------------------------------------
24
+ # 1. Start server in background
25
+ # ---------------------------------------------------------------------------
26
+
27
def start_server(port: int = 7860):
    """Launch the SalesPath env server as a background uvicorn subprocess.

    Instead of the previous fixed 4-second sleep (race-prone on slow
    machines, wasteful on fast ones), polls the TCP port until the server
    accepts connections, up to a 30-second deadline.

    Args:
        port: TCP port to bind uvicorn to.

    Returns:
        The ``subprocess.Popen`` handle; the caller must terminate it.
    """
    import socket  # local import: only needed for the readiness probe

    proc = subprocess.Popen(
        [
            sys.executable, "-m", "uvicorn",
            "salespath_env.server.app:app",
            "--host", "0.0.0.0",
            "--port", str(port),
            "--log-level", "error",
        ],
        # Run from the project root so the package import resolves.
        cwd=os.path.join(os.path.dirname(__file__), ".."),
    )
    print(f"⏳ Starting server on port {port}...")
    deadline = time.time() + 30.0
    while time.time() < deadline:
        try:
            with socket.create_connection(("127.0.0.1", port), timeout=1.0):
                break  # port is accepting connections — server is up
        except OSError:
            time.sleep(0.25)  # not up yet; retry shortly
    return proc
41
+
42
+
43
+ # ---------------------------------------------------------------------------
44
+ # 2. Import client
45
+ # ---------------------------------------------------------------------------
46
+
47
+ from salespath_env.client import SalesPathClient
48
+
49
+
50
+ # ---------------------------------------------------------------------------
51
+ # 3. Test episodes
52
+ # ---------------------------------------------------------------------------
53
+
54
# Scripted action sequences, keyed by environment difficulty (1-4).
# Each entry is a list of (action_type, utterance) pairs replayed verbatim
# by run_episode(); difficulty 4 is the trap case where the correct
# terminal action is DISQUALIFY rather than CLOSE.
EPISODES = {
    1: [
        # Happy path — difficulty 1
        ("PROSPECT", "Hello, tell me about your challenges."),
        ("QUALIFY", "What's your budget and who decides?"),
        ("PRESENT", "Here's how we solve your inventory problem."),
        ("CLOSE", "Shall we move forward?"),
    ],
    2: [
        # Objection + demo — difficulty 2
        ("PROSPECT", "Hi, I'd like to learn more about your operations."),
        ("QUALIFY", "Can you share your budget range?"),
        ("PRESENT", "Here is our solution."),
        ("HANDLE_OBJECTION", "Totally understand the pricing concern — here's why the ROI works."),
        ("OFFER_DEMO", "Let me show you a live demo next week."),
        ("CLOSE", "Ready to move forward?"),
    ],
    3: [
        # Full hard path — difficulty 3
        ("PROSPECT", "Hello, I've researched your compliance challenges."),
        ("QUALIFY", "Who is the decision maker and what is the budget?"),
        ("PRESENT", "Here's how we address audit trails and data silos."),
        ("HANDLE_OBJECTION", "I understand the timing is tough — many clients felt the same."),
        ("HANDLE_OBJECTION", "On the price point — we can structure payments quarterly."),
        ("OFFER_DEMO", "Let me show you a live demo."),
        ("NEGOTIATE", "Here is a pilot option at a reduced rate."),
        ("CLOSE", "Shall we proceed?"),
    ],
    4: [
        # Trap case — correct action is DISQUALIFY
        ("PROSPECT", "Hi, tell me about your security needs."),
        ("QUALIFY", "What is your budget and who decides?"),
        ("DISQUALIFY", "Given the budget constraints and no decision maker, this isn't the right time."),
    ],
}
89
+
90
+
91
def run_episode(client: SalesPathClient, difficulty: int, script: list) -> dict:
    """Replay one scripted episode and collect reward/violation statistics."""
    observation = client.reset(difficulty=difficulty)
    banner = "=" * 60
    print(f"\n{banner}")
    print(f" Difficulty {difficulty} | Prospect: {observation.get('prospect_response', '')[:80]}")
    print(f"{banner}")

    summary = {"rewards": [], "violations": [], "turns": 0, "done": False}

    for action_type, content in script:
        observation = client.step(action_type, content)
        turn_violations = observation.get("constraints_violated", [])

        # Accumulate per-turn stats; turns/done always reflect the latest step.
        summary["rewards"].append(observation["reward"])
        summary["violations"].extend(turn_violations)
        summary["turns"] = observation["turn_number"]
        summary["done"] = observation["done"]

        status = "⚠️" if turn_violations else "✅"
        print(
            f" {status} Turn {observation['turn_number']:2d} "
            f"{action_type:<20} "
            f"reward={observation['reward']:+.3f} "
            f"violations={turn_violations}"
        )
        if observation["done"]:
            break

    total = sum(summary["rewards"])
    print(f"\n Cumulative reward: {total:+.3f} | Violations: {summary['violations']}")
    return summary
119
+
120
+
121
+ # ---------------------------------------------------------------------------
122
+ # 4. Bug-regression checks
123
+ # ---------------------------------------------------------------------------
124
+
125
def test_r05_no_repeat(client: SalesPathClient):
    """R05: same action twice in a row must fire a violation."""
    print("\n--- BUG CHECK: R05 no-repeat ---")
    client.reset(difficulty=1)
    client.step("PROSPECT", "Hello.")
    # Second consecutive PROSPECT — the env should flag rule R05.
    repeat_obs = client.step("PROSPECT", "Hello again.")
    fired = repeat_obs.get("constraints_violated", [])
    passed = "R05" in fired
    verdict = "✅ PASS" if passed else "❌ FAIL"
    print(f" R05 fired on consecutive PROSPECT: {verdict} {fired}")
    return passed
135
+
136
+
137
def test_r07_followup(client: SalesPathClient):
    """R07: FOLLOW_UP after a prospect response should be a violation."""
    print("\n--- BUG CHECK: R07 followup timing ---")
    client.reset(difficulty=1)
    client.step("PROSPECT", "Hello.")  # prospect responds positively
    # FOLLOW_UP is only legitimate when the prospect has gone quiet, so
    # sending one right after a reply should trigger R07.
    followup_obs = client.step("FOLLOW_UP", "Just checking in.")
    fired = followup_obs.get("constraints_violated", [])
    passed = "R07" in fired
    verdict = "✅ PASS" if passed else "❌ FAIL"
    print(f" R07 fired when prospect already responded: {verdict} {fired}")
    return passed
147
+
148
+
149
def test_r08_disqualify(client: SalesPathClient):
    """R08: DISQUALIFY on a closeable difficulty-1 prospect must be a violation."""
    print("\n--- BUG CHECK: R08 disqualify logic ---")
    client.reset(difficulty=1) # high budget, decision maker present → closeable
    client.step("PROSPECT", "Hello.")
    client.step("QUALIFY", "What's your budget?")
    # Disqualifying a perfectly viable prospect should trip rule R08.
    disqualify_obs = client.step("DISQUALIFY", "I don't think you're a fit.")
    fired = disqualify_obs.get("constraints_violated", [])
    passed = "R08" in fired
    verdict = "✅ PASS" if passed else "❌ FAIL"
    print(f" R08 fired on valid prospect: {verdict} {fired}")
    return passed
160
+
161
+
162
+ # ---------------------------------------------------------------------------
163
+ # 5. Main
164
+ # ---------------------------------------------------------------------------
165
+
166
def main():
    """Run the full smoke test: boot the server, replay the four scripted
    episodes (one per difficulty), then run the R05/R07/R08 regression
    checks and print a summary.

    The server subprocess is always shut down — and now also reaped — in
    the ``finally`` block, even when the health check bails out early.
    """
    PORT = 7860
    server = start_server(PORT)

    try:
        client = SalesPathClient(f"http://localhost:{PORT}")

        # Health check — abort early (cleanup still runs via finally).
        try:
            h = client.health()
            print(f"✅ Server healthy: {h}")
        except Exception as e:
            print(f"❌ Server not responding: {e}")
            return

        # Run all difficulty episodes
        all_rewards = {}
        for diff, script in EPISODES.items():
            result = run_episode(client, diff, script)
            all_rewards[diff] = sum(result["rewards"])

        # Bug regression suite
        print("\n" + "="*60)
        print(" BUG REGRESSION SUITE")
        print("="*60)
        r05_ok = test_r05_no_repeat(client)
        r07_ok = test_r07_followup(client)
        r08_ok = test_r08_disqualify(client)

        # Summary
        print("\n" + "="*60)
        print(" SUMMARY")
        print("="*60)
        for diff, total in all_rewards.items():
            print(f" Difficulty {diff}: cumulative reward = {total:+.3f}")

        bugs_passed = sum([r05_ok, r07_ok, r08_ok])
        print(f"\n Bug fixes passing: {bugs_passed}/3")
        print(f"\n{'✅ ALL SYSTEMS GO' if bugs_passed == 3 else '⚠️ SOME CHECKS FAILED'}")

    finally:
        server.terminate()
        try:
            # Reap the child so it does not linger as a zombie; escalate to
            # SIGKILL if it ignores SIGTERM.
            server.wait(timeout=10)
        except subprocess.TimeoutExpired:
            server.kill()
            server.wait()
        print("\nServer stopped.")


if __name__ == "__main__":
    main()