Spaces:
Runtime error
Runtime error
Commit ·
b77d3c5
0
Parent(s):
first commit
Browse files. This view is limited to 50 files because it contains too many changes. See raw diff
- .gitignore +1 -0
- CODING_APPROACH.md +1028 -0
- Dockerfile +27 -0
- README.md +51 -0
- RULES.md +354 -0
- push_to_hub.py +44 -0
- pyproject.toml +22 -0
- requirements.txt +4 -0
- salespath_env.egg-info/PKG-INFO +8 -0
- salespath_env.egg-info/SOURCES.txt +17 -0
- salespath_env.egg-info/dependency_links.txt +1 -0
- salespath_env.egg-info/requires.txt +4 -0
- salespath_env.egg-info/top_level.txt +1 -0
- salespath_env/README.md +0 -0
- salespath_env/__init__.py +2 -0
- salespath_env/__pycache__/__init__.cpython-313.pyc +0 -0
- salespath_env/__pycache__/client.cpython-313.pyc +0 -0
- salespath_env/__pycache__/models.cpython-313.pyc +0 -0
- salespath_env/client.py +81 -0
- salespath_env/models.py +93 -0
- salespath_env/openenv.yaml +13 -0
- salespath_env/pyproject.toml +0 -0
- salespath_env/server/Dockerfile +12 -0
- salespath_env/server/__init__.py +2 -0
- salespath_env/server/__pycache__/__init__.cpython-313.pyc +0 -0
- salespath_env/server/__pycache__/app.cpython-313.pyc +0 -0
- salespath_env/server/__pycache__/prospect_simulator.cpython-313.pyc +0 -0
- salespath_env/server/__pycache__/reward.cpython-313.pyc +0 -0
- salespath_env/server/__pycache__/rules.cpython-313.pyc +0 -0
- salespath_env/server/__pycache__/salespath_environment.cpython-313.pyc +0 -0
- salespath_env/server/__pycache__/task_bank.cpython-313.pyc +0 -0
- salespath_env/server/app.py +18 -0
- salespath_env/server/prospect_simulator.py +162 -0
- salespath_env/server/requirements.txt +3 -0
- salespath_env/server/reward.py +138 -0
- salespath_env/server/rules.py +222 -0
- salespath_env/server/salespath_environment.py +294 -0
- salespath_env/server/task_bank.py +199 -0
- training/__init__.py +0 -0
- training/__pycache__/__init__.cpython-313.pyc +0 -0
- training/__pycache__/curriculum.cpython-313.pyc +0 -0
- training/__pycache__/debug_episode.cpython-313.pyc +0 -0
- training/__pycache__/grpo_train.cpython-313.pyc +0 -0
- training/__pycache__/rollout.cpython-313.pyc +0 -0
- training/__pycache__/test_rollout.cpython-313.pyc +0 -0
- training/colab_train.ipynb +100 -0
- training/curriculum.py +80 -0
- training/debug_episode.py +40 -0
- training/grpo_train.py +315 -0
- training/rollout.py +143 -0
.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
/.spa
|
CODING_APPROACH.md
ADDED
|
@@ -0,0 +1,1028 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SalesPath — End-to-End Coding Approach
|
| 2 |
+
### For Agent Execution. Follow in order. No skipping.
|
| 3 |
+
|
| 4 |
+
---
|
| 5 |
+
|
| 6 |
+
## Phase 0: Setup (Do First, ~15 min)
|
| 7 |
+
|
| 8 |
+
```bash
|
| 9 |
+
# Install OpenEnv
|
| 10 |
+
pip install openenv
|
| 11 |
+
|
| 12 |
+
# Scaffold the project
|
| 13 |
+
openenv init salespath_env
|
| 14 |
+
cd salespath_env
|
| 15 |
+
|
| 16 |
+
# Install dependencies
|
| 17 |
+
pip install -e .
|
| 18 |
+
|
| 19 |
+
# Verify scaffold works
|
| 20 |
+
uv run server --host 0.0.0.0 --port 8000
|
| 21 |
+
# Should start the FastAPI server on port 8000. Press Ctrl+C after confirming.
|
| 22 |
+
```
|
| 23 |
+
|
| 24 |
+
Edit `pyproject.toml` — add dependencies:
|
| 25 |
+
```toml
|
| 26 |
+
[project]
|
| 27 |
+
name = "salespath_env"
|
| 28 |
+
version = "0.1.0"
|
| 29 |
+
dependencies = [
|
| 30 |
+
"openenv",
|
| 31 |
+
"fastapi",
|
| 32 |
+
"uvicorn",
|
| 33 |
+
"pydantic>=2.0",
|
| 34 |
+
"trl>=0.8.0",
|
| 35 |
+
"unsloth",
|
| 36 |
+
"torch",
|
| 37 |
+
"transformers",
|
| 38 |
+
]
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
---
|
| 42 |
+
|
| 43 |
+
## Phase 1: Models (Person A) — `models.py`
|
| 44 |
+
|
| 45 |
+
Write this file first. Everything else depends on it.
|
| 46 |
+
|
| 47 |
+
```python
|
| 48 |
+
# salespath_env/models.py
|
| 49 |
+
from __future__ import annotations
|
| 50 |
+
import uuid
|
| 51 |
+
from dataclasses import dataclass, field
|
| 52 |
+
from typing import Optional
|
| 53 |
+
from openenv.core import Action, Observation, State
|
| 54 |
+
|
| 55 |
+
VALID_ACTIONS = {
|
| 56 |
+
"PROSPECT", "QUALIFY", "PRESENT", "HANDLE_OBJECTION",
|
| 57 |
+
"OFFER_DEMO", "NEGOTIATE", "CLOSE", "FOLLOW_UP", "DISQUALIFY"
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
class SalesPathAction(Action):
|
| 61 |
+
action_type: str
|
| 62 |
+
content: str
|
| 63 |
+
target: str = ""
|
| 64 |
+
|
| 65 |
+
def is_valid(self) -> bool:
|
| 66 |
+
return self.action_type in VALID_ACTIONS
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class SalesPathObservation(Observation):
|
| 70 |
+
prospect_response: str = ""
|
| 71 |
+
workflow_stage: str = "START"
|
| 72 |
+
constraints_violated: list[str] = field(default_factory=list)
|
| 73 |
+
steps_completed: list[str] = field(default_factory=list)
|
| 74 |
+
turn_number: int = 0
|
| 75 |
+
reward: float = 0.0
|
| 76 |
+
reward_components: dict = field(default_factory=dict)
|
| 77 |
+
done: bool = False
|
| 78 |
+
info: dict = field(default_factory=dict)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class SalesPathState(State):
|
| 82 |
+
episode_id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
| 83 |
+
prospect_profile: dict = field(default_factory=dict)
|
| 84 |
+
conversation_history: list[dict] = field(default_factory=list)
|
| 85 |
+
workflow_stage: str = "START"
|
| 86 |
+
required_workflow: list[str] = field(default_factory=list)
|
| 87 |
+
steps_completed: list[str] = field(default_factory=list)
|
| 88 |
+
constraints_violated: list[str] = field(default_factory=list)
|
| 89 |
+
objections_handled: int = 0
|
| 90 |
+
turn_number: int = 0
|
| 91 |
+
difficulty: int = 1
|
| 92 |
+
done: bool = False
|
| 93 |
+
# Hidden — never expose in Observation
|
| 94 |
+
_hidden: dict = field(default_factory=dict)
|
| 95 |
+
```
|
| 96 |
+
|
| 97 |
+
---
|
| 98 |
+
|
| 99 |
+
## Phase 2: Task Bank (Person A) — `server/task_bank.py`
|
| 100 |
+
|
| 101 |
+
This generates prospect profiles. Keep it simple — 10 profiles per difficulty level.
|
| 102 |
+
|
| 103 |
+
```python
|
| 104 |
+
# server/task_bank.py
|
| 105 |
+
import random
|
| 106 |
+
from dataclasses import dataclass
|
| 107 |
+
|
| 108 |
+
@dataclass
|
| 109 |
+
class ProspectProfile:
|
| 110 |
+
company_name: str
|
| 111 |
+
company_size: str # "small" / "medium" / "enterprise"
|
| 112 |
+
industry: str
|
| 113 |
+
budget_signal: str # "high" / "medium" / "low" / "unknown"
|
| 114 |
+
pain_points: list[str]
|
| 115 |
+
decision_maker: bool
|
| 116 |
+
# Hidden — simulator uses these, agent never sees raw values
|
| 117 |
+
true_budget: float # 0.0 to 1.0 scale
|
| 118 |
+
close_threshold: float # budget needed to close
|
| 119 |
+
stall_probability: float # for Level 3+
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
PROFILES_L1 = [
|
| 123 |
+
ProspectProfile(
|
| 124 |
+
company_name="Meridian Retail",
|
| 125 |
+
company_size="medium",
|
| 126 |
+
industry="retail",
|
| 127 |
+
budget_signal="high",
|
| 128 |
+
pain_points=["manual inventory tracking", "slow reporting"],
|
| 129 |
+
decision_maker=True,
|
| 130 |
+
true_budget=0.8,
|
| 131 |
+
close_threshold=0.5,
|
| 132 |
+
stall_probability=0.0,
|
| 133 |
+
),
|
| 134 |
+
# Add 9 more L1 profiles following the same pattern
|
| 135 |
+
# L1: budget_signal always known, decision_maker always True, close_threshold <= 0.6
|
| 136 |
+
]
|
| 137 |
+
|
| 138 |
+
PROFILES_L2 = [
|
| 139 |
+
ProspectProfile(
|
| 140 |
+
company_name="Apex Logistics",
|
| 141 |
+
company_size="enterprise",
|
| 142 |
+
industry="logistics",
|
| 143 |
+
budget_signal="unknown", # revealed after QUALIFY
|
| 144 |
+
pain_points=["route optimization", "driver coordination", "fuel tracking"],
|
| 145 |
+
decision_maker=True,
|
| 146 |
+
true_budget=0.7,
|
| 147 |
+
close_threshold=0.5,
|
| 148 |
+
stall_probability=0.0,
|
| 149 |
+
),
|
| 150 |
+
# 9 more L2 profiles: budget hidden, one objection expected
|
| 151 |
+
]
|
| 152 |
+
|
| 153 |
+
PROFILES_L3 = [
|
| 154 |
+
ProspectProfile(
|
| 155 |
+
company_name="Nova Financial",
|
| 156 |
+
company_size="enterprise",
|
| 157 |
+
industry="finance",
|
| 158 |
+
budget_signal="unknown",
|
| 159 |
+
pain_points=["compliance reporting", "audit trails", "data silos"],
|
| 160 |
+
decision_maker=False, # must navigate to decision maker
|
| 161 |
+
true_budget=0.6,
|
| 162 |
+
close_threshold=0.55,
|
| 163 |
+
stall_probability=0.3, # will stall at turn 10
|
| 164 |
+
),
|
| 165 |
+
# 9 more L3 profiles: budget hidden, two objections, mode shift
|
| 166 |
+
]
|
| 167 |
+
|
| 168 |
+
PROFILES_L4 = [
|
| 169 |
+
ProspectProfile(
|
| 170 |
+
company_name="Cipher Tech",
|
| 171 |
+
company_size="small",
|
| 172 |
+
industry="technology",
|
| 173 |
+
budget_signal="high", # MISLEADING — true_budget is actually low
|
| 174 |
+
pain_points=["security", "compliance"],
|
| 175 |
+
decision_maker=True,
|
| 176 |
+
true_budget=0.2, # can't actually afford it
|
| 177 |
+
close_threshold=0.5,
|
| 178 |
+
stall_probability=0.5,
|
| 179 |
+
),
|
| 180 |
+
# 9 more L4: misleading signals, correct answer is DISQUALIFY
|
| 181 |
+
]
|
| 182 |
+
|
| 183 |
+
ALL_PROFILES = {1: PROFILES_L1, 2: PROFILES_L2, 3: PROFILES_L3, 4: PROFILES_L4}
|
| 184 |
+
|
| 185 |
+
def sample_profile(difficulty: int) -> ProspectProfile:
|
| 186 |
+
return random.choice(ALL_PROFILES[difficulty])
|
| 187 |
+
```
|
| 188 |
+
|
| 189 |
+
---
|
| 190 |
+
|
| 191 |
+
## Phase 3: Business Rules (Person A) — `server/rules.py`
|
| 192 |
+
|
| 193 |
+
```python
|
| 194 |
+
# server/rules.py
|
| 195 |
+
from dataclasses import dataclass
|
| 196 |
+
from typing import Callable
|
| 197 |
+
from ..models import SalesPathAction, SalesPathState
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
@dataclass
|
| 201 |
+
class BusinessRule:
|
| 202 |
+
rule_id: str
|
| 203 |
+
name: str
|
| 204 |
+
description: str
|
| 205 |
+
check: Callable[[SalesPathState, SalesPathAction], bool]
|
| 206 |
+
# Returns True if VIOLATED
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def _qualify_before_present(state: SalesPathState, action: SalesPathAction) -> bool:
|
| 210 |
+
if action.action_type == "PRESENT":
|
| 211 |
+
return "QUALIFY" not in state.steps_completed
|
| 212 |
+
return False
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def _demo_before_negotiate(state: SalesPathState, action: SalesPathAction) -> bool:
|
| 216 |
+
if action.action_type == "NEGOTIATE":
|
| 217 |
+
return "OFFER_DEMO" not in state.steps_completed
|
| 218 |
+
return False
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def _budget_known_to_negotiate(state: SalesPathState, action: SalesPathAction) -> bool:
|
| 222 |
+
if action.action_type == "NEGOTIATE":
|
| 223 |
+
return state.prospect_profile.get("budget_signal") == "unknown"
|
| 224 |
+
return False
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def _discount_after_objections(state: SalesPathState, action: SalesPathAction) -> bool:
|
| 228 |
+
if action.action_type == "NEGOTIATE":
|
| 229 |
+
if "discount" in action.content.lower():
|
| 230 |
+
return state.objections_handled < 2
|
| 231 |
+
return False
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def _no_repeat_action(state: SalesPathState, action: SalesPathAction) -> bool:
|
| 235 |
+
if state.conversation_history:
|
| 236 |
+
last_action = state.conversation_history[-1].get("action_type", "")
|
| 237 |
+
return last_action == action.action_type
|
| 238 |
+
return False
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def _prospect_first(state: SalesPathState, action: SalesPathAction) -> bool:
|
| 242 |
+
if state.turn_number == 1:
|
| 243 |
+
return action.action_type != "PROSPECT"
|
| 244 |
+
return False
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def _followup_timing(state: SalesPathState, action: SalesPathAction) -> bool:
|
| 248 |
+
if action.action_type == "FOLLOW_UP":
|
| 249 |
+
if state.conversation_history:
|
| 250 |
+
last_speaker = state.conversation_history[-1].get("speaker", "agent")
|
| 251 |
+
return last_speaker == "prospect" # prospect just responded
|
| 252 |
+
return False
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def _disqualify_logic(state: SalesPathState, action: SalesPathAction) -> bool:
|
| 256 |
+
if action.action_type == "DISQUALIFY":
|
| 257 |
+
profile = state.prospect_profile
|
| 258 |
+
true_budget = state._hidden.get("true_budget", 0.5)
|
| 259 |
+
close_threshold = state._hidden.get("close_threshold", 0.5)
|
| 260 |
+
dm = profile.get("decision_maker", True)
|
| 261 |
+
# Violation: disqualifying when prospect is actually closeable
|
| 262 |
+
return (true_budget >= close_threshold) and dm
|
| 263 |
+
return False
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def _close_requires_demo(state: SalesPathState, action: SalesPathAction) -> bool:
|
| 267 |
+
if action.action_type == "CLOSE":
|
| 268 |
+
if state.difficulty >= 2:
|
| 269 |
+
return "OFFER_DEMO" not in state.steps_completed
|
| 270 |
+
return False
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
BUSINESS_RULES = [
|
| 274 |
+
BusinessRule("R01", "qualify_before_present",
|
| 275 |
+
"Must QUALIFY before PRESENT", _qualify_before_present),
|
| 276 |
+
BusinessRule("R02", "demo_before_negotiate",
|
| 277 |
+
"Must OFFER_DEMO before NEGOTIATE", _demo_before_negotiate),
|
| 278 |
+
BusinessRule("R03", "budget_known_to_negotiate",
|
| 279 |
+
"Budget must be known before NEGOTIATE", _budget_known_to_negotiate),
|
| 280 |
+
BusinessRule("R04", "discount_after_objections",
|
| 281 |
+
"Discount only after 2 objections", _discount_after_objections),
|
| 282 |
+
BusinessRule("R05", "no_repeat_action",
|
| 283 |
+
"Cannot repeat same action consecutively", _no_repeat_action),
|
| 284 |
+
BusinessRule("R06", "prospect_first",
|
| 285 |
+
"First action must be PROSPECT", _prospect_first),
|
| 286 |
+
BusinessRule("R07", "followup_timing",
|
| 287 |
+
"FOLLOW_UP only after prospect silence", _followup_timing),
|
| 288 |
+
BusinessRule("R08", "disqualify_logic",
|
| 289 |
+
"DISQUALIFY only when prospect is genuinely unqualified", _disqualify_logic),
|
| 290 |
+
BusinessRule("R09", "close_requires_demo",
|
| 291 |
+
"Must OFFER_DEMO before CLOSE (Levels 2+)", _close_requires_demo),
|
| 292 |
+
]
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def check_rules(state: SalesPathState, action: SalesPathAction) -> list[str]:
|
| 296 |
+
"""Returns list of violated rule IDs."""
|
| 297 |
+
return [
|
| 298 |
+
rule.rule_id
|
| 299 |
+
for rule in BUSINESS_RULES
|
| 300 |
+
if rule.check(state, action)
|
| 301 |
+
]
|
| 302 |
+
```
|
| 303 |
+
|
| 304 |
+
---
|
| 305 |
+
|
| 306 |
+
## Phase 4: Prospect Simulator (Person A) — `server/prospect_simulator.py`
|
| 307 |
+
|
| 308 |
+
```python
|
| 309 |
+
# server/prospect_simulator.py
|
| 310 |
+
# PURE RULE-BASED. No LLM. No imports from transformers.
|
| 311 |
+
|
| 312 |
+
from ..models import SalesPathState, SalesPathAction
|
| 313 |
+
|
| 314 |
+
RESPONSE_TEXT = {
|
| 315 |
+
"open:positive_signal": "That sounds interesting. Tell me more about how this works.",
|
| 316 |
+
"open:neutral_signal": "I see. We're evaluating a few options at the moment.",
|
| 317 |
+
"objection:price": "The pricing seems higher than what we budgeted for.",
|
| 318 |
+
"objection:timing": "The timing isn't ideal — we're in the middle of a quarter close.",
|
| 319 |
+
"objection:premature_pitch": "I'm not sure we're ready to discuss solutions yet. What do you know about our situation?",
|
| 320 |
+
"deflect:budget_not_discussed": "We haven't really talked about what we're looking for yet.",
|
| 321 |
+
"deflect:stall": "Let me get back to you on this. A lot is happening on our end.",
|
| 322 |
+
"accept:demo_scheduled": "Yes, let's set up a demo. What time works next week?",
|
| 323 |
+
"accept:close_success": "Alright, I think we can move forward with this. Send over the paperwork.",
|
| 324 |
+
"reject:close_failed": "I don't think we're ready to commit at this point.",
|
| 325 |
+
"silence": "",
|
| 326 |
+
"exit:disqualified": "I think we're done here. This isn't the right fit.",
|
| 327 |
+
}
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
class ProspectSimulator:
|
| 331 |
+
|
| 332 |
+
def respond(self, action: SalesPathAction, state: SalesPathState) -> tuple[str, str]:
|
| 333 |
+
"""
|
| 334 |
+
Returns (response_token, response_text).
|
| 335 |
+
Deterministic — same inputs always produce same output.
|
| 336 |
+
"""
|
| 337 |
+
token = self._get_token(action, state)
|
| 338 |
+
text = RESPONSE_TEXT[token]
|
| 339 |
+
return token, text
|
| 340 |
+
|
| 341 |
+
def _get_token(self, action: SalesPathAction, state: SalesPathState) -> str:
|
| 342 |
+
atype = action.action_type
|
| 343 |
+
hidden = state._hidden
|
| 344 |
+
turn = state.turn_number
|
| 345 |
+
profile = state.prospect_profile
|
| 346 |
+
objections = state.objections_handled
|
| 347 |
+
difficulty = state.difficulty
|
| 348 |
+
|
| 349 |
+
# Rule violation responses (priority — check first)
|
| 350 |
+
if "R01" in state.constraints_violated[-1:]:
|
| 351 |
+
return "objection:premature_pitch"
|
| 352 |
+
if "R03" in state.constraints_violated[-1:]:
|
| 353 |
+
return "deflect:budget_not_discussed"
|
| 354 |
+
|
| 355 |
+
# Action-specific logic
|
| 356 |
+
if atype == "PROSPECT":
|
| 357 |
+
return "open:positive_signal"
|
| 358 |
+
|
| 359 |
+
if atype == "QUALIFY":
|
| 360 |
+
# Reveal budget signal if it was hidden
|
| 361 |
+
if profile.get("budget_signal") == "unknown":
|
| 362 |
+
state.prospect_profile["budget_signal"] = hidden.get("revealed_budget", "medium")
|
| 363 |
+
return "open:neutral_signal"
|
| 364 |
+
|
| 365 |
+
if atype == "PRESENT":
|
| 366 |
+
if difficulty >= 2:
|
| 367 |
+
return "objection:price" if objections == 0 else "open:positive_signal"
|
| 368 |
+
return "open:positive_signal"
|
| 369 |
+
|
| 370 |
+
if atype == "HANDLE_OBJECTION":
|
| 371 |
+
state.objections_handled += 1
|
| 372 |
+
if objections + 1 >= hidden.get("num_objections", 1):
|
| 373 |
+
return "open:positive_signal"
|
| 374 |
+
return "objection:timing" if objections == 0 else "open:positive_signal"
|
| 375 |
+
|
| 376 |
+
if atype == "OFFER_DEMO":
|
| 377 |
+
return "accept:demo_scheduled"
|
| 378 |
+
|
| 379 |
+
if atype == "NEGOTIATE":
|
| 380 |
+
return "open:neutral_signal"
|
| 381 |
+
|
| 382 |
+
if atype == "CLOSE":
|
| 383 |
+
true_budget = hidden.get("true_budget", 0.7)
|
| 384 |
+
threshold = hidden.get("close_threshold", 0.5)
|
| 385 |
+
if true_budget >= threshold and profile.get("decision_maker", True):
|
| 386 |
+
return "accept:close_success"
|
| 387 |
+
return "reject:close_failed"
|
| 388 |
+
|
| 389 |
+
if atype == "FOLLOW_UP":
|
| 390 |
+
return "open:neutral_signal"
|
| 391 |
+
|
| 392 |
+
if atype == "DISQUALIFY":
|
| 393 |
+
return "exit:disqualified"
|
| 394 |
+
|
| 395 |
+
# Mode shift at turn 10 for Level 3+
|
| 396 |
+
if difficulty >= 3 and turn >= 10:
|
| 397 |
+
import random
|
| 398 |
+
if random.random() < hidden.get("stall_probability", 0.0):
|
| 399 |
+
return "deflect:stall"
|
| 400 |
+
|
| 401 |
+
return "open:neutral_signal"
|
| 402 |
+
```
|
| 403 |
+
|
| 404 |
+
---
|
| 405 |
+
|
| 406 |
+
## Phase 5: Reward Function (Person B) — `server/reward.py`
|
| 407 |
+
|
| 408 |
+
```python
|
| 409 |
+
# server/reward.py
|
| 410 |
+
|
| 411 |
+
from ..models import SalesPathState, SalesPathAction
|
| 412 |
+
|
| 413 |
+
DIFFICULTY_OPTIMAL_TURNS = {1: 5, 2: 8, 3: 12, 4: 14}
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
def compute_reward(
|
| 417 |
+
state: SalesPathState,
|
| 418 |
+
action: SalesPathAction,
|
| 419 |
+
response_token: str,
|
| 420 |
+
new_violations: list[str],
|
| 421 |
+
episode_done: bool,
|
| 422 |
+
) -> tuple[float, dict]:
|
| 423 |
+
"""
|
| 424 |
+
Returns (total_reward, component_dict).
|
| 425 |
+
Always returns components — never a single scalar.
|
| 426 |
+
"""
|
| 427 |
+
components = {}
|
| 428 |
+
|
| 429 |
+
# --- Component 1: Outcome (only on terminal step) ---
|
| 430 |
+
r_outcome = 0.0
|
| 431 |
+
if episode_done:
|
| 432 |
+
if response_token == "accept:close_success":
|
| 433 |
+
r_outcome = 1.0
|
| 434 |
+
elif action.action_type == "DISQUALIFY":
|
| 435 |
+
# Check if disqualify was correct (no R08 violation)
|
| 436 |
+
if "R08" not in new_violations:
|
| 437 |
+
r_outcome = 0.5
|
| 438 |
+
else:
|
| 439 |
+
r_outcome = -0.5
|
| 440 |
+
elif state.turn_number >= 20:
|
| 441 |
+
r_outcome = -0.3
|
| 442 |
+
elif len(state.constraints_violated) >= 3:
|
| 443 |
+
r_outcome = -0.5
|
| 444 |
+
else:
|
| 445 |
+
r_outcome = -0.5 # failed close
|
| 446 |
+
components["r_outcome"] = r_outcome
|
| 447 |
+
|
| 448 |
+
# --- Component 2: Compliance ---
|
| 449 |
+
total_violations = len(state.constraints_violated) + len(new_violations)
|
| 450 |
+
r_compliance = max(-1.0, -0.2 * len(new_violations)) # per-step signal
|
| 451 |
+
components["r_compliance"] = r_compliance
|
| 452 |
+
|
| 453 |
+
# --- Component 3: Step Ordering ---
|
| 454 |
+
required = state.required_workflow
|
| 455 |
+
completed = state.steps_completed
|
| 456 |
+
if len(required) > 1 and len(completed) > 0:
|
| 457 |
+
# Count correct transitions
|
| 458 |
+
correct = sum(
|
| 459 |
+
1 for i in range(min(len(completed), len(required)))
|
| 460 |
+
if completed[i] == required[i]
|
| 461 |
+
)
|
| 462 |
+
r_ordering = correct / len(required)
|
| 463 |
+
else:
|
| 464 |
+
r_ordering = 1.0 if (not required or action.action_type == required[0]) else 0.0
|
| 465 |
+
components["r_ordering"] = r_ordering
|
| 466 |
+
|
| 467 |
+
# --- Component 4: Efficiency ---
|
| 468 |
+
if episode_done:
|
| 469 |
+
optimal = DIFFICULTY_OPTIMAL_TURNS.get(state.difficulty, 10)
|
| 470 |
+
overhead = max(0, state.turn_number - optimal)
|
| 471 |
+
r_efficiency = max(-0.3, -0.05 * overhead)
|
| 472 |
+
else:
|
| 473 |
+
r_efficiency = 0.0 # only computed at episode end
|
| 474 |
+
components["r_efficiency"] = r_efficiency
|
| 475 |
+
|
| 476 |
+
# --- Component 5: Format ---
|
| 477 |
+
r_format = 1.0 if action.is_valid() else -0.1
|
| 478 |
+
components["r_format"] = r_format
|
| 479 |
+
|
| 480 |
+
# --- Weighted total ---
|
| 481 |
+
weights = {
|
| 482 |
+
"r_outcome": 0.40,
|
| 483 |
+
"r_compliance": 0.30,
|
| 484 |
+
"r_ordering": 0.15,
|
| 485 |
+
"r_efficiency": 0.10,
|
| 486 |
+
"r_format": 0.05,
|
| 487 |
+
}
|
| 488 |
+
total = sum(weights[k] * v for k, v in components.items())
|
| 489 |
+
components["total"] = total
|
| 490 |
+
|
| 491 |
+
return total, components
|
| 492 |
+
```
|
| 493 |
+
|
| 494 |
+
---
|
| 495 |
+
|
| 496 |
+
## Phase 6: Environment Core (Person A) — `server/salespath_environment.py`
|
| 497 |
+
|
| 498 |
+
```python
|
| 499 |
+
# server/salespath_environment.py
|
| 500 |
+
import uuid
|
| 501 |
+
from openenv.core.env_server import Environment
|
| 502 |
+
from ..models import SalesPathAction, SalesPathObservation, SalesPathState
|
| 503 |
+
from .task_bank import sample_profile
|
| 504 |
+
from .rules import check_rules, BUSINESS_RULES
|
| 505 |
+
from .reward import compute_reward
|
| 506 |
+
from .prospect_simulator import ProspectSimulator
|
| 507 |
+
|
| 508 |
+
DIFFICULTY_WORKFLOW = {
|
| 509 |
+
1: ["QUALIFY", "PRESENT", "CLOSE"],
|
| 510 |
+
2: ["QUALIFY", "PRESENT", "HANDLE_OBJECTION", "OFFER_DEMO", "CLOSE"],
|
| 511 |
+
3: ["QUALIFY", "PRESENT", "HANDLE_OBJECTION", "OFFER_DEMO",
|
| 512 |
+
"HANDLE_OBJECTION", "NEGOTIATE", "CLOSE"],
|
| 513 |
+
4: [], # agent must determine; DISQUALIFY may be correct
|
| 514 |
+
}
|
| 515 |
+
|
| 516 |
+
MAX_VIOLATIONS_BEFORE_TERMINATE = 3
|
| 517 |
+
MAX_TURNS = 20
|
| 518 |
+
|
| 519 |
+
|
| 520 |
+
class SalesPathEnvironment(Environment):
    """Sales-conversation environment exposing reset/step/state.

    Each episode pairs one sampled prospect profile with a required workflow.
    Rule checking, prospect replies and reward computation are delegated to
    ``check_rules``, ``ProspectSimulator.respond`` and ``compute_reward``;
    this class only sequences those calls and tracks episode state.
    """

    def __init__(self):
        """Start with a blank state and a fresh prospect simulator."""
        super().__init__()
        self._state = SalesPathState()
        self._simulator = ProspectSimulator()

    def reset(self, difficulty: int = 1) -> SalesPathObservation:
        """Begin a new episode and return the opening observation.

        Args:
            difficulty: Curriculum level. Values outside 1-4 are not
                supported: the objection-count table below has no entry
                for them.

        Returns:
            The initial observation (turn 0, stage ``"START"``, reward 0.0).
        """
        profile = sample_profile(difficulty)

        # Hidden values are kept off the public profile; they are attached to
        # the state separately (see ``_hidden`` below).
        hidden = {
            "true_budget": profile.true_budget,
            "close_threshold": profile.close_threshold,
            "stall_probability": profile.stall_probability,
            "num_objections": {1: 0, 2: 1, 3: 2, 4: 2}[difficulty],
            "revealed_budget": (
                "high" if profile.true_budget >= 0.7
                else "medium" if profile.true_budget >= 0.4
                else "low"
            ),
        }
        public_profile = {
            "company_name": profile.company_name,
            "company_size": profile.company_size,
            "industry": profile.industry,
            "budget_signal": profile.budget_signal,
            "pain_points": profile.pain_points,
            "decision_maker": profile.decision_maker,
        }

        self._state = SalesPathState(
            episode_id=str(uuid.uuid4()),
            prospect_profile=public_profile,
            required_workflow=DIFFICULTY_WORKFLOW[difficulty],
            difficulty=difficulty,
        )
        self._state._hidden = hidden

        return SalesPathObservation(
            prospect_response=(
                f"You are engaging {profile.company_name}, a {profile.company_size} "
                f"{profile.industry} company. Pain points: {', '.join(profile.pain_points)}. "
                f"Begin the sales conversation."
            ),
            workflow_stage="START",
            steps_completed=[],
            constraints_violated=[],
            turn_number=0,
            reward=0.0,
            done=False,
            info={"difficulty": difficulty, "episode_id": self._state.episode_id},
        )

    def step(self, action: SalesPathAction) -> SalesPathObservation:
        """Apply one agent action and return the resulting observation.

        Every call consumes a turn, including calls with an invalid action.
        The episode ends on ``CLOSE``/``DISQUALIFY``, when the accumulated
        violations reach ``MAX_VIOLATIONS_BEFORE_TERMINATE``, or when the turn
        counter reaches ``MAX_TURNS``.
        """
        state = self._state
        state.turn_number += 1
        turn_limit = state.turn_number >= MAX_TURNS

        # Validate action format
        if not action.is_valid():
            # An invalid action still uses up a turn, so the turn limit has to
            # be honoured here as well; otherwise an agent that only ever
            # emits invalid actions would keep the episode open forever.
            if turn_limit:
                state.done = True
            return SalesPathObservation(
                prospect_response="Invalid action type.",
                workflow_stage=state.workflow_stage,
                steps_completed=list(state.steps_completed),
                constraints_violated=list(state.constraints_violated),
                turn_number=state.turn_number,
                reward=-0.2,
                done=turn_limit,
                info={"error": f"Invalid action_type: {action.action_type}",
                      "r_format": -0.1},
            )

        # Check business rules
        new_violations = check_rules(state, action)
        state.constraints_violated.extend(new_violations)

        # Update conversation history
        state.conversation_history.append({
            "turn": state.turn_number,
            "speaker": "agent",
            "action_type": action.action_type,
            "content": action.content,
        })

        # Update steps completed (each action type is recorded once)
        if action.action_type not in state.steps_completed:
            state.steps_completed.append(action.action_type)
        state.workflow_stage = action.action_type

        # Get prospect response
        response_token, response_text = self._simulator.respond(action, state)
        state.conversation_history.append({
            "turn": state.turn_number,
            "speaker": "prospect",
            "response_token": response_token,
            "text": response_text,
        })

        # Determine episode termination
        terminal_actions = {"CLOSE", "DISQUALIFY"}
        too_many_violations = len(state.constraints_violated) >= MAX_VIOLATIONS_BEFORE_TERMINATE
        done = (
            action.action_type in terminal_actions
            or too_many_violations
            or turn_limit
        )
        state.done = done

        # Compute reward (all reward logic lives in compute_reward)
        total_reward, components = compute_reward(
            state, action, response_token, new_violations, done
        )

        return SalesPathObservation(
            prospect_response=response_text,
            workflow_stage=state.workflow_stage,
            steps_completed=list(state.steps_completed),
            constraints_violated=list(state.constraints_violated),
            turn_number=state.turn_number,
            reward=total_reward,
            reward_components=components,
            done=done,
            info={
                "response_token": response_token,
                "new_violations": new_violations,
                "episode_id": state.episode_id,
            },
        )

    @property
    def state(self) -> SalesPathState:
        """Return the current episode state object."""
        return self._state
|
| 650 |
+
```
|
| 651 |
+
|
| 652 |
+
---
|
| 653 |
+
|
| 654 |
+
## Phase 7: FastAPI App (Person A) — `server/app.py`
|
| 655 |
+
|
| 656 |
+
```python
|
| 657 |
+
# server/app.py — thin wrapper only
|
| 658 |
+
from openenv.core.env_server import create_fastapi_app
|
| 659 |
+
from ..models import SalesPathAction, SalesPathObservation
|
| 660 |
+
from .salespath_environment import SalesPathEnvironment
|
| 661 |
+
|
| 662 |
+
# Build the FastAPI app from the environment class and its action/observation
# models; this module intentionally contains no other logic.
app = create_fastapi_app(SalesPathEnvironment, SalesPathAction, SalesPathObservation)
|
| 667 |
+
```
|
| 668 |
+
|
| 669 |
+
---
|
| 670 |
+
|
| 671 |
+
## Phase 8: Client (Person B) — `client.py`
|
| 672 |
+
|
| 673 |
+
```python
|
| 674 |
+
# client.py
|
| 675 |
+
from openenv.core import EnvClient
|
| 676 |
+
from .models import SalesPathAction, SalesPathObservation, SalesPathState
|
| 677 |
+
|
| 678 |
+
|
| 679 |
+
class SalesPathEnv(EnvClient):
    """Client wrapper for the SalesPath environment, built on ``EnvClient``."""

    # Message types declared on the client as class attributes.
    action_type = SalesPathAction
    observation_type = SalesPathObservation
    state_type = SalesPathState

    async def reset(self, difficulty: int = 1) -> SalesPathObservation:
        """Start a new episode at *difficulty* through the base client's reset."""
        first_observation = await super().reset(difficulty=difficulty)
        return first_observation

    async def step(self, action_type: str, content: str, target: str = "") -> SalesPathObservation:
        """Pack the fields into a ``SalesPathAction`` and forward it to the base step."""
        return await super().step(
            SalesPathAction(action_type=action_type, content=content, target=target)
        )
|
| 694 |
+
```
|
| 695 |
+
|
| 696 |
+
---
|
| 697 |
+
|
| 698 |
+
## Phase 9: Rollout Function (Person B) — `training/rollout.py`
|
| 699 |
+
|
| 700 |
+
```python
|
| 701 |
+
# training/rollout.py
|
| 702 |
+
import re
|
| 703 |
+
from salespath_env.client import SalesPathEnv
|
| 704 |
+
from salespath_env.models import SalesPathObservation
|
| 705 |
+
|
| 706 |
+
# System prompt template; ``{workflow}`` is filled in with str.format().
# Assembled line by line so each business rule sits on its own source line.
SYSTEM_PROMPT = "\n".join((
    "You are a B2B sales agent. Your goal is to close deals by following a strict workflow.",
    "",
    "Required workflow steps (in order): {workflow}",
    "",
    "Business rules — NEVER violate these:",
    "- R01: Must QUALIFY before PRESENT",
    "- R02: Must OFFER_DEMO before NEGOTIATE",
    "- R03: Budget must be known before NEGOTIATE",
    "- R04: Discount only after 2 objections handled",
    "- R05: Cannot repeat same action twice in a row",
    "- R06: First action must always be PROSPECT",
    "- R07: FOLLOW_UP only after prospect goes silent",
    "- R08: DISQUALIFY only if prospect is genuinely unqualified",
    "- R09: Must OFFER_DEMO before CLOSE (difficulty 2+)",
    "",
    "Respond EXACTLY in this format:",
    "ACTION: <one of: PROSPECT, QUALIFY, PRESENT, HANDLE_OBJECTION, OFFER_DEMO, NEGOTIATE, CLOSE, FOLLOW_UP, DISQUALIFY>",
    "CONTENT: <your message to the prospect>",
))
|
| 724 |
+
|
| 725 |
+
|
| 726 |
+
# Compiled once at import time; both labels are matched case-insensitively.
_ACTION_RE = re.compile(r"ACTION:\s*(\w+)", re.IGNORECASE)
# CONTENT runs until the next "ACTION:" line or the end of the text, so a
# message that spans several lines is captured whole.
_CONTENT_RE = re.compile(
    r"CONTENT:\s*(.+?)(?=\n\s*ACTION:|\Z)",
    re.IGNORECASE | re.DOTALL,
)

_DEFAULT_ACTION = "QUALIFY"
_DEFAULT_CONTENT = "Tell me more about your needs."


def parse_action(text: str) -> tuple[str, str]:
    """Extract ACTION and CONTENT from model output.

    Args:
        text: Raw model generation, expected to contain an ``ACTION: <type>``
            line and a ``CONTENT: <message>`` section.

    Returns:
        ``(action_type, content)``. ``action_type`` is upper-cased. When the
        ACTION label is missing, ``"QUALIFY"`` is returned; when the CONTENT
        section is missing or blank, a generic follow-up question is returned.
    """
    action_match = _ACTION_RE.search(text)
    content_match = _CONTENT_RE.search(text)

    action_type = action_match.group(1).upper() if action_match else _DEFAULT_ACTION
    content = content_match.group(1).strip() if content_match else ""

    # A whitespace-only CONTENT is as useless as a missing one.
    return action_type, content or _DEFAULT_CONTENT
|
| 735 |
+
|
| 736 |
+
|
| 737 |
+
def build_prompt(
    obs: SalesPathObservation,
    workflow: list[str],
    tokenizer,
    max_turns: int = 20,
) -> str:
    """Render the chat prompt for the agent's next move.

    Args:
        obs: Latest observation; its response, stage, completed steps, turn
            number and violations are shown to the model.
        workflow: Required workflow steps, joined with arrows into the
            system prompt.
        tokenizer: Object providing ``apply_chat_template``.
        max_turns: Turn budget shown as ``Turn: n/max_turns``. Defaults to
            20, the value that was previously hard-coded in the prompt.

    Returns:
        Whatever ``tokenizer.apply_chat_template`` returns for the two
        messages with ``tokenize=False`` and ``add_generation_prompt=True``.
    """
    system_text = SYSTEM_PROMPT.format(workflow=" → ".join(workflow))
    user_text = (
        f"Prospect response: {obs.prospect_response}\n"
        f"Current stage: {obs.workflow_stage}\n"
        f"Steps completed: {obs.steps_completed}\n"
        f"Turn: {obs.turn_number}/{max_turns}\n"
        f"Violations so far: {obs.constraints_violated}\n\n"
        "What is your next action?"
    )
    messages = [
        {"role": "system", "content": system_text},
        {"role": "user", "content": user_text},
    ]
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
| 750 |
+
|
| 751 |
+
|
| 752 |
+
async def run_episode(
    model,
    tokenizer,
    env_url: str,
    difficulty: int = 1,
    max_steps: int = 20,
) -> dict:
    """Run one full episode. Returns trajectory with rewards.

    Args:
        model: Generation model; must provide ``generate`` and ``device``.
        tokenizer: Tokenizer used to build the prompt and decode the output.
        env_url: Base URL of the environment server.
        difficulty: Curriculum level, 1-4 (a key of the workflow table below).
        max_steps: Client-side cap on the number of actions sent. It stops
            the loop even if the server never reports ``done``.

    Returns:
        Dict with ``trajectory`` (one entry per step), ``total_reward``,
        ``steps_completed``, ``violations`` and ``difficulty``.
    """
    # This module does not import torch at top level, so it is imported
    # here; without it ``torch.no_grad()`` below raises NameError.
    import torch

    DIFFICULTY_WORKFLOW = {
        1: ["QUALIFY", "PRESENT", "CLOSE"],
        2: ["QUALIFY", "PRESENT", "HANDLE_OBJECTION", "OFFER_DEMO", "CLOSE"],
        3: ["QUALIFY", "PRESENT", "HANDLE_OBJECTION", "OFFER_DEMO",
            "HANDLE_OBJECTION", "NEGOTIATE", "CLOSE"],
        4: [],
    }
    workflow = DIFFICULTY_WORKFLOW[difficulty]

    async with SalesPathEnv(base_url=env_url) as env:
        obs = await env.reset(difficulty=difficulty)
        trajectory = []
        total_reward = 0.0

        while not obs.done and len(trajectory) < max_steps:
            prompt = build_prompt(obs, workflow, tokenizer)
            inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=256,
                    temperature=0.8,
                    do_sample=True,
                )
            # Decode only the newly generated tokens, not the echoed prompt.
            prompt_length = inputs["input_ids"].shape[1]
            generated = tokenizer.decode(
                outputs[0][prompt_length:],
                skip_special_tokens=True,
            )

            action_type, content = parse_action(generated)
            obs = await env.step(action_type, content)

            trajectory.append({
                "prompt": prompt,
                "generated": generated,
                "action_type": action_type,
                "reward": obs.reward,
                "components": obs.reward_components,
                "done": obs.done,
            })
            total_reward += obs.reward

    return {
        "trajectory": trajectory,
        "total_reward": total_reward,
        "steps_completed": obs.steps_completed,
        "violations": obs.constraints_violated,
        "difficulty": difficulty,
    }
