Imsachin010 committed on
Commit
b77d3c5
·
0 Parent(s):

first commit

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +1 -0
  2. CODING_APPROACH.md +1028 -0
  3. Dockerfile +27 -0
  4. README.md +51 -0
  5. RULES.md +354 -0
  6. push_to_hub.py +44 -0
  7. pyproject.toml +22 -0
  8. requirements.txt +4 -0
  9. salespath_env.egg-info/PKG-INFO +8 -0
  10. salespath_env.egg-info/SOURCES.txt +17 -0
  11. salespath_env.egg-info/dependency_links.txt +1 -0
  12. salespath_env.egg-info/requires.txt +4 -0
  13. salespath_env.egg-info/top_level.txt +1 -0
  14. salespath_env/README.md +0 -0
  15. salespath_env/__init__.py +2 -0
  16. salespath_env/__pycache__/__init__.cpython-313.pyc +0 -0
  17. salespath_env/__pycache__/client.cpython-313.pyc +0 -0
  18. salespath_env/__pycache__/models.cpython-313.pyc +0 -0
  19. salespath_env/client.py +81 -0
  20. salespath_env/models.py +93 -0
  21. salespath_env/openenv.yaml +13 -0
  22. salespath_env/pyproject.toml +0 -0
  23. salespath_env/server/Dockerfile +12 -0
  24. salespath_env/server/__init__.py +2 -0
  25. salespath_env/server/__pycache__/__init__.cpython-313.pyc +0 -0
  26. salespath_env/server/__pycache__/app.cpython-313.pyc +0 -0
  27. salespath_env/server/__pycache__/prospect_simulator.cpython-313.pyc +0 -0
  28. salespath_env/server/__pycache__/reward.cpython-313.pyc +0 -0
  29. salespath_env/server/__pycache__/rules.cpython-313.pyc +0 -0
  30. salespath_env/server/__pycache__/salespath_environment.cpython-313.pyc +0 -0
  31. salespath_env/server/__pycache__/task_bank.cpython-313.pyc +0 -0
  32. salespath_env/server/app.py +18 -0
  33. salespath_env/server/prospect_simulator.py +162 -0
  34. salespath_env/server/requirements.txt +3 -0
  35. salespath_env/server/reward.py +138 -0
  36. salespath_env/server/rules.py +222 -0
  37. salespath_env/server/salespath_environment.py +294 -0
  38. salespath_env/server/task_bank.py +199 -0
  39. training/__init__.py +0 -0
  40. training/__pycache__/__init__.cpython-313.pyc +0 -0
  41. training/__pycache__/curriculum.cpython-313.pyc +0 -0
  42. training/__pycache__/debug_episode.cpython-313.pyc +0 -0
  43. training/__pycache__/grpo_train.cpython-313.pyc +0 -0
  44. training/__pycache__/rollout.cpython-313.pyc +0 -0
  45. training/__pycache__/test_rollout.cpython-313.pyc +0 -0
  46. training/colab_train.ipynb +100 -0
  47. training/curriculum.py +80 -0
  48. training/debug_episode.py +40 -0
  49. training/grpo_train.py +315 -0
  50. training/rollout.py +143 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ /.spa
