Spaces:
Sleeping
Sleeping
Lomesh2000 commited on
Commit Β·
57eab70
1
Parent(s): 385203e
multi agent environment learning
Browse files- Colab_Training.ipynb +113 -0
- README.md +144 -17
- pyproject.toml +34 -0
- requirements.txt +1 -0
- salespath_env/__pycache__/__init__.cpython-312.pyc +0 -0
- salespath_env/__pycache__/client.cpython-312.pyc +0 -0
- salespath_env/__pycache__/models.cpython-312.pyc +0 -0
- salespath_env/client.py +138 -0
- salespath_env/server/__pycache__/__init__.cpython-312.pyc +0 -0
- salespath_env/server/__pycache__/app.cpython-312.pyc +0 -0
- salespath_env/server/__pycache__/prospect_simulator.cpython-312.pyc +0 -0
- salespath_env/server/__pycache__/reward.cpython-312.pyc +0 -0
- salespath_env/server/__pycache__/rules.cpython-312.pyc +0 -0
- salespath_env/server/__pycache__/salespath_environment.cpython-312.pyc +0 -0
- salespath_env/server/__pycache__/task_bank.cpython-312.pyc +0 -0
- salespath_env/server/app.py +116 -18
- salespath_env/server/prospect_simulator.py +175 -161
- salespath_env/server/rules.py +44 -12
- salespath_env/server/salespath_environment.py +14 -0
- training/__pycache__/plot_rewards.cpython-312.pyc +0 -0
- training/__pycache__/train_grpo.cpython-312.pyc +0 -0
- training/__pycache__/train_sft.cpython-312.pyc +0 -0
- training/__pycache__/train_test.cpython-312.pyc +0 -0
- training/plot_rewards.py +103 -0
- training/sft_demos.jsonl +14 -0
- training/train_grpo.py +388 -0
- training/train_sft.py +172 -0
- training/train_test.py +212 -0
Colab_Training.ipynb
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"nbformat": 4,
|
| 3 |
+
"nbformat_minor": 0,
|
| 4 |
+
"metadata": {
|
| 5 |
+
"colab": {
|
| 6 |
+
"provenance": []
|
| 7 |
+
},
|
| 8 |
+
"kernelspec": {
|
| 9 |
+
"name": "python3",
|
| 10 |
+
"display_name": "Python 3"
|
| 11 |
+
},
|
| 12 |
+
"language_info": {
|
| 13 |
+
"name": "python"
|
| 14 |
+
}
|
| 15 |
+
},
|
| 16 |
+
"cells": [
|
| 17 |
+
{
|
| 18 |
+
"cell_type": "markdown",
|
| 19 |
+
"source": [
|
| 20 |
+
"# SalesPath: OpenEnv RL Training via GRPO\n",
|
| 21 |
+
"\n",
|
| 22 |
+
"This notebook contains the complete training pipeline for the SalesPath environment. It performs:\n",
|
| 23 |
+
"1. **SFT Warm-start**: Fine-tunes a base model on expert sales demonstrations.\n",
|
| 24 |
+
"2. **GRPO RL**: Uses live rollouts against your hosted environment to optimize the agent.\n",
|
| 25 |
+
"\n",
|
| 26 |
+
"> **CRITICAL:** Before running this, ensure you are using a **T4 GPU** (`Runtime` -> `Change runtime type` -> `Hardware accelerator` -> `T4 GPU`).\n",
|
| 27 |
+
"> \n",
|
| 28 |
+
"> You must also have pushed your environment code to a **Hugging Face Space** so this notebook can interact with it."
|
| 29 |
+
],
|
| 30 |
+
"metadata": {}
|
| 31 |
+
},
|
| 32 |
+
{
|
| 33 |
+
"cell_type": "code",
|
| 34 |
+
"execution_count": null,
|
| 35 |
+
"metadata": {},
|
| 36 |
+
"outputs": [],
|
| 37 |
+
"source": [
|
| 38 |
+
"# 1. Install required dependencies\n",
|
| 39 |
+
"!pip install \"unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git\"\n",
|
| 40 |
+
"!pip install --no-deps trl peft accelerate bitsandbytes datasets matplotlib openenv-core"
|
| 41 |
+
]
|
| 42 |
+
},
|
| 43 |
+
{
|
| 44 |
+
"cell_type": "code",
|
| 45 |
+
"execution_count": null,
|
| 46 |
+
"metadata": {},
|
| 47 |
+
"outputs": [],
|
| 48 |
+
"source": [
|
| 49 |
+
"# 2. Clone your environment repository from Hugging Face Spaces\n",
|
| 50 |
+
"# β οΈ REPLACE WITH YOUR ACTUAL HF SPACE URL\n",
|
| 51 |
+
"HF_SPACE_URL = \"https://huggingface.co/spaces/YOUR_USERNAME/salespath-env\"\n",
|
| 52 |
+
"\n",
|
| 53 |
+
"import os\n",
|
| 54 |
+
"repo_name = HF_SPACE_URL.split(\"/\")[-1]\n",
|
| 55 |
+
"\n",
|
| 56 |
+
"!git clone {HF_SPACE_URL}\n",
|
| 57 |
+
"os.chdir(repo_name)\n",
|
| 58 |
+
"print(f\"\\nWorking directory changed to: {os.getcwd()}\")"
|
| 59 |
+
]
|
| 60 |
+
},
|
| 61 |
+
{
|
| 62 |
+
"cell_type": "code",
|
| 63 |
+
"execution_count": null,
|
| 64 |
+
"metadata": {},
|
| 65 |
+
"outputs": [],
|
| 66 |
+
"source": [
|
| 67 |
+
"# 3. Run SFT Warm-start (~10-15 minutes)\n",
|
| 68 |
+
"# This trains the model to understand the basic output format and sales flow.\n",
|
| 69 |
+
"!python training/train_sft.py"
|
| 70 |
+
]
|
| 71 |
+
},
|
| 72 |
+
{
|
| 73 |
+
"cell_type": "code",
|
| 74 |
+
"execution_count": null,
|
| 75 |
+
"metadata": {},
|
| 76 |
+
"outputs": [],
|
| 77 |
+
"source": [
|
| 78 |
+
"# 4. Run GRPO Reinforcement Learning (~45-60 minutes)\n",
|
| 79 |
+
"import os\n",
|
| 80 |
+
"\n",
|
| 81 |
+
"# Derive the direct API URL for the Hugging Face space\n",
|
| 82 |
+
"username = HF_SPACE_URL.split(\"/\")[-2]\n",
|
| 83 |
+
"space_name = HF_SPACE_URL.split(\"/\")[-1]\n",
|
| 84 |
+
"direct_url = f\"https://{username}-{space_name}.hf.space\"\n",
|
| 85 |
+
"\n",
|
| 86 |
+
"os.environ[\"SALESPATH_ENV_URL\"] = direct_url\n",
|
| 87 |
+
"os.environ[\"SFT_CHECKPOINT\"] = \"./sft_checkpoint\"\n",
|
| 88 |
+
"\n",
|
| 89 |
+
"print(f\"Targeting Environment API: {direct_url}\")\n",
|
| 90 |
+
"\n",
|
| 91 |
+
"# Run the GRPO training script\n",
|
| 92 |
+
"!python training/train_grpo.py"
|
| 93 |
+
]
|
| 94 |
+
},
|
| 95 |
+
{
|
| 96 |
+
"cell_type": "code",
|
| 97 |
+
"execution_count": null,
|
| 98 |
+
"metadata": {},
|
| 99 |
+
"outputs": [],
|
| 100 |
+
"source": [
|
| 101 |
+
"# 5. Plot the Training Rewards\n",
|
| 102 |
+
"!python training/plot_rewards.py --log ./reward_log.jsonl --out ./plots\n",
|
| 103 |
+
"\n",
|
| 104 |
+
"from IPython.display import Image, display\n",
|
| 105 |
+
"print(\"\\n=== Reward Curve ===\")\n",
|
| 106 |
+
"display(Image(\"./plots/reward_curve.png\"))\n",
|
| 107 |
+
"\n",
|
| 108 |
+
"print(\"\\n=== Reward by Difficulty ===\")\n",
|
| 109 |
+
"display(Image(\"./plots/reward_by_difficulty.png\"))"
|
| 110 |
+
]
|
| 111 |
+
}
|
| 112 |
+
]
|
| 113 |
+
}
|
README.md
CHANGED
|
@@ -7,22 +7,21 @@ sdk: docker
|
|
| 7 |
app_port: 7860
|
| 8 |
pinned: false
|
| 9 |
license: mit
|
| 10 |
-
short_description: RL gym environment for sales
|
| 11 |
---
|
| 12 |
|
| 13 |
-
# SalesPath Environment
|
| 14 |
|
| 15 |
-
A [OpenEnv](https://github.com/openenv)-compatible
|
|
|
|
| 16 |
|
| 17 |
-
|
|
|
|
|
|
|
| 18 |
|
| 19 |
-
|
| 20 |
-
|--------|----------|-------------|
|
| 21 |
-
| `POST` | `/reset` | Reset the environment, returns initial observation |
|
| 22 |
-
| `POST` | `/step` | Take an action, returns next observation + reward |
|
| 23 |
-
| `GET` | `/health` | Health check |
|
| 24 |
|
| 25 |
-
## Quick Start
|
| 26 |
|
| 27 |
### Reset
|
| 28 |
```bash
|
|
@@ -35,13 +34,141 @@ curl -X POST https://imsachin010-salespath-env.hf.space/reset \
|
|
| 35 |
```bash
|
| 36 |
curl -X POST https://imsachin010-salespath-env.hf.space/step \
|
| 37 |
-H "Content-Type: application/json" \
|
| 38 |
-
-d '{
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
```
|
| 40 |
|
| 41 |
-
##
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
-
-
|
| 44 |
-
-
|
| 45 |
-
-
|
| 46 |
-
- `HANDLE_OBJECTION` β Handle prospect objections
|
| 47 |
-
- `CLOSE` β Attempt to close the deal
|
|
|
|
| 7 |
app_port: 7860
|
| 8 |
pinned: false
|
| 9 |
license: mit
|
| 10 |
+
short_description: RL gym environment for training B2B sales agents via GRPO
|
| 11 |
---
|
| 12 |
|
| 13 |
+
# SalesPath β RL Environment for B2B Sales Agents
|
| 14 |
|
| 15 |
+
A [OpenEnv](https://github.com/openenv)-compatible reinforcement learning gym
|
| 16 |
+
that trains an LLM to navigate the full B2B sales process through GRPO.
|
| 17 |
|
| 18 |
+
The agent must learn to qualify leads, handle objections, offer demos,
|
| 19 |
+
negotiate, and close β all while respecting business rules enforced by a
|
| 20 |
+
deterministic rule-based ProspectSimulator (no LLM on the environment side).
|
| 21 |
|
| 22 |
+
---
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
+
## Quick Start (hosted on HF Spaces)
|
| 25 |
|
| 26 |
### Reset
|
| 27 |
```bash
|
|
|
|
| 34 |
```bash
|
| 35 |
curl -X POST https://imsachin010-salespath-env.hf.space/step \
|
| 36 |
-H "Content-Type: application/json" \
|
| 37 |
+
-d '{
|
| 38 |
+
"action": {
|
| 39 |
+
"action_type": "PROSPECT",
|
| 40 |
+
"content": "Hello! I understand you have inventory tracking challenges. Tell me more."
|
| 41 |
+
}
|
| 42 |
+
}'
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
### Health check
|
| 46 |
+
```bash
|
| 47 |
+
curl https://imsachin010-salespath-env.hf.space/health
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
---
|
| 51 |
+
|
| 52 |
+
## Action Space
|
| 53 |
+
|
| 54 |
+
| Action | When to use |
|
| 55 |
+
|---|---|
|
| 56 |
+
| `PROSPECT` | Opening turn only β initial outreach |
|
| 57 |
+
| `QUALIFY` | Uncover budget, decision maker, pain points |
|
| 58 |
+
| `PRESENT` | Pitch the solution (requires QUALIFY first) |
|
| 59 |
+
| `HANDLE_OBJECTION` | Respond to pricing / timing objections |
|
| 60 |
+
| `OFFER_DEMO` | Schedule a live product demo |
|
| 61 |
+
| `NEGOTIATE` | Discuss pricing/terms (requires OFFER_DEMO + known budget) |
|
| 62 |
+
| `CLOSE` | Attempt to sign the deal |
|
| 63 |
+
| `FOLLOW_UP` | Re-engage after prospect silence |
|
| 64 |
+
| `DISQUALIFY` | End the conversation (correct only if budget < threshold AND no decision maker) |
|
| 65 |
+
|
| 66 |
+
---
|
| 67 |
+
|
| 68 |
+
## Business Rules Enforced
|
| 69 |
+
|
| 70 |
+
| Rule | Description |
|
| 71 |
+
|---|---|
|
| 72 |
+
| R01 | Must QUALIFY before PRESENT |
|
| 73 |
+
| R02 | Must OFFER_DEMO before NEGOTIATE |
|
| 74 |
+
| R03 | Cannot NEGOTIATE while budget is unknown |
|
| 75 |
+
| R04 | Discount in NEGOTIATE only after 2 objections handled |
|
| 76 |
+
| R05 | Cannot repeat the same action on consecutive turns |
|
| 77 |
+
| R06 | First action must be PROSPECT |
|
| 78 |
+
| R07 | FOLLOW_UP only valid after prospect silence (no response for 1+ turns) |
|
| 79 |
+
| R08 | DISQUALIFY valid only when budget < threshold AND no decision maker |
|
| 80 |
+
| R09 | Must OFFER_DEMO before CLOSE (difficulty 2+) |
|
| 81 |
+
|
| 82 |
+
---
|
| 83 |
+
|
| 84 |
+
## Reward Function
|
| 85 |
+
|
| 86 |
+
Composite weighted reward computed every step:
|
| 87 |
+
|
| 88 |
+
| Component | Weight | Description |
|
| 89 |
+
|---|---|---|
|
| 90 |
+
| `r_outcome` | 0.40 | +1.0 on successful close, +0.5 on valid DISQUALIFY, -0.5 on bad close |
|
| 91 |
+
| `r_compliance` | 0.30 | -0.2 per rule violation this turn |
|
| 92 |
+
| `r_ordering` | 0.15 | Fraction of workflow steps completed in correct order |
|
| 93 |
+
| `r_efficiency` | 0.10 | Penalty for turns beyond the optimal episode length |
|
| 94 |
+
| `r_format` | 0.05 | +1.0 for valid action type, -0.1 for invalid |
|
| 95 |
+
|
| 96 |
+
---
|
| 97 |
+
|
| 98 |
+
## Difficulty Levels
|
| 99 |
+
|
| 100 |
+
| Level | Description | Correct terminal action |
|
| 101 |
+
|---|---|---|
|
| 102 |
+
| 1 | Budget known, decision maker present, easy close | CLOSE |
|
| 103 |
+
| 2 | Budget hidden, 1 objection, demo required | CLOSE |
|
| 104 |
+
| 3 | Budget hidden, 2 objections, stalling prospect | CLOSE |
|
| 105 |
+
| 4 | Misleading signals, low budget, no decision maker | DISQUALIFY |
|
| 106 |
+
|
| 107 |
+
---
|
| 108 |
+
|
| 109 |
+
## Training Pipeline
|
| 110 |
+
|
| 111 |
+
```
|
| 112 |
+
sft_demos.jsonl (14 expert demos)
|
| 113 |
+
β
|
| 114 |
+
train_sft.py β SFT warm-start (SFTTrainer, TRL)
|
| 115 |
+
β
|
| 116 |
+
sft_checkpoint/
|
| 117 |
+
β
|
| 118 |
+
train_grpo.py β GRPO RL fine-tuning (GRPOTrainer, TRL + Unsloth 4-bit)
|
| 119 |
+
β
|
| 120 |
+
grpo_checkpoint/ + reward_log.jsonl
|
| 121 |
+
β
|
| 122 |
+
plot_rewards.py β reward curves
|
| 123 |
```
|
| 124 |
|
| 125 |
+
### Commands
|
| 126 |
+
|
| 127 |
+
```bash
|
| 128 |
+
# 1. Smoke test (no GPU, ~30 seconds)
|
| 129 |
+
python training/train_test.py
|
| 130 |
+
|
| 131 |
+
# 2. SFT warm-start (~15 min on T4)
|
| 132 |
+
python training/train_sft.py
|
| 133 |
+
|
| 134 |
+
# 3. Full GRPO training (~60 min on T4)
|
| 135 |
+
uvicorn salespath_env.server.app:app --port 7860 &
|
| 136 |
+
python training/train_grpo.py
|
| 137 |
+
|
| 138 |
+
# 4. Plot reward curves
|
| 139 |
+
python training/plot_rewards.py
|
| 140 |
+
```
|
| 141 |
+
|
| 142 |
+
---
|
| 143 |
+
|
| 144 |
+
## File Structure
|
| 145 |
+
|
| 146 |
+
```
|
| 147 |
+
salespath-env/
|
| 148 |
+
βββ salespath_env/
|
| 149 |
+
β βββ client.py β HTTP client for training scripts
|
| 150 |
+
β βββ models.py β SalesPathAction / Observation / State
|
| 151 |
+
β βββ server/
|
| 152 |
+
β βββ app.py β FastAPI app (OpenEnv)
|
| 153 |
+
β βββ salespath_environment.py
|
| 154 |
+
β βββ prospect_simulator.py β Rule-based, no LLM
|
| 155 |
+
β βββ rules.py β 9 business rules (R01βR09)
|
| 156 |
+
β βββ reward.py β 5-component reward function
|
| 157 |
+
β βββ task_bank.py β Prospect profiles (4 difficulty levels)
|
| 158 |
+
βββ training/
|
| 159 |
+
β βββ sft_demos.jsonl β Expert demonstration data
|
| 160 |
+
β βββ train_test.py β Smoke test (no GPU)
|
| 161 |
+
β βββ train_sft.py β SFT warm-start
|
| 162 |
+
β βββ train_grpo.py β GRPO RL training
|
| 163 |
+
β βββ plot_rewards.py β Reward curve visualisation
|
| 164 |
+
βββ Dockerfile
|
| 165 |
+
βββ requirements.txt
|
| 166 |
+
```
|
| 167 |
+
|
| 168 |
+
---
|
| 169 |
+
|
| 170 |
+
## Links
|
| 171 |
|
| 172 |
+
- π Blog post: _[add HuggingFace blog link here]_
|
| 173 |
+
- π₯ Demo video: _[add YouTube link here]_
|
| 174 |
+
- π€ HF Space: https://huggingface.co/spaces/imsachin010/salespath-env
|
|
|
|
|
|
pyproject.toml
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "salespath-env"
|
| 3 |
+
version = "0.1.0"
|
| 4 |
+
description = "OpenEnv RL environment for training B2B sales agents via GRPO"
|
| 5 |
+
requires-python = ">=3.10"
|
| 6 |
+
license = { text = "MIT" }
|
| 7 |
+
|
| 8 |
+
dependencies = [
|
| 9 |
+
"openenv-core>=0.2.3",
|
| 10 |
+
"fastapi>=0.110.0",
|
| 11 |
+
"uvicorn[standard]>=0.29.0",
|
| 12 |
+
"pydantic>=2.0",
|
| 13 |
+
]
|
| 14 |
+
|
| 15 |
+
[project.optional-dependencies]
|
| 16 |
+
training = [
|
| 17 |
+
"trl>=0.8.6",
|
| 18 |
+
"transformers>=4.40.0",
|
| 19 |
+
"datasets>=2.18.0",
|
| 20 |
+
"peft>=0.10.0",
|
| 21 |
+
"bitsandbytes>=0.43.0",
|
| 22 |
+
"accelerate>=0.28.0",
|
| 23 |
+
"torch>=2.2.0",
|
| 24 |
+
"matplotlib>=3.8.0",
|
| 25 |
+
"unsloth",
|
| 26 |
+
]
|
| 27 |
+
|
| 28 |
+
[build-system]
|
| 29 |
+
requires = ["setuptools>=68", "wheel"]
|
| 30 |
+
build-backend = "setuptools.backends.legacy:build"
|
| 31 |
+
|
| 32 |
+
[tool.setuptools.packages.find]
|
| 33 |
+
where = ["."]
|
| 34 |
+
include = ["salespath_env*"]
|
requirements.txt
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
fastapi>=0.110.0
|
| 2 |
uvicorn[standard]>=0.29.0
|
| 3 |
pydantic>=2.0
|
|
|
|
| 1 |
+
# Environment server (used by Dockerfile)
|
| 2 |
fastapi>=0.110.0
|
| 3 |
uvicorn[standard]>=0.29.0
|
| 4 |
pydantic>=2.0
|
salespath_env/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (178 Bytes). View file
|
|
|
salespath_env/__pycache__/client.cpython-312.pyc
ADDED
|
Binary file (5.72 kB). View file
|
|
|
salespath_env/__pycache__/models.cpython-312.pyc
ADDED
|
Binary file (3.15 kB). View file
|
|
|
salespath_env/client.py
CHANGED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# salespath_env/client.py
|
| 2 |
+
"""
|
| 3 |
+
HTTP client for the SalesPath environment.
|
| 4 |
+
Used by training scripts to talk to the hosted FastAPI server.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import requests
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class SalesPathClient:
|
| 13 |
+
"""
|
| 14 |
+
Thin wrapper around the /reset and /step HTTP endpoints.
|
| 15 |
+
|
| 16 |
+
Example
|
| 17 |
+
-------
|
| 18 |
+
>>> client = SalesPathClient("http://localhost:7860")
|
| 19 |
+
>>> obs = client.reset(difficulty=1)
|
| 20 |
+
>>> obs = client.step("PROSPECT", "Hi, tell me about your pain points.")
|
| 21 |
+
>>> print(obs["reward"])
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
def __init__(self, base_url: str = "http://localhost:7860"):
|
| 25 |
+
self.base_url = base_url.rstrip("/")
|
| 26 |
+
self._session = requests.Session()
|
| 27 |
+
|
| 28 |
+
# ------------------------------------------------------------------
|
| 29 |
+
# Core API
|
| 30 |
+
# ------------------------------------------------------------------
|
| 31 |
+
|
| 32 |
+
def reset(self, difficulty: int = 1) -> dict:
|
| 33 |
+
"""
|
| 34 |
+
Reset the environment for a new episode.
|
| 35 |
+
|
| 36 |
+
OpenEnv /reset returns the raw observation dict.
|
| 37 |
+
Returns a flat dict with all observation fields.
|
| 38 |
+
"""
|
| 39 |
+
resp = self._session.post(
|
| 40 |
+
f"{self.base_url}/reset",
|
| 41 |
+
json={"difficulty": difficulty},
|
| 42 |
+
timeout=30,
|
| 43 |
+
)
|
| 44 |
+
resp.raise_for_status()
|
| 45 |
+
data = resp.json()
|
| 46 |
+
# /reset may return raw observation or wrapped {observation:{...}}
|
| 47 |
+
if "observation" in data:
|
| 48 |
+
flat = dict(data["observation"])
|
| 49 |
+
flat.setdefault("reward", data.get("reward", 0.0))
|
| 50 |
+
flat.setdefault("done", data.get("done", False))
|
| 51 |
+
return flat
|
| 52 |
+
return data
|
| 53 |
+
|
| 54 |
+
def step(
|
| 55 |
+
self,
|
| 56 |
+
action_type: str,
|
| 57 |
+
content: str = "",
|
| 58 |
+
target: str = "",
|
| 59 |
+
) -> dict:
|
| 60 |
+
"""
|
| 61 |
+
Take one action in the environment.
|
| 62 |
+
|
| 63 |
+
OpenEnv /step returns {observation:{...}, reward:float, done:bool}.
|
| 64 |
+
This method flattens it so callers get a single dict with all
|
| 65 |
+
observation fields plus reward and done at the top level.
|
| 66 |
+
|
| 67 |
+
Returns
|
| 68 |
+
-------
|
| 69 |
+
dict with keys:
|
| 70 |
+
prospect_response, workflow_stage, constraints_violated,
|
| 71 |
+
steps_completed, turn_number, reward, reward_components,
|
| 72 |
+
done, info
|
| 73 |
+
"""
|
| 74 |
+
resp = self._session.post(
|
| 75 |
+
f"{self.base_url}/step",
|
| 76 |
+
json={
|
| 77 |
+
"action": {
|
| 78 |
+
"action_type": action_type,
|
| 79 |
+
"content": content,
|
| 80 |
+
"target": target,
|
| 81 |
+
}
|
| 82 |
+
},
|
| 83 |
+
timeout=30,
|
| 84 |
+
)
|
| 85 |
+
resp.raise_for_status()
|
| 86 |
+
data = resp.json()
|
| 87 |
+
# Flatten: {observation:{...}, reward, done} β one flat dict
|
| 88 |
+
if "observation" in data:
|
| 89 |
+
flat = dict(data["observation"])
|
| 90 |
+
flat["reward"] = data.get("reward", flat.get("reward", 0.0))
|
| 91 |
+
flat["done"] = data.get("done", flat.get("done", False))
|
| 92 |
+
return flat
|
| 93 |
+
return data
|
| 94 |
+
|
| 95 |
+
def health(self) -> dict:
|
| 96 |
+
resp = self._session.get(f"{self.base_url}/health", timeout=10)
|
| 97 |
+
resp.raise_for_status()
|
| 98 |
+
return resp.json()
|
| 99 |
+
|
| 100 |
+
# ------------------------------------------------------------------
|
| 101 |
+
# Convenience: run a full hard-coded demo episode
|
| 102 |
+
# ------------------------------------------------------------------
|
| 103 |
+
|
| 104 |
+
def run_demo_episode(self, difficulty: int = 1, verbose: bool = True) -> float:
|
| 105 |
+
"""
|
| 106 |
+
Run one scripted episode and return total cumulative reward.
|
| 107 |
+
Useful for smoke-testing the server end-to-end.
|
| 108 |
+
"""
|
| 109 |
+
obs = self.reset(difficulty)
|
| 110 |
+
if verbose:
|
| 111 |
+
print(f"\n=== Episode start (difficulty={difficulty}) ===")
|
| 112 |
+
print(f"Prospect: {obs.get('prospect_response', '')}\n")
|
| 113 |
+
|
| 114 |
+
# Scripted optimal sequence for difficulty 1
|
| 115 |
+
script = [
|
| 116 |
+
("PROSPECT", "Hello! I'd love to learn about your current challenges."),
|
| 117 |
+
("QUALIFY", "Can you tell me about your budget and decision process?"),
|
| 118 |
+
("PRESENT", "Here's how our platform solves your inventory problem."),
|
| 119 |
+
("CLOSE", "Based on everything, shall we move forward?"),
|
| 120 |
+
]
|
| 121 |
+
|
| 122 |
+
total_reward = 0.0
|
| 123 |
+
for action_type, content in script:
|
| 124 |
+
obs = self.step(action_type, content)
|
| 125 |
+
total_reward += obs.get("reward", 0.0)
|
| 126 |
+
if verbose:
|
| 127 |
+
print(f"[Turn {obs['turn_number']}] Agent: {action_type}")
|
| 128 |
+
print(f" Prospect: {obs['prospect_response']}")
|
| 129 |
+
print(f" Reward: {obs['reward']:.3f} | Done: {obs['done']}")
|
| 130 |
+
if obs.get("constraints_violated"):
|
| 131 |
+
print(f" β Violations: {obs['constraints_violated']}")
|
| 132 |
+
print()
|
| 133 |
+
if obs["done"]:
|
| 134 |
+
break
|
| 135 |
+
|
| 136 |
+
if verbose:
|
| 137 |
+
print(f"=== Episode done. Cumulative reward: {total_reward:.3f} ===\n")
|
| 138 |
+
return total_reward
|
salespath_env/server/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (173 Bytes). View file
|
|
|
salespath_env/server/__pycache__/app.cpython-312.pyc
ADDED
|
Binary file (4.34 kB). View file
|
|
|
salespath_env/server/__pycache__/prospect_simulator.cpython-312.pyc
ADDED
|
Binary file (5.03 kB). View file
|
|
|
salespath_env/server/__pycache__/reward.cpython-312.pyc
ADDED
|
Binary file (3.01 kB). View file
|
|
|
salespath_env/server/__pycache__/rules.cpython-312.pyc
ADDED
|
Binary file (7.36 kB). View file
|
|
|
salespath_env/server/__pycache__/salespath_environment.cpython-312.pyc
ADDED
|
Binary file (6.62 kB). View file
|
|
|
salespath_env/server/__pycache__/task_bank.cpython-312.pyc
ADDED
|
Binary file (2.66 kB). View file
|
|
|
salespath_env/server/app.py
CHANGED
|
@@ -1,18 +1,116 @@
|
|
| 1 |
-
# salespath_env/server/app.py
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# salespath_env/server/app.py
|
| 2 |
+
"""
|
| 3 |
+
Custom stateful FastAPI server for SalesPath.
|
| 4 |
+
|
| 5 |
+
Why not create_fastapi_app?
|
| 6 |
+
OpenEnv's built-in HTTP /reset and /step endpoints are STATELESS β
|
| 7 |
+
they create a new Environment instance per request and destroy it.
|
| 8 |
+
State is preserved only over WebSocket sessions.
|
| 9 |
+
|
| 10 |
+
For our training loop (HTTP polling), we need a persistent environment
|
| 11 |
+
that survives across /reset + multiple /step calls. This file provides
|
| 12 |
+
that by keeping a single global SalesPathEnvironment instance.
|
| 13 |
+
|
| 14 |
+
The response envelope matches OpenEnv exactly:
|
| 15 |
+
{ "observation": {...}, "reward": float, "done": bool }
|
| 16 |
+
so all existing clients work without changes.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from typing import Any, Dict, Optional
|
| 20 |
+
|
| 21 |
+
from fastapi import FastAPI
|
| 22 |
+
from pydantic import BaseModel
|
| 23 |
+
|
| 24 |
+
from ..models import SalesPathAction
|
| 25 |
+
from .salespath_environment import SalesPathEnvironment
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# ---------------------------------------------------------------------------
|
| 29 |
+
# Single persistent environment instance
|
| 30 |
+
# ---------------------------------------------------------------------------
|
| 31 |
+
|
| 32 |
+
_env: SalesPathEnvironment = SalesPathEnvironment()
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# ---------------------------------------------------------------------------
|
| 36 |
+
# Request models
|
| 37 |
+
# ---------------------------------------------------------------------------
|
| 38 |
+
|
| 39 |
+
class ResetRequest(BaseModel):
|
| 40 |
+
difficulty: int = 1
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class ActionPayload(BaseModel):
|
| 44 |
+
action_type: str
|
| 45 |
+
content: str = ""
|
| 46 |
+
target: str = ""
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class StepRequest(BaseModel):
|
| 50 |
+
action: ActionPayload
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# ---------------------------------------------------------------------------
|
| 54 |
+
# FastAPI app
|
| 55 |
+
# ---------------------------------------------------------------------------
|
| 56 |
+
|
| 57 |
+
app = FastAPI(
|
| 58 |
+
title="SalesPath Environment",
|
| 59 |
+
description="OpenEnv-compatible RL environment for B2B sales agent training.",
|
| 60 |
+
version="0.1.0",
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
@app.post("/reset")
|
| 65 |
+
def reset(req: ResetRequest = ResetRequest()):
|
| 66 |
+
"""
|
| 67 |
+
Start a new episode.
|
| 68 |
+
Resets the environment and returns the initial observation.
|
| 69 |
+
"""
|
| 70 |
+
obs = _env.reset(difficulty=req.difficulty)
|
| 71 |
+
return {
|
| 72 |
+
"observation": obs.model_dump(),
|
| 73 |
+
"reward": obs.reward,
|
| 74 |
+
"done": obs.done,
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
@app.post("/step")
|
| 79 |
+
def step(req: StepRequest):
|
| 80 |
+
"""
|
| 81 |
+
Take one action in the current episode.
|
| 82 |
+
Returns the next observation, reward, and done flag.
|
| 83 |
+
"""
|
| 84 |
+
action = SalesPathAction(
|
| 85 |
+
action_type=req.action.action_type,
|
| 86 |
+
content=req.action.content,
|
| 87 |
+
target=req.action.target,
|
| 88 |
+
)
|
| 89 |
+
obs = _env.step(action)
|
| 90 |
+
return {
|
| 91 |
+
"observation": obs.model_dump(),
|
| 92 |
+
"reward": obs.reward,
|
| 93 |
+
"done": obs.done,
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
@app.get("/health")
|
| 98 |
+
def health():
|
| 99 |
+
return {"status": "healthy"}
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
@app.get("/state")
|
| 103 |
+
def state():
|
| 104 |
+
"""Expose internal state (for debugging). Hidden state excluded."""
|
| 105 |
+
s = _env.state
|
| 106 |
+
return {
|
| 107 |
+
"episode_id": s.episode_id,
|
| 108 |
+
"turn_number": s.turn_number,
|
| 109 |
+
"workflow_stage": s.workflow_stage,
|
| 110 |
+
"steps_completed": s.steps_completed,
|
| 111 |
+
"constraints_violated": s.constraints_violated,
|
| 112 |
+
"objections_handled": s.objections_handled,
|
| 113 |
+
"difficulty": s.difficulty,
|
| 114 |
+
"done": s.done,
|
| 115 |
+
"prospect_profile": s.prospect_profile,
|
| 116 |
+
}
|
salespath_env/server/prospect_simulator.py
CHANGED
|
@@ -1,162 +1,176 @@
|
|
| 1 |
-
# salespath_env/server/prospect_simulator.py
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
"
|
| 11 |
-
|
| 12 |
-
"objection:
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
"
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
"""
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
def
|
| 66 |
-
self,
|
| 67 |
-
action: SalesPathAction,
|
| 68 |
-
state: SalesPathState,
|
| 69 |
-
) -> str:
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
#
|
| 91 |
-
#
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
if
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
return "open:neutral_signal"
|
|
|
|
| 1 |
+
# salespath_env/server/prospect_simulator.py
|
| 2 |
+
|
| 3 |
+
import random
|
| 4 |
+
|
| 5 |
+
from ..models import SalesPathAction, SalesPathState
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
RESPONSE_TEXT = {
|
| 9 |
+
"open:positive_signal": "That sounds interesting. Tell me more about how this works.",
|
| 10 |
+
"open:neutral_signal": "I see. We're evaluating a few options at the moment.",
|
| 11 |
+
|
| 12 |
+
"objection:price": "The pricing seems higher than what we budgeted for.",
|
| 13 |
+
"objection:timing": "The timing isn't ideal β we're in the middle of a quarter close.",
|
| 14 |
+
"objection:premature_pitch": (
|
| 15 |
+
"I'm not sure we're ready to discuss solutions yet. "
|
| 16 |
+
"What do you know about our current situation?"
|
| 17 |
+
),
|
| 18 |
+
|
| 19 |
+
"deflect:budget_not_discussed": (
|
| 20 |
+
"We haven't really talked about what we're looking for yet."
|
| 21 |
+
),
|
| 22 |
+
"deflect:stall": (
|
| 23 |
+
"Let me get back to you on this. A lot is happening on our end."
|
| 24 |
+
),
|
| 25 |
+
|
| 26 |
+
"accept:demo_scheduled": (
|
| 27 |
+
"Yes, let's set up a demo. What time works next week?"
|
| 28 |
+
),
|
| 29 |
+
"accept:close_success": (
|
| 30 |
+
"Alright, I think we can move forward with this. "
|
| 31 |
+
"Send over the paperwork."
|
| 32 |
+
),
|
| 33 |
+
|
| 34 |
+
"reject:close_failed": (
|
| 35 |
+
"I don't think we're ready to commit at this point."
|
| 36 |
+
),
|
| 37 |
+
|
| 38 |
+
"silence": "",
|
| 39 |
+
|
| 40 |
+
"exit:disqualified": (
|
| 41 |
+
"I think we're done here. This isn't the right fit."
|
| 42 |
+
),
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
# Prefix injected into QUALIFY response to reveal budget signal
|
| 46 |
+
# without mutating prospect_profile (immutable prospect state).
|
| 47 |
+
BUDGET_REVEAL_TEXT = {
|
| 48 |
+
"high": "We do have solid budget allocated for this initiative. ",
|
| 49 |
+
"medium": "We have some budget set aside, though flexibility is limited. ",
|
| 50 |
+
"low": "Our budget is quite constrained right now. ",
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class ProspectSimulator:
    """
    Deterministic rule-based prospect. No LLM, no transformers.

    Immutability guarantee:
        state.prospect_profile is never written here. The revealed budget
        is surfaced through the response *text* alone; every state write
        belongs to the environment (salespath_environment.py).
    """

    # Actions whose token is independent of state. Consulted only after
    # the violation and stall gates in _get_token have passed.
    _SIMPLE_TOKENS = {
        "PROSPECT": "open:positive_signal",
        "QUALIFY": "open:neutral_signal",
        "OFFER_DEMO": "accept:demo_scheduled",
        "NEGOTIATE": "open:neutral_signal",
        "FOLLOW_UP": "open:neutral_signal",
        "DISQUALIFY": "exit:disqualified",
    }

    # Maps the most recent rule violation to its override response token.
    _VIOLATION_TOKENS = {
        "R01": "objection:premature_pitch",
        "R03": "deflect:budget_not_discussed",
    }

    def respond(
        self,
        action: SalesPathAction,
        state: SalesPathState,
    ) -> tuple[str, str]:
        """Return ``(response_token, response_text)`` for the agent's action."""
        token = self._get_token(action, state)
        return token, self._build_text(token, action, state)

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _build_text(
        self,
        token: str,
        action: SalesPathAction,
        state: SalesPathState,
    ) -> str:
        """Render the prospect's text for *token*, with optional budget reveal."""
        text = RESPONSE_TEXT[token]

        # A QUALIFY probe while the budget is still unknown reveals it via
        # a text prefix. We read hidden_state, never write prospect_profile.
        if (
            action.action_type == "QUALIFY"
            and state.prospect_profile.get("budget_signal", "unknown") == "unknown"
        ):
            revealed = state.hidden_state.get("revealed_budget", "medium")
            text = BUDGET_REVEAL_TEXT.get(revealed, "") + text

        return text

    def _get_token(
        self,
        action: SalesPathAction,
        state: SalesPathState,
    ) -> str:
        """Pick the response token for *action* given the current state."""
        atype = action.action_type
        hidden = state.hidden_state
        handled_before = state.objections_handled

        # 1. Rule-violation responses take absolute priority.
        if state.constraints_violated:
            override = self._VIOLATION_TOKENS.get(state.constraints_violated[-1])
            if override is not None:
                return override

        # 2. Stall injection for difficulty 3+ on later turns.
        #    Kept ahead of the action branches so it can actually fire.
        if state.difficulty >= 3 and state.turn_number >= 5:
            stall_prob = hidden.get("stall_probability", 0.0)
            if stall_prob > 0.0 and random.random() < stall_prob:
                return "deflect:stall"

        # 3. Deterministic, state-dependent action branches.
        if atype == "PRESENT":
            if state.difficulty >= 2 and handled_before == 0:
                return "objection:price"
            return "open:positive_signal"

        if atype == "HANDLE_OBJECTION":
            state.objections_handled += 1  # the simulator's only state write
            if state.objections_handled >= hidden.get("num_objections", 1):
                return "open:positive_signal"
            if handled_before == 0:
                return "objection:timing"
            return "open:positive_signal"

        if atype == "CLOSE":
            closeable = (
                hidden.get("true_budget", 0.7) >= hidden.get("close_threshold", 0.5)
                and state.prospect_profile.get("decision_maker", True)
            )
            return "accept:close_success" if closeable else "reject:close_failed"

        # 4. State-independent actions, with a neutral fallback for anything else.
        return self._SIMPLE_TOKENS.get(atype, "open:neutral_signal")
|
salespath_env/server/rules.py
CHANGED
|
@@ -78,10 +78,15 @@ def _no_repeat_action(
|
|
| 78 |
"""
|
| 79 |
R05:
|
| 80 |
Same action twice in a row is invalid.
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
return False
|
| 86 |
|
| 87 |
|
|
@@ -104,13 +109,35 @@ def _followup_timing(
|
|
| 104 |
) -> bool:
|
| 105 |
"""
|
| 106 |
R07:
|
| 107 |
-
FOLLOW_UP only valid after silence.
|
| 108 |
-
|
|
|
|
| 109 |
"""
|
| 110 |
if action.action_type == "FOLLOW_UP":
|
| 111 |
-
if state.conversation_history:
|
| 112 |
-
|
| 113 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
return False
|
| 115 |
|
| 116 |
|
|
@@ -120,15 +147,20 @@ def _disqualify_logic(
|
|
| 120 |
) -> bool:
|
| 121 |
"""
|
| 122 |
R08:
|
| 123 |
-
DISQUALIFY
|
| 124 |
-
|
|
|
|
|
|
|
|
|
|
| 125 |
"""
|
| 126 |
if action.action_type == "DISQUALIFY":
|
| 127 |
true_budget = state.hidden_state.get("true_budget", 0.5)
|
| 128 |
close_threshold = state.hidden_state.get("close_threshold", 0.5)
|
| 129 |
decision_maker = state.prospect_profile.get("decision_maker", True)
|
| 130 |
|
| 131 |
-
|
|
|
|
|
|
|
| 132 |
|
| 133 |
return False
|
| 134 |
|
|
|
|
| 78 |
"""
|
| 79 |
R05:
|
| 80 |
Same action twice in a row is invalid.
|
| 81 |
+
FIX: conversation_history alternates agent/prospect entries.
|
| 82 |
+
Must filter to agent-only turns before comparing.
|
| 83 |
+
"""
|
| 84 |
+
agent_turns = [
|
| 85 |
+
e for e in state.conversation_history
|
| 86 |
+
if e.get("speaker") == "agent"
|
| 87 |
+
]
|
| 88 |
+
if agent_turns:
|
| 89 |
+
return agent_turns[-1].get("action_type", "") == action.action_type
|
| 90 |
return False
|
| 91 |
|
| 92 |
|
|
|
|
| 109 |
) -> bool:
|
| 110 |
"""
|
| 111 |
R07:
|
| 112 |
+
FOLLOW_UP only valid after prospect silence (no response for 1+ agent turns).
|
| 113 |
+
Violation if the prospect HAS replied since the last agent action.
|
| 114 |
+
FIX: Previous logic was inverted β it was blocking valid FOLLOW_UP.
|
| 115 |
"""
|
| 116 |
if action.action_type == "FOLLOW_UP":
|
| 117 |
+
if not state.conversation_history:
|
| 118 |
+
return True # Nothing happened yet β FOLLOW_UP makes no sense
|
| 119 |
+
|
| 120 |
+
agent_turns = [
|
| 121 |
+
e for e in state.conversation_history
|
| 122 |
+
if e.get("speaker") == "agent"
|
| 123 |
+
]
|
| 124 |
+
prospect_turns = [
|
| 125 |
+
e for e in state.conversation_history
|
| 126 |
+
if e.get("speaker") == "prospect"
|
| 127 |
+
]
|
| 128 |
+
|
| 129 |
+
if not agent_turns:
|
| 130 |
+
return True
|
| 131 |
+
|
| 132 |
+
last_agent_turn_num = agent_turns[-1]["turn"]
|
| 133 |
+
last_prospect_turn_num = max(
|
| 134 |
+
(e["turn"] for e in prospect_turns),
|
| 135 |
+
default=0,
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
# Violation if prospect already responded AFTER the last agent turn
|
| 139 |
+
return last_prospect_turn_num >= last_agent_turn_num
|
| 140 |
+
|
| 141 |
return False
|
| 142 |
|
| 143 |
|
|
|
|
| 147 |
) -> bool:
|
| 148 |
"""
|
| 149 |
R08:
|
| 150 |
+
DISQUALIFY is correct ONLY when:
|
| 151 |
+
- true_budget < close_threshold AND
|
| 152 |
+
- decision_maker is False
|
| 153 |
+
Violation if prospect is actually closeable OR has a decision maker.
|
| 154 |
+
FIX: Both conditions must hold for a valid disqualification.
|
| 155 |
"""
|
| 156 |
if action.action_type == "DISQUALIFY":
|
| 157 |
true_budget = state.hidden_state.get("true_budget", 0.5)
|
| 158 |
close_threshold = state.hidden_state.get("close_threshold", 0.5)
|
| 159 |
decision_maker = state.prospect_profile.get("decision_maker", True)
|
| 160 |
|
| 161 |
+
# Valid disqualify requires: low budget AND no decision maker
|
| 162 |
+
valid_disqualify = (true_budget < close_threshold) and (not decision_maker)
|
| 163 |
+
return not valid_disqualify # Violation if NOT a valid disqualify case
|
| 164 |
|
| 165 |
return False
|
| 166 |
|
salespath_env/server/salespath_environment.py
CHANGED
|
@@ -220,6 +220,20 @@ class SalesPathEnvironment(Environment):
|
|
| 220 |
)
|
| 221 |
)
|
| 222 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 223 |
state.conversation_history.append(
|
| 224 |
{
|
| 225 |
"turn": state.turn_number,
|
|
|
|
| 220 |
)
|
| 221 |
)
|
| 222 |
|
| 223 |
+
# -----------------------------------
|
| 224 |
+
# Budget reveal (env owns state write)
|
| 225 |
+
# Simulator surfaced the info via text;
|
| 226 |
+
# now we update prospect_profile so rules
|
| 227 |
+
# (e.g. R03) can see the revealed value.
|
| 228 |
+
# -----------------------------------
|
| 229 |
+
if (
|
| 230 |
+
action.action_type == "QUALIFY"
|
| 231 |
+
and state.prospect_profile.get("budget_signal") == "unknown"
|
| 232 |
+
):
|
| 233 |
+
state.prospect_profile["budget_signal"] = (
|
| 234 |
+
state.hidden_state.get("revealed_budget", "medium")
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
state.conversation_history.append(
|
| 238 |
{
|
| 239 |
"turn": state.turn_number,
|
training/__pycache__/plot_rewards.cpython-312.pyc
ADDED
|
Binary file (5.92 kB). View file
|
|
|
training/__pycache__/train_grpo.cpython-312.pyc
ADDED
|
Binary file (13.6 kB). View file
|
|
|
training/__pycache__/train_sft.cpython-312.pyc
ADDED
|
Binary file (5.39 kB). View file
|
|
|
training/__pycache__/train_test.cpython-312.pyc
ADDED
|
Binary file (8.78 kB). View file
|
|
|
training/plot_rewards.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
plot_rewards.py β Visualise GRPO training progress
|
| 3 |
+
====================================================
|
| 4 |
+
Reads reward_log.jsonl written by train_grpo.py and
|
| 5 |
+
produces two plots:
|
| 6 |
+
|
| 7 |
+
1. Mean reward per step (with min/max band)
|
| 8 |
+
2. Reward by difficulty level
|
| 9 |
+
|
| 10 |
+
Run:
|
| 11 |
+
python training/plot_rewards.py
|
| 12 |
+
python training/plot_rewards.py --log ./reward_log.jsonl --out ./plots/
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import argparse
|
| 16 |
+
import json
|
| 17 |
+
import os
|
| 18 |
+
from collections import defaultdict
|
| 19 |
+
|
| 20 |
+
def load_log(path: str) -> list[dict]:
    """Parse a JSONL reward log into a list of record dicts.

    Blank lines are skipped; any malformed line raises
    ``json.JSONDecodeError``.
    """
    records: list[dict] = []
    # FIX: explicit encoding — the log is written as UTF-8, and the
    # platform default (e.g. cp1252 on Windows) could otherwise fail
    # or garble non-ASCII text fields.
    with open(path, encoding="utf-8") as f:
        for line in f:
            stripped = line.strip()
            if stripped:
                records.append(json.loads(stripped))
    return records
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def plot(log_path: str, out_dir: str):
    """Render GRPO training plots from a reward_log.jsonl file.

    Writes two PNGs into *out_dir*:
      reward_curve.png          - mean reward per step with min/max band
      reward_by_difficulty.png  - box plot of mean reward per difficulty

    Returns None. Prints a message and returns early when matplotlib is
    not installed or the log contains no records.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        # FIX: replaced mojibake emoji in the message with plain text.
        print("matplotlib not installed. pip install matplotlib")
        return

    os.makedirs(out_dir, exist_ok=True)
    records = load_log(log_path)

    if not records:
        print(f"No records found in {log_path}")
        return

    steps = [r["step"] for r in records]
    means = [r["mean_reward"] for r in records]
    maxes = [r["max_reward"] for r in records]
    mins = [r["min_reward"] for r in records]
    difficulties = [r["difficulty"] for r in records]

    # --- Plot 1: mean reward with min/max band ---
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(steps, means, label="Mean reward", color="#4C6EF5", linewidth=2)
    ax.fill_between(steps, mins, maxes, alpha=0.2, color="#4C6EF5", label="Min/Max band")
    ax.axhline(0, color="gray", linestyle="--", linewidth=0.8)

    # Vertical markers wherever the curriculum difficulty changes.
    prev_d = None
    for s, d in zip(steps, difficulties):
        if d != prev_d:
            ax.axvline(s, color="orange", linestyle=":", linewidth=1.2, alpha=0.7)
            ax.text(s + 0.5, ax.get_ylim()[0] * 0.9, f"D{d}", fontsize=8, color="orange")
            prev_d = d

    ax.set_xlabel("Training Step")
    ax.set_ylabel("Episode Reward")
    ax.set_title("SalesPath GRPO - Mean Reward per Step")
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    path1 = os.path.join(out_dir, "reward_curve.png")
    plt.savefig(path1, dpi=150)
    plt.close(fig)  # FIX: release the figure; repeated calls were leaking memory
    print(f"Saved: {path1}")

    # --- Plot 2: per-difficulty box ---
    by_diff = defaultdict(list)
    for r in records:
        by_diff[r["difficulty"]].append(r["mean_reward"])

    fig2, ax2 = plt.subplots(figsize=(7, 5))
    labels = sorted(by_diff.keys())
    data = [by_diff[d] for d in labels]
    # NOTE(review): `labels=` was renamed `tick_labels=` in matplotlib 3.9
    # and emits a deprecation warning there; kept as `labels=` for
    # compatibility with older matplotlib — confirm the pinned version.
    ax2.boxplot(data, labels=[f"Difficulty {d}" for d in labels], patch_artist=True)
    ax2.set_ylabel("Mean Episode Reward")
    ax2.set_title("Reward Distribution by Difficulty Level")
    ax2.grid(True, alpha=0.3)
    plt.tight_layout()
    path2 = os.path.join(out_dir, "reward_by_difficulty.png")
    plt.savefig(path2, dpi=150)
    plt.close(fig2)  # FIX: release the second figure as well
    print(f"Saved: {path2}")

    # Quick textual summary of the run.
    best = max(means)
    print(f"\nTraining summary ({len(records)} steps)")
    print(f"  Start mean reward : {means[0]:+.3f}")
    print(f"  End mean reward   : {means[-1]:+.3f}")
    print(f"  Best mean reward  : {best:+.3f} (step {steps[means.index(best)]})")
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
if __name__ == "__main__":
    # CLI entry point: --log is the reward log path, --out the plot directory.
    cli = argparse.ArgumentParser()
    cli.add_argument("--log", default="./reward_log.jsonl")
    cli.add_argument("--out", default="./plots")
    ns = cli.parse_args()
    plot(ns.log, ns.out)
|
training/sft_demos.jsonl
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"messages": [{"role": "system", "content": "You are a professional B2B sales agent. Your goal is to close deals by following the correct sales process. Always respond with exactly one action in the format:\nACTION: <ACTION_TYPE>\nCONTENT: <your message to the prospect>\n\nValid actions: PROSPECT, QUALIFY, PRESENT, HANDLE_OBJECTION, OFFER_DEMO, NEGOTIATE, CLOSE, FOLLOW_UP, DISQUALIFY"}, {"role": "user", "content": "You are engaging Meridian Retail, a medium retail company. Pain points: manual inventory tracking, slow reporting. Begin the sales conversation. [Stage: START | Turn: 0]"}, {"role": "assistant", "content": "ACTION: PROSPECT\nCONTENT: Hi! I understand Meridian Retail is dealing with manual inventory tracking and slow reporting. I'd love to learn more about how these challenges are impacting your day-to-day operations."}]}
|
| 2 |
+
{"messages": [{"role": "system", "content": "You are a professional B2B sales agent. Your goal is to close deals by following the correct sales process. Always respond with exactly one action in the format:\nACTION: <ACTION_TYPE>\nCONTENT: <your message to the prospect>\n\nValid actions: PROSPECT, QUALIFY, PRESENT, HANDLE_OBJECTION, OFFER_DEMO, NEGOTIATE, CLOSE, FOLLOW_UP, DISQUALIFY"}, {"role": "user", "content": "You are engaging Meridian Retail, a medium retail company. Pain points: manual inventory tracking, slow reporting. [Stage: PROSPECT | Turn: 1 | Prospect said: That sounds interesting. Tell me more about how this works.]"}, {"role": "assistant", "content": "ACTION: QUALIFY\nCONTENT: Great! Before I dive in, could you share more about your budget range and who the key decision makers are for this type of initiative?"}]}
|
| 3 |
+
{"messages": [{"role": "system", "content": "You are a professional B2B sales agent. Your goal is to close deals by following the correct sales process. Always respond with exactly one action in the format:\nACTION: <ACTION_TYPE>\nCONTENT: <your message to the prospect>\n\nValid actions: PROSPECT, QUALIFY, PRESENT, HANDLE_OBJECTION, OFFER_DEMO, NEGOTIATE, CLOSE, FOLLOW_UP, DISQUALIFY"}, {"role": "user", "content": "You are engaging Meridian Retail, a medium retail company. Pain points: manual inventory tracking, slow reporting. [Stage: QUALIFY | Turn: 2 | Prospect said: I see. We're evaluating a few options at the moment. Budget signal: high. Steps completed: PROSPECT, QUALIFY]"}, {"role": "assistant", "content": "ACTION: PRESENT\nCONTENT: Perfect. Our platform automates inventory tracking in real-time and generates reports in seconds β exactly the pain points you mentioned. It integrates with your existing systems with zero downtime."}]}
|
| 4 |
+
{"messages": [{"role": "system", "content": "You are a professional B2B sales agent. Your goal is to close deals by following the correct sales process. Always respond with exactly one action in the format:\nACTION: <ACTION_TYPE>\nCONTENT: <your message to the prospect>\n\nValid actions: PROSPECT, QUALIFY, PRESENT, HANDLE_OBJECTION, OFFER_DEMO, NEGOTIATE, CLOSE, FOLLOW_UP, DISQUALIFY"}, {"role": "user", "content": "You are engaging Meridian Retail, a medium retail company. [Stage: PRESENT | Turn: 3 | Prospect said: That sounds interesting. Tell me more. | Steps completed: PROSPECT, QUALIFY, PRESENT]"}, {"role": "assistant", "content": "ACTION: CLOSE\nCONTENT: Based on everything we've discussed, I believe our solution is an excellent fit for Meridian Retail. Shall we move forward and get the paperwork started?"}]}
|
| 5 |
+
{"messages": [{"role": "system", "content": "You are a professional B2B sales agent. Your goal is to close deals by following the correct sales process. Always respond with exactly one action in the format:\nACTION: <ACTION_TYPE>\nCONTENT: <your message to the prospect>\n\nValid actions: PROSPECT, QUALIFY, PRESENT, HANDLE_OBJECTION, OFFER_DEMO, NEGOTIATE, CLOSE, FOLLOW_UP, DISQUALIFY"}, {"role": "user", "content": "You are engaging Apex Logistics, an enterprise logistics company. Pain points: route optimization, driver coordination, fuel tracking. Budget: unknown. Begin the sales conversation. [Stage: START | Turn: 0]"}, {"role": "assistant", "content": "ACTION: PROSPECT\nCONTENT: Hello! I see Apex Logistics is tackling route optimization and driver coordination challenges. These are critical in logistics. I'd love to explore how we can help β could you walk me through your current process?"}]}
|
| 6 |
+
{"messages": [{"role": "system", "content": "You are a professional B2B sales agent. Your goal is to close deals by following the correct sales process. Always respond with exactly one action in the format:\nACTION: <ACTION_TYPE>\nCONTENT: <your message to the prospect>\n\nValid actions: PROSPECT, QUALIFY, PRESENT, HANDLE_OBJECTION, OFFER_DEMO, NEGOTIATE, CLOSE, FOLLOW_UP, DISQUALIFY"}, {"role": "user", "content": "You are engaging Apex Logistics. [Stage: PROSPECT | Turn: 1 | Prospect said: That sounds interesting. | Budget: unknown]"}, {"role": "assistant", "content": "ACTION: QUALIFY\nCONTENT: To make sure I propose the right solution, could you share your approximate budget for this initiative and confirm who the final decision maker would be?"}]}
|
| 7 |
+
{"messages": [{"role": "system", "content": "You are a professional B2B sales agent. Your goal is to close deals by following the correct sales process. Always respond with exactly one action in the format:\nACTION: <ACTION_TYPE>\nCONTENT: <your message to the prospect>\n\nValid actions: PROSPECT, QUALIFY, PRESENT, HANDLE_OBJECTION, OFFER_DEMO, NEGOTIATE, CLOSE, FOLLOW_UP, DISQUALIFY"}, {"role": "user", "content": "You are engaging Apex Logistics. [Stage: QUALIFY | Turn: 2 | Prospect said: We have some budget set aside. Budget signal: medium | Steps: PROSPECT, QUALIFY]"}, {"role": "assistant", "content": "ACTION: PRESENT\nCONTENT: Our platform optimizes routes using real-time traffic and GPS data, cutting fuel costs by up to 20% and improving driver coordination through a unified dashboard. Companies your size see ROI within 6 months."}]}
|
| 8 |
+
{"messages": [{"role": "system", "content": "You are a professional B2B sales agent. Your goal is to close deals by following the correct sales process. Always respond with exactly one action in the format:\nACTION: <ACTION_TYPE>\nCONTENT: <your message to the prospect>\n\nValid actions: PROSPECT, QUALIFY, PRESENT, HANDLE_OBJECTION, OFFER_DEMO, NEGOTIATE, CLOSE, FOLLOW_UP, DISQUALIFY"}, {"role": "user", "content": "You are engaging Apex Logistics. [Stage: PRESENT | Turn: 3 | Prospect said: The pricing seems higher than what we budgeted. | Objections: 0 | Steps: PROSPECT, QUALIFY, PRESENT]"}, {"role": "assistant", "content": "ACTION: HANDLE_OBJECTION\nCONTENT: That's a fair concern. The upfront cost is offset by the fuel savings and reduced overtime β most clients break even within 4 months. We also offer flexible payment plans. Would that make it more accessible?"}]}
|
| 9 |
+
{"messages": [{"role": "system", "content": "You are a professional B2B sales agent. Your goal is to close deals by following the correct sales process. Always respond with exactly one action in the format:\nACTION: <ACTION_TYPE>\nCONTENT: <your message to the prospect>\n\nValid actions: PROSPECT, QUALIFY, PRESENT, HANDLE_OBJECTION, OFFER_DEMO, NEGOTIATE, CLOSE, FOLLOW_UP, DISQUALIFY"}, {"role": "user", "content": "You are engaging Apex Logistics. [Stage: HANDLE_OBJECTION | Turn: 4 | Prospect said: That sounds reasonable. | Objections handled: 1 | Steps: PROSPECT, QUALIFY, PRESENT, HANDLE_OBJECTION]"}, {"role": "assistant", "content": "ACTION: OFFER_DEMO\nCONTENT: I'd love to show you the platform in action with your actual data. Can we schedule a 30-minute live demo next week so your team can see the route optimization in real time?"}]}
|
| 10 |
+
{"messages": [{"role": "system", "content": "You are a professional B2B sales agent. Your goal is to close deals by following the correct sales process. Always respond with exactly one action in the format:\nACTION: <ACTION_TYPE>\nCONTENT: <your message to the prospect>\n\nValid actions: PROSPECT, QUALIFY, PRESENT, HANDLE_OBJECTION, OFFER_DEMO, NEGOTIATE, CLOSE, FOLLOW_UP, DISQUALIFY"}, {"role": "user", "content": "You are engaging Apex Logistics. [Stage: OFFER_DEMO | Turn: 5 | Prospect said: Yes, let's set up a demo. | Steps: PROSPECT, QUALIFY, PRESENT, HANDLE_OBJECTION, OFFER_DEMO]"}, {"role": "assistant", "content": "ACTION: CLOSE\nCONTENT: Excellent! Given the demo and the ROI we've discussed, I think we're aligned. Are you ready to move forward with the agreement?"}]}
|
| 11 |
+
{"messages": [{"role": "system", "content": "You are a professional B2B sales agent. Your goal is to close deals by following the correct sales process. Always respond with exactly one action in the format:\nACTION: <ACTION_TYPE>\nCONTENT: <your message to the prospect>\n\nValid actions: PROSPECT, QUALIFY, PRESENT, HANDLE_OBJECTION, OFFER_DEMO, NEGOTIATE, CLOSE, FOLLOW_UP, DISQUALIFY"}, {"role": "user", "content": "You are engaging Cipher Tech, a small tech company. Budget signal: high (but misleading β actual budget is low). No decision maker present. [Stage: START | Turn: 0]"}, {"role": "assistant", "content": "ACTION: PROSPECT\nCONTENT: Hi! I understand you're facing security and compliance challenges. Could you tell me more about your current setup and what you're trying to achieve?"}]}
|
| 12 |
+
{"messages": [{"role": "system", "content": "You are a professional B2B sales agent. Your goal is to close deals by following the correct sales process. Always respond with exactly one action in the format:\nACTION: <ACTION_TYPE>\nCONTENT: <your message to the prospect>\n\nValid actions: PROSPECT, QUALIFY, PRESENT, HANDLE_OBJECTION, OFFER_DEMO, NEGOTIATE, CLOSE, FOLLOW_UP, DISQUALIFY"}, {"role": "user", "content": "You are engaging Cipher Tech. [Stage: QUALIFY | Turn: 2 | Prospect said: Our budget is quite constrained. No decision maker available. | Budget: low | Decision maker: false | Steps: PROSPECT, QUALIFY]"}, {"role": "assistant", "content": "ACTION: DISQUALIFY\nCONTENT: Thank you for your transparency. Based on your current budget constraints and the absence of a decision maker, I don't think we're the right fit at this time. I'll leave the door open for when circumstances change."}]}
|
| 13 |
+
{"messages": [{"role": "system", "content": "You are a professional B2B sales agent. Your goal is to close deals by following the correct sales process. Always respond with exactly one action in the format:\nACTION: <ACTION_TYPE>\nCONTENT: <your message to the prospect>\n\nValid actions: PROSPECT, QUALIFY, PRESENT, HANDLE_OBJECTION, OFFER_DEMO, NEGOTIATE, CLOSE, FOLLOW_UP, DISQUALIFY"}, {"role": "user", "content": "You are engaging Nova Financial. [Stage: OFFER_DEMO | Turn: 7 | Prospect said: Let me get back to you. A lot is happening on our end. | Objections: 2 | Budget: medium | Steps: PROSPECT, QUALIFY, PRESENT, HANDLE_OBJECTION, HANDLE_OBJECTION, OFFER_DEMO]"}, {"role": "assistant", "content": "ACTION: NEGOTIATE\nCONTENT: I understand you're busy. To make this decision easier, I can offer a 3-month pilot at a reduced rate so your team can validate the ROI before full commitment."}]}
|
| 14 |
+
{"messages": [{"role": "system", "content": "You are a professional B2B sales agent. Your goal is to close deals by following the correct sales process. Always respond with exactly one action in the format:\nACTION: <ACTION_TYPE>\nCONTENT: <your message to the prospect>\n\nValid actions: PROSPECT, QUALIFY, PRESENT, HANDLE_OBJECTION, OFFER_DEMO, NEGOTIATE, CLOSE, FOLLOW_UP, DISQUALIFY"}, {"role": "user", "content": "You are engaging Atlas Health. [Stage: NEGOTIATE | Turn: 9 | Prospect said: I see. We're evaluating options. | Steps: PROSPECT, QUALIFY, PRESENT, HANDLE_OBJECTION, HANDLE_OBJECTION, OFFER_DEMO, NEGOTIATE]"}, {"role": "assistant", "content": "ACTION: CLOSE\nCONTENT: We've covered your pain points, seen the demo, and aligned on terms. I believe this is the right solution for Atlas Health. Can we finalize the agreement today?"}]}
|
training/train_grpo.py
ADDED
|
@@ -0,0 +1,388 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
train_grpo.py β Full GRPO RL Training
|
| 3 |
+
=======================================
|
| 4 |
+
Stage 2: loads the SFT checkpoint and fine-tunes with GRPO
|
| 5 |
+
using live rollouts against the SalesPath environment.
|
| 6 |
+
|
| 7 |
+
Architecture
|
| 8 |
+
------------
|
| 9 |
+
SFT checkpoint β Unsloth 4-bit QLoRA β GRPOTrainer (TRL)
|
| 10 |
+
β
|
| 11 |
+
SalesPath env (HTTP)
|
| 12 |
+
reward = composite score
|
| 13 |
+
|
| 14 |
+
Recommended hardware : A100 / T4 GPU (Google Colab)
|
| 15 |
+
Expected runtime : ~45-90 min for 200 steps on T4
|
| 16 |
+
|
| 17 |
+
Run:
|
| 18 |
+
# 1. Start the env server in another terminal:
|
| 19 |
+
# uvicorn salespath_env.server.app:app --port 7860
|
| 20 |
+
#
|
| 21 |
+
# 2. Then run this script:
|
| 22 |
+
python training/train_grpo.py
|
| 23 |
+
|
| 24 |
+
Outputs:
|
| 25 |
+
./grpo_checkpoint/ β final RL-trained model
|
| 26 |
+
reward_log.jsonl β per-step reward components for plotting
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
from __future__ import annotations
|
| 30 |
+
|
| 31 |
+
import json
|
| 32 |
+
import os
|
| 33 |
+
import re
|
| 34 |
+
import sys
|
| 35 |
+
import time
|
| 36 |
+
from typing import Any
|
| 37 |
+
|
| 38 |
+
import torch
|
| 39 |
+
|
| 40 |
+
# ---------------------------------------------------------------------------
|
| 41 |
+
# Config
|
| 42 |
+
# ---------------------------------------------------------------------------
|
| 43 |
+
ENV_URL = os.environ.get("SALESPATH_ENV_URL", "http://localhost:7860")
|
| 44 |
+
SFT_CHECKPOINT = os.environ.get("SFT_CHECKPOINT", "./sft_checkpoint")
|
| 45 |
+
OUTPUT_DIR = "./grpo_checkpoint"
|
| 46 |
+
REWARD_LOG_PATH = "./reward_log.jsonl"
|
| 47 |
+
|
| 48 |
+
MODEL_NAME = SFT_CHECKPOINT # start from SFT weights
|
| 49 |
+
MAX_SEQ_LEN = 1024
|
| 50 |
+
LORA_R = 16
|
| 51 |
+
LORA_ALPHA = 16
|
| 52 |
+
|
| 53 |
+
# GRPO hyper-parameters
|
| 54 |
+
NUM_TRAIN_STEPS = 200 # increase to 500+ for best results
|
| 55 |
+
ROLLOUTS_PER_STEP = 8 # episodes collected before each gradient update
|
| 56 |
+
DIFFICULTY_SCHEDULE = { # step β difficulty to use for rollouts
|
| 57 |
+
0: 1,
|
| 58 |
+
50: 2,
|
| 59 |
+
100: 3,
|
| 60 |
+
150: 4,
|
| 61 |
+
}
|
| 62 |
+
LR = 5e-6
|
| 63 |
+
KL_COEFF = 0.05 # keep close to SFT policy
|
| 64 |
+
GRAD_ACCUM = 4
|
| 65 |
+
BATCH_SIZE = 2
|
| 66 |
+
|
| 67 |
+
REPORT_TO = "none" # swap to "wandb" for live reward curves
|
| 68 |
+
|
| 69 |
+
# ---------------------------------------------------------------------------
|
| 70 |
+
# 1. Load model (Unsloth 4-bit QLoRA)
|
| 71 |
+
# ---------------------------------------------------------------------------
|
| 72 |
+
|
| 73 |
+
try:
|
| 74 |
+
from unsloth import FastLanguageModel
|
| 75 |
+
USE_UNSLOTH = True
|
| 76 |
+
except ImportError:
|
| 77 |
+
USE_UNSLOTH = False
|
| 78 |
+
print("β οΈ Unsloth not found β falling back to HuggingFace transformers.")
|
| 79 |
+
|
| 80 |
+
if USE_UNSLOTH:
|
| 81 |
+
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 82 |
+
model_name=MODEL_NAME,
|
| 83 |
+
max_seq_length=MAX_SEQ_LEN,
|
| 84 |
+
dtype=None,
|
| 85 |
+
load_in_4bit=True,
|
| 86 |
+
)
|
| 87 |
+
model = FastLanguageModel.get_peft_model(
|
| 88 |
+
model,
|
| 89 |
+
r=LORA_R,
|
| 90 |
+
lora_alpha=LORA_ALPHA,
|
| 91 |
+
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
|
| 92 |
+
"gate_proj", "up_proj", "down_proj"],
|
| 93 |
+
lora_dropout=0.0,
|
| 94 |
+
bias="none",
|
| 95 |
+
use_gradient_checkpointing="unsloth",
|
| 96 |
+
random_state=42,
|
| 97 |
+
)
|
| 98 |
+
else:
|
| 99 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
| 100 |
+
from peft import get_peft_model, LoraConfig, TaskType
|
| 101 |
+
|
| 102 |
+
bnb = BitsAndBytesConfig(
|
| 103 |
+
load_in_4bit=True,
|
| 104 |
+
bnb_4bit_use_double_quant=True,
|
| 105 |
+
bnb_4bit_quant_type="nf4",
|
| 106 |
+
bnb_4bit_compute_dtype=torch.bfloat16,
|
| 107 |
+
)
|
| 108 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
| 109 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 110 |
+
MODEL_NAME, quantization_config=bnb, device_map="auto"
|
| 111 |
+
)
|
| 112 |
+
model = get_peft_model(model, LoraConfig(
|
| 113 |
+
r=LORA_R, lora_alpha=LORA_ALPHA,
|
| 114 |
+
target_modules=["q_proj", "v_proj"],
|
| 115 |
+
task_type=TaskType.CAUSAL_LM,
|
| 116 |
+
))
|
| 117 |
+
|
| 118 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 119 |
+
tokenizer.padding_side = "right"
|
| 120 |
+
print(f"β
Model loaded from: {MODEL_NAME}")
|
| 121 |
+
|
| 122 |
+
# ---------------------------------------------------------------------------
|
| 123 |
+
# 2. Environment client
|
| 124 |
+
# ---------------------------------------------------------------------------
|
| 125 |
+
|
| 126 |
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
| 127 |
+
from salespath_env.client import SalesPathClient
|
| 128 |
+
|
| 129 |
+
client = SalesPathClient(ENV_URL)
|
| 130 |
+
print(f"β
Connected to env at {ENV_URL} β {client.health()}")
|
| 131 |
+
|
| 132 |
+
# ---------------------------------------------------------------------------
|
| 133 |
+
# 3. Prompt / action helpers
|
| 134 |
+
# ---------------------------------------------------------------------------
|
| 135 |
+
|
| 136 |
+
SYSTEM_PROMPT = (
|
| 137 |
+
"You are a professional B2B sales agent. "
|
| 138 |
+
"Follow the correct sales process to close deals.\n"
|
| 139 |
+
"Always respond with exactly ONE action in this format:\n"
|
| 140 |
+
"ACTION: <ACTION_TYPE>\n"
|
| 141 |
+
"CONTENT: <your message>\n\n"
|
| 142 |
+
"Valid actions: PROSPECT, QUALIFY, PRESENT, HANDLE_OBJECTION, "
|
| 143 |
+
"OFFER_DEMO, NEGOTIATE, CLOSE, FOLLOW_UP, DISQUALIFY"
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
# Parses the agent's "ACTION: <TYPE> / CONTENT: <msg>" output format.
# Uses \s* (not a mandatory "\n") between the two fields so that a model
# that emits both fields on ONE line ("ACTION: CLOSE CONTENT: ...") still
# parses instead of silently falling back to the QUALIFY default.
ACTION_RE = re.compile(
    r"ACTION:\s*([A-Z_]+)\s*CONTENT:\s*(.+)",
    re.DOTALL,
)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def obs_to_user_message(obs: dict, stage: str, turn: int) -> str:
    """
    Render an environment observation as the next user-turn message.

    Includes the prospect's reply, optional progress / violation
    summaries, and a trailing "[Stage: ... | Turn: ...]" marker.
    """
    lines = [obs.get("prospect_response", "")]
    completed = obs.get("steps_completed")
    if completed:
        lines.append(f"Steps completed: {', '.join(completed)}")
    violated = obs.get("constraints_violated")
    if violated:
        lines.append(f"❌ Violations: {', '.join(violated)}")
    lines.append(f"[Stage: {stage} | Turn: {turn}]")
    return "\n".join(lines)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def parse_action(text: str) -> tuple[str, str]:
    """Extract (action_type, content) from the model's raw output."""
    stripped = text.strip()
    match = ACTION_RE.search(stripped)
    if match is None:
        # Model ignored the required format: treat the whole text as a
        # QUALIFY message rather than crashing the rollout.
        return "QUALIFY", stripped
    action, content = match.groups()
    return action.strip(), content.strip()
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def generate_action(messages: list[dict]) -> str:
    """
    Generate one assistant turn for *messages*.

    Runs a single sampled forward pass (temperature 0.7, top-p 0.9) and
    returns only the newly generated text, prompt excluded.
    """
    prompt_ids = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    with torch.no_grad():
        generated = model.generate(
            prompt_ids,
            max_new_tokens=128,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Slice off the prompt tokens; decode just the generated suffix.
    completion_ids = generated[0, prompt_ids.shape[-1]:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True)
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
# ---------------------------------------------------------------------------
|
| 195 |
+
# 4. Rollout collector
|
| 196 |
+
# ---------------------------------------------------------------------------
|
| 197 |
+
|
| 198 |
+
def run_episode(difficulty: int) -> list[dict]:
    """
    Play one full episode against the SalesPath environment.

    Returns one sample dict per agent turn with keys:
    messages (prompt context), completion, reward, reward_components.
    """
    obs = client.reset(difficulty=difficulty)
    history = [{"role": "system", "content": SYSTEM_PROMPT}]
    collected: list[dict] = []

    for _turn in range(20):  # hard cap matches env MAX_TURNS
        history.append({
            "role": "user",
            "content": obs_to_user_message(
                obs,
                obs.get("workflow_stage", "START"),
                obs.get("turn_number", 0),
            ),
        })

        # Sample an action from the current policy and parse it.
        completion = generate_action(list(history))
        action_type, content = parse_action(completion)

        # Apply it to the environment.
        obs = client.step(action_type, content)

        collected.append({
            "messages": list(history),
            "completion": completion,
            "reward": obs["reward"],
            "reward_components": obs.get("reward_components", {}),
        })
        history.append({"role": "assistant", "content": completion})

        if obs["done"]:
            break

    return collected
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def collect_rollouts(
    n: int,
    difficulty: int,
) -> tuple[list[str], list[str], list[float]]:
    """
    Collect *n* episode rollouts at the given difficulty.

    Returns flat (prompts, completions, rewards) lists suitable for
    GRPOTrainer — one entry per agent turn across all episodes.
    """
    prompts: list[str] = []
    completions: list[str] = []
    rewards: list[float] = []

    for episode_idx in range(n):
        episode = run_episode(difficulty)
        for sample in episode:
            rendered_prompt = tokenizer.apply_chat_template(
                sample["messages"],
                tokenize=False,
                add_generation_prompt=True,
            )
            prompts.append(rendered_prompt)
            completions.append(sample["completion"])
            rewards.append(sample["reward"])

        episode_total = sum(sample["reward"] for sample in episode)
        print(f" ep {episode_idx+1}/{n} steps={len(episode)} ep_reward={episode_total:+.3f}")

    return prompts, completions, rewards
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
# ---------------------------------------------------------------------------
|
| 266 |
+
# 5. Reward log helper
|
| 267 |
+
# ---------------------------------------------------------------------------
|
| 268 |
+
|
| 269 |
+
reward_log: list[dict] = []
|
| 270 |
+
|
| 271 |
+
def log_rewards(step: int, rewards: list[float], difficulty: int) -> None:
    """
    Record aggregate reward statistics for one training step.

    Appends an entry to the in-memory `reward_log` and to the
    REWARD_LOG_PATH JSONL file, then prints a one-line summary.

    Args:
        step:       1-based training step number.
        rewards:    per-sample rewards collected this step.
        difficulty: environment difficulty used for the rollouts.
    """
    if not rewards:
        # Guard: an empty rollout batch would otherwise crash on
        # min()/max() and divide by zero. Log nothing, keep training.
        print(f" πŸ“ˆ step={step:4d} diff={difficulty} (no samples collected)")
        return

    entry = {
        "step": step,
        "difficulty": difficulty,
        "mean_reward": sum(rewards) / len(rewards),
        "max_reward": max(rewards),
        "min_reward": min(rewards),
        "n_samples": len(rewards),
    }
    reward_log.append(entry)
    # Append-mode JSONL so a resumed run extends rather than overwrites.
    with open(REWARD_LOG_PATH, "a") as f:
        f.write(json.dumps(entry) + "\n")
    print(
        f" πŸ“ˆ step={step:4d} diff={difficulty} "
        f"mean={entry['mean_reward']:+.3f} "
        f"max={entry['max_reward']:+.3f}"
    )
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
# ---------------------------------------------------------------------------
|
| 291 |
+
# 6. GRPOTrainer setup
|
| 292 |
+
# ---------------------------------------------------------------------------
|
| 293 |
+
|
| 294 |
+
from datasets import Dataset
|
| 295 |
+
from trl import GRPOTrainer, GRPOConfig
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
def make_reward_fn(precomputed: dict[str, float]):
    """
    Build the reward callable GRPOTrainer expects.

    Rollouts are executed ahead of time; the returned function simply
    looks each (prompt + completion) concatenation up in *precomputed*,
    defaulting to 0.0 for pairs it has never seen.
    """
    def reward_fn(prompts: list[str], completions: list[str], **kwargs) -> list[float]:
        scores: list[float] = []
        for prompt, completion in zip(prompts, completions):
            scores.append(precomputed.get(prompt + completion, 0.0))
        return scores

    return reward_fn
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
grpo_config = GRPOConfig(
|
| 312 |
+
output_dir=OUTPUT_DIR,
|
| 313 |
+
num_train_epochs=1, # we control steps manually
|
| 314 |
+
per_device_train_batch_size=BATCH_SIZE,
|
| 315 |
+
gradient_accumulation_steps=GRAD_ACCUM,
|
| 316 |
+
learning_rate=LR,
|
| 317 |
+
kl_coeff=KL_COEFF,
|
| 318 |
+
logging_steps=1,
|
| 319 |
+
save_steps=50,
|
| 320 |
+
fp16=not USE_UNSLOTH,
|
| 321 |
+
report_to=REPORT_TO,
|
| 322 |
+
max_completion_length=128,
|
| 323 |
+
remove_unused_columns=False,
|
| 324 |
+
)
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
# ---------------------------------------------------------------------------
|
| 328 |
+
# 7. Training loop
|
| 329 |
+
# ---------------------------------------------------------------------------
|
| 330 |
+
|
| 331 |
+
print(f"\nπ Starting GRPO training for {NUM_TRAIN_STEPS} steps")
|
| 332 |
+
print(f" Rollouts per step : {ROLLOUTS_PER_STEP}")
|
| 333 |
+
print(f" KL coefficient : {KL_COEFF}")
|
| 334 |
+
print(f" Difficulty schedule: {DIFFICULTY_SCHEDULE}\n")
|
| 335 |
+
|
| 336 |
+
for step in range(NUM_TRAIN_STEPS):
|
| 337 |
+
# Determine difficulty for this step
|
| 338 |
+
difficulty = 1
|
| 339 |
+
for threshold, d in sorted(DIFFICULTY_SCHEDULE.items()):
|
| 340 |
+
if step >= threshold:
|
| 341 |
+
difficulty = d
|
| 342 |
+
|
| 343 |
+
print(f"\n[Step {step+1}/{NUM_TRAIN_STEPS}] difficulty={difficulty}")
|
| 344 |
+
|
| 345 |
+
# -- Collect rollouts --
|
| 346 |
+
prompts, completions, rewards = collect_rollouts(
|
| 347 |
+
ROLLOUTS_PER_STEP, difficulty
|
| 348 |
+
)
|
| 349 |
+
log_rewards(step + 1, rewards, difficulty)
|
| 350 |
+
|
| 351 |
+
# -- Build dataset for this step --
|
| 352 |
+
reward_lookup = {
|
| 353 |
+
p + c: r
|
| 354 |
+
for p, c, r in zip(prompts, completions, rewards)
|
| 355 |
+
}
|
| 356 |
+
step_dataset = Dataset.from_dict({
|
| 357 |
+
"prompt": prompts,
|
| 358 |
+
"completion": completions,
|
| 359 |
+
})
|
| 360 |
+
|
| 361 |
+
# -- GRPOTrainer one-step update --
|
| 362 |
+
trainer = GRPOTrainer(
|
| 363 |
+
model=model,
|
| 364 |
+
reward_funcs=make_reward_fn(reward_lookup),
|
| 365 |
+
args=grpo_config,
|
| 366 |
+
train_dataset=step_dataset,
|
| 367 |
+
processing_class=tokenizer,
|
| 368 |
+
)
|
| 369 |
+
trainer.train()
|
| 370 |
+
|
| 371 |
+
# Save checkpoint every 50 steps
|
| 372 |
+
if (step + 1) % 50 == 0:
|
| 373 |
+
ckpt = os.path.join(OUTPUT_DIR, f"step_{step+1}")
|
| 374 |
+
model.save_pretrained(ckpt)
|
| 375 |
+
tokenizer.save_pretrained(ckpt)
|
| 376 |
+
print(f" πΎ Checkpoint saved: {ckpt}")
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
# ---------------------------------------------------------------------------
|
| 380 |
+
# 8. Final save
|
| 381 |
+
# ---------------------------------------------------------------------------
|
| 382 |
+
|
| 383 |
+
model.save_pretrained(OUTPUT_DIR)
|
| 384 |
+
tokenizer.save_pretrained(OUTPUT_DIR)
|
| 385 |
+
print(f"\nβ
GRPO training complete.")
|
| 386 |
+
print(f" Model β {OUTPUT_DIR}")
|
| 387 |
+
print(f" Rewards β {REWARD_LOG_PATH}")
|
| 388 |
+
print("\nPlot rewards with: python training/plot_rewards.py")
|
training/train_sft.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
train_sft.py β SFT Warm-Start Stage
|
| 3 |
+
=====================================
|
| 4 |
+
Fine-tunes a base LLM on expert sales demonstrations BEFORE GRPO.
|
| 5 |
+
SFT teaches the model the correct action FORMAT and rough ordering,
|
| 6 |
+
giving GRPO a much better starting policy.
|
| 7 |
+
|
| 8 |
+
Recommended hardware : T4 GPU (Google Colab free tier)
|
| 9 |
+
Expected runtime : ~10-15 minutes for 14 demos Γ 3 epochs
|
| 10 |
+
|
| 11 |
+
Run:
|
| 12 |
+
python training/train_sft.py
|
| 13 |
+
|
| 14 |
+
Outputs:
|
| 15 |
+
./sft_checkpoint/ β load this as base in train_grpo.py
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import json
|
| 19 |
+
import os
|
| 20 |
+
import sys
|
| 21 |
+
|
| 22 |
+
# ---------------------------------------------------------------------------
|
| 23 |
+
# Config β tweak these
|
| 24 |
+
# ---------------------------------------------------------------------------
|
| 25 |
+
MODEL_NAME = "unsloth/Qwen2.5-1.5B-Instruct" # swap for 0.5B on tiny GPU
|
| 26 |
+
OUTPUT_DIR = "./sft_checkpoint"
|
| 27 |
+
DATA_PATH = os.path.join(os.path.dirname(__file__), "sft_demos.jsonl")
|
| 28 |
+
MAX_SEQ_LEN = 1024
|
| 29 |
+
NUM_EPOCHS = 3
|
| 30 |
+
BATCH_SIZE = 2
|
| 31 |
+
GRAD_ACCUM = 4
|
| 32 |
+
LR = 2e-4
|
| 33 |
+
LORA_R = 16
|
| 34 |
+
LORA_ALPHA = 16
|
| 35 |
+
|
| 36 |
+
# ---------------------------------------------------------------------------
|
| 37 |
+
# 1. Load model with Unsloth 4-bit QLoRA
|
| 38 |
+
# ---------------------------------------------------------------------------
|
| 39 |
+
try:
|
| 40 |
+
from unsloth import FastLanguageModel
|
| 41 |
+
USE_UNSLOTH = True
|
| 42 |
+
except ImportError:
|
| 43 |
+
USE_UNSLOTH = False
|
| 44 |
+
print("β οΈ Unsloth not installed β falling back to plain HuggingFace.")
|
| 45 |
+
print(" Install with: pip install unsloth")
|
| 46 |
+
|
| 47 |
+
if USE_UNSLOTH:
|
| 48 |
+
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 49 |
+
model_name=MODEL_NAME,
|
| 50 |
+
max_seq_length=MAX_SEQ_LEN,
|
| 51 |
+
dtype=None, # auto-detect: bf16 on Ampere+, fp16 otherwise
|
| 52 |
+
load_in_4bit=True,
|
| 53 |
+
)
|
| 54 |
+
model = FastLanguageModel.get_peft_model(
|
| 55 |
+
model,
|
| 56 |
+
r=LORA_R,
|
| 57 |
+
lora_alpha=LORA_ALPHA,
|
| 58 |
+
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
|
| 59 |
+
"gate_proj", "up_proj", "down_proj"],
|
| 60 |
+
lora_dropout=0.05,
|
| 61 |
+
bias="none",
|
| 62 |
+
use_gradient_checkpointing="unsloth",
|
| 63 |
+
random_state=42,
|
| 64 |
+
)
|
| 65 |
+
else:
|
| 66 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
| 67 |
+
from peft import get_peft_model, LoraConfig, TaskType
|
| 68 |
+
import torch
|
| 69 |
+
|
| 70 |
+
bnb_config = BitsAndBytesConfig(
|
| 71 |
+
load_in_4bit=True,
|
| 72 |
+
bnb_4bit_use_double_quant=True,
|
| 73 |
+
bnb_4bit_quant_type="nf4",
|
| 74 |
+
bnb_4bit_compute_dtype=torch.bfloat16,
|
| 75 |
+
)
|
| 76 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
| 77 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 78 |
+
MODEL_NAME,
|
| 79 |
+
quantization_config=bnb_config,
|
| 80 |
+
device_map="auto",
|
| 81 |
+
)
|
| 82 |
+
lora_config = LoraConfig(
|
| 83 |
+
r=LORA_R, lora_alpha=LORA_ALPHA,
|
| 84 |
+
target_modules=["q_proj", "v_proj"],
|
| 85 |
+
task_type=TaskType.CAUSAL_LM,
|
| 86 |
+
)
|
| 87 |
+
model = get_peft_model(model, lora_config)
|
| 88 |
+
|
| 89 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 90 |
+
tokenizer.padding_side = "right"
|
| 91 |
+
|
| 92 |
+
print(f"β
Model loaded: {MODEL_NAME} (4-bit QLoRA, r={LORA_R})")
|
| 93 |
+
|
| 94 |
+
# ---------------------------------------------------------------------------
|
| 95 |
+
# 2. Load & format SFT dataset
|
| 96 |
+
# ---------------------------------------------------------------------------
|
| 97 |
+
|
| 98 |
+
def load_sft_data(path: str) -> list[dict]:
    """
    Load SFT demonstration records from a JSONL file.

    Blank lines are skipped; every other line must be a JSON object.

    Args:
        path: path to the .jsonl demonstrations file.

    Returns:
        One dict per demonstration, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    records = []
    # JSONL is UTF-8 by convention; be explicit so a non-UTF-8 locale
    # default (e.g. cp1252 on Windows) cannot corrupt the demo text.
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                records.append(json.loads(line))
    return records
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def format_chat(record: dict) -> str:
    """
    Render one demonstration record's messages into a single training
    string using the model's chat template (no generation prompt).
    """
    rendered = tokenizer.apply_chat_template(
        record["messages"],
        tokenize=False,
        add_generation_prompt=False,
    )
    return rendered
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
raw_data = load_sft_data(DATA_PATH)
|
| 120 |
+
print(f"β
Loaded {len(raw_data)} SFT demonstrations from {DATA_PATH}")
|
| 121 |
+
|
| 122 |
+
from datasets import Dataset
|
| 123 |
+
|
| 124 |
+
formatted = [{"text": format_chat(r)} for r in raw_data]
|
| 125 |
+
dataset = Dataset.from_list(formatted)
|
| 126 |
+
print(f" Sample:\n{formatted[0]['text'][:300]}\n...")
|
| 127 |
+
|
| 128 |
+
# ---------------------------------------------------------------------------
|
| 129 |
+
# 3. SFT Trainer
|
| 130 |
+
# ---------------------------------------------------------------------------
|
| 131 |
+
|
| 132 |
+
from trl import SFTTrainer, SFTConfig
|
| 133 |
+
|
| 134 |
+
sft_config = SFTConfig(
|
| 135 |
+
output_dir=OUTPUT_DIR,
|
| 136 |
+
num_train_epochs=NUM_EPOCHS,
|
| 137 |
+
per_device_train_batch_size=BATCH_SIZE,
|
| 138 |
+
gradient_accumulation_steps=GRAD_ACCUM,
|
| 139 |
+
learning_rate=LR,
|
| 140 |
+
warmup_ratio=0.1,
|
| 141 |
+
lr_scheduler_type="cosine",
|
| 142 |
+
logging_steps=1,
|
| 143 |
+
save_strategy="epoch",
|
| 144 |
+
fp16=not USE_UNSLOTH, # Unsloth handles this internally
|
| 145 |
+
bf16=False,
|
| 146 |
+
max_seq_length=MAX_SEQ_LEN,
|
| 147 |
+
dataset_text_field="text",
|
| 148 |
+
report_to="none", # swap to "wandb" if you have W&B set up
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
trainer = SFTTrainer(
|
| 152 |
+
model=model,
|
| 153 |
+
tokenizer=tokenizer,
|
| 154 |
+
train_dataset=dataset,
|
| 155 |
+
args=sft_config,
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
print("π Starting SFT training...")
|
| 159 |
+
trainer_stats = trainer.train()
|
| 160 |
+
|
| 161 |
+
print(f"\nβ
SFT done.")
|
| 162 |
+
print(f" Loss : {trainer_stats.training_loss:.4f}")
|
| 163 |
+
print(f" Saved : {OUTPUT_DIR}")
|
| 164 |
+
|
| 165 |
+
# ---------------------------------------------------------------------------
|
| 166 |
+
# 4. Save final checkpoint
|
| 167 |
+
# ---------------------------------------------------------------------------
|
| 168 |
+
|
| 169 |
+
model.save_pretrained(OUTPUT_DIR)
|
| 170 |
+
tokenizer.save_pretrained(OUTPUT_DIR)
|
| 171 |
+
print(f"β
Checkpoint saved to {OUTPUT_DIR}")
|
| 172 |
+
print("\nNext step β run: python training/train_grpo.py")
|
training/train_test.py
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
train_test.py β Quick smoke test (no GPU, no LLM needed).
|
| 3 |
+
|
| 4 |
+
Tests the FULL pipeline end-to-end in ~30 seconds:
|
| 5 |
+
1. Starts the env server in a subprocess
|
| 6 |
+
2. Runs 4 scripted episodes (one per difficulty)
|
| 7 |
+
3. Prints reward traces and rule-violation checks
|
| 8 |
+
4. Verifies the three fixed bugs (R05, R07, R08) behave correctly
|
| 9 |
+
|
| 10 |
+
Run:
|
| 11 |
+
python training/train_test.py
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import subprocess
|
| 15 |
+
import sys
|
| 16 |
+
import time
|
| 17 |
+
import os
|
| 18 |
+
|
| 19 |
+
# Add project root to path so we can import client
|
| 20 |
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# ---------------------------------------------------------------------------
|
| 24 |
+
# 1. Start server in background
|
| 25 |
+
# ---------------------------------------------------------------------------
|
| 26 |
+
|
| 27 |
+
def start_server(port: int = 7860):
    """
    Launch the SalesPath env server (uvicorn) in a subprocess.

    Instead of a fixed sleep, polls the TCP port until it accepts
    connections (up to ~15 s), so slow machines don't race the tests.

    Args:
        port: TCP port to bind uvicorn to.

    Returns:
        The subprocess.Popen handle; caller is responsible for
        terminate()-ing it.
    """
    import socket

    proc = subprocess.Popen(
        [
            sys.executable, "-m", "uvicorn",
            "salespath_env.server.app:app",
            "--host", "0.0.0.0",
            "--port", str(port),
            "--log-level", "error",
        ],
        cwd=os.path.join(os.path.dirname(__file__), ".."),
    )
    print(f"⏳ Starting server on port {port}...")
    deadline = time.time() + 15
    while time.time() < deadline:
        try:
            # Port accepting connections => uvicorn is up.
            with socket.create_connection(("127.0.0.1", port), timeout=1):
                break
        except OSError:
            time.sleep(0.25)
    return proc
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# ---------------------------------------------------------------------------
|
| 44 |
+
# 2. Import client
|
| 45 |
+
# ---------------------------------------------------------------------------
|
| 46 |
+
|
| 47 |
+
from salespath_env.client import SalesPathClient
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
# ---------------------------------------------------------------------------
|
| 51 |
+
# 3. Test episodes
|
| 52 |
+
# ---------------------------------------------------------------------------
|
| 53 |
+
|
| 54 |
+
EPISODES = {
|
| 55 |
+
1: [
|
| 56 |
+
# Happy path β difficulty 1
|
| 57 |
+
("PROSPECT", "Hello, tell me about your challenges."),
|
| 58 |
+
("QUALIFY", "What's your budget and who decides?"),
|
| 59 |
+
("PRESENT", "Here's how we solve your inventory problem."),
|
| 60 |
+
("CLOSE", "Shall we move forward?"),
|
| 61 |
+
],
|
| 62 |
+
2: [
|
| 63 |
+
# Objection + demo β difficulty 2
|
| 64 |
+
("PROSPECT", "Hi, I'd like to learn more about your operations."),
|
| 65 |
+
("QUALIFY", "Can you share your budget range?"),
|
| 66 |
+
("PRESENT", "Here is our solution."),
|
| 67 |
+
("HANDLE_OBJECTION", "Totally understand the pricing concern β here's why the ROI works."),
|
| 68 |
+
("OFFER_DEMO", "Let me show you a live demo next week."),
|
| 69 |
+
("CLOSE", "Ready to move forward?"),
|
| 70 |
+
],
|
| 71 |
+
3: [
|
| 72 |
+
# Full hard path β difficulty 3
|
| 73 |
+
("PROSPECT", "Hello, I've researched your compliance challenges."),
|
| 74 |
+
("QUALIFY", "Who is the decision maker and what is the budget?"),
|
| 75 |
+
("PRESENT", "Here's how we address audit trails and data silos."),
|
| 76 |
+
("HANDLE_OBJECTION", "I understand the timing is tough β many clients felt the same."),
|
| 77 |
+
("HANDLE_OBJECTION", "On the price point β we can structure payments quarterly."),
|
| 78 |
+
("OFFER_DEMO", "Let me show you a live demo."),
|
| 79 |
+
("NEGOTIATE", "Here is a pilot option at a reduced rate."),
|
| 80 |
+
("CLOSE", "Shall we proceed?"),
|
| 81 |
+
],
|
| 82 |
+
4: [
|
| 83 |
+
# Trap case β correct action is DISQUALIFY
|
| 84 |
+
("PROSPECT", "Hi, tell me about your security needs."),
|
| 85 |
+
("QUALIFY", "What is your budget and who decides?"),
|
| 86 |
+
("DISQUALIFY", "Given the budget constraints and no decision maker, this isn't the right time."),
|
| 87 |
+
],
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def run_episode(client: SalesPathClient, difficulty: int, script: list) -> dict:
    """
    Replay a scripted (action, content) sequence against the env and
    collect per-turn rewards and rule violations.
    """
    obs = client.reset(difficulty=difficulty)
    banner = "=" * 60
    print(f"\n{banner}")
    print(f" Difficulty {difficulty} | Prospect: {obs.get('prospect_response', '')[:80]}")
    print(f"{banner}")

    summary = {"rewards": [], "violations": [], "turns": 0, "done": False}

    for action_type, content in script:
        obs = client.step(action_type, content)
        turn_violations = obs.get("constraints_violated", [])
        summary["rewards"].append(obs["reward"])
        summary["violations"].extend(turn_violations)
        summary["turns"] = obs["turn_number"]
        summary["done"] = obs["done"]

        if turn_violations:
            status = "⚠️"
        else:
            status = "βœ…"
        print(
            f" {status} Turn {obs['turn_number']:2d} "
            f"{action_type:<20} "
            f"reward={obs['reward']:+.3f} "
            f"violations={turn_violations}"
        )
        if obs["done"]:
            break

    total = sum(summary["rewards"])
    print(f"\n Cumulative reward: {total:+.3f} | Violations: {summary['violations']}")
    return summary
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
# ---------------------------------------------------------------------------
|
| 122 |
+
# 4. Bug-regression checks
|
| 123 |
+
# ---------------------------------------------------------------------------
|
| 124 |
+
|
| 125 |
+
def test_r05_no_repeat(client: SalesPathClient):
    """R05: repeating the same action on consecutive turns must be flagged."""
    print("\n--- BUG CHECK: R05 no-repeat ---")
    client.reset(difficulty=1)
    client.step("PROSPECT", "Hello.")
    # Second PROSPECT in a row β€” the env should report an R05 violation.
    obs = client.step("PROSPECT", "Hello again.")
    violations = obs.get("constraints_violated", [])
    passed = "R05" in violations
    print(f" R05 fired on consecutive PROSPECT: {'βœ… PASS' if passed else '❌ FAIL'} {violations}")
    return passed
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def test_r07_followup(client: SalesPathClient):
    """R07: FOLLOW_UP right after a prospect reply must be a violation."""
    print("\n--- BUG CHECK: R07 followup timing ---")
    client.reset(difficulty=1)
    client.step("PROSPECT", "Hello.")
    # The prospect already answered β€” following up now should violate R07.
    obs = client.step("FOLLOW_UP", "Just checking in.")
    violations = obs.get("constraints_violated", [])
    passed = "R07" in violations
    print(f" R07 fired when prospect already responded: {'βœ… PASS' if passed else '❌ FAIL'} {violations}")
    return passed
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def test_r08_disqualify(client: SalesPathClient):
    """R08: disqualifying a clearly closeable prospect must be a violation."""
    print("\n--- BUG CHECK: R08 disqualify logic ---")
    # Difficulty 1 = high budget with a decision maker present, so the
    # prospect is closeable and DISQUALIFY is the wrong move.
    client.reset(difficulty=1)
    client.step("PROSPECT", "Hello.")
    client.step("QUALIFY", "What's your budget?")
    obs = client.step("DISQUALIFY", "I don't think you're a fit.")
    violations = obs.get("constraints_violated", [])
    passed = "R08" in violations
    print(f" R08 fired on valid prospect: {'βœ… PASS' if passed else '❌ FAIL'} {violations}")
    return passed
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
# ---------------------------------------------------------------------------
|
| 163 |
+
# 5. Main
|
| 164 |
+
# ---------------------------------------------------------------------------
|
| 165 |
+
|
| 166 |
+
def main():
    """
    Smoke-test the full pipeline: boot the env server, replay scripted
    episodes at every difficulty, then run the bug-regression suite.
    """
    PORT = 7860
    server = start_server(PORT)

    try:
        client = SalesPathClient(f"http://localhost:{PORT}")

        # Health check — bail out early if the server never came up.
        try:
            h = client.health()
            print(f"βœ… Server healthy: {h}")
        except Exception as e:
            print(f"❌ Server not responding: {e}")
            return

        # Run all difficulty episodes
        all_rewards = {}
        for diff, script in EPISODES.items():
            result = run_episode(client, diff, script)
            all_rewards[diff] = sum(result["rewards"])

        # Bug regression suite
        print("\n" + "=" * 60)
        print(" BUG REGRESSION SUITE")
        print("=" * 60)
        r05_ok = test_r05_no_repeat(client)
        r07_ok = test_r07_followup(client)
        r08_ok = test_r08_disqualify(client)

        # Summary
        print("\n" + "=" * 60)
        print(" SUMMARY")
        print("=" * 60)
        for diff, total in all_rewards.items():
            print(f" Difficulty {diff}: cumulative reward = {total:+.3f}")

        bugs_passed = sum([r05_ok, r07_ok, r08_ok])
        print(f"\n Bug fixes passing: {bugs_passed}/3")
        print(f"\n{'βœ… ALL SYSTEMS GO' if bugs_passed == 3 else '⚠️ SOME CHECKS FAILED'}")

    finally:
        server.terminate()
        # Reap the child so the test never leaves a zombie process or a
        # still-occupied port; escalate to kill() if terminate() hangs.
        try:
            server.wait(timeout=10)
        except Exception:
            server.kill()
        print("\nServer stopped.")
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
if __name__ == "__main__":
|
| 212 |
+
main()
|