|
| 804 |
+
```
|
| 805 |
+
|
| 806 |
+
---
|
| 807 |
+
|
| 808 |
+
## Phase 10: Curriculum Scheduler (Person B) — `training/curriculum.py`
|
| 809 |
+
|
| 810 |
+
```python
|
| 811 |
+
# training/curriculum.py
|
| 812 |
+
from dataclasses import dataclass
|
| 813 |
+
|
| 814 |
+
@dataclass
class CurriculumConfig:
    """Maps a mean-reward threshold to a difficulty distribution."""

    thresholds: dict  # mean_reward -> difficulty_distribution

    def get_distribution(self, mean_reward: float) -> dict:
        """Return the distribution of the highest threshold not above *mean_reward*.

        When *mean_reward* is below every threshold, the distribution of the
        lowest threshold is returned.
        """
        reached = [floor for floor in self.thresholds if mean_reward >= floor]
        chosen = max(reached) if reached else min(self.thresholds)
        return self.thresholds[chosen]
|
| 823 |
+
|
| 824 |
+
|
| 825 |
+
# Default schedule: each row is (mean-reward threshold, sampling weights for
# difficulties 1-4). Higher thresholds shift weight toward harder levels.
DEFAULT_CURRICULUM = CurriculumConfig(
    thresholds={
        floor: dict(zip((1, 2, 3, 4), weights))
        for floor, weights in (
            (0.0, (0.90, 0.10, 0.00, 0.00)),
            (0.30, (0.50, 0.40, 0.10, 0.00)),
            (0.50, (0.20, 0.40, 0.35, 0.05)),
            (0.65, (0.10, 0.30, 0.40, 0.20)),
        )
    }
)
|
| 833 |
+
|
| 834 |
+
|
| 835 |
+
def sample_difficulty(curriculum: CurriculumConfig, mean_reward: float) -> int:
    """Draw one difficulty level from the distribution selected by *mean_reward*.

    The distribution's keys are the candidate levels and its values are the
    sampling weights passed to ``random.choices``.
    """
    import random

    distribution = curriculum.get_distribution(mean_reward)
    levels = list(distribution)
    weights = list(distribution.values())
    (picked,) = random.choices(levels, weights=weights, k=1)
    return picked