CODING_APPROACH.md ADDED
@@ -0,0 +1,1028 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SalesPath — End-to-End Coding Approach
2
+ ### For Agent Execution. Follow in order. No skipping.
3
+
4
+ ---
5
+
6
+ ## Phase 0: Setup (Do First, ~15 min)
7
+
8
+ ```bash
9
+ # Install OpenEnv
10
+ pip install openenv
11
+
12
+ # Scaffold the project
13
+ openenv init salespath_env
14
+ cd salespath_env
15
+
16
+ # Install dependencies
17
+ pip install -e .
18
+
19
+ # Verify scaffold works
20
+ uv run server --host 0.0.0.0 --port 8000
21
+ # Should start FastAPI on 8000. Ctrl+C after confirming.
22
+ ```
23
+
24
+ Edit `pyproject.toml` — add dependencies:
25
+ ```toml
26
+ [project]
27
+ name = "salespath_env"
28
+ version = "0.1.0"
29
+ dependencies = [
30
+ "openenv",
31
+ "fastapi",
32
+ "uvicorn",
33
+ "pydantic>=2.0",
34
+ "trl>=0.8.0",
35
+ "unsloth",
36
+ "torch",
37
+ "transformers",
38
+ ]
39
+ ```
40
+
41
+ ---
42
+
43
+ ## Phase 1: Models (Person A) — `models.py`
44
+
45
+ Write this file first. Everything else depends on it.
46
+
47
+ ```python
48
+ # salespath_env/models.py
49
+ from __future__ import annotations
50
+ import uuid
51
+ from dataclasses import dataclass, field
52
+ from typing import Optional
53
+ from openenv.core import Action, Observation, State
54
+
55
+ VALID_ACTIONS = {
56
+ "PROSPECT", "QUALIFY", "PRESENT", "HANDLE_OBJECTION",
57
+ "OFFER_DEMO", "NEGOTIATE", "CLOSE", "FOLLOW_UP", "DISQUALIFY"
58
+ }
59
+
60
+ class SalesPathAction(Action):
61
+ action_type: str
62
+ content: str
63
+ target: str = ""
64
+
65
+ def is_valid(self) -> bool:
66
+ return self.action_type in VALID_ACTIONS
67
+
68
+
69
+ class SalesPathObservation(Observation):
70
+ prospect_response: str = ""
71
+ workflow_stage: str = "START"
72
+ constraints_violated: list[str] = field(default_factory=list)
73
+ steps_completed: list[str] = field(default_factory=list)
74
+ turn_number: int = 0
75
+ reward: float = 0.0
76
+ reward_components: dict = field(default_factory=dict)
77
+ done: bool = False
78
+ info: dict = field(default_factory=dict)
79
+
80
+
81
+ class SalesPathState(State):
82
+ episode_id: str = field(default_factory=lambda: str(uuid.uuid4()))
83
+ prospect_profile: dict = field(default_factory=dict)
84
+ conversation_history: list[dict] = field(default_factory=list)
85
+ workflow_stage: str = "START"
86
+ required_workflow: list[str] = field(default_factory=list)
87
+ steps_completed: list[str] = field(default_factory=list)
88
+ constraints_violated: list[str] = field(default_factory=list)
89
+ objections_handled: int = 0
90
+ turn_number: int = 0
91
+ difficulty: int = 1
92
+ done: bool = False
93
+ # Hidden — never expose in Observation
94
+ _hidden: dict = field(default_factory=dict)
95
+ ```
96
+
97
+ ---
98
+
99
+ ## Phase 2: Task Bank (Person A) — `server/task_bank.py`
100
+
101
+ This generates prospect profiles. Keep it simple — 10 profiles per difficulty level.
102
+
103
+ ```python
104
+ # server/task_bank.py
105
+ import random
106
+ from dataclasses import dataclass
107
+
108
+ @dataclass
109
+ class ProspectProfile:
110
+ company_name: str
111
+ company_size: str # "small" / "medium" / "enterprise"
112
+ industry: str
113
+ budget_signal: str # "high" / "medium" / "low" / "unknown"
114
+ pain_points: list[str]
115
+ decision_maker: bool
116
+ # Hidden — simulator uses these, agent never sees raw values
117
+ true_budget: float # 0.0 to 1.0 scale
118
+ close_threshold: float # budget needed to close
119
+ stall_probability: float # for Level 3+
120
+
121
+
122
+ PROFILES_L1 = [
123
+ ProspectProfile(
124
+ company_name="Meridian Retail",
125
+ company_size="medium",
126
+ industry="retail",
127
+ budget_signal="high",
128
+ pain_points=["manual inventory tracking", "slow reporting"],
129
+ decision_maker=True,
130
+ true_budget=0.8,
131
+ close_threshold=0.5,
132
+ stall_probability=0.0,
133
+ ),
134
+ # Add 9 more L1 profiles following same pattern
135
+ # L1: budget_signal always known, decision_maker always True, close_threshold <= 0.6
136
+ ]
137
+
138
+ PROFILES_L2 = [
139
+ ProspectProfile(
140
+ company_name="Apex Logistics",
141
+ company_size="enterprise",
142
+ industry="logistics",
143
+ budget_signal="unknown", # revealed after QUALIFY
144
+ pain_points=["route optimization", "driver coordination", "fuel tracking"],
145
+ decision_maker=True,
146
+ true_budget=0.7,
147
+ close_threshold=0.5,
148
+ stall_probability=0.0,
149
+ ),
150
+ # 9 more L2 profiles: budget hidden, one objection expected
151
+ ]
152
+
153
+ PROFILES_L3 = [
154
+ ProspectProfile(
155
+ company_name="Nova Financial",
156
+ company_size="enterprise",
157
+ industry="finance",
158
+ budget_signal="unknown",
159
+ pain_points=["compliance reporting", "audit trails", "data silos"],
160
+ decision_maker=False, # must navigate to decision maker
161
+ true_budget=0.6,
162
+ close_threshold=0.55,
163
+ stall_probability=0.3, # will stall at turn 10
164
+ ),
165
+ # 9 more L3 profiles: budget hidden, two objections, mode shift
166
+ ]
167
+
168
+ PROFILES_L4 = [
169
+ ProspectProfile(
170
+ company_name="Cipher Tech",
171
+ company_size="small",
172
+ industry="technology",
173
+ budget_signal="high", # MISLEADING — true_budget is actually low
174
+ pain_points=["security", "compliance"],
175
+ decision_maker=True,
176
+ true_budget=0.2, # can't actually afford it
177
+ close_threshold=0.5,
178
+ stall_probability=0.5,
179
+ ),
180
+ # 9 more L4: misleading signals, correct answer is DISQUALIFY
181
+ ]
182
+
183
+ ALL_PROFILES = {1: PROFILES_L1, 2: PROFILES_L2, 3: PROFILES_L3, 4: PROFILES_L4}
184
+
185
+ def sample_profile(difficulty: int) -> ProspectProfile:
186
+ return random.choice(ALL_PROFILES[difficulty])
187
+ ```
188
+
189
+ ---
190
+
191
+ ## Phase 3: Business Rules (Person A) — `server/rules.py`
192
+
193
+ ```python
194
+ # server/rules.py
195
+ from dataclasses import dataclass
196
+ from typing import Callable
197
+ from ..models import SalesPathAction, SalesPathState
198
+
199
+
200
+ @dataclass
201
+ class BusinessRule:
202
+ rule_id: str
203
+ name: str
204
+ description: str
205
+ check: Callable[[SalesPathState, SalesPathAction], bool]
206
+ # Returns True if VIOLATED
207
+
208
+
209
+ def _qualify_before_present(state: SalesPathState, action: SalesPathAction) -> bool:
210
+ if action.action_type == "PRESENT":
211
+ return "QUALIFY" not in state.steps_completed
212
+ return False
213
+
214
+
215
+ def _demo_before_negotiate(state: SalesPathState, action: SalesPathAction) -> bool:
216
+ if action.action_type == "NEGOTIATE":
217
+ return "OFFER_DEMO" not in state.steps_completed
218
+ return False
219
+
220
+
221
+ def _budget_known_to_negotiate(state: SalesPathState, action: SalesPathAction) -> bool:
222
+ if action.action_type == "NEGOTIATE":
223
+ return state.prospect_profile.get("budget_signal") == "unknown"
224
+ return False
225
+
226
+
227
+ def _discount_after_objections(state: SalesPathState, action: SalesPathAction) -> bool:
228
+ if action.action_type == "NEGOTIATE":
229
+ if "discount" in action.content.lower():
230
+ return state.objections_handled < 2
231
+ return False
232
+
233
+
234
+ def _no_repeat_action(state: SalesPathState, action: SalesPathAction) -> bool:
235
+ if state.conversation_history:
236
+ last_action = state.conversation_history[-1].get("action_type", "")
237
+ return last_action == action.action_type
238
+ return False
239
+
240
+
241
+ def _prospect_first(state: SalesPathState, action: SalesPathAction) -> bool:
242
+ if state.turn_number == 1:
243
+ return action.action_type != "PROSPECT"
244
+ return False
245
+
246
+
247
+ def _followup_timing(state: SalesPathState, action: SalesPathAction) -> bool:
248
+ if action.action_type == "FOLLOW_UP":
249
+ if state.conversation_history:
250
+ last_speaker = state.conversation_history[-1].get("speaker", "agent")
251
+ return last_speaker == "prospect" # prospect just responded
252
+ return False
253
+
254
+
255
+ def _disqualify_logic(state: SalesPathState, action: SalesPathAction) -> bool:
256
+ if action.action_type == "DISQUALIFY":
257
+ profile = state.prospect_profile
258
+ true_budget = state._hidden.get("true_budget", 0.5)
259
+ close_threshold = state._hidden.get("close_threshold", 0.5)
260
+ dm = profile.get("decision_maker", True)
261
+ # Violation: disqualifying when prospect is actually closeable
262
+ return (true_budget >= close_threshold) and dm
263
+ return False
264
+
265
+
266
+ def _close_requires_demo(state: SalesPathState, action: SalesPathAction) -> bool:
267
+ if action.action_type == "CLOSE":
268
+ if state.difficulty >= 2:
269
+ return "OFFER_DEMO" not in state.steps_completed
270
+ return False
271
+
272
+
273
+ BUSINESS_RULES = [
274
+ BusinessRule("R01", "qualify_before_present",
275
+ "Must QUALIFY before PRESENT", _qualify_before_present),
276
+ BusinessRule("R02", "demo_before_negotiate",
277
+ "Must OFFER_DEMO before NEGOTIATE", _demo_before_negotiate),
278
+ BusinessRule("R03", "budget_known_to_negotiate",
279
+ "Budget must be known before NEGOTIATE", _budget_known_to_negotiate),
280
+ BusinessRule("R04", "discount_after_objections",
281
+ "Discount only after 2 objections", _discount_after_objections),
282
+ BusinessRule("R05", "no_repeat_action",
283
+ "Cannot repeat same action consecutively", _no_repeat_action),
284
+ BusinessRule("R06", "prospect_first",
285
+ "First action must be PROSPECT", _prospect_first),
286
+ BusinessRule("R07", "followup_timing",
287
+ "FOLLOW_UP only after prospect silence", _followup_timing),
288
+ BusinessRule("R08", "disqualify_logic",
289
+ "DISQUALIFY only when prospect is genuinely unqualified", _disqualify_logic),
290
+ BusinessRule("R09", "close_requires_demo",
291
+ "Must OFFER_DEMO before CLOSE (Levels 2+)", _close_requires_demo),
292
+ ]
293
+
294
+
295
+ def check_rules(state: SalesPathState, action: SalesPathAction) -> list[str]:
296
+ """Returns list of violated rule IDs."""
297
+ return [
298
+ rule.rule_id
299
+ for rule in BUSINESS_RULES
300
+ if rule.check(state, action)
301
+ ]
302
+ ```
303
+
304
+ ---
305
+
306
+ ## Phase 4: Prospect Simulator (Person A) — `server/prospect_simulator.py`
307
+
308
+ ```python
309
+ # server/prospect_simulator.py
310
+ # PURE RULE-BASED. No LLM. No imports from transformers.
311
+
312
+ from ..models import SalesPathState, SalesPathAction
313
+
314
# Canned prospect utterances, keyed by the response token the simulator emits.
RESPONSE_TEXT = {
    "open:positive_signal": "That sounds interesting. Tell me more about how this works.",
    "open:neutral_signal": "I see. We're evaluating a few options at the moment.",
    "objection:price": "The pricing seems higher than what we budgeted for.",
    "objection:timing": "The timing isn't ideal — we're in the middle of a quarter close.",
    "objection:premature_pitch": "I'm not sure we're ready to discuss solutions yet. What do you know about our situation?",
    "deflect:budget_not_discussed": "We haven't really talked about what we're looking for yet.",
    "deflect:stall": "Let me get back to you on this. A lot is happening on our end.",
    "accept:demo_scheduled": "Yes, let's set up a demo. What time works next week?",
    "accept:close_success": "Alright, I think we can move forward with this. Send over the paperwork.",
    "reject:close_failed": "I don't think we're ready to commit at this point.",
    "silence": "",
    "exit:disqualified": "I think we're done here. This isn't the right fit.",
}


class ProspectSimulator:
    """Rule-based prospect: maps an agent action plus state to a canned reply."""

    def respond(self, action: SalesPathAction, state: SalesPathState) -> tuple[str, str]:
        """Return ``(response_token, response_text)`` for the given action.

        The result is a pure function of the inputs except for the stall
        branch, which draws from ``random`` when ``stall_probability`` is
        strictly between 0 and 1.

        Side effects on ``state``: QUALIFY may overwrite
        ``prospect_profile["budget_signal"]`` and HANDLE_OBJECTION increments
        ``objections_handled``.
        """
        token = self._get_token(action, state)
        return token, RESPONSE_TEXT[token]

    def _get_token(self, action: SalesPathAction, state: SalesPathState) -> str:
        """Select the response token for ``action`` given the current state."""
        atype = action.action_type
        hidden = state._hidden
        profile = state.prospect_profile
        objections = state.objections_handled
        difficulty = state.difficulty

        # Reactions to rule violations take priority. The R01/R03 conditions
        # are re-derived from state instead of peeking at the last entry of
        # constraints_violated, which can be left over from an earlier turn
        # or hide R01/R03 behind a rule ID appended after it.
        if atype == "PRESENT" and "QUALIFY" not in state.steps_completed:
            return "objection:premature_pitch"
        if atype == "NEGOTIATE" and profile.get("budget_signal") == "unknown":
            return "deflect:budget_not_discussed"

        # Mode shift from turn 10 on Level 3+: the prospect may stall. This
        # must run before the per-action returns below, otherwise it is
        # unreachable. Terminal actions are exempt so CLOSE / DISQUALIFY
        # outcomes stay governed by their own logic.
        if difficulty >= 3 and state.turn_number >= 10 and atype not in ("CLOSE", "DISQUALIFY"):
            import random

            if random.random() < hidden.get("stall_probability", 0.0):
                return "deflect:stall"

        if atype == "PROSPECT":
            return "open:positive_signal"

        if atype == "QUALIFY":
            # Qualifying reveals a budget signal that started out hidden.
            if profile.get("budget_signal") == "unknown":
                state.prospect_profile["budget_signal"] = hidden.get("revealed_budget", "medium")
            return "open:neutral_signal"

        if atype == "PRESENT":
            if difficulty >= 2 and objections == 0:
                return "objection:price"
            return "open:positive_signal"

        if atype == "HANDLE_OBJECTION":
            state.objections_handled += 1
            # `objections` is the count before this turn's increment.
            if objections + 1 >= hidden.get("num_objections", 1):
                return "open:positive_signal"
            return "objection:timing" if objections == 0 else "open:positive_signal"

        if atype == "OFFER_DEMO":
            return "accept:demo_scheduled"

        if atype == "CLOSE":
            true_budget = hidden.get("true_budget", 0.7)
            threshold = hidden.get("close_threshold", 0.5)
            if true_budget >= threshold and profile.get("decision_maker", True):
                return "accept:close_success"
            return "reject:close_failed"

        if atype == "DISQUALIFY":
            return "exit:disqualified"

        # NEGOTIATE, FOLLOW_UP and any unrecognised action type.
        return "open:neutral_signal"
402
+ ```
403
+
404
+ ---
405
+
406
+ ## Phase 5: Reward Function (Person B) — `server/reward.py`
407
+
408
+ ```python
409
+ # server/reward.py
410
+
411
+ from ..models import SalesPathState, SalesPathAction
412
+
413
# Turn budget per difficulty level; turns beyond it are penalised at episode end.
DIFFICULTY_OPTIMAL_TURNS = {1: 5, 2: 8, 3: 12, 4: 14}


def compute_reward(
    state: SalesPathState,
    action: SalesPathAction,
    response_token: str,
    new_violations: list[str],
    episode_done: bool,
) -> tuple[float, dict]:
    """Score one environment step.

    Args:
        state: Episode state after the action was applied. Reads
            ``turn_number``, ``constraints_violated``, ``required_workflow``,
            ``steps_completed`` and ``difficulty``.
        action: The action just taken.
        response_token: Token the prospect simulator returned for the action.
        new_violations: Rule IDs violated by this action only.
        episode_done: Whether this step ended the episode.

    Returns:
        ``(total, components)``. ``components`` holds ``r_outcome``,
        ``r_compliance``, ``r_ordering``, ``r_efficiency``, ``r_format`` and
        the weighted ``total``.
    """
    components = {}

    # --- Component 1: outcome (terminal step only) ---
    r_outcome = 0.0
    if episode_done:
        if response_token == "accept:close_success":
            r_outcome = 1.0
        elif action.action_type == "DISQUALIFY":
            # A DISQUALIFY is correct when it did not trip R08.
            r_outcome = 0.5 if "R08" not in new_violations else -0.5
        elif state.turn_number >= 20:
            r_outcome = -0.3  # ran out of turns
        else:
            r_outcome = -0.5  # too many violations, or a failed close
    components["r_outcome"] = r_outcome

    # --- Component 2: compliance (per-step signal) ---
    components["r_compliance"] = max(-1.0, -0.2 * len(new_violations))

    # --- Component 3: step ordering ---
    required = state.required_workflow
    completed = state.steps_completed
    if not required:
        r_ordering = 1.0
    elif not completed:
        r_ordering = 1.0 if action.action_type == required[0] else 0.0
    else:
        # Count required steps reached in order, skipping completed steps the
        # workflow does not list (e.g. the opening PROSPECT that R06 makes
        # mandatory). A positional comparison would be shifted by such steps
        # and score a rule-abiding run as 0. Required steps are de-duplicated
        # because steps_completed may hold each action type only once.
        expected = list(dict.fromkeys(required))
        matched = 0
        for step in completed:
            if matched < len(expected) and step == expected[matched]:
                matched += 1
        r_ordering = matched / len(expected)
    components["r_ordering"] = r_ordering

    # --- Component 4: efficiency (terminal step only) ---
    if episode_done:
        optimal = DIFFICULTY_OPTIMAL_TURNS.get(state.difficulty, 10)
        overhead = max(0, state.turn_number - optimal)
        r_efficiency = max(-0.3, -0.05 * overhead)
    else:
        r_efficiency = 0.0
    components["r_efficiency"] = r_efficiency

    # --- Component 5: format ---
    components["r_format"] = 1.0 if action.is_valid() else -0.1

    # --- Weighted total ---
    weights = {
        "r_outcome": 0.40,
        "r_compliance": 0.30,
        "r_ordering": 0.15,
        "r_efficiency": 0.10,
        "r_format": 0.05,
    }
    total = sum(weights[name] * value for name, value in components.items())
    components["total"] = total

    return total, components
492
+ ```
493
+
494
+ ---
495
+
496
+ ## Phase 6: Environment Core (Person A) — `server/salespath_environment.py`
497
+
498
+ ```python
499
+ # server/salespath_environment.py
500
+ import uuid
501
+ from openenv.core.env_server import Environment
502
+ from ..models import SalesPathAction, SalesPathObservation, SalesPathState
503
+ from .task_bank import sample_profile
504
+ from .rules import check_rules, BUSINESS_RULES
505
+ from .reward import compute_reward
506
+ from .prospect_simulator import ProspectSimulator
507
+
508
+ DIFFICULTY_WORKFLOW = {
509
+ 1: ["QUALIFY", "PRESENT", "CLOSE"],
510
+ 2: ["QUALIFY", "PRESENT", "HANDLE_OBJECTION", "OFFER_DEMO", "CLOSE"],
511
+ 3: ["QUALIFY", "PRESENT", "HANDLE_OBJECTION", "OFFER_DEMO",
512
+ "HANDLE_OBJECTION", "NEGOTIATE", "CLOSE"],
513
+ 4: [], # agent must determine; DISQUALIFY may be correct
514
+ }
515
+
516
+ MAX_VIOLATIONS_BEFORE_TERMINATE = 3
517
+ MAX_TURNS = 20
518
+
519
+
520
class SalesPathEnvironment(Environment):
    """Sales-conversation environment: one prospect per episode.

    ``reset`` samples a prospect for the requested difficulty and ``step``
    applies one agent action, checks the business rules, asks the prospect
    simulator for a reply and scores the turn.
    """

    def __init__(self):
        super().__init__()
        self._state = SalesPathState()
        self._simulator = ProspectSimulator()

    def reset(self, difficulty: int = 1) -> SalesPathObservation:
        """Start a new episode and return the opening observation.

        Args:
            difficulty: Level used to sample the profile and to look up the
                required workflow and the number of objections.
        """
        profile = sample_profile(difficulty)
        # Values the simulator and rules use but the observation never carries.
        hidden = {
            "true_budget": profile.true_budget,
            "close_threshold": profile.close_threshold,
            "stall_probability": profile.stall_probability,
            "num_objections": {1: 0, 2: 1, 3: 2, 4: 2}[difficulty],
            "revealed_budget": (
                "high" if profile.true_budget >= 0.7
                else "medium" if profile.true_budget >= 0.4
                else "low"
            ),
        }
        public_profile = {
            "company_name": profile.company_name,
            "company_size": profile.company_size,
            "industry": profile.industry,
            "budget_signal": profile.budget_signal,
            # Copied so per-episode state never aliases the shared profile.
            "pain_points": list(profile.pain_points),
            "decision_maker": profile.decision_maker,
        }
        self._state = SalesPathState(
            episode_id=str(uuid.uuid4()),
            prospect_profile=public_profile,
            # Copied so the module-level workflow table cannot be mutated
            # through the episode state.
            required_workflow=list(DIFFICULTY_WORKFLOW[difficulty]),
            difficulty=difficulty,
        )
        self._state._hidden = hidden

        return SalesPathObservation(
            prospect_response=(
                f"You are engaging {profile.company_name}, a {profile.company_size} "
                f"{profile.industry} company. Pain points: {', '.join(profile.pain_points)}. "
                f"Begin the sales conversation."
            ),
            workflow_stage="START",
            steps_completed=[],
            constraints_violated=[],
            turn_number=0,
            reward=0.0,
            done=False,
            info={"difficulty": difficulty, "episode_id": self._state.episode_id},
        )

    def step(self, action: SalesPathAction) -> SalesPathObservation:
        """Apply one agent action and return the resulting observation.

        The episode ends on CLOSE or DISQUALIFY, once
        ``MAX_VIOLATIONS_BEFORE_TERMINATE`` violations have accumulated, or
        when ``MAX_TURNS`` is reached. Malformed actions consume a turn and
        are also subject to the turn limit.
        """
        state = self._state
        state.turn_number += 1

        # Malformed action: penalise and skip rules, simulator and scoring.
        if not action.is_valid():
            # The turn was consumed, so the turn cap must still be able to
            # end the episode; otherwise a stream of malformed actions would
            # never terminate.
            state.done = state.turn_number >= MAX_TURNS
            return SalesPathObservation(
                prospect_response="Invalid action type.",
                workflow_stage=state.workflow_stage,
                steps_completed=list(state.steps_completed),
                constraints_violated=list(state.constraints_violated),
                turn_number=state.turn_number,
                reward=-0.2,
                done=state.done,
                info={"error": f"Invalid action_type: {action.action_type}",
                      "r_format": -0.1},
            )

        # Rules are evaluated against the state as it was before this action.
        new_violations = check_rules(state, action)
        state.constraints_violated.extend(new_violations)

        state.conversation_history.append({
            "turn": state.turn_number,
            "speaker": "agent",
            "action_type": action.action_type,
            "content": action.content,
        })

        # Each action type is recorded at most once.
        if action.action_type not in state.steps_completed:
            state.steps_completed.append(action.action_type)
        state.workflow_stage = action.action_type

        response_token, response_text = self._simulator.respond(action, state)
        state.conversation_history.append({
            "turn": state.turn_number,
            "speaker": "prospect",
            "response_token": response_token,
            "text": response_text,
        })

        terminal_actions = {"CLOSE", "DISQUALIFY"}
        too_many_violations = len(state.constraints_violated) >= MAX_VIOLATIONS_BEFORE_TERMINATE
        turn_limit = state.turn_number >= MAX_TURNS
        done = (
            action.action_type in terminal_actions
            or too_many_violations
            or turn_limit
        )
        state.done = done

        total_reward, components = compute_reward(
            state, action, response_token, new_violations, done
        )

        return SalesPathObservation(
            prospect_response=response_text,
            workflow_stage=state.workflow_stage,
            steps_completed=list(state.steps_completed),
            constraints_violated=list(state.constraints_violated),
            turn_number=state.turn_number,
            reward=total_reward,
            reward_components=components,
            done=done,
            info={
                "response_token": response_token,
                "new_violations": new_violations,
                "episode_id": state.episode_id,
            },
        )

    @property
    def state(self) -> SalesPathState:
        """The live episode state, including the hidden prospect values."""
        return self._state
650
+ ```
651
+
652
+ ---
653
+
654
+ ## Phase 7: FastAPI App (Person A) — `server/app.py`
655
+
656
+ ```python
657
+ # server/app.py — thin wrapper only
658
+ from openenv.core.env_server import create_fastapi_app
659
+ from ..models import SalesPathAction, SalesPathObservation
660
+ from .salespath_environment import SalesPathEnvironment
661
+
662
+ app = create_fastapi_app(
663
+ SalesPathEnvironment,
664
+ SalesPathAction,
665
+ SalesPathObservation,
666
+ )
667
+ ```
668
+
669
+ ---
670
+
671
+ ## Phase 8: Client (Person B) — `client.py`
672
+
673
+ ```python
674
+ # client.py
675
+ from openenv.core import EnvClient
676
+ from .models import SalesPathAction, SalesPathObservation, SalesPathState
677
+
678
+
679
+ class SalesPathEnv(EnvClient):
680
+ action_type = SalesPathAction
681
+ observation_type = SalesPathObservation
682
+ state_type = SalesPathState
683
+
684
+ async def reset(self, difficulty: int = 1) -> SalesPathObservation:
685
+ return await super().reset(difficulty=difficulty)
686
+
687
+ async def step(self, action_type: str, content: str, target: str = "") -> SalesPathObservation:
688
+ action = SalesPathAction(
689
+ action_type=action_type,
690
+ content=content,
691
+ target=target,
692
+ )
693
+ return await super().step(action)
694
+ ```
695
+
696
+ ---
697
+
698
+ ## Phase 9: Rollout Function (Person B) — `training/rollout.py`
699
+
700
+ ```python
701
+ # training/rollout.py
702
+ import re
703
+ from salespath_env.client import SalesPathEnv
704
+ from salespath_env.models import SalesPathObservation
705
+
706
+ SYSTEM_PROMPT = """You are a B2B sales agent. Your goal is to close deals by following a strict workflow.
707
+
708
+ Required workflow steps (in order): {workflow}
709
+
710
+ Business rules — NEVER violate these:
711
+ - R01: Must QUALIFY before PRESENT
712
+ - R02: Must OFFER_DEMO before NEGOTIATE
713
+ - R03: Budget must be known before NEGOTIATE
714
+ - R04: Discount only after 2 objections handled
715
+ - R05: Cannot repeat same action twice in a row
716
+ - R06: First action must always be PROSPECT
717
+ - R07: FOLLOW_UP only after prospect goes silent
718
+ - R08: DISQUALIFY only if prospect is genuinely unqualified
719
+ - R09: Must OFFER_DEMO before CLOSE (difficulty 2+)
720
+
721
+ Respond EXACTLY in this format:
722
+ ACTION: <one of: PROSPECT, QUALIFY, PRESENT, HANDLE_OBJECTION, OFFER_DEMO, NEGOTIATE, CLOSE, FOLLOW_UP, DISQUALIFY>
723
+ CONTENT: <your message to the prospect>"""
724
+
725
+
726
def parse_action(text: str) -> tuple[str, str]:
    """Extract the ACTION type and CONTENT message from model output.

    CONTENT may span several lines: it runs until a following ``ACTION:``
    line or the end of the text. Labels are matched case-insensitively and
    the action type is upper-cased.

    Args:
        text: Raw text generated by the model.

    Returns:
        ``(action_type, content)``. Falls back to ``"QUALIFY"`` when no
        action is found, and to a generic question when the content is
        missing or blank.
    """
    default_content = "Tell me more about your needs."

    action_match = re.search(r"ACTION:\s*(\w+)", text, re.IGNORECASE)
    # Lazy match bounded by a lookahead, so a multi-line message is kept
    # whole but a second ACTION/CONTENT pair is not swallowed.
    content_match = re.search(
        r"CONTENT:\s*(.+?)(?=\n\s*ACTION:|\Z)",
        text,
        re.IGNORECASE | re.DOTALL,
    )

    action_type = action_match.group(1).upper() if action_match else "QUALIFY"
    content = content_match.group(1).strip() if content_match else ""

    return action_type, content or default_content
735
+
736
+
737
+ def build_prompt(obs: SalesPathObservation, workflow: list[str], tokenizer) -> str:
738
+ messages = [
739
+ {"role": "system", "content": SYSTEM_PROMPT.format(workflow=" → ".join(workflow))},
740
+ {"role": "user", "content": (
741
+ f"Prospect response: {obs.prospect_response}\n"
742
+ f"Current stage: {obs.workflow_stage}\n"
743
+ f"Steps completed: {obs.steps_completed}\n"
744
+ f"Turn: {obs.turn_number}/20\n"
745
+ f"Violations so far: {obs.constraints_violated}\n\n"
746
+ "What is your next action?"
747
+ )},
748
+ ]
749
+ return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
750
+
751
+
752
async def run_episode(model, tokenizer, env_url: str, difficulty: int = 1) -> dict:
    """Run one full episode against the environment server.

    Args:
        model: Generation model; must expose ``device`` and ``generate``.
        tokenizer: Tokenizer used to build prompts and decode completions.
        env_url: Base URL of the running environment server.
        difficulty: Difficulty level passed to ``env.reset``.

    Returns:
        A dict with the per-turn ``trajectory``, the summed ``total_reward``,
        the final ``steps_completed`` and ``violations``, and ``difficulty``.
    """
    # Imported here because this module's imports do not bring torch into
    # scope, yet torch.no_grad() is needed below.
    import torch

    DIFFICULTY_WORKFLOW = {
        1: ["QUALIFY", "PRESENT", "CLOSE"],
        2: ["QUALIFY", "PRESENT", "HANDLE_OBJECTION", "OFFER_DEMO", "CLOSE"],
        3: ["QUALIFY", "PRESENT", "HANDLE_OBJECTION", "OFFER_DEMO",
            "HANDLE_OBJECTION", "NEGOTIATE", "CLOSE"],
        4: [],
    }
    workflow = DIFFICULTY_WORKFLOW[difficulty]

    async with SalesPathEnv(base_url=env_url) as env:
        obs = await env.reset(difficulty=difficulty)
        trajectory = []
        total_reward = 0.0

        while not obs.done:
            prompt = build_prompt(obs, workflow, tokenizer)
            inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

            # Sampling only; no gradients are needed while collecting rollouts.
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=256,
                    temperature=0.8,
                    do_sample=True,
                )
            # Decode only the newly generated tokens, not the echoed prompt.
            prompt_length = inputs["input_ids"].shape[1]
            generated = tokenizer.decode(
                outputs[0][prompt_length:],
                skip_special_tokens=True,
            )

            action_type, content = parse_action(generated)
            obs = await env.step(action_type, content)

            trajectory.append({
                "prompt": prompt,
                "generated": generated,
                "action_type": action_type,
                "reward": obs.reward,
                "components": obs.reward_components,
                "done": obs.done,
            })
            total_reward += obs.reward

        return {
            "trajectory": trajectory,
            "total_reward": total_reward,
            "steps_completed": obs.steps_completed,
            "violations": obs.constraints_violated,
            "difficulty": difficulty,
        }
804
+ ```
805
+
806
+ ---
807
+
808
+ ## Phase 10: Curriculum Scheduler (Person B) — `training/curriculum.py`
809
+
810
+ ```python
811
+ # training/curriculum.py
812
+ from dataclasses import dataclass
813
+
814
@dataclass
class CurriculumConfig:
    """Maps mean-reward thresholds to difficulty sampling distributions."""

    # Keys are mean-reward thresholds; each value is the distribution returned
    # once the mean reward reaches that threshold.
    thresholds: dict

    def get_distribution(self, mean_reward: float) -> dict:
        """Return the entry for the highest threshold that ``mean_reward`` reaches.

        When ``mean_reward`` is below every threshold, the entry with the
        smallest threshold is returned instead.
        """
        reached = [level for level in self.thresholds if mean_reward >= level]
        if reached:
            selected = max(reached)
        else:
            selected = min(self.thresholds)
        return self.thresholds[selected]
823
+
824
+
825
+ DEFAULT_CURRICULUM = CurriculumConfig(
826
+ thresholds={
827
+ 0.0: {1: 0.90, 2: 0.10, 3: 0.00, 4: 0.00},
828
+ 0.30: {1: 0.50, 2: 0.40, 3: 0.10, 4: 0.00},
829
+ 0.50: {1: 0.20, 2: 0.40, 3: 0.35, 4: 0.05},
830
+ 0.65: {1: 0.10, 2: 0.30, 3: 0.40, 4: 0.20},
831
+ }
832
+ )
833
+
834
+
835
def sample_difficulty(curriculum: CurriculumConfig, mean_reward: float) -> int:
    """Draw one difficulty level at random for the current mean reward.

    The distribution is obtained from ``curriculum.get_distribution``; its keys
    are the candidate levels and its values are the sampling weights.
    """
    import random

    distribution = curriculum.get_distribution(mean_reward)
    levels = list(distribution)
    weights = [distribution[level] for level in levels]
    drawn = random.choices(levels, weights=weights, k=1)
    return drawn[0]
843
+ ```
844
+
845
+ ---
846
+
847
+ ## Phase 11: Training Script (Person B) — `training/grpo_train.py`
848
+
849
+ ```python
850
+ # training/grpo_train.py
851
+ import torch
852
+ import asyncio
853
+ import numpy as np
854
+ from unsloth import FastLanguageModel
855
+ from trl import GRPOConfig, GRPOTrainer
856
+ from curriculum import DEFAULT_CURRICULUM, sample_difficulty
857
+ from rollout import run_episode
858
+
859
+ # --- Model Load ---
860
+ model, tokenizer = FastLanguageModel.from_pretrained(
861
+ model_name="unsloth/Qwen2.5-7B-Instruct",
862
+ max_seq_length=2048,
863
+ load_in_4bit=True,
864
+ )
865
+ model = FastLanguageModel.get_peft_model(
866
+ model,
867
+ r=16,
868
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
869
+ "gate_proj", "up_proj", "down_proj"],
870
+ lora_alpha=16,
871
+ lora_dropout=0,
872
+ bias="none",
873
+ use_gradient_checkpointing="unsloth",
874
+ )
875
+
876
+ ENV_URL = "http://localhost:8000" # or HuggingFace Space URL
877
+
878
+ # --- Reward function for GRPO (wraps environment) ---
879
def salespath_reward_fn(completions, prompts, **kwargs) -> list[float]:
    """Return one reward per completion for the GRPO trainer.

    This is a simplified single-step reward: each completion's reward is looked
    up in the optional ``step_rewards`` mapping passed through ``kwargs`` and
    defaults to ``0.0`` when the mapping or the entry is missing. Full-episode
    reward is tracked separately in the curriculum loop.

    Args:
        completions: Batch of generated completions. Each one is used as a
            lookup key, so it must be hashable.
        prompts: Prompts the completions were generated from. Unused here; kept
            so the function matches the reward-function call signature.
        **kwargs: Extra fields supplied by the caller; only ``step_rewards``
            is read.

    Returns:
        Rewards in the same order as ``completions``.
    """
    # Resolve the mapping once for the whole batch. A plain lookup does not
    # need the parsed action, so completions are not parsed here.
    step_rewards = kwargs.get("step_rewards", {})
    return [step_rewards.get(completion, 0.0) for completion in completions]
894
+
895
+
896
+ # --- Training config ---
897
+ training_config = GRPOConfig(
898
+ output_dir="salespath_grpo_output",
899
+ num_train_epochs=3,
900
+ per_device_train_batch_size=2,
901
+ gradient_accumulation_steps=4,
902
+ num_generations=8,
903
+ max_new_tokens=256,
904
+ temperature=0.8,
905
+ learning_rate=1e-5,
906
+ logging_steps=10,
907
+ save_steps=100,
908
+ report_to="none",
909
+ )
910
+
911
+ # --- Curriculum training loop ---
912
async def curriculum_train():
    """Run the curriculum loop: sample a difficulty, play an episode, log metrics.

    Plays 500 episodes. Each episode's difficulty is sampled from
    ``DEFAULT_CURRICULUM`` using the mean total reward of the last 20 episodes;
    the mean stays at ``0.0`` until 20 episodes have been played.
    """
    window = 20
    mean_reward = 0.0
    reward_history = []

    for step in range(500):
        difficulty = sample_difficulty(DEFAULT_CURRICULUM, mean_reward)
        result = await run_episode(model, tokenizer, ENV_URL, difficulty)

        reward_history.append(result["total_reward"])
        # Use >= so the mean is updated as soon as a full window exists,
        # i.e. on the 20th episode rather than the 21st.
        if len(reward_history) >= window:
            mean_reward = np.mean(reward_history[-window:])

        # Log metrics
        if step % 10 == 0:
            print(f"Step {step:4d} | Difficulty {difficulty} | "
                  f"Reward {result['total_reward']:.3f} | "
                  f"Mean(20) {mean_reward:.3f} | "
                  f"Violations {len(result['violations'])} | "
                  f"Steps {result['steps_completed']}")

        # Manual inspection every 50 steps: show the first generation of the episode.
        if step % 50 == 0:
            print("\n=== RAW GENERATION SAMPLE ===")
            if result["trajectory"]:
                print(result["trajectory"][0]["generated"])
            print("==============================\n")
938
+
939
+
940
+ if __name__ == "__main__":
941
+ asyncio.run(curriculum_train())
942
+ ```
943
+
944
+ ---
945
+
946
+ ## Phase 12: Dockerfile (Person A) — `server/Dockerfile`
947
+
948
+ ```dockerfile
949
+ ARG BASE_IMAGE=openenv-base:latest
950
+ FROM ${BASE_IMAGE}
951
+
952
+ COPY server/requirements.txt /tmp/requirements.txt
953
+ RUN pip install --no-cache-dir -r /tmp/requirements.txt
954
+
955
+ COPY src/openenv/core/ /app/src/openenv/core/
956
+ COPY salespath_env/ /app/salespath_env/
957
+
958
+ HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
959
+ CMD curl -f http://localhost:8000/health || exit 1
960
+
961
+ CMD ["uvicorn", "salespath_env.server.app:app", "--host", "0.0.0.0", "--port", "8000"]
962
+ ```
963
+
964
+ `server/requirements.txt`:
965
+ ```
966
+ fastapi
967
+ uvicorn
968
+ pydantic>=2.0
969
+ ```
970
+
971
+ ---
972
+
973
+ ## Phase 13: Deploy to HuggingFace
974
+
975
+ ```bash
976
+ # From salespath_env/ directory
977
+ openenv push --repo-id Imsachin010/salespath-env
978
+
979
+ # Verify it's running
980
+ curl -X POST https://imsachin010-salespath-env.hf.space/reset \
981
+ -H "Content-Type: application/json" \
982
+ -d '{"difficulty": 1}'
983
+ ```
984
+
985
+ ---
986
+
987
+ ## Phase 14: Model Save (After Training)
988
+
989
+ ```python
990
+ # CORRECT save — do not change this
991
+ model.save_pretrained_merged(
992
+ "salespath_trained_merged",
993
+ tokenizer,
994
+ save_method="merged_16bit",
995
+ )
996
+
997
+ # Push to HuggingFace Hub
998
+ model.push_to_hub_merged(
999
+ "Imsachin010/salespath-qwen25-7b",
1000
+ tokenizer,
1001
+ save_method="merged_16bit",
1002
+ )
1003
+ ```
1004
+
1005
+ ---
1006
+
1007
+ ## Build Order Summary
1008
+
1009
+ ```
1010
+ Person A (Environment): Person B (Training):
1011
+ 1. models.py (wait for models.py)
1012
+ 2. server/task_bank.py 1. server/reward.py
1013
+ 3. server/rules.py 2. training/rollout.py
1014
+ 4. server/prospect_simulator.py 3. training/curriculum.py
1015
+ 5. server/salespath_environment 4. training/grpo_train.py
1016
+ 6. server/app.py 5. training/colab_train.ipynb
1017
+ 7. Dockerfile
1018
+ 8. openenv push → verify health
1019
+ 6. Connect rollout to live env URL
1020
+ 7. Run first training loop (difficulty=1 only)
1021
+ 8. Verify reward > 0 on step 1
1022
+ 9. Enable curriculum
1023
+ ```
1024
+
1025
+ **Critical gate:** Person B does not run training until Person A has confirmed:
1026
+ - `POST /reset` returns a valid observation
1027
+ - `POST /step` with a valid action returns a valid observation
1028
+ - `POST /step` with an invalid action returns error in `info`, not a 500
Dockerfile ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11-slim
2
+
3
+ # HuggingFace Spaces runs on port 7860 by default
4
+ ENV PORT=7860
5
+ ENV PYTHONUNBUFFERED=1
6
+ ENV PYTHONDONTWRITEBYTECODE=1
7
+
8
+ WORKDIR /app
9
+
10
+ # Install system dependencies
11
+ RUN apt-get update && apt-get install -y --no-install-recommends \
12
+ curl \
13
+ && rm -rf /var/lib/apt/lists/*
14
+
15
+ # Install Python dependencies
16
+ COPY requirements.txt .
17
+ RUN pip install --no-cache-dir -r requirements.txt
18
+
19
+ # Copy the salespath_env package
20
+ COPY salespath_env/ ./salespath_env/
21
+
22
+ # Health check
23
+ HEALTHCHECK --interval=30s --timeout=5s --start-period=15s --retries=3 \
24
+ CMD curl -f http://localhost:${PORT}/health || exit 1
25
+
26
+ # Start the FastAPI server on HF Spaces port
27
+ CMD ["sh", "-c", "uvicorn salespath_env.server.app:app --host 0.0.0.0 --port ${PORT}"]
README.md ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: SalesPath Environment
3
+ emoji: 🤝
4
+ colorFrom: blue
5
+ colorTo: indigo
6
+ sdk: docker
7
+ app_port: 7860
8
+ pinned: false
9
+ license: mit
10
+ short_description: RL gym environment for sales agent training
11
+ ---
12
+
13
+ # SalesPath Environment
14
+
15
+ An [OpenEnv](https://github.com/openenv)-compatible Reinforcement Learning gym environment for training sales agents via LLM fine-tuning.
16
+
17
+ ## API Endpoints
18
+
19
+ | Method | Endpoint | Description |
20
+ |--------|----------|-------------|
21
+ | `POST` | `/reset` | Reset the environment, returns initial observation |
22
+ | `POST` | `/step` | Take an action, returns next observation + reward |
23
+ | `GET` | `/health` | Health check |
24
+
25
+ ## Quick Start
26
+
27
+ ### Reset
28
+ ```bash
29
+ curl -X POST https://imsachin010-salespath-env.hf.space/reset \
30
+ -H "Content-Type: application/json" \
31
+ -d '{"difficulty": 1}'
32
+ ```
33
+
34
+ ### Step
35
+ ```bash
36
+ curl -X POST https://imsachin010-salespath-env.hf.space/step \
37
+ -H "Content-Type: application/json" \
38
+ -d '{"action": {"action_type": "PROSPECT", "content": "Hello, tell me about your workflow challenges."}}'
39
+ ```
40
+
41
+ ## Action Types
42
+
43
+ - `PROSPECT` — Initial outreach and discovery
44
+ - `QUALIFY` — Qualify the lead
45
+ - `PRESENT` — Deliver the sales pitch
46
+ - `HANDLE_OBJECTION` — Handle prospect objections
47
+ - `OFFER_DEMO` — Offer product demonstration
48
+ - `NEGOTIATE` — Discuss pricing and terms
49
+ - `FOLLOW_UP` — Follow-up message
50
+ - `DISQUALIFY` — Exit if prospect is not a fit
51
+ - `CLOSE` — Attempt to close the deal
RULES.md ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SalesPath — Agent Rules & Constraints
2
+ ### Read this before touching any file. These are non-negotiable.
3
+
4
+ ---
5
+
6
+ ## 0. Project Identity
7
+
8
+ - **Project name:** `salespath_env`
9
+ - **HuggingFace repo:** `Imsachin010/salespath-env`
10
+ - **Theme:** Theme #2 — Long-Horizon Planning (Scale AI bonus prize)
11
+ - **Stack:** OpenEnv + GRPO (HF TRL) + Unsloth + Qwen 2.5 7B Instruct
12
+
13
+ ---
14
+
15
+ ## 1. Directory Structure — Do Not Deviate
16
+
17
+ ```
18
+ salespath_env/
19
+ ├── __init__.py
20
+ ├── models.py ← ALL Pydantic dataclasses live here only
21
+ ├── client.py ← SalesPathEnv(EnvClient) lives here only
22
+ ├── README.md
23
+ ├── openenv.yaml
24
+ ├── pyproject.toml
25
+ ├── server/
26
+ │ ├── __init__.py
27
+ │ ├── salespath_environment.py ← SalesPathEnvironment(Environment)
28
+ │ ├── prospect_simulator.py ← ProspectSimulator (rule-based only)
29
+ │ ├── reward.py ← ALL reward logic lives here only
30
+ │ ├── task_bank.py ← ALL prospect profiles and tasks
31
+ │ ├── rules.py ← ALL business rule definitions
32
+ │ ├── app.py ← FastAPI app only, no logic
33
+ │ ├── requirements.txt
34
+ │ └── Dockerfile
35
+ training/
36
+ ├── grpo_train.py ← training script
37
+ ├── rollout.py ← rollout function
38
+ ├── curriculum.py ← difficulty scheduler
39
+ └── colab_train.ipynb ← Colab notebook for judges
40
+ ```
41
+
42
+ ---
43
+
44
+ ## 2. OpenEnv API — Exact Signatures to Follow
45
+
46
+ ```python
47
+ # models.py — extend these base classes
48
+ from openenv.core import Action, Observation, State # actual imports
49
+
50
+ class SalesPathAction(Action):
51
+ action_type: str # one of the 9 valid action types
52
+ content: str # natural language content of the action
53
+ target: str = "" # optional target (e.g., which objection)
54
+
55
+ class SalesPathObservation(Observation):
56
+ prospect_response: str
57
+ workflow_stage: str
58
+ constraints_violated: list[str]
59
+ steps_completed: list[str]
60
+ turn_number: int
61
+ reward: float
62
+ done: bool
63
+ info: dict
64
+
65
+ class SalesPathState(State):
66
+ episode_id: str
67
+ prospect_profile: dict
68
+ conversation_history: list[dict]
69
+ workflow_stage: str
70
+ steps_completed: list[str]
71
+ constraints_violated: list[str]
72
+ turn_number: int
73
+ difficulty: int # 1, 2, 3, or 4
74
+ hidden_state: dict # NOT exposed to agent
75
+ ```
76
+
77
+ ```python
78
+ # server/salespath_environment.py
79
+ from openenv.core.env_server import Environment
80
+
81
+ class SalesPathEnvironment(Environment):
82
+ def reset(self, difficulty: int = 1) -> SalesPathObservation: ...
83
+ def step(self, action: SalesPathAction) -> SalesPathObservation: ...
84
+ @property
85
+ def state(self) -> SalesPathState: ...
86
+ ```
87
+
88
+ ```python
89
+ # server/app.py — nothing else in this file
90
+ from openenv.core.env_server import create_fastapi_app
91
+ from ..models import SalesPathAction, SalesPathObservation
92
+ from .salespath_environment import SalesPathEnvironment
93
+
94
+ app = create_fastapi_app(SalesPathEnvironment, SalesPathAction, SalesPathObservation)
95
+ ```
96
+
97
+ ---
98
+
99
+ ## 3. Hard Rules — Code Will Be Rejected If Violated
100
+
101
+ ### 3.1 No LLM in the Environment
102
+ - `ProspectSimulator` is a **pure rule-based state machine**
103
+ - No API calls, no model inference, no `transformers` imports inside `server/`
104
+ - If you find yourself writing `model.generate()` inside `server/`, stop. Wrong file.
105
+
106
+ ### 3.2 Immutable Prospect State
107
+ - Once `reset()` sets the prospect profile, agent actions **cannot modify `hidden_state`**
108
+ - `hidden_state` is read-only after `reset()`
109
+ - Never expose `hidden_state` fields in `SalesPathObservation`
110
+
111
+ ### 3.3 Reward Lives in One Place
112
+ - All reward computation goes in `server/reward.py`
113
+ - `salespath_environment.py` calls `compute_reward()` — it does not compute reward itself
114
+ - Never compute reward inside `step()` directly
115
+
116
+ ### 3.4 Business Rules Live in One Place
117
+ - All rule definitions go in `server/rules.py` as a list of `BusinessRule` dataclasses
118
+ - `step()` calls `check_rules(state, action)` from `rules.py` — it does not check rules inline
119
+
120
+ ### 3.5 Turn Limit is Absolute
121
+ - Max turns = 20. Hard terminate. No exceptions.
122
+ - Episode must set `done=True` and assign `r_outcome = -0.3` at turn 20 regardless of state
123
+
124
+ ### 3.6 Action Validation is Strict
125
+ - If `action_type` is not one of the 9 valid types, return `done=False`, `reward=-0.2`, observation with error message
126
+ - Do not raise exceptions to the agent — return a valid `SalesPathObservation` with error in `info`
127
+
128
+ ### 3.7 Reward Must Be Multi-Component
129
+ - Reward function must log all 5 components separately in `info` dict
130
+ - Never return a single scalar reward without component breakdown
131
+ - Component keys: `r_outcome`, `r_compliance`, `r_ordering`, `r_efficiency`, `r_format`
132
+
133
+ ### 3.8 No Global Mutable State in Environment
134
+ - Each WebSocket session gets its own `SalesPathEnvironment` instance
135
+ - No class-level variables that change during episodes
136
+ - No module-level state
137
+
138
+ ---
139
+
140
+ ## 4. Valid Action Types — Exact Strings
141
+
142
+ ```python
143
+ VALID_ACTIONS = {
144
+ "PROSPECT", # initial outreach — only valid on turn 1
145
+ "QUALIFY", # ask qualification questions
146
+ "PRESENT", # deliver pitch
147
+ "HANDLE_OBJECTION", # respond to raised objection
148
+ "OFFER_DEMO", # propose product demonstration
149
+ "NEGOTIATE", # discuss pricing/terms
150
+ "CLOSE", # submit closing offer → terminates episode
151
+ "FOLLOW_UP", # follow up after no response
152
+ "DISQUALIFY", # exit if prospect is not a fit → terminates episode
153
+ }
154
+ ```
155
+
156
+ ---
157
+
158
+ ## 5. Business Rules — Exact Definitions
159
+
160
+ These are checked after every `step()`. Each violation appends an entry to `constraints_violated`.
161
+
162
+ ```python
163
+ RULES = [
164
+ # ID Name Condition for VIOLATION
165
+ R01 "qualify_before_present" PRESENT called before any QUALIFY
166
+ R02 "demo_before_negotiate" NEGOTIATE called before OFFER_DEMO
167
+ R03 "budget_known_to_negotiate" NEGOTIATE called while budget_signal == "unknown"
168
+ R04 "discount_after_objections" Discount mentioned in NEGOTIATE before 2 objections handled
169
+ R05 "no_repeat_action" Same action_type on consecutive turns
170
+ R06 "prospect_first" Any action other than PROSPECT on turn 1
171
+ R07 "followup_timing" FOLLOW_UP called when prospect responded last turn
172
+ R08 "disqualify_logic" DISQUALIFY called when budget >= threshold AND decision_maker==True
173
+ R09 "close_requires_demo" CLOSE called before OFFER_DEMO
174
+ ]
175
+ ```
176
+
177
+ Three violations → `done=True`, `r_outcome = -0.5`
178
+
179
+ ---
180
+
181
+ ## 6. Prospect Simulator — Exact Response Rules
182
+
183
+ `ProspectSimulator.respond(action, state)` returns one of these string tokens. The environment converts tokens to natural language text for the observation.
184
+
185
+ ```python
186
+ RESPONSE_TOKENS = {
187
+ "open:positive_signal", # prospect is engaged and open
188
+ "open:neutral_signal", # prospect acknowledges but non-committal
189
+ "objection:price", # raises price objection
190
+ "objection:timing", # raises timing objection
191
+ "objection:premature_pitch", # triggered by R01 violation
192
+ "deflect:budget_not_discussed", # triggered by R03 violation
193
+ "deflect:stall", # prospect stalls (Level 3+)
194
+ "accept:demo_scheduled", # agrees to demo
195
+ "accept:close_success", # agrees to close → episode success
196
+ "reject:close_failed", # rejects close
197
+ "silence", # no response (enables FOLLOW_UP)
198
+ "exit:disqualified", # prospect exits conversation
199
+ }
200
+ ```
201
+
202
+ ---
203
+
204
+ ## 7. Difficulty Configuration
205
+
206
+ ```python
207
+ DIFFICULTY_CONFIG = {
208
+ 1: {
209
+ "max_turns": 20,
210
+ "workflow_steps": ["QUALIFY", "PRESENT", "CLOSE"],
211
+ "num_objections": 0,
212
+ "budget_hidden": False,
213
+ "mode_shift": False,
214
+ "optimal_turns": 5,
215
+ },
216
+ 2: {
217
+ "max_turns": 20,
218
+ "workflow_steps": ["QUALIFY", "PRESENT", "HANDLE_OBJECTION", "OFFER_DEMO", "CLOSE"],
219
+ "num_objections": 1,
220
+ "budget_hidden": True, # revealed after QUALIFY
221
+ "mode_shift": False,
222
+ "optimal_turns": 8,
223
+ },
224
+ 3: {
225
+ "max_turns": 20,
226
+ "workflow_steps": ["QUALIFY", "PRESENT", "HANDLE_OBJECTION", "OFFER_DEMO",
227
+ "HANDLE_OBJECTION", "NEGOTIATE", "CLOSE"],
228
+ "num_objections": 2,
229
+ "budget_hidden": True,
230
+ "mode_shift": True, # prospect signals shift at turn 10
231
+ "optimal_turns": 12,
232
+ },
233
+ 4: {
234
+ "max_turns": 20,
235
+ "workflow_steps": "full", # agent must determine correct path
236
+ "num_objections": 2,
237
+ "budget_hidden": True,
238
+ "mode_shift": True,
239
+ "misleading_signals": True, # budget signals are deceptive
240
+ "optimal_turns": 14,
241
+ },
242
+ }
243
+ ```
244
+
245
+ ---
246
+
247
+ ## 8. Reward — Exact Weights
248
+
249
+ ```python
250
+ REWARD_WEIGHTS = {
251
+ "r_outcome": 0.40,
252
+ "r_compliance": 0.30,
253
+ "r_ordering": 0.15,
254
+ "r_efficiency": 0.10,
255
+ "r_format": 0.05,
256
+ }
257
+
258
+ OUTCOME_VALUES = {
259
+ "close_success": 1.0,
260
+ "disqualify_correct": 0.5,
261
+ "turn_limit_reached": -0.3,
262
+ "close_failed": -0.5,
263
+ "three_violations": -0.5,
264
+ }
265
+
266
+ COMPLIANCE_PER_VIOLATION = -0.2 # capped at -1.0
267
+ EFFICIENCY_PER_EXTRA_TURN = -0.05 # capped at -0.3
268
+ FORMAT_PASS = 1.0
269
+ FORMAT_FAIL = -0.1
270
+ ```
271
+
272
+ ---
273
+
274
+ ## 9. Training Rules
275
+
276
+ ### Prompt Format (what gets sent to the LLM)
277
+ ```
278
+ System: You are a B2B sales agent. Follow this workflow strictly:
279
+ {workflow_steps_for_difficulty}
280
+
281
+ Business rules you must never violate:
282
+ {rules_list}
283
+
284
+ Current state:
285
+ - Prospect: {prospect_summary}
286
+ - Stage: {workflow_stage}
287
+ - Steps done: {steps_completed}
288
+ - Turn: {turn_number}/20
289
+
290
+ Prospect said: {prospect_response}
291
+
292
+ Respond with:
293
+ ACTION: <action_type>
294
+ CONTENT: <your message>
295
+ ```
296
+
297
+ ### Response parsing
298
+ - Extract `ACTION:` line → `action_type`
299
+ - Extract `CONTENT:` line → `content`
300
+ - If parsing fails → `r_format = -0.1`, use fallback QUALIFY
301
+
302
+ ### GRPO config
303
+ ```python
304
+ GRPOConfig(
305
+ num_generations=8, # rollouts per prompt
306
+ max_new_tokens=256,
307
+ temperature=0.8,
308
+ learning_rate=1e-5,
309
+ per_device_train_batch_size=2,
310
+ gradient_accumulation_steps=4,
311
+ )
312
+ ```
313
+
314
+ ---
315
+
316
+ ## 10. What to Monitor During Training
317
+
318
+ Log these every 10 steps. If any of these goes wrong, stop and inspect raw generations:
319
+
320
+ | Metric | Healthy Range | Alarm |
321
+ |--------|--------------|-------|
322
+ | `mean_reward` | Rising | Flat for >50 steps |
323
+ | `mean_r_compliance` | Rising | < -0.5 after step 100 |
324
+ | `violations_per_episode` | Falling | > 3.0 after step 100 |
325
+ | `ordering_rate` | Rising toward 0.85 | < 0.3 after step 150 |
326
+ | `close_success_rate` | Rising | 0 after step 200 |
327
+
328
+ Inspect raw generations every 50 steps. Look for: repeated actions, empty CONTENT, invalid ACTION types, CLOSE before QUALIFY.
329
+
330
+ ---
331
+
332
+ ## 11. Save Model Correctly
333
+
334
+ ```python
335
+ # CORRECT — do not deviate
336
+ model.save_pretrained_merged(
337
+ "salespath_trained",
338
+ tokenizer,
339
+ save_method="merged_16bit", # NOT naive upcast of 4bit
340
+ )
341
+ ```
342
+
343
+ Never do: `model.save_pretrained()` on a 4-bit model without merging first.
344
+
345
+ ---
346
+
347
+ ## 12. File Ownership (2-Person Team)
348
+
349
+ | Person | Files |
350
+ |--------|-------|
351
+ | **A** | `models.py`, `server/salespath_environment.py`, `server/prospect_simulator.py`, `server/rules.py`, `server/task_bank.py`, `server/app.py`, `Dockerfile` |
352
+ | **B** | `server/reward.py`, `training/grpo_train.py`, `training/rollout.py`, `training/curriculum.py`, `training/colab_train.ipynb`, `client.py` |
353
+
354
+ Both: `README.md`, `openenv.yaml`, `pyproject.toml`
push_to_hub.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import HfApi
2
+ import os
3
+
4
+ REPO_ID = "Imsachin010/salespath-env"
5
+ FOLDER_PATH = "."
6
+
7
+ IGNORE_PATTERNS = [
8
+ "*.pyc",
9
+ "**/__pycache__/**",
10
+ ".git/**",
11
+ ".spa/**",
12
+ ".SPA/**",
13
+ "*.egg-info/**",
14
+ "push_to_hub.py",
15
+ "salespath_env/server/Dockerfile", # root Dockerfile is used instead
16
+ "training/**", # exclude training scripts from Space
17
+ ]
18
+
19
def main():
    """Create the Space repo if needed, upload the project folder, and print its URL.

    The repo is created as a public Docker Space (a no-op when it already
    exists), then ``FOLDER_PATH`` is uploaded with ``IGNORE_PATTERNS`` applied.
    """
    api = HfApi()

    api.create_repo(
        repo_id=REPO_ID,
        repo_type="space",
        space_sdk="docker",
        exist_ok=True,
        private=False,
    )

    api.upload_folder(
        folder_path=FOLDER_PATH,
        repo_id=REPO_ID,
        repo_type="space",
        ignore_patterns=IGNORE_PATTERNS,
        commit_message="Deploy SalesPath Environment",
    )

    # Lowercase the host so the printed URL matches the form used in the
    # README's curl examples (imsachin010-salespath-env.hf.space).
    subdomain = REPO_ID.replace("/", "-").lower()
    print(
        f"Live Space URL:\n"
        f"https://{subdomain}.hf.space"
    )
42
+
43
+ if __name__ == "__main__":
44
+ main()
pyproject.toml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools>=42"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "salespath_env"
7
+ version = "0.1.0"
8
+ requires-python = ">=3.10"
9
+ dependencies = [
10
+ "openenv",
11
+ "fastapi",
12
+ "uvicorn",
13
+ "pydantic>=2.0",
14
+ "trl>=0.8.0",
15
+ "unsloth",
16
+ "torch",
17
+ "transformers",
18
+ ]
19
+
20
+ [tool.setuptools.packages.find]
21
+ where = ["."]
22
+ include = ["salespath_env*"]
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ fastapi>=0.110.0
2
+ uvicorn[standard]>=0.29.0
3
+ pydantic>=2.0
4
+ openenv-core>=0.2.3
salespath_env.egg-info/PKG-INFO ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.4
2
+ Name: salespath-env
3
+ Version: 0.1.0
4
+ Requires-Python: >=3.10
5
+ Requires-Dist: openenv-core>=0.2.3
6
+ Requires-Dist: fastapi>=0.110.0
7
+ Requires-Dist: uvicorn[standard]>=0.29.0
8
+ Requires-Dist: pydantic>=2.0
salespath_env.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ README.md
2
+ pyproject.toml
3
+ salespath_env/__init__.py
4
+ salespath_env/client.py
5
+ salespath_env/models.py
6
+ salespath_env.egg-info/PKG-INFO
7
+ salespath_env.egg-info/SOURCES.txt
8
+ salespath_env.egg-info/dependency_links.txt
9
+ salespath_env.egg-info/requires.txt
10
+ salespath_env.egg-info/top_level.txt
11
+ salespath_env/server/__init__.py
12
+ salespath_env/server/app.py
13
+ salespath_env/server/prospect_simulator.py
14
+ salespath_env/server/reward.py
15
+ salespath_env/server/rules.py
16
+ salespath_env/server/salespath_environment.py
17
+ salespath_env/server/task_bank.py
salespath_env.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
 
 
1
+
salespath_env.egg-info/requires.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ openenv-core>=0.2.3
2
+ fastapi>=0.110.0
3
+ uvicorn[standard]>=0.29.0
4
+ pydantic>=2.0
salespath_env.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ salespath_env
salespath_env/README.md ADDED
File without changes
salespath_env/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ """SalesPath OpenEnv package."""
2
+
salespath_env/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (218 Bytes). View file
 
salespath_env/__pycache__/client.cpython-313.pyc ADDED
Binary file (3.56 kB). View file
 
salespath_env/__pycache__/models.cpython-313.pyc ADDED
Binary file (3.36 kB). View file
 
salespath_env/client.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # salespath_env/client.py
2
+
3
+ from typing import Any, Dict
4
+
5
+ from openenv.core import EnvClient
6
+ from openenv.core.client_types import StepResult
7
+
8
+ from .models import (
9
+ SalesPathAction,
10
+ SalesPathObservation,
11
+ SalesPathState,
12
+ )
13
+
14
+
15
class SalesPathEnv(EnvClient[SalesPathAction, SalesPathObservation, SalesPathState]):
    """Client for the SalesPath environment server.

    Provides the serialisation hooks required by ``EnvClient`` and overrides
    ``reset``/``step`` so callers get a ``SalesPathObservation`` back directly
    instead of a ``StepResult`` wrapper.
    """

    # ---- Hooks required by EnvClient ----

    def _step_payload(self, action: SalesPathAction) -> Dict[str, Any]:
        """Convert an action into the JSON dict sent over the WebSocket.

        WSStepMessage.data is the action dict itself (no wrapper key); the
        ``metadata`` field is excluded from the dump.
        """
        serialised = action.model_dump(exclude={"metadata"})
        return serialised

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[SalesPathObservation]:
        """Build a ``StepResult`` from the server's JSON payload.

        The observation may be nested under ``"observation"`` or be the payload
        itself. Top-level ``reward``/``done`` win over the observation's own
        values when present.
        """
        if "observation" in payload:
            raw_observation = payload["observation"]
        else:
            raw_observation = payload
        observation = SalesPathObservation(**raw_observation)

        reward = payload.get("reward", observation.reward)
        done = payload.get("done", observation.done)
        return StepResult(observation=observation, reward=reward, done=done)

    def _parse_state(self, payload: Dict[str, Any]) -> SalesPathState:
        """Build a ``SalesPathState`` from the server's JSON payload.

        The state may be nested under ``"state"`` or be the payload itself.
        """
        raw_state = payload["state"] if "state" in payload else payload
        return SalesPathState(**raw_state)

    # ---- Convenience wrappers returning the bare observation ----

    @staticmethod
    def _with_step_fields(
        result: StepResult[SalesPathObservation],
    ) -> SalesPathObservation:
        """Copy the wrapper's ``reward`` and ``done`` onto the observation.

        Some server payloads carry reward/done only at the top level, so the
        observation is returned as a copy with those two fields overwritten.
        """
        overrides = {"reward": result.reward, "done": result.done}
        return result.observation.model_copy(update=overrides)

    async def reset(
        self,
        difficulty: int = 1,
    ) -> SalesPathObservation:
        """Start a new episode at ``difficulty`` and return the first observation."""
        step_result = await super().reset(difficulty=difficulty)
        return self._with_step_fields(step_result)

    async def step(
        self,
        action_type: str,
        content: str,
        target: str = "",
    ) -> SalesPathObservation:
        """Send one action built from the given fields and return the next observation."""
        step_result = await super().step(
            SalesPathAction(
                action_type=action_type,
                content=content,
                target=target,
            )
        )
        return self._with_step_fields(step_result)
salespath_env/models.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # salespath_env/models.py
2
+
3
+ from __future__ import annotations
4
+
5
+ import uuid
6
+ from typing import Dict, List
7
+ from pydantic import BaseModel, Field
8
+
9
+ # Safe OpenEnv Imports: Use OpenEnv base classes if available,
10
+ # otherwise fall back to Pydantic to bypass security blocks.
11
+ try:
12
+ from openenv.core import Action, Observation, State
13
+ except (ImportError, Exception):
14
+ Action = BaseModel
15
+ Observation = BaseModel
16
+ State = BaseModel
17
+
18
+
19
+ VALID_ACTIONS = {
20
+ "PROSPECT",
21
+ "QUALIFY",
22
+ "PRESENT",
23
+ "HANDLE_OBJECTION",
24
+ "OFFER_DEMO",
25
+ "NEGOTIATE",
26
+ "CLOSE",
27
+ "FOLLOW_UP",
28
+ "DISQUALIFY",
29
+ }
30
+
31
+
32
class SalesPathAction(Action):
    """A single move the agent submits to the environment."""

    # Name of the action; expected to be one of VALID_ACTIONS.
    action_type: str
    # Natural-language content of the action.
    content: str
    # Optional target of the action; empty when not used.
    target: str = ""

    def is_valid(self) -> bool:
        """Report whether ``action_type`` is a member of ``VALID_ACTIONS``."""
        allowed = VALID_ACTIONS
        return self.action_type in allowed
46
+
47
+
48
class SalesPathObservation(Observation):
    """
    What the agent is allowed to observe.
    Hidden state must NEVER be exposed here.
    """

    # Prospect's reply text; empty string by default.
    prospect_response: str = ""
    # Current workflow stage label; "START" by default.
    workflow_stage: str = "START"

    # Identifiers of constraints violated so far (fresh empty list per instance).
    constraints_violated: List[str] = Field(default_factory=list)
    # Workflow steps completed so far (fresh empty list per instance).
    steps_completed: List[str] = Field(default_factory=list)

    turn_number: int = 0

    reward: float = 0.0
    # NOTE(review): presumably the per-component reward breakdown
    # (r_outcome, r_compliance, ...) — confirm against server/reward.py.
    reward_components: Dict = Field(default_factory=dict)

    done: bool = False
    # Free-form extra information; empty dict by default.
    info: Dict = Field(default_factory=dict)
67
+
68
+
69
class SalesPathState(State):
    """
    Internal environment state.
    Includes hidden state not exposed to the agent.
    """

    # Random UUID4 string generated per instance.
    episode_id: str = Field(default_factory=lambda: str(uuid.uuid4()))

    prospect_profile: Dict = Field(default_factory=dict)
    conversation_history: List[Dict] = Field(default_factory=list)

    # Current workflow stage label; "START" by default.
    workflow_stage: str = "START"
    # NOTE(review): presumably the ordered workflow for the episode's
    # difficulty — confirm against the environment's reset().
    required_workflow: List[str] = Field(default_factory=list)

    steps_completed: List[str] = Field(default_factory=list)
    constraints_violated: List[str] = Field(default_factory=list)

    objections_handled: int = 0
    turn_number: int = 0
    difficulty: int = 1

    done: bool = False

    # Hidden state — NEVER exposed in Observation
    hidden_state: Dict = Field(default_factory=dict)
salespath_env/openenv.yaml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "salespath_env"
3
+ version = "0.1.0"
4
+ dependencies = [
5
+ "openenv",
6
+ "fastapi",
7
+ "uvicorn",
8
+ "pydantic>=2.0",
9
+ "trl>=0.8.0",
10
+ "unsloth",
11
+ "torch",
12
+ "transformers",
13
+ ]
salespath_env/pyproject.toml ADDED
File without changes
salespath_env/server/Dockerfile ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Base image is injectable at build time.
ARG BASE_IMAGE=openenv-base:latest
FROM ${BASE_IMAGE}

# The package is copied from ./salespath_env below, so the build context is
# the repository root; the requirements path must be relative to that root.
COPY salespath_env/server/requirements.txt /tmp/requirements.txt
RUN pip install --no-cache-dir -r /tmp/requirements.txt

COPY salespath_env/ /app/salespath_env/

# NOTE(review): assumes curl exists in the base image and that the base image
# makes /app importable (WORKDIR or PYTHONPATH) — confirm.
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
    CMD curl -f http://localhost:8000/health || exit 1

CMD ["uvicorn", "salespath_env.server.app:app", "--host", "0.0.0.0", "--port", "8000"]
salespath_env/server/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ """SalesPath environment server package."""
2
+
salespath_env/server/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (236 Bytes). View file
 
salespath_env/server/__pycache__/app.cpython-313.pyc ADDED
Binary file (455 Bytes). View file
 
salespath_env/server/__pycache__/prospect_simulator.cpython-313.pyc ADDED
Binary file (4.15 kB). View file
 
salespath_env/server/__pycache__/reward.cpython-313.pyc ADDED
Binary file (3.1 kB). View file
 
salespath_env/server/__pycache__/rules.cpython-313.pyc ADDED
Binary file (6.24 kB). View file
 
salespath_env/server/__pycache__/salespath_environment.cpython-313.pyc ADDED
Binary file (6.4 kB). View file
 
salespath_env/server/__pycache__/task_bank.cpython-313.pyc ADDED
Binary file (2.72 kB). View file
 
salespath_env/server/app.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # salespath_env/server/app.py
2
+
3
+ from openenv.core.env_server import create_fastapi_app
4
+
5
+ from ..models import (
6
+ SalesPathAction,
7
+ SalesPathObservation,
8
+ )
9
+ from .salespath_environment import (
10
+ SalesPathEnvironment,
11
+ )
12
+
13
+
14
+ app = create_fastapi_app(
15
+ SalesPathEnvironment,
16
+ SalesPathAction,
17
+ SalesPathObservation,
18
+ )
salespath_env/server/prospect_simulator.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # salespath_env/server/prospect_simulator.py
2
+
3
+ from ..models import SalesPathAction, SalesPathState
4
+
5
+
6
+ RESPONSE_TEXT = {
7
+ "open:positive_signal": "That sounds interesting. Tell me more about how this works.",
8
+ "open:neutral_signal": "I see. We're evaluating a few options at the moment.",
9
+
10
+ "objection:price": "The pricing seems higher than what we budgeted for.",
11
+ "objection:timing": "The timing isn't ideal — we're in the middle of a quarter close.",
12
+ "objection:premature_pitch": (
13
+ "I'm not sure we're ready to discuss solutions yet. "
14
+ "What do you know about our current situation?"
15
+ ),
16
+
17
+ "deflect:budget_not_discussed": (
18
+ "We haven't really talked about what we're looking for yet."
19
+ ),
20
+ "deflect:stall": (
21
+ "Let me get back to you on this. A lot is happening on our end."
22
+ ),
23
+
24
+ "accept:demo_scheduled": (
25
+ "Yes, let's set up a demo. What time works next week?"
26
+ ),
27
+ "accept:close_success": (
28
+ "Alright, I think we can move forward with this. "
29
+ "Send over the paperwork."
30
+ ),
31
+
32
+ "reject:close_failed": (
33
+ "I don't think we're ready to commit at this point."
34
+ ),
35
+
36
+ "silence": "",
37
+
38
+ "exit:disqualified": (
39
+ "I think we're done here. This isn't the right fit."
40
+ ),
41
+ }
42
+
43
+
44
class ProspectSimulator:
    """
    Pure rule-based prospect simulator. No LLM. No transformers.

    Replies are determined by the action and the episode state. The only
    sampled behavior is the optional stall at difficulty 3+, driven by
    ``hidden_state["stall_probability"]`` (0.0 means never stall).
    """

    def respond(
        self,
        action: SalesPathAction,
        state: SalesPathState,
    ) -> tuple[str, str]:
        """
        Produce the prospect's reply to ``action``.

        May mutate ``state``: QUALIFY reveals an unknown budget signal and
        HANDLE_OBJECTION increments ``objections_handled``.

        Returns:
            (response_token, response_text)
        """

        token = self._get_token(action, state)
        return token, RESPONSE_TEXT[token]

    def _get_token(
        self,
        action: SalesPathAction,
        state: SalesPathState,
    ) -> str:
        """Pick the response token: rule push-back, then action reply, then stall."""

        rule_token = self._rule_triggered_token(action, state)
        if rule_token is not None:
            return rule_token

        token = self._action_token(action, state)

        # The stall replaces only the non-committal neutral reply, so
        # objections, accept/reject outcomes and state updates are unaffected.
        if token == "open:neutral_signal" and self._should_stall(state):
            return "deflect:stall"

        return token

    def _rule_triggered_token(
        self,
        action: SalesPathAction,
        state: SalesPathState,
    ):
        """Return a push-back token if the latest violation belongs to this action."""

        if not state.constraints_violated:
            return None

        latest = state.constraints_violated[-1]
        atype = action.action_type

        # constraints_violated is cumulative, so its last entry may come from
        # an earlier turn. Re-check the rule's own condition against the
        # current action so a stale violation does not hijack later replies.
        if (
            latest == "R01"
            and atype == "PRESENT"
            and "QUALIFY" not in state.steps_completed
        ):
            return "objection:premature_pitch"

        if (
            latest == "R03"
            and atype == "NEGOTIATE"
            and state.prospect_profile.get("budget_signal") == "unknown"
        ):
            return "deflect:budget_not_discussed"

        return None

    def _action_token(
        self,
        action: SalesPathAction,
        state: SalesPathState,
    ) -> str:
        """Return the reply token for the action type, updating state as needed."""

        atype = action.action_type
        profile = state.prospect_profile
        hidden = state.hidden_state
        # Snapshot taken before HANDLE_OBJECTION increments the counter.
        objections = state.objections_handled

        if atype == "PROSPECT":
            return "open:positive_signal"

        if atype == "QUALIFY":
            # Reveal budget if hidden
            if profile.get("budget_signal") == "unknown":
                state.prospect_profile["budget_signal"] = hidden.get(
                    "revealed_budget",
                    "medium",
                )
            return "open:neutral_signal"

        if atype == "PRESENT":
            if state.difficulty >= 2 and objections == 0:
                return "objection:price"
            return "open:positive_signal"

        if atype == "HANDLE_OBJECTION":
            state.objections_handled += 1
            required_objections = hidden.get("num_objections", 1)

            if state.objections_handled >= required_objections:
                return "open:positive_signal"

            # First objection handled but more are expected: raise the next.
            if objections == 0:
                return "objection:timing"

            return "open:positive_signal"

        if atype == "OFFER_DEMO":
            return "accept:demo_scheduled"

        if atype == "NEGOTIATE":
            return "open:neutral_signal"

        if atype == "CLOSE":
            true_budget = hidden.get("true_budget", 0.7)
            close_threshold = hidden.get("close_threshold", 0.5)
            decision_maker = profile.get("decision_maker", True)

            if true_budget >= close_threshold and decision_maker:
                return "accept:close_success"

            return "reject:close_failed"

        if atype == "FOLLOW_UP":
            return "open:neutral_signal"

        if atype == "DISQUALIFY":
            return "exit:disqualified"

        return "open:neutral_signal"

    def _should_stall(self, state: SalesPathState) -> bool:
        """Sample the difficulty 3+ stall, active from turn 10 onward."""

        if state.difficulty < 3 or state.turn_number < 10:
            return False

        import random

        return random.random() < state.hidden_state.get(
            "stall_probability",
            0.0,
        )
salespath_env/server/requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ pydantic>=2.0
salespath_env/server/reward.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # salespath_env/server/reward.py
2
+
3
+ from ..models import SalesPathAction, SalesPathState
4
+
5
+
6
+ DIFFICULTY_OPTIMAL_TURNS = {
7
+ 1: 5,
8
+ 2: 8,
9
+ 3: 12,
10
+ 4: 14,
11
+ }
12
+
13
+
14
def compute_reward(
    state: SalesPathState,
    action: SalesPathAction,
    response_token: str,
    new_violations: list[str],
    episode_done: bool,
) -> tuple[float, dict]:
    """
    Compute the weighted step reward and its component breakdown.

    Args:
        state: Episode state after the action and prospect reply were applied.
        action: The action taken this step.
        response_token: Token returned by the prospect simulator.
        new_violations: Rule IDs violated by this action only.
        episode_done: Whether this step ended the episode.

    Returns:
        (total_reward, reward_components); the components dict also holds
        the weighted sum under the key "total".
    """

    components = {}

    # --------------------------------------------------
    # 1. Outcome Reward (terminal only)
    # --------------------------------------------------

    r_outcome = 0.0

    if episode_done:
        if response_token == "accept:close_success":
            r_outcome = 1.0
        elif action.action_type == "DISQUALIFY":
            # Correct disqualification is rewarded; a wrong one (R08) is not.
            r_outcome = 0.5 if "R08" not in new_violations else -0.5
        elif state.turn_number >= 20:
            r_outcome = -0.3
        else:
            # Failed close or termination through accumulated violations.
            r_outcome = -0.5

    components["r_outcome"] = r_outcome

    # --------------------------------------------------
    # 2. Compliance Reward
    # --------------------------------------------------

    r_compliance = max(-1.0, -0.2 * len(new_violations))
    components["r_compliance"] = r_compliance

    # --------------------------------------------------
    # 3. Ordering Reward
    # --------------------------------------------------

    # steps_completed holds each action type once and also contains steps
    # that are not part of the workflow (e.g. the opening PROSPECT). Compare
    # like with like: de-duplicate the workflow and ignore unrelated steps,
    # otherwise a perfectly ordered episode is scored as fully misordered.
    required = list(dict.fromkeys(state.required_workflow))
    relevant = [step for step in state.steps_completed if step in required]

    if required and relevant:
        correct = sum(
            1 for want, got in zip(required, relevant) if want == got
        )
        r_ordering = correct / len(required)
    else:
        # Nothing to compare yet (or no prescribed workflow).
        r_ordering = 1.0

    components["r_ordering"] = r_ordering

    # --------------------------------------------------
    # 4. Efficiency Reward (terminal only)
    # --------------------------------------------------

    if episode_done:
        optimal = DIFFICULTY_OPTIMAL_TURNS.get(state.difficulty, 10)
        extra_turns = max(0, state.turn_number - optimal)
        r_efficiency = max(-0.3, -0.05 * extra_turns)
    else:
        r_efficiency = 0.0

    components["r_efficiency"] = r_efficiency

    # --------------------------------------------------
    # 5. Format Reward
    # --------------------------------------------------

    components["r_format"] = 1.0 if action.is_valid() else -0.1

    # --------------------------------------------------
    # Final Weighted Reward
    # --------------------------------------------------

    weights = {
        "r_outcome": 0.40,
        "r_compliance": 0.30,
        "r_ordering": 0.15,
        "r_efficiency": 0.10,
        "r_format": 0.05,
    }

    total_reward = sum(
        weight * components[key] for key, weight in weights.items()
    )

    components["total"] = total_reward

    return total_reward, components
salespath_env/server/rules.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # salespath_env/server/rules.py
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Callable
5
+
6
+ from ..models import SalesPathAction, SalesPathState
7
+
8
+
9
@dataclass
class BusinessRule:
    """
    A single sales-process rule.

    ``check`` returns True when the rule is VIOLATED.
    """

    # Short identifier such as "R01".
    rule_id: str
    # Machine-friendly rule name.
    name: str
    # Human-readable statement of the rule.
    description: str
    # Predicate over (state, action); True means the action breaks the rule.
    check: Callable[[SalesPathState, SalesPathAction], bool]
19
+
20
+
21
+ def _qualify_before_present(
22
+ state: SalesPathState,
23
+ action: SalesPathAction,
24
+ ) -> bool:
25
+ """
26
+ R01:
27
+ PRESENT before QUALIFY is invalid.
28
+ """
29
+ if action.action_type == "PRESENT":
30
+ return "QUALIFY" not in state.steps_completed
31
+ return False
32
+
33
+
34
+ def _demo_before_negotiate(
35
+ state: SalesPathState,
36
+ action: SalesPathAction,
37
+ ) -> bool:
38
+ """
39
+ R02:
40
+ NEGOTIATE before OFFER_DEMO is invalid.
41
+ """
42
+ if action.action_type == "NEGOTIATE":
43
+ return "OFFER_DEMO" not in state.steps_completed
44
+ return False
45
+
46
+
47
+ def _budget_known_to_negotiate(
48
+ state: SalesPathState,
49
+ action: SalesPathAction,
50
+ ) -> bool:
51
+ """
52
+ R03:
53
+ Cannot NEGOTIATE while budget is unknown.
54
+ """
55
+ if action.action_type == "NEGOTIATE":
56
+ return state.prospect_profile.get("budget_signal") == "unknown"
57
+ return False
58
+
59
+
60
+ def _discount_after_objections(
61
+ state: SalesPathState,
62
+ action: SalesPathAction,
63
+ ) -> bool:
64
+ """
65
+ R04:
66
+ Discount only after 2 objections handled.
67
+ """
68
+ if action.action_type == "NEGOTIATE":
69
+ if "discount" in action.content.lower():
70
+ return state.objections_handled < 2
71
+ return False
72
+
73
+
74
+ def _no_repeat_action(
75
+ state: SalesPathState,
76
+ action: SalesPathAction,
77
+ ) -> bool:
78
+ """
79
+ R05:
80
+ Same action twice in a row is invalid.
81
+ """
82
+ if state.conversation_history:
83
+ last_action = state.conversation_history[-1].get("action_type", "")
84
+ return last_action == action.action_type
85
+ return False
86
+
87
+
88
+ def _prospect_first(
89
+ state: SalesPathState,
90
+ action: SalesPathAction,
91
+ ) -> bool:
92
+ """
93
+ R06:
94
+ First action must be PROSPECT.
95
+ """
96
+ if state.turn_number == 1:
97
+ return action.action_type != "PROSPECT"
98
+ return False
99
+
100
+
101
+ def _followup_timing(
102
+ state: SalesPathState,
103
+ action: SalesPathAction,
104
+ ) -> bool:
105
+ """
106
+ R07:
107
+ FOLLOW_UP only valid after silence.
108
+ If prospect just responded last turn, violation.
109
+ """
110
+ if action.action_type == "FOLLOW_UP":
111
+ if state.conversation_history:
112
+ last_speaker = state.conversation_history[-1].get("speaker", "agent")
113
+ return last_speaker == "prospect"
114
+ return False
115
+
116
+
117
+ def _disqualify_logic(
118
+ state: SalesPathState,
119
+ action: SalesPathAction,
120
+ ) -> bool:
121
+ """
122
+ R08:
123
+ DISQUALIFY only when prospect is genuinely not closeable.
124
+ Violation if prospect is actually closeable.
125
+ """
126
+ if action.action_type == "DISQUALIFY":
127
+ true_budget = state.hidden_state.get("true_budget", 0.5)
128
+ close_threshold = state.hidden_state.get("close_threshold", 0.5)
129
+ decision_maker = state.prospect_profile.get("decision_maker", True)
130
+
131
+ return (true_budget >= close_threshold) and decision_maker
132
+
133
+ return False
134
+
135
+
136
+ def _close_requires_demo(
137
+ state: SalesPathState,
138
+ action: SalesPathAction,
139
+ ) -> bool:
140
+ """
141
+ R09:
142
+ Difficulty 2+ requires OFFER_DEMO before CLOSE.
143
+ """
144
+ if action.action_type == "CLOSE":
145
+ if state.difficulty >= 2:
146
+ return "OFFER_DEMO" not in state.steps_completed
147
+ return False
148
+
149
+
150
+ BUSINESS_RULES = [
151
+ BusinessRule(
152
+ "R01",
153
+ "qualify_before_present",
154
+ "Must QUALIFY before PRESENT",
155
+ _qualify_before_present,
156
+ ),
157
+ BusinessRule(
158
+ "R02",
159
+ "demo_before_negotiate",
160
+ "Must OFFER_DEMO before NEGOTIATE",
161
+ _demo_before_negotiate,
162
+ ),
163
+ BusinessRule(
164
+ "R03",
165
+ "budget_known_to_negotiate",
166
+ "Budget must be known before NEGOTIATE",
167
+ _budget_known_to_negotiate,
168
+ ),
169
+ BusinessRule(
170
+ "R04",
171
+ "discount_after_objections",
172
+ "Discount only after 2 objections handled",
173
+ _discount_after_objections,
174
+ ),
175
+ BusinessRule(
176
+ "R05",
177
+ "no_repeat_action",
178
+ "Cannot repeat same action consecutively",
179
+ _no_repeat_action,
180
+ ),
181
+ BusinessRule(
182
+ "R06",
183
+ "prospect_first",
184
+ "First action must be PROSPECT",
185
+ _prospect_first,
186
+ ),
187
+ BusinessRule(
188
+ "R07",
189
+ "followup_timing",
190
+ "FOLLOW_UP only after prospect silence",
191
+ _followup_timing,
192
+ ),
193
+ BusinessRule(
194
+ "R08",
195
+ "disqualify_logic",
196
+ "DISQUALIFY only when prospect is genuinely unqualified",
197
+ _disqualify_logic,
198
+ ),
199
+ BusinessRule(
200
+ "R09",
201
+ "close_requires_demo",
202
+ "Must OFFER_DEMO before CLOSE (difficulty 2+)",
203
+ _close_requires_demo,
204
+ ),
205
+ ]
206
+
207
+
208
def check_rules(
    state: SalesPathState,
    action: SalesPathAction,
) -> list[str]:
    """
    Evaluate every rule in BUSINESS_RULES against the pending action.

    Returns the IDs of all violated rules, in BUSINESS_RULES order.
    """
    return [
        rule.rule_id
        for rule in BUSINESS_RULES
        if rule.check(state, action)
    ]
salespath_env/server/salespath_environment.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # salespath_env/server/salespath_environment.py
2
+
3
+ import uuid
4
+
5
+ from openenv.core.env_server import Environment
6
+
7
+ from ..models import (
8
+ SalesPathAction,
9
+ SalesPathObservation,
10
+ SalesPathState,
11
+ )
12
+ from .task_bank import sample_profile
13
+ from .rules import check_rules
14
+ from .reward import compute_reward
15
+ from .prospect_simulator import ProspectSimulator
16
+
17
+
18
+ DIFFICULTY_WORKFLOW = {
19
+ 1: [
20
+ "QUALIFY",
21
+ "PRESENT",
22
+ "CLOSE",
23
+ ],
24
+ 2: [
25
+ "QUALIFY",
26
+ "PRESENT",
27
+ "HANDLE_OBJECTION",
28
+ "OFFER_DEMO",
29
+ "CLOSE",
30
+ ],
31
+ 3: [
32
+ "QUALIFY",
33
+ "PRESENT",
34
+ "HANDLE_OBJECTION",
35
+ "OFFER_DEMO",
36
+ "HANDLE_OBJECTION",
37
+ "NEGOTIATE",
38
+ "CLOSE",
39
+ ],
40
+ 4: [], # Agent must determine; DISQUALIFY may be correct
41
+ }
42
+
43
+
44
+ MAX_VIOLATIONS_BEFORE_TERMINATE = 3
45
+ MAX_TURNS = 20
46
+
47
+
48
class SalesPathEnvironment(Environment):
    """
    Core OpenEnv environment.

    All business logic routes through:
    - rules.py               (check_rules)
    - reward.py              (compute_reward)
    - prospect_simulator.py  (ProspectSimulator)
    """

    def __init__(self):
        super().__init__()
        self._state = SalesPathState()
        self._simulator = ProspectSimulator()

    def reset(self, difficulty: int = 1) -> SalesPathObservation:
        """
        Start a new episode.

        Args:
            difficulty: Curriculum level. Values without a configured
                workflow fall back to level 1, the same fallback
                sample_profile() applies, instead of raising KeyError.

        Returns:
            The initial observation carrying the intro message.
        """

        if difficulty not in DIFFICULTY_WORKFLOW:
            difficulty = 1

        profile = sample_profile(difficulty)

        if profile.true_budget >= 0.7:
            revealed_budget = "high"
        elif profile.true_budget >= 0.4:
            revealed_budget = "medium"
        else:
            revealed_budget = "low"

        hidden_state = {
            "true_budget": profile.true_budget,
            "close_threshold": profile.close_threshold,
            "stall_probability": profile.stall_probability,
            "num_objections": {
                1: 0,
                2: 1,
                3: 2,
                4: 2,
            }[difficulty],
            "revealed_budget": revealed_budget,
        }

        public_profile = {
            "company_name": profile.company_name,
            "company_size": profile.company_size,
            "industry": profile.industry,
            "budget_signal": profile.budget_signal,
            "pain_points": profile.pain_points,
            "decision_maker": profile.decision_maker,
        }

        self._state = SalesPathState(
            episode_id=str(uuid.uuid4()),
            prospect_profile=public_profile,
            conversation_history=[],
            workflow_stage="START",
            # Copy so per-episode state never aliases the module-level table.
            required_workflow=list(DIFFICULTY_WORKFLOW[difficulty]),
            steps_completed=[],
            constraints_violated=[],
            objections_handled=0,
            turn_number=0,
            difficulty=difficulty,
            done=False,
            hidden_state=hidden_state,
        )

        intro_message = (
            f"You are engaging {profile.company_name}, "
            f"a {profile.company_size} {profile.industry} company. "
            f"Pain points: {', '.join(profile.pain_points)}. "
            f"Begin the sales conversation."
        )

        return SalesPathObservation(
            prospect_response=intro_message,
            workflow_stage="START",
            constraints_violated=[],
            steps_completed=[],
            turn_number=0,
            reward=0.0,
            reward_components={},
            done=False,
            info={
                "difficulty": difficulty,
                "episode_id": self._state.episode_id,
            },
        )

    def step(
        self,
        action: SalesPathAction,
    ) -> SalesPathObservation:
        """
        One environment transition.

        An invalid action type never raises: it consumes a turn, yields a
        fixed penalty observation and leaves the rest of the state untouched.
        """

        state = self._state

        # -----------------------------------
        # Advance turn
        # -----------------------------------

        state.turn_number += 1

        # -----------------------------------
        # Strict action validation
        # Must return observation, never crash
        # -----------------------------------

        if not action.is_valid():
            # Invalid actions still consume turns, so the turn limit must
            # apply here too; otherwise the episode could never terminate.
            turn_limit_reached = state.turn_number >= MAX_TURNS
            if turn_limit_reached:
                state.done = True

            return SalesPathObservation(
                prospect_response="Invalid action type.",
                workflow_stage=state.workflow_stage,
                constraints_violated=list(state.constraints_violated),
                steps_completed=list(state.steps_completed),
                turn_number=state.turn_number,
                # NOTE(review): reward (-0.2) differs from the reported
                # r_format component (-0.1) — confirm which is intended.
                reward=-0.2,
                reward_components={
                    "r_format": -0.1,
                },
                done=turn_limit_reached,
                info={
                    "error": (
                        f"Invalid action_type: "
                        f"{action.action_type}"
                    )
                },
            )

        # -----------------------------------
        # Rule checks (against state BEFORE this action is recorded)
        # -----------------------------------

        new_violations = check_rules(state, action)
        state.constraints_violated.extend(new_violations)

        # -----------------------------------
        # Record agent action
        # -----------------------------------

        state.conversation_history.append(
            {
                "turn": state.turn_number,
                "speaker": "agent",
                "action_type": action.action_type,
                "content": action.content,
            }
        )

        # -----------------------------------
        # Update workflow state
        # -----------------------------------

        # Each action type is recorded once, in order of first use.
        if action.action_type not in state.steps_completed:
            state.steps_completed.append(action.action_type)

        state.workflow_stage = action.action_type

        # -----------------------------------
        # Prospect response (may mutate state)
        # -----------------------------------

        response_token, response_text = self._simulator.respond(
            action,
            state,
        )

        state.conversation_history.append(
            {
                "turn": state.turn_number,
                "speaker": "prospect",
                "response_token": response_token,
                "text": response_text,
            }
        )

        # -----------------------------------
        # Episode termination
        # -----------------------------------

        terminal_actions = {
            "CLOSE",
            "DISQUALIFY",
        }

        too_many_violations = (
            len(state.constraints_violated)
            >= MAX_VIOLATIONS_BEFORE_TERMINATE
        )

        turn_limit_reached = state.turn_number >= MAX_TURNS

        done = (
            action.action_type in terminal_actions
            or too_many_violations
            or turn_limit_reached
        )

        state.done = done

        # -----------------------------------
        # Reward
        # -----------------------------------

        total_reward, components = compute_reward(
            state=state,
            action=action,
            response_token=response_token,
            new_violations=new_violations,
            episode_done=done,
        )

        return SalesPathObservation(
            prospect_response=response_text,
            workflow_stage=state.workflow_stage,
            constraints_violated=list(state.constraints_violated),
            steps_completed=list(state.steps_completed),
            turn_number=state.turn_number,
            reward=total_reward,
            reward_components=components,
            done=done,
            info={
                "response_token": response_token,
                "new_violations": new_violations,
                "episode_id": state.episode_id,
            },
        )

    @property
    def state(self) -> SalesPathState:
        """Current internal state, including hidden fields."""
        return self._state
salespath_env/server/task_bank.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # salespath_env/server/task_bank.py
2
+
3
+ import random
4
+ from dataclasses import dataclass
5
+
6
+
7
@dataclass
class ProspectProfile:
    """One prospect scenario: public attributes plus hidden ground truth."""

    company_name: str
    company_size: str  # small / medium / enterprise
    industry: str
    budget_signal: str  # high / medium / low / unknown
    pain_points: list[str]
    decision_maker: bool

    # Hidden values — never exposed directly to agent
    true_budget: float  # 0.0 → 1.0
    close_threshold: float
    stall_probability: float
20
+
21
+
22
+ # -------------------------
23
+ # LEVEL 1 — Easy
24
+ # budget known
25
+ # decision maker present
26
+ # close is usually possible
27
+ # -------------------------
28
+
29
+ PROFILES_L1 = [
30
+ ProspectProfile(
31
+ company_name="Meridian Retail",
32
+ company_size="medium",
33
+ industry="retail",
34
+ budget_signal="high",
35
+ pain_points=[
36
+ "manual inventory tracking",
37
+ "slow reporting",
38
+ ],
39
+ decision_maker=True,
40
+ true_budget=0.8,
41
+ close_threshold=0.5,
42
+ stall_probability=0.0,
43
+ ),
44
+
45
+ ProspectProfile(
46
+ company_name="Northline Foods",
47
+ company_size="small",
48
+ industry="food distribution",
49
+ budget_signal="medium",
50
+ pain_points=[
51
+ "supplier delays",
52
+ "inventory mismatch",
53
+ ],
54
+ decision_maker=True,
55
+ true_budget=0.6,
56
+ close_threshold=0.5,
57
+ stall_probability=0.0,
58
+ ),
59
+ ]
60
+
61
+
62
+ # -------------------------
63
+ # LEVEL 2 — Medium
64
+ # budget hidden initially
65
+ # one objection expected
66
+ # -------------------------
67
+
68
+ PROFILES_L2 = [
69
+ ProspectProfile(
70
+ company_name="Apex Logistics",
71
+ company_size="enterprise",
72
+ industry="logistics",
73
+ budget_signal="unknown",
74
+ pain_points=[
75
+ "route optimization",
76
+ "driver coordination",
77
+ "fuel tracking",
78
+ ],
79
+ decision_maker=True,
80
+ true_budget=0.7,
81
+ close_threshold=0.5,
82
+ stall_probability=0.0,
83
+ ),
84
+
85
+ ProspectProfile(
86
+ company_name="Vertex Supply",
87
+ company_size="medium",
88
+ industry="manufacturing",
89
+ budget_signal="unknown",
90
+ pain_points=[
91
+ "vendor visibility",
92
+ "purchase delays",
93
+ ],
94
+ decision_maker=True,
95
+ true_budget=0.55,
96
+ close_threshold=0.5,
97
+ stall_probability=0.0,
98
+ ),
99
+ ]
100
+
101
+
102
+ # -------------------------
103
+ # LEVEL 3 — Hard
104
+ # budget hidden
105
+ # 2 objections
106
+ # possible stalling
107
+ # decision maker may be absent
108
+ # -------------------------
109
+
110
+ PROFILES_L3 = [
111
+ ProspectProfile(
112
+ company_name="Nova Financial",
113
+ company_size="enterprise",
114
+ industry="finance",
115
+ budget_signal="unknown",
116
+ pain_points=[
117
+ "compliance reporting",
118
+ "audit trails",
119
+ "data silos",
120
+ ],
121
+ decision_maker=False,
122
+ true_budget=0.6,
123
+ close_threshold=0.55,
124
+ stall_probability=0.3,
125
+ ),
126
+
127
+ ProspectProfile(
128
+ company_name="Atlas Health",
129
+ company_size="enterprise",
130
+ industry="healthcare",
131
+ budget_signal="unknown",
132
+ pain_points=[
133
+ "patient workflow delays",
134
+ "reporting compliance",
135
+ ],
136
+ decision_maker=False,
137
+ true_budget=0.65,
138
+ close_threshold=0.55,
139
+ stall_probability=0.25,
140
+ ),
141
+ ]
142
+
143
+
144
+ # -------------------------
145
+ # LEVEL 4 — Trap cases
146
+ # misleading signals
147
+ # correct action may be DISQUALIFY
148
+ # -------------------------
149
+
150
+ PROFILES_L4 = [
151
+ ProspectProfile(
152
+ company_name="Cipher Tech",
153
+ company_size="small",
154
+ industry="technology",
155
+ budget_signal="high", # misleading
156
+ pain_points=[
157
+ "security",
158
+ "compliance",
159
+ ],
160
+ decision_maker=True,
161
+ true_budget=0.2,
162
+ close_threshold=0.5,
163
+ stall_probability=0.5,
164
+ ),
165
+
166
+ ProspectProfile(
167
+ company_name="BluePeak Studio",
168
+ company_size="small",
169
+ industry="creative agency",
170
+ budget_signal="high", # misleading
171
+ pain_points=[
172
+ "project visibility",
173
+ "client reporting",
174
+ ],
175
+ decision_maker=True,
176
+ true_budget=0.25,
177
+ close_threshold=0.5,
178
+ stall_probability=0.4,
179
+ ),
180
+ ]
181
+
182
+
183
+ ALL_PROFILES = {
184
+ 1: PROFILES_L1,
185
+ 2: PROFILES_L2,
186
+ 3: PROFILES_L3,
187
+ 4: PROFILES_L4,
188
+ }
189
+
190
+
191
def sample_profile(difficulty: int) -> ProspectProfile:
    """
    Pick one profile at random for the given difficulty.

    Difficulties without an entry in ALL_PROFILES fall back to level 1.
    """

    level = difficulty if difficulty in ALL_PROFILES else 1
    candidates = ALL_PROFILES[level]
    return random.choice(candidates)
training/__init__.py ADDED
File without changes
training/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (172 Bytes). View file
 
training/__pycache__/curriculum.cpython-313.pyc ADDED
Binary file (2.02 kB). View file
 
training/__pycache__/debug_episode.cpython-313.pyc ADDED
Binary file (2.8 kB). View file
 
training/__pycache__/grpo_train.cpython-313.pyc ADDED
Binary file (14.3 kB). View file
 
training/__pycache__/rollout.cpython-313.pyc ADDED
Binary file (5.47 kB). View file
 
training/__pycache__/test_rollout.cpython-313.pyc ADDED
Binary file (1.77 kB). View file
 
training/colab_train.ipynb ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# SalesPath Colab Training\n",
8
+ "\n",
9
+ "This notebook installs dependencies, runs a local environment server, validates rollout, and launches curriculum training."
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": null,
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "!pip install -U pip\n",
19
+ "!pip install fastapi uvicorn pydantic httpx torch transformers trl unsloth openenv"
20
+ ]
21
+ },
22
+ {
23
+ "cell_type": "code",
24
+ "execution_count": null,
25
+ "metadata": {},
26
+ "outputs": [],
27
+ "source": [
28
+ "# If the repo is not already present, clone it.\n",
29
+ "# !git clone https://github.com/<your-org-or-user>/salespath_env.git\n",
30
+ "# %cd salespath_env\n",
31
+ "\n",
32
+ "%cd /content/salespath_env"
33
+ ]
34
+ },
35
+ {
36
+ "cell_type": "code",
37
+ "execution_count": null,
38
+ "metadata": {},
39
+ "outputs": [],
40
+ "source": [
41
+ "# Start the OpenEnv-compatible server in background.\n",
42
+ "!nohup python -m uvicorn salespath_env.server.app:app --host 0.0.0.0 --port 8000 > /content/server.log 2>&1 &\n",
43
+ "!sleep 3\n",
44
+ "!python -c \"import httpx; r=httpx.get('http://127.0.0.1:8000/health', timeout=30); print(r.status_code, r.text)\""
45
+ ]
46
+ },
47
+ {
48
+ "cell_type": "code",
49
+ "execution_count": null,
50
+ "metadata": {},
51
+ "outputs": [],
52
+ "source": [
53
+ "# Rollout smoke test (single episode)\n",
54
+ "!python -m training.test_rollout"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "code",
59
+ "execution_count": null,
60
+ "metadata": {},
61
+ "outputs": [],
62
+ "source": [
63
+ "# Curriculum run (example)\n",
64
+ "!python -m training.grpo_train --steps 30 --env-url http://127.0.0.1:8000 --model-name Qwen/Qwen2.5-0.5B-Instruct"
65
+ ]
66
+ },
67
+ {
68
+ "cell_type": "markdown",
69
+ "metadata": {},
70
+ "source": [
71
+ "## Optional: Push merged model to Hugging Face\n",
72
+ "\n",
73
+ "Set your token first:\n",
74
+ "\n",
75
+ "```python\n",
76
+ "import os\n",
77
+ "os.environ['HF_TOKEN'] = 'hf_xxx'\n",
78
+ "```\n",
79
+ "\n",
80
+ "Then run:\n",
81
+ "\n",
82
+ "```bash\n",
83
+ "python -m training.grpo_train --steps 100 --push-merged --hub-repo Imsachin010/salespath-qwen25-7b\n",
84
+ "```"
85
+ ]
86
+ }
87
+ ],
88
+ "metadata": {
89
+ "kernelspec": {
90
+ "display_name": "Python 3",
91
+ "language": "python",
92
+ "name": "python3"
93
+ },
94
+ "language_info": {
95
+ "name": "python"
96
+ }
97
+ },
98
+ "nbformat": 4,
99
+ "nbformat_minor": 5
100
+ }
training/curriculum.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # training/curriculum.py
2
+
3
+ from dataclasses import dataclass
4
+ import random
5
+
6
+
7
@dataclass
class CurriculumConfig:
    """Reward-driven schedule of difficulty distributions.

    ``thresholds`` maps a minimum mean reward to the distribution
    (difficulty level -> sampling weight) that applies once the mean
    reward has reached that value.
    """

    thresholds: dict

    def get_distribution(
        self,
        mean_reward: float,
    ) -> dict:
        """Return the distribution of the highest threshold that is met.

        If ``mean_reward`` is below every threshold, the distribution of
        the smallest threshold is returned instead.
        """
        lowest = min(self.thresholds)
        reached = [t for t in self.thresholds if mean_reward >= t]
        return self.thresholds[max(reached) if reached else lowest]
29
+
30
+
31
# Default schedule: each key is the minimum mean reward at which the
# associated distribution applies; each distribution maps a difficulty
# level (1-4) to its sampling weight. Higher thresholds shift weight
# towards the harder levels.
DEFAULT_CURRICULUM = CurriculumConfig(
    thresholds={
        # Starting point: almost exclusively difficulty 1.
        0.0: {
            1: 0.90,
            2: 0.10,
            3: 0.00,
            4: 0.00,
        },

        # Difficulty 3 starts to appear.
        0.30: {
            1: 0.50,
            2: 0.40,
            3: 0.10,
            4: 0.00,
        },

        # Difficulty 4 starts to appear.
        0.50: {
            1: 0.20,
            2: 0.40,
            3: 0.35,
            4: 0.05,
        },

        # Top tier: most of the weight sits on difficulties 3 and 4.
        0.65: {
            1: 0.10,
            2: 0.30,
            3: 0.40,
            4: 0.20,
        },
    }
)
62
+
63
+
64
def sample_difficulty(
    curriculum: CurriculumConfig,
    mean_reward: float,
) -> int:
    """Draw one difficulty level for the given mean reward.

    The curriculum supplies a ``{difficulty: weight}`` mapping for
    ``mean_reward``; a single level is then picked at random in
    proportion to those weights.
    """
    weights_by_level = curriculum.get_distribution(mean_reward)
    levels = list(weights_by_level)
    weights = [weights_by_level[level] for level in levels]

    (chosen,) = random.choices(levels, weights=weights, k=1)
    return chosen
training/debug_episode.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import asyncio
3
+
4
+ from salespath_env.client import SalesPathEnv
5
+
6
+
7
async def run_debug(env_url: str, difficulty: int):
    """Replay a fixed, rule-breaking action script and print each transition.

    The script sends PRESENT three times in a row so the printed
    violations and reward components can be inspected step by step.
    Stops early as soon as the environment reports ``done``.
    """
    scripted_actions = (
        ("PRESENT", "pitch too early"),
        ("PRESENT", "repeat pitch"),
        ("PRESENT", "repeat pitch again"),
    )

    def status_line(observation) -> str:
        # Shared summary line printed after reset and after every step.
        return (
            f" turn={observation.turn_number} done={observation.done} "
            f"reward={observation.reward}"
        )

    async with SalesPathEnv(base_url=env_url) as env:
        obs = await env.reset(difficulty=difficulty)
        print("RESET")
        print(status_line(obs))
        print(f" response={obs.prospect_response}")

        step_no = 0
        for action_type, content in scripted_actions:
            step_no += 1
            obs = await env.step(action_type=action_type, content=content, target="")
            print(f"\nSTEP {step_no} action={action_type}")
            print(status_line(obs))
            print(f" violations={obs.constraints_violated}")
            print(f" new_violations={obs.info.get('new_violations')}")
            print(f" components={obs.reward_components}")
            if obs.done:
                break
29
+
30
+
31
def parse_args():
    """Parse the debug script's command-line flags (``--env-url``, ``--difficulty``)."""
    cli = argparse.ArgumentParser(description="Debug stateful episode transitions.")
    option_table = (
        ("--env-url", {"default": "http://127.0.0.1:8000"}),
        ("--difficulty", {"type": int, "default": 2}),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()
36
+
37
+
38
# Script entry point: read the CLI flags and drive the async debug run.
if __name__ == "__main__":
    args = parse_args()
    asyncio.run(run_debug(args.env_url, args.difficulty))
training/grpo_train.py ADDED
@@ -0,0 +1,315 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import asyncio
3
+ import ast
4
+ import os
5
+ import re
6
+ from pathlib import Path
7
+
8
+ import numpy as np
9
+ from transformers import AutoModelForCausalLM, AutoTokenizer
10
+
11
+ from training.curriculum import DEFAULT_CURRICULUM, sample_difficulty
12
+ from training.rollout import run_episode
13
+
14
+
15
# Defaults used by the CLI flags in parse_args().
DEFAULT_MODEL = "Qwen/Qwen2.5-0.5B-Instruct"
DEFAULT_ENV_URL = "http://127.0.0.1:8000"
# Action names accepted by the reward function; anything else is treated
# as a malformed completion.
VALID_ACTIONS = {
    "PROSPECT",
    "QUALIFY",
    "PRESENT",
    "HANDLE_OBJECTION",
    "OFFER_DEMO",
    "NEGOTIATE",
    "CLOSE",
    "FOLLOW_UP",
    "DISQUALIFY",
}
# Ordered workflow steps per difficulty level, used to build the synthetic
# GRPO prompts. Difficulty 4 has no fixed workflow (rendered as "Dynamic").
WORKFLOW_MAP = {
    1: ["QUALIFY", "PRESENT", "CLOSE"],
    2: ["QUALIFY", "PRESENT", "HANDLE_OBJECTION", "OFFER_DEMO", "CLOSE"],
    3: ["QUALIFY", "PRESENT", "HANDLE_OBJECTION", "OFFER_DEMO", "HANDLE_OBJECTION", "NEGOTIATE", "CLOSE"],
    4: [],
}
34
+
35
+
36
def _load_model_and_tokenizer(model_name: str):
    """Load the tokenizer and causal LM for ``model_name``.

    Returns a ``(model, tokenizer)`` pair. The tokenizer's pad token is
    set to its EOS token when none is defined.
    """
    tok = AutoTokenizer.from_pretrained(model_name)
    if tok.pad_token is None:
        # Reuse EOS so padded batches can be built.
        tok.pad_token = tok.eos_token

    load_options = {"dtype": "auto", "device_map": "auto"}
    lm = AutoModelForCausalLM.from_pretrained(model_name, **load_options)
    return lm, tok
46
+
47
+
48
async def curriculum_train(
    model,
    tokenizer,
    env_url: str,
    total_steps: int = 100,
    print_every: int = 10,
):
    """Run curriculum-scheduled rollouts and collect reward statistics.

    Each step samples a difficulty from ``DEFAULT_CURRICULUM`` based on
    the rolling mean reward, plays one episode via ``run_episode`` and
    records the outcome. No model weights are updated here.

    Returns a dict with ``mean_reward`` (mean over the last 20 episodes),
    ``reward_history`` (one total reward per episode) and ``run_log``
    (one summary dict per episode).
    """
    rolling_mean = 0.0
    rewards: list[float] = []
    records: list[dict] = []

    for step in range(total_steps):
        level = sample_difficulty(DEFAULT_CURRICULUM, rolling_mean)
        episode = await run_episode(
            model=model,
            tokenizer=tokenizer,
            env_url=env_url,
            difficulty=level,
        )

        rewards.append(float(episode["total_reward"]))
        # The rolling mean over the 20 most recent episodes drives the curriculum.
        rolling_mean = float(np.mean(rewards[-20:]))

        record = {
            "step": step,
            "difficulty": level,
            "reward": float(episode["total_reward"]),
            "violations": len(episode["violations"]),
            "steps_completed": list(episode["steps_completed"]),
        }
        records.append(record)

        if step % print_every == 0:
            print(
                f"Step {step:04d} | Difficulty {level} | "
                f"Reward {episode['total_reward']:.3f} | Mean(20) {rolling_mean:.3f} | "
                f"Violations {len(episode['violations'])} | Steps {episode['steps_completed']}"
            )

    return {
        "mean_reward": rolling_mean,
        "reward_history": rewards,
        "run_log": records,
    }
94
+
95
+
96
+ def _save_metrics(output_dir: str, metrics: dict):
97
+ output_path = Path(output_dir)
98
+ output_path.mkdir(parents=True, exist_ok=True)
99
+ rewards_path = output_path / "reward_history.txt"
100
+ with rewards_path.open("w", encoding="utf-8") as f:
101
+ for idx, reward in enumerate(metrics["reward_history"]):
102
+ f.write(f"{idx}\t{reward:.6f}\n")
103
+ print(f"Saved reward history to {rewards_path}")
104
+
105
+
106
+ def _extract_action_content(text: str) -> tuple[str, str]:
107
+ action_match = re.search(r"ACTION:\s*(\w+)", text, re.IGNORECASE)
108
+ content_match = re.search(r"CONTENT:\s*(.+?)(?:\n|$)", text, re.IGNORECASE | re.DOTALL)
109
+ action_type = action_match.group(1).upper() if action_match else ""
110
+ content = content_match.group(1).strip() if content_match else ""
111
+ return action_type, content
112
+
113
+
114
+ def _extract_steps_completed(prompt_text: str) -> list[str]:
115
+ match = re.search(r"Steps completed:\s*(\[.*?\])", prompt_text, re.DOTALL)
116
+ if not match:
117
+ return []
118
+ try:
119
+ parsed = ast.literal_eval(match.group(1))
120
+ if isinstance(parsed, list):
121
+ return [str(v).upper() for v in parsed]
122
+ except Exception:
123
+ return []
124
+ return []
125
+
126
+
127
def salespath_reward_func(prompts, completions, **kwargs):
    """Score completions on format validity and basic workflow-order rules.

    For each ``(prompt, completion)`` pair:

    * ``-0.2`` if the completion lacks a valid ``ACTION`` or a non-empty
      ``CONTENT``;
    * otherwise ``+0.1``, minus ``0.2`` for every rule hint broken
      (R06, R01, R02, R09), judged against the ``Steps completed`` list
      parsed from the prompt.

    R09 (OFFER_DEMO before CLOSE) is skipped when the prompt lists an
    explicit workflow that contains no OFFER_DEMO step. This is the
    difficulty-1 workflow, where closing without a demo is the correct
    move and must not be penalised.

    Args:
        prompts: Prompt strings, aligned with ``completions``.
        completions: Generated completion strings.
        **kwargs: Ignored; accepted so extra keyword arguments passed by
            the caller do not raise.

    Returns:
        One float reward per pair, in input order.
    """
    rewards: list[float] = []

    for prompt, completion in zip(prompts, completions):
        action_type, content = _extract_action_content(completion)
        steps_completed = _extract_steps_completed(prompt)

        # Format + valid action
        if action_type not in VALID_ACTIONS or not content:
            rewards.append(-0.2)
            continue

        reward = 0.1

        # Work out whether the prompt's explicit workflow includes a demo.
        # A prompt with no recognisable workflow (e.g. "Dynamic") keeps R09 on.
        workflow_match = re.search(
            r"Required workflow steps \(in order\):\s*(.*)", prompt
        )
        listed_steps = []
        if workflow_match:
            listed_steps = [
                step.strip().upper()
                for step in workflow_match.group(1).split("->")
                if step.strip().upper() in VALID_ACTIONS
            ]
        demo_required = not listed_steps or "OFFER_DEMO" in listed_steps

        # Rule hints
        if not steps_completed and action_type != "PROSPECT":
            reward -= 0.2  # R06
        if action_type == "PRESENT" and "QUALIFY" not in steps_completed:
            reward -= 0.2  # R01
        if action_type == "NEGOTIATE" and "OFFER_DEMO" not in steps_completed:
            reward -= 0.2  # R02
        if (
            action_type == "CLOSE"
            and demo_required
            and "OFFER_DEMO" not in steps_completed
        ):
            reward -= 0.2  # R09 (difficulty 2+ only)

        rewards.append(float(reward))

    return rewards
160
+
161
+
162
def _build_grpo_dataset_rows(num_rows: int = 128):
    """Build ``num_rows`` synthetic prompt rows for GRPO training.

    Row ``i`` cycles deterministically through the difficulty levels,
    the number of already-completed steps, the turn counter and a small
    set of canned prospect replies. Each row is ``{"prompt": <str>}``.
    """
    prospect_snippets = [
        "We are evaluating options right now.",
        "Budget is tight this quarter.",
        "Can you explain implementation effort?",
        "Pricing seems high compared to alternatives.",
    ]

    def make_prompt(i: int) -> str:
        workflow = WORKFLOW_MAP[(i % 4) + 1]
        # Every third row starts from scratch; the rest have 1 or 2 steps done
        # (slicing clamps to the workflow length, so an empty workflow stays empty).
        if i % 3 == 0:
            steps_completed = []
        else:
            steps_completed = workflow[: i % 2 + 1]

        workflow_text = " -> ".join(workflow) if workflow else "Dynamic"
        stage = steps_completed[-1] if steps_completed else "START"
        turn = (i % 8) + 1
        snippet = prospect_snippets[i % len(prospect_snippets)]

        return (
            "You are a B2B sales agent.\n\n"
            f"Required workflow steps (in order): {workflow_text}\n"
            f"Current stage: {stage}\n"
            f"Steps completed: {steps_completed}\n"
            f"Turn: {turn}/20\n"
            "Business rules: R01..R09 must be respected.\n"
            f"Prospect response: {snippet}\n\n"
            "Respond exactly with:\nACTION: <action>\nCONTENT: <message>"
        )

    return [{"prompt": make_prompt(i)} for i in range(num_rows)]
187
+
188
+
189
def run_grpo(args):
    """Run TRL GRPO training on the synthetic SalesPath prompt dataset.

    Builds the prompt dataset, trains with ``salespath_reward_func`` as
    the reward, saves the result to ``<output_dir>/grpo_final`` and, when
    ``--push-to-hub`` is set, pushes it to ``args.hub_repo``.

    Args:
        args: Parsed CLI namespace from ``parse_args()``.

    Raises:
        RuntimeError: If ``datasets`` or ``trl`` cannot be imported.
    """
    try:
        from datasets import Dataset
        from trl import GRPOConfig, GRPOTrainer
    except Exception as exc:
        raise RuntimeError(
            "Failed to initialize TRL GRPO stack. On this machine, this is usually due to "
            "Windows blocking pyarrow dataset binaries in the local virtualenv. "
            "Use the provided Colab notebook (`training/colab_train.ipynb`) for GRPO runs, "
            "or fix local pyarrow/datasets installation first."
        ) from exc

    # Only the tokenizer is needed here: the trainer is given the model *name*
    # and loads the weights itself, so loading the model here as well would
    # hold a second, unused copy in memory.
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    rows = _build_grpo_dataset_rows(args.grpo_dataset_size)
    train_dataset = Dataset.from_list(rows)

    # Route the push to the requested repo; without hub_model_id the trainer
    # would derive the repo name from output_dir instead of --hub-repo.
    hub_kwargs = {"hub_model_id": args.hub_repo} if args.push_to_hub else {}

    config = GRPOConfig(
        output_dir=args.output_dir,
        learning_rate=args.learning_rate,
        per_device_train_batch_size=args.per_device_train_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        num_generations=args.num_generations,
        max_completion_length=args.max_completion_length,
        temperature=args.temperature,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        max_steps=args.grpo_steps,
        report_to="none",
        **hub_kwargs,
    )

    trainer = GRPOTrainer(
        model=args.model_name,
        reward_funcs=salespath_reward_func,
        args=config,
        train_dataset=train_dataset,
        processing_class=tokenizer,
    )

    trainer.train()
    final_dir = Path(args.output_dir) / "grpo_final"
    trainer.save_model(str(final_dir))
    print(f"Saved GRPO model to {final_dir}")

    if args.push_to_hub:
        trainer.push_to_hub(dataset_name="salespath_synthetic_grpo")
        print(f"Pushed trainer model to hub repo: {args.hub_repo}")
234
+
235
+
236
def parse_args():
    """Parse the training entrypoint's command-line flags.

    Covers the run mode (``curriculum`` or ``grpo``), model/environment
    selection, output and hub options, and the GRPO hyper-parameters.
    """
    cli = argparse.ArgumentParser(description="SalesPath training entrypoint.")

    # General options
    cli.add_argument("--mode", choices=["curriculum", "grpo"], default="curriculum")
    cli.add_argument("--model-name", default=DEFAULT_MODEL)
    cli.add_argument("--env-url", default=DEFAULT_ENV_URL)
    cli.add_argument("--steps", type=int, default=100, help="Curriculum rollout steps.")
    cli.add_argument("--print-every", type=int, default=10)
    cli.add_argument("--output-dir", default="salespath_training_outputs")
    cli.add_argument("--hub-repo", default="Imsachin010/salespath-qwen25-7b")
    for switch in ("--push-to-hub", "--push-merged"):
        cli.add_argument(switch, action="store_true")

    # GRPO-specific knobs: (flag, type, default)
    grpo_knobs = (
        ("--grpo-steps", int, 30),
        ("--grpo-dataset-size", int, 128),
        ("--learning-rate", float, 1e-5),
        ("--per-device-train-batch-size", int, 2),
        ("--gradient-accumulation-steps", int, 4),
        ("--num-generations", int, 8),
        ("--max-completion-length", int, 128),
        ("--temperature", float, 0.8),
        ("--logging-steps", int, 10),
        ("--save-steps", int, 100),
    )
    for flag, kind, default in grpo_knobs:
        cli.add_argument(flag, type=kind, default=default)

    return cli.parse_args()
261
+
262
+
263
async def _run_curriculum_mode(args):
    """Run the curriculum rollout loop and persist its metrics.

    Loads the model, runs ``curriculum_train`` against ``args.env_url``,
    writes the reward history to ``args.output_dir`` and, with
    ``--push-merged``, saves (and optionally pushes) a merged model when
    the loaded model exposes the merged-save methods.

    Args:
        args: Parsed CLI namespace from ``parse_args()``.
    """
    print(f"Loading model: {args.model_name}")
    model, tokenizer = _load_model_and_tokenizer(args.model_name)
    print(f"Starting curriculum loop against {args.env_url}")

    metrics = await curriculum_train(
        model=model,
        tokenizer=tokenizer,
        env_url=args.env_url,
        total_steps=args.steps,
        print_every=args.print_every,
    )
    print(f"Final mean reward (last 20): {metrics['mean_reward']:.4f}")
    _save_metrics(args.output_dir, metrics)

    if not args.push_merged:
        return

    if not hasattr(model, "save_pretrained_merged"):
        print(
            "Model does not support merged save APIs. "
            "Use an Unsloth merged-capable model to enable --push-merged."
        )
        return

    merged_dir = Path(args.output_dir) / "salespath_trained_merged"
    model.save_pretrained_merged(
        str(merged_dir),
        tokenizer,
        save_method="merged_16bit",
    )
    print(f"Saved merged model to {merged_dir}")

    # Pushing needs both a token and the push API; say why when it is skipped
    # instead of finishing silently.
    hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")
    if not hf_token:
        print(
            "HF_TOKEN / HUGGINGFACE_TOKEN is not set; "
            f"skipping push of merged model to {args.hub_repo}."
        )
        return
    if not hasattr(model, "push_to_hub_merged"):
        print(
            "Model does not expose push_to_hub_merged; "
            f"skipping push of merged model to {args.hub_repo}."
        )
        return

    model.push_to_hub_merged(
        args.hub_repo,
        tokenizer,
        save_method="merged_16bit",
        token=hf_token,
    )
    print(f"Pushed merged model to {args.hub_repo}")
301
+
302
+
303
async def _main():
    """Parse the CLI flags and dispatch to the selected training mode."""
    args = parse_args()

    if args.mode != "curriculum":
        # The only other choice accepted by the parser is "grpo".
        print("Launching TRL GRPO mode...")
        run_grpo(args)
        return

    await _run_curriculum_mode(args)
311
+
312
+
313
# Script entry point: run the async dispatcher to completion.
if __name__ == "__main__":
    asyncio.run(_main())
315
+
training/rollout.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # training/rollout.py
2
+
3
+ import re
4
+ import torch
5
+
6
+ from salespath_env.client import SalesPathEnv
7
+ from salespath_env.models import SalesPathObservation
8
+
9
+
10
# System-message template for the policy model. The ``{workflow}``
# placeholder is filled in by build_prompt(); the rest lists the business
# rules and the ACTION/CONTENT reply format that parse_action() expects.
SYSTEM_PROMPT = """
You are a B2B sales agent.

Your goal is to close deals by following a strict workflow.

Required workflow steps (in order):
{workflow}

Business rules — NEVER violate these:

- R01: Must QUALIFY before PRESENT
- R02: Must OFFER_DEMO before NEGOTIATE
- R03: Budget must be known before NEGOTIATE
- R04: Discount only after 2 objections handled
- R05: Cannot repeat same action twice in a row
- R06: First action must always be PROSPECT
- R07: FOLLOW_UP only after prospect goes silent
- R08: DISQUALIFY only if prospect is genuinely unqualified
- R09: Must OFFER_DEMO before CLOSE (difficulty 2+)

You must respond EXACTLY in this format:

ACTION: <one valid action>
CONTENT: <your message>
"""
35
+
36
+
37
def parse_action(text: str) -> tuple[str, str]:
    """Read the ``ACTION`` and ``CONTENT`` fields from model output.

    Labels are matched case-insensitively. When a field is missing, the
    action falls back to ``QUALIFY`` and the content to a generic
    discovery question, so the returned pair is always usable.
    """
    fallback_action = "QUALIFY"
    fallback_content = "Tell me more about your current process."

    found_action = re.search(r"ACTION:\s*(\w+)", text, re.IGNORECASE)
    found_content = re.search(r"CONTENT:\s*(.+?)(?:\n|$)", text, re.IGNORECASE | re.DOTALL)

    if found_action:
        action_type = found_action.group(1).upper()
    else:
        action_type = fallback_action

    if found_content:
        content = found_content.group(1).strip()
    else:
        content = fallback_content

    return action_type, content
49
+
50
+
51
def build_prompt(obs: SalesPathObservation, workflow: list[str], tokenizer) -> str:
    """Render the chat prompt for the next model turn.

    The system message is ``SYSTEM_PROMPT`` with the workflow filled in;
    the user message summarises the current observation. An empty
    workflow (difficulty 4) is rendered as ``Dynamic`` rather than a
    blank section, matching the label used for the synthetic GRPO
    prompts.

    Args:
        obs: Latest environment observation.
        workflow: Ordered required steps; may be empty.
        tokenizer: Tokenizer providing ``apply_chat_template``.

    Returns:
        The templated prompt string, ending with the generation prompt.
    """
    workflow_text = " -> ".join(workflow) if workflow else "Dynamic"
    messages = [
        {
            "role": "system",
            "content": SYSTEM_PROMPT.format(workflow=workflow_text),
        },
        {
            "role": "user",
            "content": (
                f"Prospect response: {obs.prospect_response}\n"
                f"Current stage: {obs.workflow_stage}\n"
                f"Steps completed: {obs.steps_completed}\n"
                f"Turn: {obs.turn_number}/20\n"
                f"Violations so far: {obs.constraints_violated}\n\n"
                "What is your next action?"
            ),
        },
    ]
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
71
+
72
+
73
async def run_episode(
    model,
    tokenizer,
    env_url: str,
    difficulty: int = 1,
    message_timeout_s: float = 300.0,
) -> dict:
    """Play one full episode against the environment with the given model.

    Each turn builds a prompt from the latest observation, samples a
    completion, parses it into an action and sends it to the environment,
    until the observation reports ``done``.

    Args:
        model: Causal LM used for generation.
        tokenizer: Matching tokenizer (chat template, encode/decode).
        env_url: Base URL of the environment server.
        difficulty: Difficulty level; must be a key of the workflow table
            (1-4), otherwise ``KeyError`` is raised.
        message_timeout_s: Currently unused by this function; kept for
            interface compatibility.

    Returns:
        Dict with ``trajectory`` (one record per turn), ``total_reward``,
        ``steps_completed``, ``violations`` and ``difficulty``.
    """
    DIFFICULTY_WORKFLOW = {
        1: ["QUALIFY", "PRESENT", "CLOSE"],
        2: ["QUALIFY", "PRESENT", "HANDLE_OBJECTION", "OFFER_DEMO", "CLOSE"],
        3: ["QUALIFY", "PRESENT", "HANDLE_OBJECTION", "OFFER_DEMO", "HANDLE_OBJECTION", "NEGOTIATE", "CLOSE"],
        4: [],
    }

    workflow = DIFFICULTY_WORKFLOW[difficulty]

    async with SalesPathEnv(base_url=env_url) as env:
        obs = await env.reset(difficulty=difficulty)
        trajectory = []
        total_reward = 0.0

        while not obs.done:
            # --- Model inference (CPU/GPU — no network) ---
            prompt = build_prompt(obs, workflow, tokenizer)

            inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=128,
                    temperature=0.7,
                    do_sample=True,
                )

            # Decode only the newly generated tokens, not the echoed prompt.
            generated = tokenizer.decode(
                outputs[0][inputs["input_ids"].shape[1]:],
                skip_special_tokens=True,
            )

            action_type, content = parse_action(generated)

            # --- Stateful step via OpenEnv client ---
            obs = await env.step(
                action_type=action_type,
                content=content,
                target="",
            )

            trajectory.append({
                "prompt": prompt,
                "generated": generated,
                "action_type": action_type,
                "reward": obs.reward,
                "components": obs.reward_components,
                "done": obs.done,
            })

            # A step without a reward (None) counts as zero instead of
            # aborting the episode with a TypeError.
            if obs.reward is not None:
                total_reward += obs.reward

    return {
        "trajectory": trajectory,
        "total_reward": total_reward,
        "steps_completed": obs.steps_completed,
        "violations": obs.constraints_violated,
        "difficulty": difficulty,
    }