|
| 843 |
+
```
|
| 844 |
+
|
| 845 |
+
---
|
| 846 |
+
|
| 847 |
+
## Phase 11: Training Script (Person B) — `training/grpo_train.py`
|
| 848 |
+
|
| 849 |
+
```python
|
| 850 |
+
# training/grpo_train.py
|
| 851 |
+
import torch
|
| 852 |
+
import asyncio
|
| 853 |
+
import numpy as np
|
| 854 |
+
from unsloth import FastLanguageModel
|
| 855 |
+
from trl import GRPOConfig, GRPOTrainer
|
| 856 |
+
from curriculum import DEFAULT_CURRICULUM, sample_difficulty
|
| 857 |
+
from rollout import run_episode
|
| 858 |
+
|
| 859 |
+
# --- Model Load ---
|
| 860 |
+
# Load the base model and its tokenizer through Unsloth, with 4-bit loading
# enabled and a 2048-token maximum sequence length.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Qwen2.5-7B-Instruct",
    max_seq_length=2048,
    load_in_4bit=True,
)
# Wrap the loaded model with PEFT/LoRA adapters (rank 16, alpha 16, no
# dropout, no bias) on the listed attention and MLP projection modules.
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",
)
|
| 875 |
+
|
| 876 |
+
ENV_URL = "http://localhost:8000" # or HuggingFace Space URL
|
| 877 |
+
|
| 878 |
+
# --- Reward function for GRPO (wraps environment) ---
|
| 879 |
+
def salespath_reward_fn(completions, prompts, **kwargs) -> list[float]:
    """Return one reward per completion for GRPO.

    GRPO calls this with a batch of completions. Rewards are looked up in the
    optional ``step_rewards`` keyword argument, a mapping from completion to
    its single-step reward; full-episode reward is tracked separately in the
    curriculum loop.

    NOTE(review): nothing in this module passes ``step_rewards``, so unless
    the caller supplies it every completion scores 0.0.

    Args:
        completions: Batch of completions; each must be hashable because it
            is used as a lookup key.
        prompts: Matching prompts (unused).
        **kwargs: May contain ``step_rewards``; a missing or ``None`` value
            is treated as an empty mapping.

    Returns:
        Rewards in the same order as *completions*; 0.0 where no entry exists.
    """
    step_rewards = kwargs.get("step_rewards") or {}
    return [step_rewards.get(completion, 0.0) for completion in completions]
|
| 894 |
+
|
| 895 |
+
|
| 896 |
+
# --- Training config ---
|
| 897 |
+
# GRPO hyperparameters. The completion-length cap is spelled
# ``max_completion_length`` in GRPOConfig; ``max_new_tokens`` is a
# ``generate()`` argument, not a config field.
training_config = GRPOConfig(
    output_dir="salespath_grpo_output",
    num_train_epochs=3,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    num_generations=8,
    max_completion_length=256,
    temperature=0.8,
    learning_rate=1e-5,
    logging_steps=10,
    save_steps=100,
    report_to="none",
)
|
| 910 |
+
|
| 911 |
+
# --- Curriculum training loop ---
|
| 912 |
+
async def curriculum_train():
    """Roll out 500 episodes, adapting difficulty to the recent mean reward.

    Each step samples a difficulty from ``DEFAULT_CURRICULUM`` using the mean
    total reward of the last 20 episodes, runs one episode, and logs metrics.
    The mean stays at 0.0 until a full 20-episode window is available.
    No gradient update happens inside this function.
    """
    from collections import deque

    window_size = 20
    mean_reward = 0.0
    # Bounded window: only the most recent episodes are ever needed.
    recent_rewards = deque(maxlen=window_size)

    for step in range(500):
        difficulty = sample_difficulty(DEFAULT_CURRICULUM, mean_reward)
        result = await run_episode(model, tokenizer, ENV_URL, difficulty)

        recent_rewards.append(result["total_reward"])
        # Update as soon as the window is full (the earlier ``> 20`` check
        # waited one episode longer than the window needs).
        if len(recent_rewards) == window_size:
            mean_reward = float(np.mean(list(recent_rewards)))

        # Log metrics
        if step % 10 == 0:
            print(f"Step {step:4d} | Difficulty {difficulty} | "
                  f"Reward {result['total_reward']:.3f} | "
                  f"Mean(20) {mean_reward:.3f} | "
                  f"Violations {len(result['violations'])} | "
                  f"Steps {result['steps_completed']}")

        # Manual inspection every 50 steps
        if step % 50 == 0:
            print("\n=== RAW GENERATION SAMPLE ===")
            if result["trajectory"]:
                print(result["trajectory"][0]["generated"])
            print("==============================\n")


if __name__ == "__main__":
    asyncio.run(curriculum_train())
|
| 942 |
+
```
|
| 943 |
+
|
| 944 |
+
---
|
| 945 |
+
|
| 946 |
+
## Phase 12: Dockerfile (Person A) — `server/Dockerfile`
|
| 947 |
+
|
| 948 |
+
```dockerfile
|
| 949 |
+
ARG BASE_IMAGE=openenv-base:latest
|
| 950 |
+
FROM ${BASE_IMAGE}
|
| 951 |
+
|
| 952 |
+
COPY server/requirements.txt /tmp/requirements.txt
|
| 953 |
+
RUN pip install --no-cache-dir -r /tmp/requirements.txt
|
| 954 |
+
|
| 955 |
+
COPY src/openenv/core/ /app/src/openenv/core/
|
| 956 |
+
COPY salespath_env/ /app/salespath_env/
|
| 957 |
+
|
| 958 |
+
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
|
| 959 |
+
CMD curl -f http://localhost:8000/health || exit 1
|
| 960 |
+
|
| 961 |
+
CMD ["uvicorn", "salespath_env.server.app:app", "--host", "0.0.0.0", "--port", "8000"]
|
| 962 |
+
```
|
| 963 |
+
|
| 964 |
+
`server/requirements.txt`:
|
| 965 |
+
```
|
| 966 |
+
fastapi
|
| 967 |
+
uvicorn
|
| 968 |
+
pydantic>=2.0
|
| 969 |
+
```
|
| 970 |
+
|
| 971 |
+
---
|
| 972 |
+
|
| 973 |
+
## Phase 13: Deploy to HuggingFace
|
| 974 |
+
|
| 975 |
+
```bash
|
| 976 |
+
# From salespath_env/ directory
|
| 977 |
+
openenv push --repo-id Imsachin010/salespath-env
|
| 978 |
+
|
| 979 |
+
# Verify it's running
|
| 980 |
+
curl -X POST https://imsachin010-salespath-env.hf.space/reset \
|
| 981 |
+
-H "Content-Type: application/json" \
|
| 982 |
+
-d '{"difficulty": 1}'
|
| 983 |
+
```
|
| 984 |
+
|
| 985 |
+
---
|
| 986 |
+
|
| 987 |
+
## Phase 14: Model Save (After Training)
|
| 988 |
+
|
| 989 |
+
```python
|
| 990 |
+
# CORRECT save — do not change this
|
| 991 |
+
model.save_pretrained_merged(
|
| 992 |
+
"salespath_trained_merged",
|
| 993 |
+
tokenizer,
|
| 994 |
+
save_method="merged_16bit",
|
| 995 |
+
)
|
| 996 |
+
|
| 997 |
+
# Push to HuggingFace Hub
|
| 998 |
+
model.push_to_hub_merged(
|
| 999 |
+
"Imsachin010/salespath-qwen25-7b",
|
| 1000 |
+
tokenizer,
|
| 1001 |
+
save_method="merged_16bit",
|
| 1002 |
+
)
|
| 1003 |
+
```
|
| 1004 |
+
|
| 1005 |
+
---
|
| 1006 |
+
|
| 1007 |
+
## Build Order Summary
|
| 1008 |
+
|
| 1009 |
+
```
|
| 1010 |
+
Person A (Environment): Person B (Training):
|
| 1011 |
+
1. models.py (wait for models.py)
|
| 1012 |
+
2. server/task_bank.py 1. server/reward.py
|
| 1013 |
+
3. server/rules.py 2. training/rollout.py
|
| 1014 |
+
4. server/prospect_simulator.py 3. training/curriculum.py
|
| 1015 |
+
5. server/salespath_environment.py 4. training/grpo_train.py
|
| 1016 |
+
6. server/app.py 5. training/colab_train.ipynb
|
| 1017 |
+
7. Dockerfile
|
| 1018 |
+
8. openenv push → verify health
|
| 1019 |
+
6. Connect rollout to live env URL
|
| 1020 |
+
7. Run first training loop (difficulty=1 only)
|
| 1021 |
+
8. Verify reward > 0 on step 1
|
| 1022 |
+
9. Enable curriculum
|
| 1023 |
+
```
|
| 1024 |
+
|
| 1025 |
+
**Critical gate:** Person B does not run training until Person A has confirmed:
|
| 1026 |
+
- `POST /reset` returns a valid observation
|
| 1027 |
+
- `POST /step` with a valid action returns a valid observation
|
| 1028 |
+
- `POST /step` with an invalid action returns error in `info`, not a 500
|
Dockerfile
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
+
|
| 3 |
+
# HuggingFace Spaces runs on port 7860 by default
|
| 4 |
+
ENV PORT=7860
|
| 5 |
+
ENV PYTHONUNBUFFERED=1
|
| 6 |
+
ENV PYTHONDONTWRITEBYTECODE=1
|
| 7 |
+
|
| 8 |
+
WORKDIR /app
|
| 9 |
+
|
| 10 |
+
# Install system dependencies
|
| 11 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 12 |
+
curl \
|
| 13 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 14 |
+
|
| 15 |
+
# Install Python dependencies
|
| 16 |
+
COPY requirements.txt .
|
| 17 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 18 |
+
|
| 19 |
+
# Copy the salespath_env package
|
| 20 |
+
COPY salespath_env/ ./salespath_env/
|
| 21 |
+
|
| 22 |
+
# Health check
|
| 23 |
+
HEALTHCHECK --interval=30s --timeout=5s --start-period=15s --retries=3 \
|
| 24 |
+
CMD curl -f http://localhost:${PORT}/health || exit 1
|
| 25 |
+
|
| 26 |
+
# Start the FastAPI server on HF Spaces port
|
| 27 |
+
CMD ["sh", "-c", "uvicorn salespath_env.server.app:app --host 0.0.0.0 --port ${PORT}"]
|
README.md
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: SalesPath Environment
|
| 3 |
+
emoji: 🤝
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: indigo
|
| 6 |
+
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
+
pinned: false
|
| 9 |
+
license: mit
|
| 10 |
+
short_description: RL gym environment for sales agent training
|
| 11 |
+
---
|
| 12 |
+
|
| 13 |
+
# SalesPath Environment
|
| 14 |
+
|
| 15 |
+
An [OpenEnv](https://github.com/openenv)-compatible Reinforcement Learning gym environment for training sales agents via LLM fine-tuning.
|
| 16 |
+
|
| 17 |
+
## API Endpoints
|
| 18 |
+
|
| 19 |
+
| Method | Endpoint | Description |
|
| 20 |
+
|--------|----------|-------------|
|
| 21 |
+
| `POST` | `/reset` | Reset the environment, returns initial observation |
|
| 22 |
+
| `POST` | `/step` | Take an action, returns next observation + reward |
|
| 23 |
+
| `GET` | `/health` | Health check |
|
| 24 |
+
|
| 25 |
+
## Quick Start
|
| 26 |
+
|
| 27 |
+
### Reset
|
| 28 |
+
```bash
|
| 29 |
+
curl -X POST https://imsachin010-salespath-env.hf.space/reset \
|
| 30 |
+
-H "Content-Type: application/json" \
|
| 31 |
+
-d '{"difficulty": 1}'
|
| 32 |
+
```
|
| 33 |
+
|
| 34 |
+
### Step
|
| 35 |
+
```bash
|
| 36 |
+
curl -X POST https://imsachin010-salespath-env.hf.space/step \
|
| 37 |
+
-H "Content-Type: application/json" \
|
| 38 |
+
-d '{"action": {"action_type": "PROSPECT", "content": "Hello, tell me about your workflow challenges."}}'
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
## Action Types
|
| 42 |
+
|
| 43 |
+
- `PROSPECT` — Initial outreach and discovery
|
| 44 |
+
- `QUALIFY` — Qualify the lead
|
| 45 |
+
- `PRESENT` — Deliver the sales pitch
|
| 46 |
+
- `HANDLE_OBJECTION` — Handle prospect objections
|
| 47 |
+
- `OFFER_DEMO` — Offer product demonstration
|
| 48 |
+
- `NEGOTIATE` — Discuss pricing and terms
|
| 49 |
+
- `FOLLOW_UP` — Follow-up message
|
| 50 |
+
- `DISQUALIFY` — Exit if prospect is not a fit
|
| 51 |
+
- `CLOSE` — Attempt to close the deal
|
RULES.md
ADDED
|
@@ -0,0 +1,354 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SalesPath — Agent Rules & Constraints
|
| 2 |
+
### Read this before touching any file. These are non-negotiable.
|
| 3 |
+
|
| 4 |
+
---
|
| 5 |
+
|
| 6 |
+
## 0. Project Identity
|
| 7 |
+
|
| 8 |
+
- **Project name:** `salespath_env`
|
| 9 |
+
- **HuggingFace repo:** `Imsachin010/salespath-env`
|
| 10 |
+
- **Theme:** Theme #2 — Long-Horizon Planning (Scale AI bonus prize)
|
| 11 |
+
- **Stack:** OpenEnv + GRPO (HF TRL) + Unsloth + Qwen 2.5 7B Instruct
|
| 12 |
+
|
| 13 |
+
---
|
| 14 |
+
|
| 15 |
+
## 1. Directory Structure — Do Not Deviate
|
| 16 |
+
|
| 17 |
+
```
|
| 18 |
+
salespath_env/
|
| 19 |
+
├── __init__.py
|
| 20 |
+
├── models.py ← ALL Pydantic dataclasses live here only
|
| 21 |
+
├── client.py ← SalesPathEnv(EnvClient) lives here only
|
| 22 |
+
├── README.md
|
| 23 |
+
├── openenv.yaml
|
| 24 |
+
├── pyproject.toml
|
| 25 |
+
├── server/
|
| 26 |
+
│ ├── __init__.py
|
| 27 |
+
│ ├── salespath_environment.py ← SalesPathEnvironment(Environment)
|
| 28 |
+
│ ├── prospect_simulator.py ← ProspectSimulator (rule-based only)
|
| 29 |
+
│ ├── reward.py ← ALL reward logic lives here only
|
| 30 |
+
│ ├── task_bank.py ← ALL prospect profiles and tasks
|
| 31 |
+
│ ├── rules.py ← ALL business rule definitions
|
| 32 |
+
│ ├── app.py ← FastAPI app only, no logic
|
| 33 |
+
│ ├── requirements.txt
|
| 34 |
+
│ └── Dockerfile
|
| 35 |
+
training/
|
| 36 |
+
├── grpo_train.py ← training script
|
| 37 |
+
├── rollout.py ← rollout function
|
| 38 |
+
├── curriculum.py ← difficulty scheduler
|
| 39 |
+
└── colab_train.ipynb ← Colab notebook for judges
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
---
|
| 43 |
+
|
| 44 |
+
## 2. OpenEnv API — Exact Signatures to Follow
|
| 45 |
+
|
| 46 |
+
```python
|
| 47 |
+
# models.py — extend these base classes
|
| 48 |
+
from openenv.core import Action, Observation, State # actual imports
|
| 49 |
+
|
| 50 |
+
class SalesPathAction(Action):
    """One agent move: an action type plus the message sent to the prospect."""

    action_type: str  # one of the 9 valid action types
    content: str  # natural language content of the action
    target: str = ""  # optional target (e.g., which objection)
|
| 54 |
+
|
| 55 |
+
class SalesPathObservation(Observation):
    """What the agent sees after ``reset()`` or ``step()``.

    NOTE(review): the ``step()`` sketch in CODING_APPROACH.md also passes
    ``reward_components=``; confirm whether that field belongs in this spec.
    """

    prospect_response: str  # text shown to the agent for this turn
    workflow_stage: str
    constraints_violated: list[str]
    steps_completed: list[str]
    turn_number: int
    reward: float
    done: bool
    info: dict
|
| 64 |
+
|
| 65 |
+
class SalesPathState(State):
    """Server-side episode state.

    NOTE(review): the environment sketch in CODING_APPROACH.md stores hidden
    values on ``_hidden`` and passes ``required_workflow=``; confirm the
    field names against models.py.
    """

    episode_id: str
    prospect_profile: dict
    conversation_history: list[dict]
    workflow_stage: str
    steps_completed: list[str]
    constraints_violated: list[str]
    turn_number: int
    difficulty: int  # 1, 2, 3, or 4
    hidden_state: dict  # NOT exposed to agent
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
```python
|
| 78 |
+
# server/salespath_environment.py
|
| 79 |
+
from openenv.core.env_server import Environment
|
| 80 |
+
|
| 81 |
+
class SalesPathEnvironment(Environment):
    """Required method signatures for the environment (bodies elided)."""

    # Start a new episode at the given difficulty.
    def reset(self, difficulty: int = 1) -> SalesPathObservation: ...
    # Apply one agent action and return the next observation.
    def step(self, action: SalesPathAction) -> SalesPathObservation: ...
    # Read access to the current episode state.
    @property
    def state(self) -> SalesPathState: ...
|
| 86 |
+
```
|
| 87 |
+
|
| 88 |
+
```python
|
| 89 |
+
# server/app.py — nothing else in this file
|
| 90 |
+
from openenv.core.env_server import create_fastapi_app
|
| 91 |
+
from ..models import SalesPathAction, SalesPathObservation
|
| 92 |
+
from .salespath_environment import SalesPathEnvironment
|
| 93 |
+
|
| 94 |
+
app = create_fastapi_app(SalesPathEnvironment, SalesPathAction, SalesPathObservation)
|
| 95 |
+
```
|
| 96 |
+
|
| 97 |
+
---
|
| 98 |
+
|
| 99 |
+
## 3. Hard Rules — Code Will Be Rejected If Violated
|
| 100 |
+
|
| 101 |
+
### 3.1 No LLM in the Environment
|
| 102 |
+
- `ProspectSimulator` is a **pure rule-based state machine**
|
| 103 |
+
- No API calls, no model inference, no `transformers` imports inside `server/`
|
| 104 |
+
- If you find yourself writing `model.generate()` inside `server/`, stop. Wrong file.
|
| 105 |
+
|
| 106 |
+
### 3.2 Immutable Prospect State
|
| 107 |
+
- Once `reset()` sets the prospect profile, agent actions **cannot modify `hidden_state`**
|
| 108 |
+
- `hidden_state` is read-only after `reset()`
|
| 109 |
+
- Never expose `hidden_state` fields in `SalesPathObservation`
|
| 110 |
+
|
| 111 |
+
### 3.3 Reward Lives in One Place
|
| 112 |
+
- All reward computation goes in `server/reward.py`
|
| 113 |
+
- `salespath_environment.py` calls `compute_reward()` — it does not compute reward itself
|
| 114 |
+
- Never compute reward inside `step()` directly
|
| 115 |
+
|
| 116 |
+
### 3.4 Business Rules Live in One Place
|
| 117 |
+
- All rule definitions go in `server/rules.py` as a list of `BusinessRule` dataclasses
|
| 118 |
+
- `step()` calls `check_rules(state, action)` from `rules.py` — it does not check rules inline
|
| 119 |
+
|
| 120 |
+
### 3.5 Turn Limit is Absolute
|
| 121 |
+
- Max turns = 20. Hard terminate. No exceptions.
|
| 122 |
+
- Episode must set `done=True` and assign `r_outcome = -0.3` at turn 20 regardless of state
|
| 123 |
+
|
| 124 |
+
### 3.6 Action Validation is Strict
|
| 125 |
+
- If `action_type` is not one of the 9 valid types, return an observation with `done=False`, `reward=-0.2`, and an error message
|
| 126 |
+
- Do not raise exceptions to the agent — return a valid `SalesPathObservation` with error in `info`
|
| 127 |
+
|
| 128 |
+
### 3.7 Reward Must Be Multi-Component
|
| 129 |
+
- Reward function must log all 5 components separately in `info` dict
|
| 130 |
+
- Never return a single scalar reward without component breakdown
|
| 131 |
+
- Component keys: `r_outcome`, `r_compliance`, `r_ordering`, `r_efficiency`, `r_format`
|
| 132 |
+
|
| 133 |
+
### 3.8 No Global Mutable State in Environment
|
| 134 |
+
- Each WebSocket session gets its own `SalesPathEnvironment` instance
|
| 135 |
+
- No class-level variables that change during episodes
|
| 136 |
+
- No module-level state
|
| 137 |
+
|
| 138 |
+
---
|
| 139 |
+
|
| 140 |
+
## 4. Valid Action Types — Exact Strings
|
| 141 |
+
|
| 142 |
+
```python
|
| 143 |
+
VALID_ACTIONS = {
|
| 144 |
+
"PROSPECT", # initial outreach — only valid on turn 1
|
| 145 |
+
"QUALIFY", # ask qualification questions
|
| 146 |
+
"PRESENT", # deliver pitch
|
| 147 |
+
"HANDLE_OBJECTION", # respond to raised objection
|
| 148 |
+
"OFFER_DEMO", # propose product demonstration
|
| 149 |
+
"NEGOTIATE", # discuss pricing/terms
|
| 150 |
+
"CLOSE", # submit closing offer → terminates episode
|
| 151 |
+
"FOLLOW_UP", # follow up after no response
|
| 152 |
+
"DISQUALIFY", # exit if prospect is not a fit → terminates episode
|
| 153 |
+
}
|
| 154 |
+
```
|
| 155 |
+
|
| 156 |
+
---
|
| 157 |
+
|
| 158 |
+
## 5. Business Rules — Exact Definitions
|
| 159 |
+
|
| 160 |
+
These are checked after every `step()`. Each violation appends the violated rule's ID (e.g. `"R01"`) to `constraints_violated`.
|
| 161 |
+
|
| 162 |
+
```python
|
| 163 |
+
RULES = [
|
| 164 |
+
# ID Name Condition for VIOLATION
|
| 165 |
+
R01 "qualify_before_present" PRESENT called before any QUALIFY
|
| 166 |
+
R02 "demo_before_negotiate" NEGOTIATE called before OFFER_DEMO
|
| 167 |
+
R03 "budget_known_to_negotiate" NEGOTIATE called while budget_signal == "unknown"
|
| 168 |
+
R04 "discount_after_objections" Discount mentioned in NEGOTIATE before 2 objections handled
|
| 169 |
+
R05 "no_repeat_action" Same action_type on consecutive turns
|
| 170 |
+
R06 "prospect_first" Any action other than PROSPECT on turn 1
|
| 171 |
+
R07 "followup_timing" FOLLOW_UP called when prospect responded last turn
|
| 172 |
+
R08 "disqualify_logic" DISQUALIFY called when budget >= threshold AND decision_maker==True
|
| 173 |
+
R09 "close_requires_demo" CLOSE called before OFFER_DEMO
|
| 174 |
+
]
|
| 175 |
+
```
|
| 176 |
+
|
| 177 |
+
Three violations → `done=True`, `r_outcome = -0.5`
|
| 178 |
+
|
| 179 |
+
---
|
| 180 |
+
|
| 181 |
+
## 6. Prospect Simulator — Exact Response Rules
|
| 182 |
+
|
| 183 |
+
`ProspectSimulator.respond(action, state)` returns one of these string tokens. The environment converts tokens to natural language text for the observation.
|
| 184 |
+
|
| 185 |
+
```python
|
| 186 |
+
RESPONSE_TOKENS = {
|
| 187 |
+
"open:positive_signal", # prospect is engaged and open
|
| 188 |
+
"open:neutral_signal", # prospect acknowledges but non-committal
|
| 189 |
+
"objection:price", # raises price objection
|
| 190 |
+
"objection:timing", # raises timing objection
|
| 191 |
+
"objection:premature_pitch", # triggered by R01 violation
|
| 192 |
+
"deflect:budget_not_discussed", # triggered by R03 violation
|
| 193 |
+
"deflect:stall", # prospect stalls (Level 3+)
|
| 194 |
+
"accept:demo_scheduled", # agrees to demo
|
| 195 |
+
"accept:close_success", # agrees to close → episode success
|
| 196 |
+
"reject:close_failed", # rejects close
|
| 197 |
+
"silence", # no response (enables FOLLOW_UP)
|
| 198 |
+
"exit:disqualified", # prospect exits conversation
|
| 199 |
+
}
|
| 200 |
+
```
|
| 201 |
+
|
| 202 |
+
---
|
| 203 |
+
|
| 204 |
+
## 7. Difficulty Configuration
|
| 205 |
+
|
| 206 |
+
```python
|
| 207 |
+
DIFFICULTY_CONFIG = {
|
| 208 |
+
1: {
|
| 209 |
+
"max_turns": 20,
|
| 210 |
+
"workflow_steps": ["QUALIFY", "PRESENT", "CLOSE"],
|
| 211 |
+
"num_objections": 0,
|
| 212 |
+
"budget_hidden": False,
|
| 213 |
+
"mode_shift": False,
|
| 214 |
+
"optimal_turns": 5,
|
| 215 |
+
},
|
| 216 |
+
2: {
|
| 217 |
+
"max_turns": 20,
|
| 218 |
+
"workflow_steps": ["QUALIFY", "PRESENT", "HANDLE_OBJECTION", "OFFER_DEMO", "CLOSE"],
|
| 219 |
+
"num_objections": 1,
|
| 220 |
+
"budget_hidden": True, # revealed after QUALIFY
|
| 221 |
+
"mode_shift": False,
|
| 222 |
+
"optimal_turns": 8,
|
| 223 |
+
},
|
| 224 |
+
3: {
|
| 225 |
+
"max_turns": 20,
|
| 226 |
+
"workflow_steps": ["QUALIFY", "PRESENT", "HANDLE_OBJECTION", "OFFER_DEMO",
|
| 227 |
+
"HANDLE_OBJECTION", "NEGOTIATE", "CLOSE"],
|
| 228 |
+
"num_objections": 2,
|
| 229 |
+
"budget_hidden": True,
|
| 230 |
+
"mode_shift": True, # prospect signals shift at turn 10
|
| 231 |
+
"optimal_turns": 12,
|
| 232 |
+
},
|
| 233 |
+
4: {
|
| 234 |
+
"max_turns": 20,
|
| 235 |
+
"workflow_steps": "full", # agent must determine correct path
|
| 236 |
+
"num_objections": 2,
|
| 237 |
+
"budget_hidden": True,
|
| 238 |
+
"mode_shift": True,
|
| 239 |
+
"misleading_signals": True, # budget signals are deceptive
|
| 240 |
+
"optimal_turns": 14,
|
| 241 |
+
},
|
| 242 |
+
}
|
| 243 |
+
```
|
| 244 |
+
|
| 245 |
+
---
|
| 246 |
+
|
| 247 |
+
## 8. Reward — Exact Weights
|
| 248 |
+
|
| 249 |
+
```python
|
| 250 |
+
REWARD_WEIGHTS = {
|
| 251 |
+
"r_outcome": 0.40,
|
| 252 |
+
"r_compliance": 0.30,
|
| 253 |
+
"r_ordering": 0.15,
|
| 254 |
+
"r_efficiency": 0.10,
|
| 255 |
+
"r_format": 0.05,
|
| 256 |
+
}
|
| 257 |
+
|
| 258 |
+
OUTCOME_VALUES = {
|
| 259 |
+
"close_success": 1.0,
|
| 260 |
+
"disqualify_correct": 0.5,
|
| 261 |
+
"turn_limit_reached": -0.3,
|
| 262 |
+
"close_failed": -0.5,
|
| 263 |
+
"three_violations": -0.5,
|
| 264 |
+
}
|
| 265 |
+
|
| 266 |
+
COMPLIANCE_PER_VIOLATION = -0.2 # capped at -1.0
|
| 267 |
+
EFFICIENCY_PER_EXTRA_TURN = -0.05 # capped at -0.3
|
| 268 |
+
FORMAT_PASS = 1.0
|
| 269 |
+
FORMAT_FAIL = -0.1
|
| 270 |
+
```
|
| 271 |
+
|
| 272 |
+
---
|
| 273 |
+
|
| 274 |
+
## 9. Training Rules
|
| 275 |
+
|
| 276 |
+
### Prompt Format (what gets sent to the LLM)
|
| 277 |
+
```
|
| 278 |
+
System: You are a B2B sales agent. Follow this workflow strictly:
|
| 279 |
+
{workflow_steps_for_difficulty}
|
| 280 |
+
|
| 281 |
+
Business rules you must never violate:
|
| 282 |
+
{rules_list}
|
| 283 |
+
|
| 284 |
+
Current state:
|
| 285 |
+
- Prospect: {prospect_summary}
|
| 286 |
+
- Stage: {workflow_stage}
|
| 287 |
+
- Steps done: {steps_completed}
|
| 288 |
+
- Turn: {turn_number}/20
|
| 289 |
+
|
| 290 |
+
Prospect said: {prospect_response}
|
| 291 |
+
|
| 292 |
+
Respond with:
|
| 293 |
+
ACTION: <action_type>
|
| 294 |
+
CONTENT: <your message>
|
| 295 |
+
```
|
| 296 |
+
|
| 297 |
+
### Response parsing
|
| 298 |
+
- Extract `ACTION:` line → `action_type`
|
| 299 |
+
- Extract `CONTENT:` line → `content`
|
| 300 |
+
- If parsing fails → `r_format = -0.1`, use fallback QUALIFY
|
| 301 |
+
|
| 302 |
+
### GRPO config
|
| 303 |
+
```python
|
| 304 |
+
GRPOConfig(
|
| 305 |
+
num_generations=8, # rollouts per prompt
|
| 306 |
+
max_new_tokens=256,
|
| 307 |
+
temperature=0.8,
|
| 308 |
+
learning_rate=1e-5,
|
| 309 |
+
per_device_train_batch_size=2,
|
| 310 |
+
gradient_accumulation_steps=4,
|
| 311 |
+
)
|
| 312 |
+
```
|
| 313 |
+
|
| 314 |
+
---
|
| 315 |
+
|
| 316 |
+
## 10. What to Monitor During Training
|
| 317 |
+
|
| 318 |
+
Log these every 10 steps. If any metric hits its alarm condition, stop and inspect raw generations:
|
| 319 |
+
|
| 320 |
+
| Metric | Healthy Range | Alarm |
|
| 321 |
+
|--------|--------------|-------|
|
| 322 |
+
| `mean_reward` | Rising | Flat for >50 steps |
|
| 323 |
+
| `mean_r_compliance` | Rising | < -0.5 after step 100 |
|
| 324 |
+
| `violations_per_episode` | Falling | > 3.0 after step 100 |
|
| 325 |
+
| `ordering_rate` | Rising toward 0.85 | < 0.3 after step 150 |
|
| 326 |
+
| `close_success_rate` | Rising | 0 after step 200 |
|
| 327 |
+
|
| 328 |
+
Inspect raw generations every 50 steps. Look for: repeated actions, empty CONTENT, invalid ACTION types, CLOSE before QUALIFY.
|
| 329 |
+
|
| 330 |
+
---
|
| 331 |
+
|
| 332 |
+
## 11. Save Model Correctly
|
| 333 |
+
|
| 334 |
+
```python
|
| 335 |
+
# CORRECT — do not deviate
|
| 336 |
+
model.save_pretrained_merged(
|
| 337 |
+
"salespath_trained",
|
| 338 |
+
tokenizer,
|
| 339 |
+
save_method="merged_16bit", # NOT naive upcast of 4bit
|
| 340 |
+
)
|
| 341 |
+
```
|
| 342 |
+
|
| 343 |
+
Never do: `model.save_pretrained()` on a 4-bit model without merging first.
|
| 344 |
+
|
| 345 |
+
---
|
| 346 |
+
|
| 347 |
+
## 12. File Ownership (2-Person Team)
|
| 348 |
+
|
| 349 |
+
| Person | Files |
|
| 350 |
+
|--------|-------|
|
| 351 |
+
| **A** | `models.py`, `server/salespath_environment.py`, `server/prospect_simulator.py`, `server/rules.py`, `server/task_bank.py`, `server/app.py`, `Dockerfile` |
|
| 352 |
+
| **B** | `server/reward.py`, `training/grpo_train.py`, `training/rollout.py`, `training/curriculum.py`, `training/colab_train.ipynb`, `client.py` |
|
| 353 |
+
|
| 354 |
+
Both: `README.md`, `openenv.yaml`, `pyproject.toml`
|
push_to_hub.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from huggingface_hub import HfApi
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
REPO_ID = "Imsachin010/salespath-env"
|
| 5 |
+
FOLDER_PATH = "."
|
| 6 |
+
|
| 7 |
+
IGNORE_PATTERNS = [
|
| 8 |
+
"*.pyc",
|
| 9 |
+
"**/__pycache__/**",
|
| 10 |
+
".git/**",
|
| 11 |
+
".spa/**",
|
| 12 |
+
".SPA/**",
|
| 13 |
+
"*.egg-info/**",
|
| 14 |
+
"push_to_hub.py",
|
| 15 |
+
"salespath_env/server/Dockerfile", # root Dockerfile is used instead
|
| 16 |
+
"training/**", # exclude training scripts from Space
|
| 17 |
+
]
|
| 18 |
+
|
| 19 |
+
def main():
    """Create the Space if needed, upload this folder to it, and print its URL.

    Uses the module-level REPO_ID, FOLDER_PATH and IGNORE_PATTERNS settings.
    """
    hub = HfApi()

    # exist_ok=True: re-running the script against an existing Space
    # should not fail at the creation step (see huggingface_hub docs).
    hub.create_repo(
        repo_id=REPO_ID,
        repo_type="space",
        space_sdk="docker",
        exist_ok=True,
        private=False,
    )

    hub.upload_folder(
        folder_path=FOLDER_PATH,
        repo_id=REPO_ID,
        repo_type="space",
        ignore_patterns=IGNORE_PATTERNS,
        commit_message="Deploy SalesPath Environment",
    )

    # "owner/name" is printed as the "owner-name" host label.
    host_label = REPO_ID.replace("/", "-")
    print(f"Live Space URL:\nhttps://{host_label}.hf.space")
|
| 42 |
+
|
| 43 |
+
if __name__ == "__main__":
|
| 44 |
+
main()
|
pyproject.toml
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=42"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "salespath_env"
|
| 7 |
+
version = "0.1.0"
|
| 8 |
+
requires-python = ">=3.10"
|
| 9 |
+
dependencies = [
|
| 10 |
+
"openenv",
|
| 11 |
+
"fastapi",
|
| 12 |
+
"uvicorn",
|
| 13 |
+
"pydantic>=2.0",
|
| 14 |
+
"trl>=0.8.0",
|
| 15 |
+
"unsloth",
|
| 16 |
+
"torch",
|
| 17 |
+
"transformers",
|
| 18 |
+
]
|
| 19 |
+
|
| 20 |
+
[tool.setuptools.packages.find]
|
| 21 |
+
where = ["."]
|
| 22 |
+
include = ["salespath_env*"]
|
requirements.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi>=0.110.0
|
| 2 |
+
uvicorn[standard]>=0.29.0
|
| 3 |
+
pydantic>=2.0
|
| 4 |
+
openenv-core>=0.2.3
|
salespath_env.egg-info/PKG-INFO
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.4
|
| 2 |
+
Name: salespath-env
|
| 3 |
+
Version: 0.1.0
|
| 4 |
+
Requires-Python: >=3.10
|
| 5 |
+
Requires-Dist: openenv-core>=0.2.3
|
| 6 |
+
Requires-Dist: fastapi>=0.110.0
|
| 7 |
+
Requires-Dist: uvicorn[standard]>=0.29.0
|
| 8 |
+
Requires-Dist: pydantic>=2.0
|
salespath_env.egg-info/SOURCES.txt
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
README.md
|
| 2 |
+
pyproject.toml
|
| 3 |
+
salespath_env/__init__.py
|
| 4 |
+
salespath_env/client.py
|
| 5 |
+
salespath_env/models.py
|
| 6 |
+
salespath_env.egg-info/PKG-INFO
|
| 7 |
+
salespath_env.egg-info/SOURCES.txt
|
| 8 |
+
salespath_env.egg-info/dependency_links.txt
|
| 9 |
+
salespath_env.egg-info/requires.txt
|
| 10 |
+
salespath_env.egg-info/top_level.txt
|
| 11 |
+
salespath_env/server/__init__.py
|
| 12 |
+
salespath_env/server/app.py
|
| 13 |
+
salespath_env/server/prospect_simulator.py
|
| 14 |
+
salespath_env/server/reward.py
|
| 15 |
+
salespath_env/server/rules.py
|
| 16 |
+
salespath_env/server/salespath_environment.py
|
| 17 |
+
salespath_env/server/task_bank.py
|
salespath_env.egg-info/dependency_links.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
salespath_env.egg-info/requires.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
openenv-core>=0.2.3
|
| 2 |
+
fastapi>=0.110.0
|
| 3 |
+
uvicorn[standard]>=0.29.0
|
| 4 |
+
pydantic>=2.0
|
salespath_env.egg-info/top_level.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
salespath_env
|
salespath_env/README.md
ADDED
|
File without changes
|
salespath_env/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SalesPath OpenEnv package."""
|
| 2 |
+
|
salespath_env/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (218 Bytes). View file
|
|
|
salespath_env/__pycache__/client.cpython-313.pyc
ADDED
|
Binary file (3.56 kB). View file
|
|
|
salespath_env/__pycache__/models.cpython-313.pyc
ADDED
|
Binary file (3.36 kB). View file
|
|
|
salespath_env/client.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# salespath_env/client.py
|
| 2 |
+
|
| 3 |
+
from typing import Any, Dict
|
| 4 |
+
|
| 5 |
+
from openenv.core import EnvClient
|
| 6 |
+
from openenv.core.client_types import StepResult
|
| 7 |
+
|
| 8 |
+
from .models import (
|
| 9 |
+
SalesPathAction,
|
| 10 |
+
SalesPathObservation,
|
| 11 |
+
SalesPathState,
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class SalesPathEnv(EnvClient[SalesPathAction, SalesPathObservation, SalesPathState]):
    """Client for the SalesPath environment server.

    Provides the (de)serialisation hooks EnvClient requires, and overrides
    ``reset``/``step`` so callers receive a ``SalesPathObservation`` directly
    rather than a ``StepResult`` wrapper.
    """

    # --- EnvClient hooks -------------------------------------------------

    def _step_payload(self, action: SalesPathAction) -> Dict[str, Any]:
        """Turn an action into the JSON dict sent to the WebSocket server.

        The dict itself is the step message data (no wrapper key); the
        ``metadata`` field is left out.
        """
        payload = action.model_dump(exclude={"metadata"})
        return payload

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[SalesPathObservation]:
        """Build a StepResult from the server's JSON reply."""
        # The observation may be nested under "observation" or be the
        # payload itself.
        body = payload.get("observation", payload)
        observation = SalesPathObservation(**body)

        # Top-level reward/done win; otherwise use the observation's values.
        reward = payload.get("reward", observation.reward)
        done = payload.get("done", observation.done)

        return StepResult(observation=observation, reward=reward, done=done)

    def _parse_state(self, payload: Dict[str, Any]) -> SalesPathState:
        """Build a SalesPathState from the server's JSON reply."""
        return SalesPathState(**payload.get("state", payload))

    # --- Convenience wrappers returning the bare observation --------------

    @staticmethod
    def _with_step_fields(
        result: StepResult[SalesPathObservation],
    ) -> SalesPathObservation:
        """
        Return a copy of the observation carrying the wrapper's reward/done.

        Needed because some server payloads carry reward/done only at the
        top level, not inside the observation.
        """
        synced = {"reward": result.reward, "done": result.done}
        return result.observation.model_copy(update=synced)

    async def reset(
        self,
        difficulty: int = 1,
    ) -> SalesPathObservation:
        """Start a new episode at ``difficulty`` and return its first observation."""
        return self._with_step_fields(await super().reset(difficulty=difficulty))

    async def step(
        self,
        action_type: str,
        content: str,
        target: str = "",
    ) -> SalesPathObservation:
        """Send one action built from the given fields and return the observation."""
        outgoing = SalesPathAction(
            action_type=action_type,
            content=content,
            target=target,
        )
        return self._with_step_fields(await super().step(outgoing))
|
salespath_env/models.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# salespath_env/models.py
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import uuid
|
| 6 |
+
from typing import Dict, List
|
| 7 |
+
from pydantic import BaseModel, Field
|
| 8 |
+
|
| 9 |
+
# Safe OpenEnv Imports: Use OpenEnv base classes if available,
|
| 10 |
+
# otherwise fall back to Pydantic to bypass security blocks.
|
| 11 |
+
try:
|
| 12 |
+
from openenv.core import Action, Observation, State
|
| 13 |
+
except (ImportError, Exception):
|
| 14 |
+
Action = BaseModel
|
| 15 |
+
Observation = BaseModel
|
| 16 |
+
State = BaseModel
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# The nine action types an agent may send. SalesPathAction.is_valid()
# accepts exactly these strings.
VALID_ACTIONS = {
    "PROSPECT",  # initial outreach
    "QUALIFY",  # ask qualification questions
    "PRESENT",  # deliver pitch
    "HANDLE_OBJECTION",  # respond to a raised objection
    "OFFER_DEMO",  # propose a product demonstration
    "NEGOTIATE",  # discuss pricing/terms
    "CLOSE",  # submit closing offer
    "FOLLOW_UP",  # follow up after no response
    "DISQUALIFY",  # exit if the prospect is not a fit
}
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class SalesPathAction(Action):
    """
    Action sent by the agent to the environment.

    Construction does not validate ``action_type``; call ``is_valid()`` to
    check it against VALID_ACTIONS.
    """

    # Expected to be one of the strings in VALID_ACTIONS.
    action_type: str
    # Free-text content of the action.
    content: str
    # Optional target of the action; empty string when unused.
    target: str = ""

    def is_valid(self) -> bool:
        """
        Return True when ``action_type`` is one of VALID_ACTIONS.

        The comparison is an exact, case-sensitive string match.
        """
        return self.action_type in VALID_ACTIONS
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class SalesPathObservation(Observation):
    """
    What the agent is allowed to observe.
    Hidden state must NEVER be exposed here.

    Every field has a default, so an empty observation can be constructed.
    """

    # Text of the prospect's reply; defaults to an empty string.
    prospect_response: str = ""
    # Current workflow stage label; "START" until changed.
    workflow_stage: str = "START"

    # Violations and completed steps so far (each defaults to a fresh list).
    constraints_violated: List[str] = Field(default_factory=list)
    steps_completed: List[str] = Field(default_factory=list)

    turn_number: int = 0

    # Scalar reward plus a dict for its breakdown.
    # NOTE(review): presumably keyed by r_outcome / r_compliance / ... —
    # confirm against server/reward.py.
    reward: float = 0.0
    reward_components: Dict = Field(default_factory=dict)

    # True once the episode has ended.
    done: bool = False
    # Free-form extra information.
    info: Dict = Field(default_factory=dict)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class SalesPathState(State):
    """
    Internal environment state.
    Includes hidden state not exposed to the agent.

    Every field has a default; mutable fields get a fresh container per
    instance via ``default_factory``.
    """

    # Random UUID4 string generated for each new state object.
    episode_id: str = Field(default_factory=lambda: str(uuid.uuid4()))

    # Prospect attributes; the simulator reads "budget_signal" and
    # "decision_maker" from this dict.
    prospect_profile: Dict = Field(default_factory=dict)
    # One dict per conversation entry.
    # NOTE(review): entry schema is not visible here — confirm against the
    # environment.
    conversation_history: List[Dict] = Field(default_factory=list)

    workflow_stage: str = "START"
    # NOTE(review): presumably the ordered workflow steps for the current
    # difficulty — confirm against task_bank.
    required_workflow: List[str] = Field(default_factory=list)

    steps_completed: List[str] = Field(default_factory=list)
    # Rule IDs such as "R01"; the simulator inspects the last entry.
    constraints_violated: List[str] = Field(default_factory=list)

    # Incremented by the simulator on each HANDLE_OBJECTION action.
    objections_handled: int = 0
    turn_number: int = 0
    difficulty: int = 1

    done: bool = False

    # Hidden state — NEVER exposed in Observation.
    # Keys read by the simulator: "revealed_budget", "num_objections",
    # "true_budget", "close_threshold", "stall_probability".
    hidden_state: Dict = Field(default_factory=dict)
|
salespath_env/openenv.yaml
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "salespath_env"
|
| 3 |
+
version = "0.1.0"
|
| 4 |
+
dependencies = [
|
| 5 |
+
"openenv",
|
| 6 |
+
"fastapi",
|
| 7 |
+
"uvicorn",
|
| 8 |
+
"pydantic>=2.0",
|
| 9 |
+
"trl>=0.8.0",
|
| 10 |
+
"unsloth",
|
| 11 |
+
"torch",
|
| 12 |
+
"transformers",
|
| 13 |
+
]
|
salespath_env/pyproject.toml
ADDED
|
File without changes
|
salespath_env/server/Dockerfile
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Base image can be overridden at build time with --build-arg BASE_IMAGE=...
ARG BASE_IMAGE=openenv-base:latest
FROM ${BASE_IMAGE}

# NOTE(review): the two COPY sources below assume different build contexts
# ("server/..." is relative to salespath_env/, "salespath_env/" is relative
# to the repository root) — confirm which directory this image is built from.
COPY server/requirements.txt /tmp/requirements.txt
RUN pip install --no-cache-dir -r /tmp/requirements.txt

COPY salespath_env/ /app/salespath_env/

# Probe the server's /health endpoint on port 8000 (needs curl in the image).
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
    CMD curl -f http://localhost:8000/health || exit 1

# Serve the FastAPI app defined in salespath_env/server/app.py.
CMD ["uvicorn", "salespath_env.server.app:app", "--host", "0.0.0.0", "--port", "8000"]
|
salespath_env/server/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SalesPath environment server package."""
|
| 2 |
+
|
salespath_env/server/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (236 Bytes). View file
|
|
|
salespath_env/server/__pycache__/app.cpython-313.pyc
ADDED
|
Binary file (455 Bytes). View file
|
|
|
salespath_env/server/__pycache__/prospect_simulator.cpython-313.pyc
ADDED
|
Binary file (4.15 kB). View file
|
|
|
salespath_env/server/__pycache__/reward.cpython-313.pyc
ADDED
|
Binary file (3.1 kB). View file
|
|
|
salespath_env/server/__pycache__/rules.cpython-313.pyc
ADDED
|
Binary file (6.24 kB). View file
|
|
|
salespath_env/server/__pycache__/salespath_environment.cpython-313.pyc
ADDED
|
Binary file (6.4 kB). View file
|
|
|
salespath_env/server/__pycache__/task_bank.cpython-313.pyc
ADDED
|
Binary file (2.72 kB). View file
|
|
|
salespath_env/server/app.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# salespath_env/server/app.py
|
| 2 |
+
|
| 3 |
+
from openenv.core.env_server import create_fastapi_app
|
| 4 |
+
|
| 5 |
+
from ..models import (
|
| 6 |
+
SalesPathAction,
|
| 7 |
+
SalesPathObservation,
|
| 8 |
+
)
|
| 9 |
+
from .salespath_environment import (
|
| 10 |
+
SalesPathEnvironment,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# ASGI application object; built from the environment class plus its
# action and observation models.
app = create_fastapi_app(SalesPathEnvironment, SalesPathAction, SalesPathObservation)
|
salespath_env/server/prospect_simulator.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# salespath_env/server/prospect_simulator.py
|
| 2 |
+
|
| 3 |
+
from ..models import SalesPathAction, SalesPathState
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
# Natural-language text for every response token the simulator can return.
# "silence" deliberately maps to an empty string.
RESPONSE_TEXT = {
    # Open signals
    "open:positive_signal": "That sounds interesting. Tell me more about how this works.",
    "open:neutral_signal": "I see. We're evaluating a few options at the moment.",
    # Objections
    "objection:price": "The pricing seems higher than what we budgeted for.",
    "objection:timing": "The timing isn't ideal — we're in the middle of a quarter close.",
    "objection:premature_pitch": (
        "I'm not sure we're ready to discuss solutions yet. What do you know about our current situation?"
    ),
    # Deflections
    "deflect:budget_not_discussed": "We haven't really talked about what we're looking for yet.",
    "deflect:stall": "Let me get back to you on this. A lot is happening on our end.",
    # Acceptances
    "accept:demo_scheduled": "Yes, let's set up a demo. What time works next week?",
    "accept:close_success": "Alright, I think we can move forward with this. Send over the paperwork.",
    # Rejection / silence / exit
    "reject:close_failed": "I don't think we're ready to commit at this point.",
    "silence": "",
    "exit:disqualified": "I think we're done here. This isn't the right fit.",
}
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class ProspectSimulator:
    """
    Rule-based prospect simulator.

    No LLM and no transformers: every reply is picked by fixed rules over the
    action type and the episode state. The only randomness is the optional
    stall at difficulty 3+ (see ``_get_token``).
    """

    # Episode-ending actions: they always get their own verdict, never a stall.
    _TERMINAL_ACTIONS = ("CLOSE", "DISQUALIFY")

    def respond(
        self,
        action: SalesPathAction,
        state: SalesPathState,
    ) -> tuple[str, str]:
        """
        Pick the prospect's reply to ``action``.

        Returns:
            (response_token, response_text), where the text is the
            RESPONSE_TEXT entry for the token.
        """
        token = self._get_token(action, state)
        text = RESPONSE_TEXT[token]

        return token, text

    def _get_token(
        self,
        action: SalesPathAction,
        state: SalesPathState,
    ) -> str:
        """
        Choose the response token for ``action`` given ``state``.

        Precedence:
          1. Replies triggered by the most recent rule violation, but only
             when the current action is the kind that rule is about.
          2. At difficulty >= 3 from turn 10 on, a random stall for
             non-terminal actions, with probability
             ``hidden_state["stall_probability"]`` (default 0.0).
          3. The regular per-action reply.

        Side effects on ``state``:
          - QUALIFY replaces an "unknown" ``budget_signal`` in
            ``prospect_profile`` with ``hidden_state["revealed_budget"]``
            (default "medium").
          - HANDLE_OBJECTION increments ``objections_handled``.
        A stalled turn performs neither side effect.
        """
        atype = action.action_type
        difficulty = state.difficulty
        turn = state.turn_number
        profile = state.prospect_profile
        hidden = state.hidden_state
        objections = state.objections_handled

        # -----------------------------
        # Rule-triggered responses first
        # -----------------------------
        # constraints_violated is cumulative, so its last entry may come from
        # an earlier turn. Gate on the action type (R01 concerns PRESENT,
        # R03 concerns NEGOTIATE) so an old violation does not take over the
        # reply to an unrelated action.
        # NOTE(review): a PRESENT/NEGOTIATE that is legal now still gets the
        # violation reply while R01/R03 is the latest entry — confirm whether
        # the environment should pass per-turn violations instead.
        if state.constraints_violated:
            latest = state.constraints_violated[-1]

            if latest == "R01" and atype == "PRESENT":
                return "objection:premature_pitch"

            if latest == "R03" and atype == "NEGOTIATE":
                return "deflect:budget_not_discussed"

        # -----------------------------
        # Difficulty 3+ mode shift
        # -----------------------------
        # Must run before the per-action returns below; placed after them it
        # could never be reached for any valid action type.
        if (
            difficulty >= 3
            and turn >= 10
            and atype not in self._TERMINAL_ACTIONS
        ):
            stall_probability = hidden.get("stall_probability", 0.0)

            # Skip the RNG entirely when stalling is disabled.
            if stall_probability > 0:
                import random

                if random.random() < stall_probability:
                    return "deflect:stall"

        # -----------------------------
        # Action-based responses
        # -----------------------------

        if atype == "PROSPECT":
            return "open:positive_signal"

        if atype == "QUALIFY":
            # Reveal budget if hidden
            if profile.get("budget_signal") == "unknown":
                state.prospect_profile["budget_signal"] = hidden.get(
                    "revealed_budget",
                    "medium",
                )

            return "open:neutral_signal"

        if atype == "PRESENT":
            # From difficulty 2 the first pitch draws a price objection.
            if difficulty >= 2 and objections == 0:
                return "objection:price"

            return "open:positive_signal"

        if atype == "HANDLE_OBJECTION":
            state.objections_handled += 1

            required_objections = hidden.get("num_objections", 1)

            if state.objections_handled >= required_objections:
                return "open:positive_signal"

            # `objections` is the count before this turn: after the first
            # handled objection a timing objection follows.
            if objections == 0:
                return "objection:timing"

            return "open:positive_signal"

        if atype == "OFFER_DEMO":
            return "accept:demo_scheduled"

        if atype == "NEGOTIATE":
            return "open:neutral_signal"

        if atype == "CLOSE":
            true_budget = hidden.get("true_budget", 0.7)
            close_threshold = hidden.get("close_threshold", 0.5)
            decision_maker = profile.get("decision_maker", True)

            if true_budget >= close_threshold and decision_maker:
                return "accept:close_success"

            return "reject:close_failed"

        if atype == "FOLLOW_UP":
            return "open:neutral_signal"

        if atype == "DISQUALIFY":
            return "exit:disqualified"

        # Unrecognised action types get a neutral reply.
        return "open:neutral_signal"
|
salespath_env/server/requirements.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi
|
| 2 |
+
uvicorn
|
| 3 |
+
pydantic>=2.0
|
salespath_env/server/reward.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# salespath_env/server/reward.py
|
| 2 |
+
|
| 3 |
+
from ..models import SalesPathAction, SalesPathState
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
# Turn budget per difficulty level; turns beyond it are penalised by the
# efficiency component once the episode ends.
DIFFICULTY_OPTIMAL_TURNS = {
    1: 5,
    2: 8,
    3: 12,
    4: 14,
}


def _ordering_score(required: list[str], completed: list[str]) -> float:
    """Return the fraction of the required workflow completed in order.

    Steps that are not part of the required workflow (for example the
    mandatory ``PROSPECT`` opener enforced by rule R06) are ignored, and the
    required workflow is de-duplicated because ``steps_completed`` records
    each action type at most once. Without this alignment a fully compliant
    episode could never match the workflow position-by-position.

    Returns 1.0 when there is no required workflow or nothing has been
    completed yet.
    """
    if not required or not completed:
        return 1.0

    # dict.fromkeys keeps first-seen order while dropping repeats.
    expected = list(dict.fromkeys(required))
    observed = [step for step in completed if step in expected]
    matches = sum(1 for want, got in zip(expected, observed) if want == got)
    return matches / len(expected)


def compute_reward(
    state: SalesPathState,
    action: SalesPathAction,
    response_token: str,
    new_violations: list[str],
    episode_done: bool,
) -> tuple[float, dict]:
    """Compute the weighted step reward and its individual components.

    Args:
        state: Episode state after the current action has been applied.
        action: The agent action taken this turn.
        response_token: Token returned by the prospect simulator.
        new_violations: Rule IDs violated by this action only.
        episode_done: Whether this step ended the episode.

    Returns:
        ``(total_reward, reward_components)`` where ``reward_components``
        holds ``r_outcome``, ``r_compliance``, ``r_ordering``,
        ``r_efficiency``, ``r_format`` and the weighted ``total``.
    """
    components = {}

    # --------------------------------------------------
    # 1. Outcome reward (terminal steps only)
    # --------------------------------------------------
    r_outcome = 0.0

    if episode_done:
        if response_token == "accept:close_success":
            r_outcome = 1.0
        elif action.action_type == "DISQUALIFY":
            # A justified disqualification is rewarded; R08 marks an
            # unjustified one.
            r_outcome = 0.5 if "R08" not in new_violations else -0.5
        elif state.turn_number >= 20:
            # Ran out of turns (same limit as the environment's MAX_TURNS).
            r_outcome = -0.3
        elif len(state.constraints_violated) >= 3:
            # Terminated for accumulating too many violations.
            r_outcome = -0.5
        else:
            # Any other terminal state, e.g. a failed CLOSE.
            r_outcome = -0.5

    components["r_outcome"] = r_outcome

    # --------------------------------------------------
    # 2. Compliance reward: -0.2 per new violation, floored at -1.0
    # --------------------------------------------------
    components["r_compliance"] = max(-1.0, -0.2 * len(new_violations))

    # --------------------------------------------------
    # 3. Ordering reward
    # --------------------------------------------------
    components["r_ordering"] = _ordering_score(
        state.required_workflow,
        state.steps_completed,
    )

    # --------------------------------------------------
    # 4. Efficiency reward (terminal steps only)
    # --------------------------------------------------
    if episode_done:
        optimal = DIFFICULTY_OPTIMAL_TURNS.get(state.difficulty, 10)
        extra_turns = max(0, state.turn_number - optimal)
        r_efficiency = max(-0.3, -0.05 * extra_turns)
    else:
        r_efficiency = 0.0

    components["r_efficiency"] = r_efficiency

    # --------------------------------------------------
    # 5. Format reward
    # --------------------------------------------------
    components["r_format"] = 1.0 if action.is_valid() else -0.1

    # --------------------------------------------------
    # Final weighted reward
    # --------------------------------------------------
    weights = {
        "r_outcome": 0.40,
        "r_compliance": 0.30,
        "r_ordering": 0.15,
        "r_efficiency": 0.10,
        "r_format": 0.05,
    }

    total_reward = sum(weights[key] * components[key] for key in weights)
    components["total"] = total_reward

    return total_reward, components
|
salespath_env/server/rules.py
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# salespath_env/server/rules.py
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Callable
|
| 5 |
+
|
| 6 |
+
from ..models import SalesPathAction, SalesPathState
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@dataclass
class BusinessRule:
    """A single sales-process rule evaluated against every agent action.

    The ``check`` callable receives the current state and the proposed
    action and returns True when the rule is VIOLATED.
    """

    # Short identifier such as "R01"; this is what gets reported.
    rule_id: str
    # Machine-friendly rule name.
    name: str
    # Human-readable explanation of the rule.
    description: str
    # Predicate: True means the action violates the rule.
    check: Callable[[SalesPathState, SalesPathAction], bool]
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _qualify_before_present(
|
| 22 |
+
state: SalesPathState,
|
| 23 |
+
action: SalesPathAction,
|
| 24 |
+
) -> bool:
|
| 25 |
+
"""
|
| 26 |
+
R01:
|
| 27 |
+
PRESENT before QUALIFY is invalid.
|
| 28 |
+
"""
|
| 29 |
+
if action.action_type == "PRESENT":
|
| 30 |
+
return "QUALIFY" not in state.steps_completed
|
| 31 |
+
return False
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _demo_before_negotiate(
|
| 35 |
+
state: SalesPathState,
|
| 36 |
+
action: SalesPathAction,
|
| 37 |
+
) -> bool:
|
| 38 |
+
"""
|
| 39 |
+
R02:
|
| 40 |
+
NEGOTIATE before OFFER_DEMO is invalid.
|
| 41 |
+
"""
|
| 42 |
+
if action.action_type == "NEGOTIATE":
|
| 43 |
+
return "OFFER_DEMO" not in state.steps_completed
|
| 44 |
+
return False
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _budget_known_to_negotiate(
|
| 48 |
+
state: SalesPathState,
|
| 49 |
+
action: SalesPathAction,
|
| 50 |
+
) -> bool:
|
| 51 |
+
"""
|
| 52 |
+
R03:
|
| 53 |
+
Cannot NEGOTIATE while budget is unknown.
|
| 54 |
+
"""
|
| 55 |
+
if action.action_type == "NEGOTIATE":
|
| 56 |
+
return state.prospect_profile.get("budget_signal") == "unknown"
|
| 57 |
+
return False
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def _discount_after_objections(
|
| 61 |
+
state: SalesPathState,
|
| 62 |
+
action: SalesPathAction,
|
| 63 |
+
) -> bool:
|
| 64 |
+
"""
|
| 65 |
+
R04:
|
| 66 |
+
Discount only after 2 objections handled.
|
| 67 |
+
"""
|
| 68 |
+
if action.action_type == "NEGOTIATE":
|
| 69 |
+
if "discount" in action.content.lower():
|
| 70 |
+
return state.objections_handled < 2
|
| 71 |
+
return False
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def _no_repeat_action(
|
| 75 |
+
state: SalesPathState,
|
| 76 |
+
action: SalesPathAction,
|
| 77 |
+
) -> bool:
|
| 78 |
+
"""
|
| 79 |
+
R05:
|
| 80 |
+
Same action twice in a row is invalid.
|
| 81 |
+
"""
|
| 82 |
+
if state.conversation_history:
|
| 83 |
+
last_action = state.conversation_history[-1].get("action_type", "")
|
| 84 |
+
return last_action == action.action_type
|
| 85 |
+
return False
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def _prospect_first(
|
| 89 |
+
state: SalesPathState,
|
| 90 |
+
action: SalesPathAction,
|
| 91 |
+
) -> bool:
|
| 92 |
+
"""
|
| 93 |
+
R06:
|
| 94 |
+
First action must be PROSPECT.
|
| 95 |
+
"""
|
| 96 |
+
if state.turn_number == 1:
|
| 97 |
+
return action.action_type != "PROSPECT"
|
| 98 |
+
return False
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def _followup_timing(
|
| 102 |
+
state: SalesPathState,
|
| 103 |
+
action: SalesPathAction,
|
| 104 |
+
) -> bool:
|
| 105 |
+
"""
|
| 106 |
+
R07:
|
| 107 |
+
FOLLOW_UP only valid after silence.
|
| 108 |
+
If prospect just responded last turn, violation.
|
| 109 |
+
"""
|
| 110 |
+
if action.action_type == "FOLLOW_UP":
|
| 111 |
+
if state.conversation_history:
|
| 112 |
+
last_speaker = state.conversation_history[-1].get("speaker", "agent")
|
| 113 |
+
return last_speaker == "prospect"
|
| 114 |
+
return False
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def _disqualify_logic(
|
| 118 |
+
state: SalesPathState,
|
| 119 |
+
action: SalesPathAction,
|
| 120 |
+
) -> bool:
|
| 121 |
+
"""
|
| 122 |
+
R08:
|
| 123 |
+
DISQUALIFY only when prospect is genuinely not closeable.
|
| 124 |
+
Violation if prospect is actually closeable.
|
| 125 |
+
"""
|
| 126 |
+
if action.action_type == "DISQUALIFY":
|
| 127 |
+
true_budget = state.hidden_state.get("true_budget", 0.5)
|
| 128 |
+
close_threshold = state.hidden_state.get("close_threshold", 0.5)
|
| 129 |
+
decision_maker = state.prospect_profile.get("decision_maker", True)
|
| 130 |
+
|
| 131 |
+
return (true_budget >= close_threshold) and decision_maker
|
| 132 |
+
|
| 133 |
+
return False
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def _close_requires_demo(
|
| 137 |
+
state: SalesPathState,
|
| 138 |
+
action: SalesPathAction,
|
| 139 |
+
) -> bool:
|
| 140 |
+
"""
|
| 141 |
+
R09:
|
| 142 |
+
Difficulty 2+ requires OFFER_DEMO before CLOSE.
|
| 143 |
+
"""
|
| 144 |
+
if action.action_type == "CLOSE":
|
| 145 |
+
if state.difficulty >= 2:
|
| 146 |
+
return "OFFER_DEMO" not in state.steps_completed
|
| 147 |
+
return False
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
# Registry of every rule, in the order violations are reported by
# check_rules().
BUSINESS_RULES = [
    BusinessRule(
        rule_id="R01",
        name="qualify_before_present",
        description="Must QUALIFY before PRESENT",
        check=_qualify_before_present,
    ),
    BusinessRule(
        rule_id="R02",
        name="demo_before_negotiate",
        description="Must OFFER_DEMO before NEGOTIATE",
        check=_demo_before_negotiate,
    ),
    BusinessRule(
        rule_id="R03",
        name="budget_known_to_negotiate",
        description="Budget must be known before NEGOTIATE",
        check=_budget_known_to_negotiate,
    ),
    BusinessRule(
        rule_id="R04",
        name="discount_after_objections",
        description="Discount only after 2 objections handled",
        check=_discount_after_objections,
    ),
    BusinessRule(
        rule_id="R05",
        name="no_repeat_action",
        description="Cannot repeat same action consecutively",
        check=_no_repeat_action,
    ),
    BusinessRule(
        rule_id="R06",
        name="prospect_first",
        description="First action must be PROSPECT",
        check=_prospect_first,
    ),
    BusinessRule(
        rule_id="R07",
        name="followup_timing",
        description="FOLLOW_UP only after prospect silence",
        check=_followup_timing,
    ),
    BusinessRule(
        rule_id="R08",
        name="disqualify_logic",
        description="DISQUALIFY only when prospect is genuinely unqualified",
        check=_disqualify_logic,
    ),
    BusinessRule(
        rule_id="R09",
        name="close_requires_demo",
        description="Must OFFER_DEMO before CLOSE (difficulty 2+)",
        check=_close_requires_demo,
    ),
]
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def check_rules(
    state: SalesPathState,
    action: SalesPathAction,
) -> list[str]:
    """Evaluate every registered business rule against an action.

    Returns the IDs of all rules the action violates, in registry order.
    The list is empty when the action is fully compliant.
    """
    return [
        rule.rule_id
        for rule in BUSINESS_RULES
        if rule.check(state, action)
    ]
|
salespath_env/server/salespath_environment.py
ADDED
|
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# salespath_env/server/salespath_environment.py
|
| 2 |
+
|
| 3 |
+
import uuid
|
| 4 |
+
|
| 5 |
+
from openenv.core.env_server import Environment
|
| 6 |
+
|
| 7 |
+
from ..models import (
|
| 8 |
+
SalesPathAction,
|
| 9 |
+
SalesPathObservation,
|
| 10 |
+
SalesPathState,
|
| 11 |
+
)
|
| 12 |
+
from .task_bank import sample_profile
|
| 13 |
+
from .rules import check_rules
|
| 14 |
+
from .reward import compute_reward
|
| 15 |
+
from .prospect_simulator import ProspectSimulator
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
DIFFICULTY_WORKFLOW = {
|
| 19 |
+
1: [
|
| 20 |
+
"QUALIFY",
|
| 21 |
+
"PRESENT",
|
| 22 |
+
"CLOSE",
|
| 23 |
+
],
|
| 24 |
+
2: [
|
| 25 |
+
"QUALIFY",
|
| 26 |
+
"PRESENT",
|
| 27 |
+
"HANDLE_OBJECTION",
|
| 28 |
+
"OFFER_DEMO",
|
| 29 |
+
"CLOSE",
|
| 30 |
+
],
|
| 31 |
+
3: [
|
| 32 |
+
"QUALIFY",
|
| 33 |
+
"PRESENT",
|
| 34 |
+
"HANDLE_OBJECTION",
|
| 35 |
+
"OFFER_DEMO",
|
| 36 |
+
"HANDLE_OBJECTION",
|
| 37 |
+
"NEGOTIATE",
|
| 38 |
+
"CLOSE",
|
| 39 |
+
],
|
| 40 |
+
4: [], # Agent must determine; DISQUALIFY may be correct
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
MAX_VIOLATIONS_BEFORE_TERMINATE = 3
|
| 45 |
+
MAX_TURNS = 20
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class SalesPathEnvironment(Environment):
    """Core OpenEnv environment for the SalesPath task.

    All business logic is delegated to:
    - rules.py (constraint checks)
    - reward.py (reward computation)
    - prospect_simulator.py (prospect responses)
    """

    def __init__(self):
        super().__init__()
        self._state = SalesPathState()
        self._simulator = ProspectSimulator()

    def reset(self, difficulty: int = 1) -> SalesPathObservation:
        """Start a new episode and return the opening observation.

        Args:
            difficulty: Curriculum level. Unknown levels fall back to 1,
                matching the fallback used by ``sample_profile``, instead
                of raising ``KeyError``.
        """
        if difficulty not in DIFFICULTY_WORKFLOW:
            difficulty = 1

        profile = sample_profile(difficulty)

        # Coarse budget label derived from the hidden numeric budget.
        if profile.true_budget >= 0.7:
            revealed_budget = "high"
        elif profile.true_budget >= 0.4:
            revealed_budget = "medium"
        else:
            revealed_budget = "low"

        objections_by_difficulty = {
            1: 0,
            2: 1,
            3: 2,
            4: 2,
        }

        hidden_state = {
            "true_budget": profile.true_budget,
            "close_threshold": profile.close_threshold,
            "stall_probability": profile.stall_probability,
            "num_objections": objections_by_difficulty[difficulty],
            "revealed_budget": revealed_budget,
        }

        public_profile = {
            "company_name": profile.company_name,
            "company_size": profile.company_size,
            "industry": profile.industry,
            "budget_signal": profile.budget_signal,
            # Copy so episode state never aliases the shared task-bank entry.
            "pain_points": list(profile.pain_points),
            "decision_maker": profile.decision_maker,
        }

        self._state = SalesPathState(
            episode_id=str(uuid.uuid4()),
            prospect_profile=public_profile,
            conversation_history=[],
            workflow_stage="START",
            # Copy so the module-level workflow table cannot be mutated
            # through the episode state.
            required_workflow=list(DIFFICULTY_WORKFLOW[difficulty]),
            steps_completed=[],
            constraints_violated=[],
            objections_handled=0,
            turn_number=0,
            difficulty=difficulty,
            done=False,
            hidden_state=hidden_state,
        )

        intro_message = (
            f"You are engaging {profile.company_name}, "
            f"a {profile.company_size} {profile.industry} company. "
            f"Pain points: {', '.join(profile.pain_points)}. "
            f"Begin the sales conversation."
        )

        return SalesPathObservation(
            prospect_response=intro_message,
            workflow_stage="START",
            constraints_violated=[],
            steps_completed=[],
            turn_number=0,
            reward=0.0,
            reward_components={},
            done=False,
            info={
                "difficulty": difficulty,
                "episode_id": self._state.episode_id,
            },
        )

    def step(
        self,
        action: SalesPathAction,
    ) -> SalesPathObservation:
        """Apply one agent action and return the resulting observation.

        Order of operations: advance the turn counter, validate the action,
        check business rules, record the action, obtain the prospect
        response, decide termination, then compute the reward.
        """
        state = self._state

        # -----------------------------------
        # Advance turn
        # -----------------------------------
        state.turn_number += 1

        # -----------------------------------
        # Strict action validation
        # Must return observation, never crash
        # -----------------------------------
        if not action.is_valid():
            # An invalid action still consumes a turn, so the turn budget
            # must be enforced here too; otherwise an agent that only emits
            # invalid actions would never finish the episode.
            done = state.turn_number >= MAX_TURNS
            state.done = done

            return SalesPathObservation(
                prospect_response="Invalid action type.",
                workflow_stage=state.workflow_stage,
                constraints_violated=list(state.constraints_violated),
                steps_completed=list(state.steps_completed),
                turn_number=state.turn_number,
                reward=-0.2,
                reward_components={
                    "r_format": -0.1,
                },
                done=done,
                info={
                    "error": f"Invalid action_type: {action.action_type}",
                },
            )

        # -----------------------------------
        # Rule checks (evaluated against the state before this action
        # is recorded)
        # -----------------------------------
        new_violations = check_rules(state, action)
        state.constraints_violated.extend(new_violations)

        # -----------------------------------
        # Record agent action
        # -----------------------------------
        state.conversation_history.append(
            {
                "turn": state.turn_number,
                "speaker": "agent",
                "action_type": action.action_type,
                "content": action.content,
            }
        )

        # -----------------------------------
        # Update workflow state (each action type is recorded once)
        # -----------------------------------
        if action.action_type not in state.steps_completed:
            state.steps_completed.append(action.action_type)

        state.workflow_stage = action.action_type

        # -----------------------------------
        # Prospect response
        # -----------------------------------
        response_token, response_text = self._simulator.respond(
            action,
            state,
        )

        state.conversation_history.append(
            {
                "turn": state.turn_number,
                "speaker": "prospect",
                "response_token": response_token,
                "text": response_text,
            }
        )

        # -----------------------------------
        # Episode termination
        # -----------------------------------
        terminal_actions = {
            "CLOSE",
            "DISQUALIFY",
        }

        too_many_violations = (
            len(state.constraints_violated)
            >= MAX_VIOLATIONS_BEFORE_TERMINATE
        )
        turn_limit_reached = state.turn_number >= MAX_TURNS

        done = (
            action.action_type in terminal_actions
            or too_many_violations
            or turn_limit_reached
        )

        state.done = done

        # -----------------------------------
        # Reward
        # -----------------------------------
        total_reward, components = compute_reward(
            state=state,
            action=action,
            response_token=response_token,
            new_violations=new_violations,
            episode_done=done,
        )

        return SalesPathObservation(
            prospect_response=response_text,
            workflow_stage=state.workflow_stage,
            constraints_violated=list(state.constraints_violated),
            steps_completed=list(state.steps_completed),
            turn_number=state.turn_number,
            reward=total_reward,
            reward_components=components,
            done=done,
            info={
                "response_token": response_token,
                "new_violations": new_violations,
                "episode_id": state.episode_id,
            },
        )

    @property
    def state(self) -> SalesPathState:
        """Return the current episode state, including hidden fields."""
        return self._state
|
salespath_env/server/task_bank.py
ADDED
|
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# salespath_env/server/task_bank.py
|
| 2 |
+
|
| 3 |
+
import random
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@dataclass
class ProspectProfile:
    """One simulated prospect: public attributes plus hidden ground truth."""

    # --- Public attributes ---
    company_name: str
    company_size: str  # small / medium / enterprise
    industry: str
    budget_signal: str  # high / medium / low / unknown
    pain_points: list[str]
    decision_maker: bool

    # --- Hidden values: never exposed directly to the agent ---
    true_budget: float  # 0.0 -> 1.0
    close_threshold: float
    stall_probability: float
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# -------------------------
|
| 23 |
+
# LEVEL 1 — Easy
|
| 24 |
+
# budget known
|
| 25 |
+
# decision maker present
|
| 26 |
+
# close is usually possible
|
| 27 |
+
# -------------------------
|
| 28 |
+
|
| 29 |
+
PROFILES_L1 = [
|
| 30 |
+
ProspectProfile(
|
| 31 |
+
company_name="Meridian Retail",
|
| 32 |
+
company_size="medium",
|
| 33 |
+
industry="retail",
|
| 34 |
+
budget_signal="high",
|
| 35 |
+
pain_points=[
|
| 36 |
+
"manual inventory tracking",
|
| 37 |
+
"slow reporting",
|
| 38 |
+
],
|
| 39 |
+
decision_maker=True,
|
| 40 |
+
true_budget=0.8,
|
| 41 |
+
close_threshold=0.5,
|
| 42 |
+
stall_probability=0.0,
|
| 43 |
+
),
|
| 44 |
+
|
| 45 |
+
ProspectProfile(
|
| 46 |
+
company_name="Northline Foods",
|
| 47 |
+
company_size="small",
|
| 48 |
+
industry="food distribution",
|
| 49 |
+
budget_signal="medium",
|
| 50 |
+
pain_points=[
|
| 51 |
+
"supplier delays",
|
| 52 |
+
"inventory mismatch",
|
| 53 |
+
],
|
| 54 |
+
decision_maker=True,
|
| 55 |
+
true_budget=0.6,
|
| 56 |
+
close_threshold=0.5,
|
| 57 |
+
stall_probability=0.0,
|
| 58 |
+
),
|
| 59 |
+
]
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# -------------------------
|
| 63 |
+
# LEVEL 2 — Medium
|
| 64 |
+
# budget hidden initially
|
| 65 |
+
# one objection expected
|
| 66 |
+
# -------------------------
|
| 67 |
+
|
| 68 |
+
PROFILES_L2 = [
|
| 69 |
+
ProspectProfile(
|
| 70 |
+
company_name="Apex Logistics",
|
| 71 |
+
company_size="enterprise",
|
| 72 |
+
industry="logistics",
|
| 73 |
+
budget_signal="unknown",
|
| 74 |
+
pain_points=[
|
| 75 |
+
"route optimization",
|
| 76 |
+
"driver coordination",
|
| 77 |
+
"fuel tracking",
|
| 78 |
+
],
|
| 79 |
+
decision_maker=True,
|
| 80 |
+
true_budget=0.7,
|
| 81 |
+
close_threshold=0.5,
|
| 82 |
+
stall_probability=0.0,
|
| 83 |
+
),
|
| 84 |
+
|
| 85 |
+
ProspectProfile(
|
| 86 |
+
company_name="Vertex Supply",
|
| 87 |
+
company_size="medium",
|
| 88 |
+
industry="manufacturing",
|
| 89 |
+
budget_signal="unknown",
|
| 90 |
+
pain_points=[
|
| 91 |
+
"vendor visibility",
|
| 92 |
+
"purchase delays",
|
| 93 |
+
],
|
| 94 |
+
decision_maker=True,
|
| 95 |
+
true_budget=0.55,
|
| 96 |
+
close_threshold=0.5,
|
| 97 |
+
stall_probability=0.0,
|
| 98 |
+
),
|
| 99 |
+
]
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# -------------------------
|
| 103 |
+
# LEVEL 3 — Hard
|
| 104 |
+
# budget hidden
|
| 105 |
+
# 2 objections
|
| 106 |
+
# possible stalling
|
| 107 |
+
# decision maker may be absent
|
| 108 |
+
# -------------------------
|
| 109 |
+
|
| 110 |
+
PROFILES_L3 = [
|
| 111 |
+
ProspectProfile(
|
| 112 |
+
company_name="Nova Financial",
|
| 113 |
+
company_size="enterprise",
|
| 114 |
+
industry="finance",
|
| 115 |
+
budget_signal="unknown",
|
| 116 |
+
pain_points=[
|
| 117 |
+
"compliance reporting",
|
| 118 |
+
"audit trails",
|
| 119 |
+
"data silos",
|
| 120 |
+
],
|
| 121 |
+
decision_maker=False,
|
| 122 |
+
true_budget=0.6,
|
| 123 |
+
close_threshold=0.55,
|
| 124 |
+
stall_probability=0.3,
|
| 125 |
+
),
|
| 126 |
+
|
| 127 |
+
ProspectProfile(
|
| 128 |
+
company_name="Atlas Health",
|
| 129 |
+
company_size="enterprise",
|
| 130 |
+
industry="healthcare",
|
| 131 |
+
budget_signal="unknown",
|
| 132 |
+
pain_points=[
|
| 133 |
+
"patient workflow delays",
|
| 134 |
+
"reporting compliance",
|
| 135 |
+
],
|
| 136 |
+
decision_maker=False,
|
| 137 |
+
true_budget=0.65,
|
| 138 |
+
close_threshold=0.55,
|
| 139 |
+
stall_probability=0.25,
|
| 140 |
+
),
|
| 141 |
+
]
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
# -------------------------
|
| 145 |
+
# LEVEL 4 — Trap cases
|
| 146 |
+
# misleading signals
|
| 147 |
+
# correct action may be DISQUALIFY
|
| 148 |
+
# -------------------------
|
| 149 |
+
|
| 150 |
+
PROFILES_L4 = [
|
| 151 |
+
ProspectProfile(
|
| 152 |
+
company_name="Cipher Tech",
|
| 153 |
+
company_size="small",
|
| 154 |
+
industry="technology",
|
| 155 |
+
budget_signal="high", # misleading
|
| 156 |
+
pain_points=[
|
| 157 |
+
"security",
|
| 158 |
+
"compliance",
|
| 159 |
+
],
|
| 160 |
+
decision_maker=True,
|
| 161 |
+
true_budget=0.2,
|
| 162 |
+
close_threshold=0.5,
|
| 163 |
+
stall_probability=0.5,
|
| 164 |
+
),
|
| 165 |
+
|
| 166 |
+
ProspectProfile(
|
| 167 |
+
company_name="BluePeak Studio",
|
| 168 |
+
company_size="small",
|
| 169 |
+
industry="creative agency",
|
| 170 |
+
budget_signal="high", # misleading
|
| 171 |
+
pain_points=[
|
| 172 |
+
"project visibility",
|
| 173 |
+
"client reporting",
|
| 174 |
+
],
|
| 175 |
+
decision_maker=True,
|
| 176 |
+
true_budget=0.25,
|
| 177 |
+
close_threshold=0.5,
|
| 178 |
+
stall_probability=0.4,
|
| 179 |
+
),
|
| 180 |
+
]
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
ALL_PROFILES = {
|
| 184 |
+
1: PROFILES_L1,
|
| 185 |
+
2: PROFILES_L2,
|
| 186 |
+
3: PROFILES_L3,
|
| 187 |
+
4: PROFILES_L4,
|
| 188 |
+
}
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def sample_profile(difficulty: int) -> ProspectProfile:
    """Pick a random prospect profile for the given difficulty level.

    Difficulty values without an entry in ``ALL_PROFILES`` fall back to the
    level-1 profiles.
    """
    candidates = ALL_PROFILES.get(difficulty, ALL_PROFILES[1])
    return random.choice(candidates)
|
training/__init__.py
ADDED
|
File without changes
|
training/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (172 Bytes). View file
|
|
|
training/__pycache__/curriculum.cpython-313.pyc
ADDED
|
Binary file (2.02 kB). View file
|
|
|
training/__pycache__/debug_episode.cpython-313.pyc
ADDED
|
Binary file (2.8 kB). View file
|
|
|
training/__pycache__/grpo_train.cpython-313.pyc
ADDED
|
Binary file (14.3 kB). View file
|
|
|
training/__pycache__/rollout.cpython-313.pyc
ADDED
|
Binary file (5.47 kB). View file
|
|
|
training/__pycache__/test_rollout.cpython-313.pyc
ADDED
|
Binary file (1.77 kB). View file
|
|
|
training/colab_train.ipynb
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"# SalesPath Colab Training\n",
|
| 8 |
+
"\n",
|
| 9 |
+
"This notebook installs dependencies, runs a local environment server, validates rollout, and launches curriculum training."
|
| 10 |
+
]
|
| 11 |
+
},
|
| 12 |
+
{
|
| 13 |
+
"cell_type": "code",
|
| 14 |
+
"execution_count": null,
|
| 15 |
+
"metadata": {},
|
| 16 |
+
"outputs": [],
|
| 17 |
+
"source": [
|
| 18 |
+
"!pip install -U pip\n",
|
| 19 |
+
"!pip install fastapi uvicorn pydantic httpx torch transformers trl unsloth openenv"
|
| 20 |
+
]
|
| 21 |
+
},
|
| 22 |
+
{
|
| 23 |
+
"cell_type": "code",
|
| 24 |
+
"execution_count": null,
|
| 25 |
+
"metadata": {},
|
| 26 |
+
"outputs": [],
|
| 27 |
+
"source": [
|
| 28 |
+
"# If the repo is not already present, clone it.\n",
|
| 29 |
+
"# !git clone https://github.com/<your-org-or-user>/salespath_env.git\n",
|
| 30 |
+
"# %cd salespath_env\n",
|
| 31 |
+
"\n",
|
| 32 |
+
"%cd /content/salespath_env"
|
| 33 |
+
]
|
| 34 |
+
},
|
| 35 |
+
{
|
| 36 |
+
"cell_type": "code",
|
| 37 |
+
"execution_count": null,
|
| 38 |
+
"metadata": {},
|
| 39 |
+
"outputs": [],
|
| 40 |
+
"source": [
|
| 41 |
+
"# Start the OpenEnv-compatible server in background.\n",
|
| 42 |
+
"!nohup python -m uvicorn salespath_env.server.app:app --host 0.0.0.0 --port 8000 > /content/server.log 2>&1 &\n",
|
| 43 |
+
"!sleep 3\n",
|
| 44 |
+
"!python -c \"import httpx; r=httpx.get('http://127.0.0.1:8000/health', timeout=30); print(r.status_code, r.text)\""
|
| 45 |
+
]
|
| 46 |
+
},
|
| 47 |
+
{
|
| 48 |
+
"cell_type": "code",
|
| 49 |
+
"execution_count": null,
|
| 50 |
+
"metadata": {},
|
| 51 |
+
"outputs": [],
|
| 52 |
+
"source": [
|
| 53 |
+
"# Rollout smoke test (single episode)\n",
|
| 54 |
+
"!python -m training.test_rollout"
|
| 55 |
+
]
|
| 56 |
+
},
|
| 57 |
+
{
|
| 58 |
+
"cell_type": "code",
|
| 59 |
+
"execution_count": null,
|
| 60 |
+
"metadata": {},
|
| 61 |
+
"outputs": [],
|
| 62 |
+
"source": [
|
| 63 |
+
"# Curriculum run (example)\n",
|
| 64 |
+
"!python -m training.grpo_train --steps 30 --env-url http://127.0.0.1:8000 --model-name Qwen/Qwen2.5-0.5B-Instruct"
|
| 65 |
+
]
|
| 66 |
+
},
|
| 67 |
+
{
|
| 68 |
+
"cell_type": "markdown",
|
| 69 |
+
"metadata": {},
|
| 70 |
+
"source": [
|
| 71 |
+
"## Optional: Push merged model to Hugging Face\n",
|
| 72 |
+
"\n",
|
| 73 |
+
"Set your token first:\n",
|
| 74 |
+
"\n",
|
| 75 |
+
"```python\n",
|
| 76 |
+
"import os\n",
|
| 77 |
+
"os.environ['HF_TOKEN'] = 'hf_xxx'\n",
|
| 78 |
+
"```\n",
|
| 79 |
+
"\n",
|
| 80 |
+
"Then run:\n",
|
| 81 |
+
"\n",
|
| 82 |
+
"```bash\n",
|
| 83 |
+
"python -m training.grpo_train --steps 100 --push-merged --hub-repo Imsachin010/salespath-qwen25-7b\n",
|
| 84 |
+
"```"
|
| 85 |
+
]
|
| 86 |
+
}
|
| 87 |
+
],
|
| 88 |
+
"metadata": {
|
| 89 |
+
"kernelspec": {
|
| 90 |
+
"display_name": "Python 3",
|
| 91 |
+
"language": "python",
|
| 92 |
+
"name": "python3"
|
| 93 |
+
},
|
| 94 |
+
"language_info": {
|
| 95 |
+
"name": "python"
|
| 96 |
+
}
|
| 97 |
+
},
|
| 98 |
+
"nbformat": 4,
|
| 99 |
+
"nbformat_minor": 5
|
| 100 |
+
}
|
training/curriculum.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# training/curriculum.py
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
import random
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@dataclass
class CurriculumConfig:
    """Reward-driven schedule of difficulty distributions.

    ``thresholds`` maps a minimum mean reward to the distribution dict that
    applies once that reward level has been reached.
    """

    thresholds: dict

    def get_distribution(self, mean_reward: float) -> dict:
        """Return the distribution for the highest threshold <= ``mean_reward``.

        When ``mean_reward`` lies below every threshold, the entry stored
        under the smallest threshold is returned instead.
        """
        reached = [level for level in self.thresholds if mean_reward >= level]
        if reached:
            chosen = max(reached)
        else:
            chosen = min(self.thresholds.keys())
        return self.thresholds[chosen]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# Default schedule. Outer keys are minimum mean-reward thresholds; each inner
# dict maps difficulty level -> sampling weight (every row sums to 1.0).
# As the mean reward crosses a higher threshold, weight shifts from the easy
# levels (1, 2) toward the harder ones (3, 4).
DEFAULT_CURRICULUM = CurriculumConfig(
    thresholds={
        0.0: {
            1: 0.90,
            2: 0.10,
            3: 0.00,
            4: 0.00,
        },

        0.30: {
            1: 0.50,
            2: 0.40,
            3: 0.10,
            4: 0.00,
        },

        0.50: {
            1: 0.20,
            2: 0.40,
            3: 0.35,
            4: 0.05,
        },

        0.65: {
            1: 0.10,
            2: 0.30,
            3: 0.40,
            4: 0.20,
        },
    }
)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def sample_difficulty(
    curriculum: CurriculumConfig,
    mean_reward: float,
) -> int:
    """Draw one difficulty level for the current mean reward.

    The curriculum supplies a ``{difficulty: weight}`` mapping; a single
    level is sampled from it with those weights.
    """
    weights_by_level = curriculum.get_distribution(mean_reward)
    levels = list(weights_by_level.keys())
    weights = list(weights_by_level.values())
    (picked,) = random.choices(levels, weights=weights, k=1)
    return picked
|
training/debug_episode.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import asyncio
|
| 3 |
+
|
| 4 |
+
from salespath_env.client import SalesPathEnv
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
async def run_debug(env_url: str, difficulty: int):
    """Replay a fixed script of three PRESENT actions and print each transition.

    Connects to the environment at ``env_url``, resets it at ``difficulty``,
    then sends the scripted actions one by one, printing turn, done flag,
    reward, violations and reward components after each step. Stops early
    once the environment reports the episode as done.
    """
    scripted_actions = (
        ("PRESENT", "pitch too early"),
        ("PRESENT", "repeat pitch"),
        ("PRESENT", "repeat pitch again"),
    )

    async with SalesPathEnv(base_url=env_url) as env:
        obs = await env.reset(difficulty=difficulty)
        print("RESET")
        print(f" turn={obs.turn_number} done={obs.done} reward={obs.reward}")
        print(f" response={obs.prospect_response}")

        step_no = 0
        for action_type, content in scripted_actions:
            step_no += 1
            obs = await env.step(action_type=action_type, content=content, target="")
            print(f"\nSTEP {step_no} action={action_type}")
            print(f" turn={obs.turn_number} done={obs.done} reward={obs.reward}")
            print(f" violations={obs.constraints_violated}")
            print(f" new_violations={obs.info.get('new_violations')}")
            print(f" components={obs.reward_components}")
            if obs.done:
                break
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def parse_args():
    """Parse the command-line options of the debug-episode script.

    Returns the argparse namespace with ``env_url`` and ``difficulty``.
    """
    cli = argparse.ArgumentParser(description="Debug stateful episode transitions.")
    option_table = (
        ("--env-url", {"default": "http://127.0.0.1:8000"}),
        ("--difficulty", {"type": int, "default": 2}),
    )
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# Script entry point: read the CLI options, then drive the async debug
# episode to completion on a fresh event loop.
if __name__ == "__main__":
    args = parse_args()
    asyncio.run(run_debug(args.env_url, args.difficulty))
|
training/grpo_train.py
ADDED
|
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import asyncio
|
| 3 |
+
import ast
|
| 4 |
+
import os
|
| 5 |
+
import re
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 10 |
+
|
| 11 |
+
from training.curriculum import DEFAULT_CURRICULUM, sample_difficulty
|
| 12 |
+
from training.rollout import run_episode
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# Defaults for the --model-name and --env-url command-line options.
DEFAULT_MODEL = "Qwen/Qwen2.5-0.5B-Instruct"
DEFAULT_ENV_URL = "http://127.0.0.1:8000"
# Action names the GRPO reward function accepts; any other parsed action is
# scored as a format failure.
VALID_ACTIONS = {
    "PROSPECT",
    "QUALIFY",
    "PRESENT",
    "HANDLE_OBJECTION",
    "OFFER_DEMO",
    "NEGOTIATE",
    "CLOSE",
    "FOLLOW_UP",
    "DISQUALIFY",
}
# Difficulty level -> ordered workflow steps used when building synthetic
# GRPO prompts. Level 4 has no fixed workflow; the prompt builder renders an
# empty list as "Dynamic".
WORKFLOW_MAP = {
    1: ["QUALIFY", "PRESENT", "CLOSE"],
    2: ["QUALIFY", "PRESENT", "HANDLE_OBJECTION", "OFFER_DEMO", "CLOSE"],
    3: ["QUALIFY", "PRESENT", "HANDLE_OBJECTION", "OFFER_DEMO", "HANDLE_OBJECTION", "NEGOTIATE", "CLOSE"],
    4: [],
}
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _load_model_and_tokenizer(model_name: str):
    """Load a causal language model and its tokenizer by name.

    If the tokenizer defines no pad token, its EOS token is used as the pad
    token. Returns ``(model, tokenizer)``.
    """
    load_options = {"dtype": "auto", "device_map": "auto"}

    tok = AutoTokenizer.from_pretrained(model_name)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    lm = AutoModelForCausalLM.from_pretrained(model_name, **load_options)
    return lm, tok
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
async def curriculum_train(
    model,
    tokenizer,
    env_url: str,
    total_steps: int = 100,
    print_every: int = 10,
):
    """Run ``total_steps`` curriculum-scheduled episodes and collect metrics.

    Before each episode a difficulty is sampled from ``DEFAULT_CURRICULUM``
    using the mean reward of the most recent 20 episodes (0.0 before the
    first episode). A progress line is printed whenever the step index is a
    multiple of ``print_every``. No model weights are updated here.

    Returns a dict with ``mean_reward`` (last-20 mean), ``reward_history``
    (one float per episode) and ``run_log`` (one summary dict per episode).
    """
    rolling_mean = 0.0
    rewards: list[float] = []
    episode_log: list[dict] = []

    for step in range(total_steps):
        level = sample_difficulty(DEFAULT_CURRICULUM, rolling_mean)
        episode = await run_episode(
            model=model,
            tokenizer=tokenizer,
            env_url=env_url,
            difficulty=level,
        )

        episode_reward = float(episode["total_reward"])
        rewards.append(episode_reward)
        rolling_mean = float(np.mean(rewards[-20:]))

        violation_count = len(episode["violations"])
        episode_log.append(
            {
                "step": step,
                "difficulty": level,
                "reward": episode_reward,
                "violations": violation_count,
                "steps_completed": list(episode["steps_completed"]),
            }
        )

        if step % print_every != 0:
            continue
        print(
            f"Step {step:04d} | Difficulty {level} | "
            f"Reward {episode_reward:.3f} | Mean(20) {rolling_mean:.3f} | "
            f"Violations {violation_count} | Steps {episode['steps_completed']}"
        )

    return {
        "mean_reward": rolling_mean,
        "reward_history": rewards,
        "run_log": episode_log,
    }
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def _save_metrics(output_dir: str, metrics: dict):
|
| 97 |
+
output_path = Path(output_dir)
|
| 98 |
+
output_path.mkdir(parents=True, exist_ok=True)
|
| 99 |
+
rewards_path = output_path / "reward_history.txt"
|
| 100 |
+
with rewards_path.open("w", encoding="utf-8") as f:
|
| 101 |
+
for idx, reward in enumerate(metrics["reward_history"]):
|
| 102 |
+
f.write(f"{idx}\t{reward:.6f}\n")
|
| 103 |
+
print(f"Saved reward history to {rewards_path}")
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def _extract_action_content(text: str) -> tuple[str, str]:
|
| 107 |
+
action_match = re.search(r"ACTION:\s*(\w+)", text, re.IGNORECASE)
|
| 108 |
+
content_match = re.search(r"CONTENT:\s*(.+?)(?:\n|$)", text, re.IGNORECASE | re.DOTALL)
|
| 109 |
+
action_type = action_match.group(1).upper() if action_match else ""
|
| 110 |
+
content = content_match.group(1).strip() if content_match else ""
|
| 111 |
+
return action_type, content
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def _extract_steps_completed(prompt_text: str) -> list[str]:
|
| 115 |
+
match = re.search(r"Steps completed:\s*(\[.*?\])", prompt_text, re.DOTALL)
|
| 116 |
+
if not match:
|
| 117 |
+
return []
|
| 118 |
+
try:
|
| 119 |
+
parsed = ast.literal_eval(match.group(1))
|
| 120 |
+
if isinstance(parsed, list):
|
| 121 |
+
return [str(v).upper() for v in parsed]
|
| 122 |
+
except Exception:
|
| 123 |
+
return []
|
| 124 |
+
return []
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def _close_requires_demo(prompt_text: str) -> bool:
    """Tell whether rule R09 (OFFER_DEMO before CLOSE) applies to a prompt.

    R09 only binds workflows that contain an OFFER_DEMO step. The workflow is
    read from the prompt's "Required workflow steps (in order):" line. When
    that line is missing, empty or "Dynamic", the rule is assumed to apply.
    """
    # [ \t]* (not \s*) keeps the capture on the same line as the label.
    found = re.search(r"Required workflow steps \(in order\):[ \t]*(.*)", prompt_text)
    if found is None:
        return True
    workflow_line = found.group(1).strip()
    if not workflow_line or workflow_line == "Dynamic":
        return True
    return "OFFER_DEMO" in workflow_line


def salespath_reward_func(prompts, completions, **kwargs):
    """Score GRPO completions on output format and basic workflow-order rules.

    Args:
        prompts: prompt strings, aligned with ``completions``.
        completions: generated strings expected to contain ``ACTION:`` and
            ``CONTENT:`` lines.
        **kwargs: accepted for trainer compatibility and ignored.

    Returns:
        One float per prompt/completion pair. A completion without a valid
        action or without content scores -0.2. Otherwise it starts at +0.1
        and loses 0.2 for each violated rule hint (R06, R01, R02, R09).
    """
    rewards: list[float] = []

    for prompt, completion in zip(prompts, completions):
        action_type, content = _extract_action_content(completion)
        steps_completed = _extract_steps_completed(prompt)

        # Format + valid action
        if action_type not in VALID_ACTIONS or not content:
            rewards.append(-0.2)
            continue

        reward = 0.1

        # Rule hints
        if not steps_completed and action_type != "PROSPECT":
            reward -= 0.2  # R06
        if action_type == "PRESENT" and "QUALIFY" not in steps_completed:
            reward -= 0.2  # R01
        if action_type == "NEGOTIATE" and "OFFER_DEMO" not in steps_completed:
            reward -= 0.2  # R02
        if (
            action_type == "CLOSE"
            and "OFFER_DEMO" not in steps_completed
            # R09 does not bind workflows without a demo step (difficulty 1),
            # so a correct QUALIFY -> PRESENT -> CLOSE close is not penalised.
            and _close_requires_demo(prompt)
        ):
            reward -= 0.2  # R09

        rewards.append(float(reward))

    return rewards
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def _build_grpo_dataset_rows(num_rows: int = 128):
    """Generate ``num_rows`` synthetic prompt rows for GRPO training.

    Row ``i`` cycles through difficulties 1-4, the four canned prospect
    replies and turns 1-8. Every third row starts with no completed steps;
    the others have the first one or two workflow steps completed. Returns a
    list of ``{"prompt": str}`` dicts.
    """
    prospect_snippets = (
        "We are evaluating options right now.",
        "Budget is tight this quarter.",
        "Can you explain implementation effort?",
        "Pricing seems high compared to alternatives.",
    )

    def make_prompt(index: int) -> str:
        workflow = WORKFLOW_MAP[index % 4 + 1]
        if index % 3 == 0:
            done = []
        else:
            # Slicing clamps to the workflow length (empty for difficulty 4).
            done = workflow[: index % 2 + 1]

        workflow_text = " -> ".join(workflow) if workflow else "Dynamic"
        stage = done[-1] if done else "START"
        reply = prospect_snippets[index % len(prospect_snippets)]

        pieces = [
            "You are a B2B sales agent.\n\n",
            f"Required workflow steps (in order): {workflow_text}\n",
            f"Current stage: {stage}\n",
            f"Steps completed: {done}\n",
            f"Turn: {(index % 8) + 1}/20\n",
            "Business rules: R01..R09 must be respected.\n",
            f"Prospect response: {reply}\n\n",
            "Respond exactly with:\nACTION: <action>\nCONTENT: <message>",
        ]
        return "".join(pieces)

    return [{"prompt": make_prompt(i)} for i in range(num_rows)]
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def run_grpo(args):
    """Train with TRL's GRPOTrainer on synthetic SalesPath prompts.

    Builds the synthetic dataset, trains ``args.model_name`` with
    ``salespath_reward_func`` as the reward, and saves the result under
    ``<args.output_dir>/grpo_final``. With ``args.push_to_hub`` the trainer
    also uploads to ``args.hub_repo``.

    Raises:
        RuntimeError: if the ``datasets`` / ``trl`` stack cannot be imported.
    """
    try:
        from datasets import Dataset
        from trl import GRPOConfig, GRPOTrainer
    except Exception as exc:
        raise RuntimeError(
            "Failed to initialize TRL GRPO stack. On this machine, this is usually due to "
            "Windows blocking pyarrow dataset binaries in the local virtualenv. "
            "Use the provided Colab notebook (`training/colab_train.ipynb`) for GRPO runs, "
            "or fix local pyarrow/datasets installation first."
        ) from exc

    # Only the tokenizer is loaded here. GRPOTrainer receives the model name
    # and instantiates the policy itself, so loading the weights first would
    # hold a second, unused copy of the model in memory.
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    rows = _build_grpo_dataset_rows(args.grpo_dataset_size)
    train_dataset = Dataset.from_list(rows)

    config = GRPOConfig(
        output_dir=args.output_dir,
        learning_rate=args.learning_rate,
        per_device_train_batch_size=args.per_device_train_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        num_generations=args.num_generations,
        max_completion_length=args.max_completion_length,
        temperature=args.temperature,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        max_steps=args.grpo_steps,
        report_to="none",
        # Without this, push_to_hub() targets a repo named after output_dir
        # rather than the one requested with --hub-repo.
        hub_model_id=args.hub_repo if args.push_to_hub else None,
    )

    trainer = GRPOTrainer(
        model=args.model_name,
        reward_funcs=salespath_reward_func,
        args=config,
        train_dataset=train_dataset,
        processing_class=tokenizer,
    )

    trainer.train()
    final_dir = Path(args.output_dir) / "grpo_final"
    trainer.save_model(str(final_dir))
    print(f"Saved GRPO model to {final_dir}")

    if args.push_to_hub:
        trainer.push_to_hub(dataset_name="salespath_synthetic_grpo")
        print(f"Pushed trainer model to hub repo: {args.hub_repo}")
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def parse_args():
    """Parse the command-line options of the SalesPath training entrypoint.

    Options are registered in two groups: general run settings (mode, model,
    environment URL, output and hub settings) and GRPO-specific knobs.
    """
    parser = argparse.ArgumentParser(description="SalesPath training entrypoint.")

    general_options = (
        ("--mode", {"choices": ["curriculum", "grpo"], "default": "curriculum"}),
        ("--model-name", {"default": DEFAULT_MODEL}),
        ("--env-url", {"default": DEFAULT_ENV_URL}),
        ("--steps", {"type": int, "default": 100, "help": "Curriculum rollout steps."}),
        ("--print-every", {"type": int, "default": 10}),
        ("--output-dir", {"default": "salespath_training_outputs"}),
        ("--hub-repo", {"default": "Imsachin010/salespath-qwen25-7b"}),
        ("--push-to-hub", {"action": "store_true"}),
        ("--push-merged", {"action": "store_true"}),
    )

    # GRPO-specific knobs
    grpo_options = (
        ("--grpo-steps", {"type": int, "default": 30}),
        ("--grpo-dataset-size", {"type": int, "default": 128}),
        ("--learning-rate", {"type": float, "default": 1e-5}),
        ("--per-device-train-batch-size", {"type": int, "default": 2}),
        ("--gradient-accumulation-steps", {"type": int, "default": 4}),
        ("--num-generations", {"type": int, "default": 8}),
        ("--max-completion-length", {"type": int, "default": 128}),
        ("--temperature", {"type": float, "default": 0.8}),
        ("--logging-steps", {"type": int, "default": 10}),
        ("--save-steps", {"type": int, "default": 100}),
    )

    for flag, spec in (*general_options, *grpo_options):
        parser.add_argument(flag, **spec)

    return parser.parse_args()
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
async def _run_curriculum_mode(args):
    """Run the curriculum rollout loop, save metrics and optionally export the model.

    Loads ``args.model_name``, runs ``args.steps`` curriculum episodes against
    ``args.env_url`` and writes the reward history to ``args.output_dir``.
    With ``args.push_merged`` it also saves a merged 16-bit copy when the
    model exposes ``save_pretrained_merged``. That copy is pushed to
    ``args.hub_repo`` when a token is found in ``HF_TOKEN`` or
    ``HUGGINGFACE_TOKEN`` and the model exposes ``push_to_hub_merged``.
    Every skipped export step is reported on stdout.
    """
    print(f"Loading model: {args.model_name}")
    model, tokenizer = _load_model_and_tokenizer(args.model_name)
    print(f"Starting curriculum loop against {args.env_url}")

    metrics = await curriculum_train(
        model=model,
        tokenizer=tokenizer,
        env_url=args.env_url,
        total_steps=args.steps,
        print_every=args.print_every,
    )
    print(f"Final mean reward (last 20): {metrics['mean_reward']:.4f}")
    _save_metrics(args.output_dir, metrics)

    if not args.push_merged:
        return

    if not hasattr(model, "save_pretrained_merged"):
        print(
            "Model does not support merged save APIs. "
            "Use an Unsloth merged-capable model to enable --push-merged."
        )
        return

    merged_dir = Path(args.output_dir) / "salespath_trained_merged"
    model.save_pretrained_merged(
        str(merged_dir),
        tokenizer,
        save_method="merged_16bit",
    )
    print(f"Saved merged model to {merged_dir}")

    hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")
    if not hf_token:
        # Previously the upload was skipped without any message.
        print("Skipping hub push: set HF_TOKEN or HUGGINGFACE_TOKEN to upload the merged model.")
        return
    if not hasattr(model, "push_to_hub_merged"):
        print("Skipping hub push: model does not provide push_to_hub_merged.")
        return

    model.push_to_hub_merged(
        args.hub_repo,
        tokenizer,
        save_method="merged_16bit",
        token=hf_token,
    )
    print(f"Pushed merged model to {args.hub_repo}")
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
async def _main():
    """Parse the CLI options and dispatch to the selected training mode."""
    args = parse_args()

    if args.mode != "curriculum":
        print("Launching TRL GRPO mode...")
        run_grpo(args)
        return

    await _run_curriculum_mode(args)
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
# Script entry point: run the async dispatcher on a fresh event loop.
if __name__ == "__main__":
    asyncio.run(_main())
|
| 315 |
+
|
training/rollout.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# training/rollout.py
|
| 2 |
+
|
| 3 |
+
import re
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from salespath_env.client import SalesPathEnv
|
| 7 |
+
from salespath_env.models import SalesPathObservation
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
SYSTEM_PROMPT = """
|
| 11 |
+
You are a B2B sales agent.
|
| 12 |
+
|
| 13 |
+
Your goal is to close deals by following a strict workflow.
|
| 14 |
+
|
| 15 |
+
Required workflow steps (in order):
|
| 16 |
+
{workflow}
|
| 17 |
+
|
| 18 |
+
Business rules — NEVER violate these:
|
| 19 |
+
|
| 20 |
+
- R01: Must QUALIFY before PRESENT
|
| 21 |
+
- R02: Must OFFER_DEMO before NEGOTIATE
|
| 22 |
+
- R03: Budget must be known before NEGOTIATE
|
| 23 |
+
- R04: Discount only after 2 objections handled
|
| 24 |
+
- R05: Cannot repeat same action twice in a row
|
| 25 |
+
- R06: First action must always be PROSPECT
|
| 26 |
+
- R07: FOLLOW_UP only after prospect goes silent
|
| 27 |
+
- R08: DISQUALIFY only if prospect is genuinely unqualified
|
| 28 |
+
- R09: Must OFFER_DEMO before CLOSE (difficulty 2+)
|
| 29 |
+
|
| 30 |
+
You must respond EXACTLY in this format:
|
| 31 |
+
|
| 32 |
+
ACTION: <one valid action>
|
| 33 |
+
CONTENT: <your message>
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def parse_action(text: str) -> tuple[str, str]:
    """Read the ``ACTION:`` word and ``CONTENT:`` line from model output.

    Matching is case-insensitive. A missing action falls back to ``QUALIFY``
    and missing content falls back to a generic discovery question.
    """
    found_action = re.search(r"ACTION:\s*(\w+)", text, re.IGNORECASE)
    found_content = re.search(r"CONTENT:\s*(.+?)(?:\n|$)", text, re.IGNORECASE | re.DOTALL)

    action_type = "QUALIFY"
    if found_action:
        action_type = found_action.group(1).upper()

    content = "Tell me more about your current process."
    if found_content:
        content = found_content.group(1).strip()

    return action_type, content
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def build_prompt(obs: SalesPathObservation, workflow: list[str], tokenizer) -> str:
    """Render the chat prompt for the next agent turn.

    The system message is ``SYSTEM_PROMPT`` with the workflow steps joined by
    " -> ". The user message summarises the observation (prospect response,
    stage, completed steps, turn and violations). The tokenizer's chat
    template turns both into one untokenized string with the generation
    prompt appended.
    """
    system_text = SYSTEM_PROMPT.format(workflow=" -> ".join(workflow))
    user_text = "".join(
        [
            f"Prospect response: {obs.prospect_response}\n",
            f"Current stage: {obs.workflow_stage}\n",
            f"Steps completed: {obs.steps_completed}\n",
            f"Turn: {obs.turn_number}/20\n",
            f"Violations so far: {obs.constraints_violated}\n\n",
            "What is your next action?",
        ]
    )
    chat = [
        {"role": "system", "content": system_text},
        {"role": "user", "content": user_text},
    ]
    return tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
async def run_episode(
    model,
    tokenizer,
    env_url: str,
    difficulty: int = 1,
    message_timeout_s: float = 300.0,
) -> dict:
    """Play one episode against the environment with the model choosing actions.

    Each turn builds a prompt from the latest observation, samples up to 128
    new tokens, parses the action/content pair and sends it to the
    environment, until the observation reports ``done``.

    ``message_timeout_s`` is accepted but not used by this function.

    Returns a dict with the per-turn ``trajectory``, the summed
    ``total_reward``, the final ``steps_completed`` and ``violations``, and
    the ``difficulty``. Raises ``KeyError`` for a difficulty outside 1-4.
    """
    workflows_by_difficulty = {
        1: ["QUALIFY", "PRESENT", "CLOSE"],
        2: ["QUALIFY", "PRESENT", "HANDLE_OBJECTION", "OFFER_DEMO", "CLOSE"],
        3: ["QUALIFY", "PRESENT", "HANDLE_OBJECTION", "OFFER_DEMO", "HANDLE_OBJECTION", "NEGOTIATE", "CLOSE"],
        4: [],
    }
    workflow = workflows_by_difficulty[difficulty]

    async with SalesPathEnv(base_url=env_url) as env:
        obs = await env.reset(difficulty=difficulty)
        turns = []
        reward_sum = 0.0

        while not obs.done:
            # --- Model inference (local, no network) ---
            prompt = build_prompt(obs, workflow, tokenizer)
            encoded = tokenizer(prompt, return_tensors="pt").to(model.device)

            with torch.no_grad():
                sequences = model.generate(
                    **encoded,
                    max_new_tokens=128,
                    temperature=0.7,
                    do_sample=True,
                )

            # Keep only the tokens generated after the prompt.
            prompt_length = encoded["input_ids"].shape[1]
            generated = tokenizer.decode(
                sequences[0][prompt_length:],
                skip_special_tokens=True,
            )
            action_type, content = parse_action(generated)

            # --- Stateful step through the environment client ---
            obs = await env.step(
                action_type=action_type,
                content=content,
                target="",
            )

            turns.append(
                dict(
                    prompt=prompt,
                    generated=generated,
                    action_type=action_type,
                    reward=obs.reward,
                    components=obs.reward_components,
                    done=obs.done,
                )
            )
            reward_sum += obs.reward

        return dict(
            trajectory=turns,
            total_reward=reward_sum,
            steps_completed=obs.steps_completed,
            violations=obs.constraints_violated,
            difficulty=difficulty,
        )
|