diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..25cbac1dcf5a9dce767b3d918b1d61d1bf60d67f --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +/.spa \ No newline at end of file diff --git a/CODING_APPROACH.md b/CODING_APPROACH.md new file mode 100644 index 0000000000000000000000000000000000000000..68a629a239f4b42f06775bf094f958ebce2e2efd --- /dev/null +++ b/CODING_APPROACH.md @@ -0,0 +1,1028 @@ +# SalesPath — End-to-End Coding Approach +### For Agent Execution. Follow in order. No skipping. + +--- + +## Phase 0: Setup (Do First, ~15 min) + +```bash +# Install OpenEnv +pip install openenv + +# Scaffold the project +openenv init salespath_env +cd salespath_env + +# Install dependencies +pip install -e . + +# Verify scaffold works +uv run server --host 0.0.0.0 --port 8000 +# Should start FastAPI on 8000. Ctrl+C after confirming. +``` + +Edit `pyproject.toml` — add dependencies: +```toml +[project] +name = "salespath_env" +version = "0.1.0" +dependencies = [ + "openenv", + "fastapi", + "uvicorn", + "pydantic>=2.0", + "trl>=0.8.0", + "unsloth", + "torch", + "transformers", +] +``` + +--- + +## Phase 1: Models (Person A) — `models.py` + +Write this file first. Everything else depends on it. 
+ +```python +# salespath_env/models.py +from __future__ import annotations +import uuid +from dataclasses import dataclass, field +from typing import Optional +from openenv.core import Action, Observation, State + +VALID_ACTIONS = { + "PROSPECT", "QUALIFY", "PRESENT", "HANDLE_OBJECTION", + "OFFER_DEMO", "NEGOTIATE", "CLOSE", "FOLLOW_UP", "DISQUALIFY" +} + +class SalesPathAction(Action): + action_type: str + content: str + target: str = "" + + def is_valid(self) -> bool: + return self.action_type in VALID_ACTIONS + + +class SalesPathObservation(Observation): + prospect_response: str = "" + workflow_stage: str = "START" + constraints_violated: list[str] = field(default_factory=list) + steps_completed: list[str] = field(default_factory=list) + turn_number: int = 0 + reward: float = 0.0 + reward_components: dict = field(default_factory=dict) + done: bool = False + info: dict = field(default_factory=dict) + + +class SalesPathState(State): + episode_id: str = field(default_factory=lambda: str(uuid.uuid4())) + prospect_profile: dict = field(default_factory=dict) + conversation_history: list[dict] = field(default_factory=list) + workflow_stage: str = "START" + required_workflow: list[str] = field(default_factory=list) + steps_completed: list[str] = field(default_factory=list) + constraints_violated: list[str] = field(default_factory=list) + objections_handled: int = 0 + turn_number: int = 0 + difficulty: int = 1 + done: bool = False + # Hidden — never expose in Observation + _hidden: dict = field(default_factory=dict) +``` + +--- + +## Phase 2: Task Bank (Person A) — `server/task_bank.py` + +This generates prospect profiles. Keep it simple — 10 profiles per difficulty level. 
+ +```python +# server/task_bank.py +import random +from dataclasses import dataclass + +@dataclass +class ProspectProfile: + company_name: str + company_size: str # "small" / "medium" / "enterprise" + industry: str + budget_signal: str # "high" / "medium" / "low" / "unknown" + pain_points: list[str] + decision_maker: bool + # Hidden — simulator uses these, agent never sees raw values + true_budget: float # 0.0 to 1.0 scale + close_threshold: float # budget needed to close + stall_probability: float # for Level 3+ + + +PROFILES_L1 = [ + ProspectProfile( + company_name="Meridian Retail", + company_size="medium", + industry="retail", + budget_signal="high", + pain_points=["manual inventory tracking", "slow reporting"], + decision_maker=True, + true_budget=0.8, + close_threshold=0.5, + stall_probability=0.0, + ), + # Add 9 more L1 profiles following same pattern + # L1: budget_signal always known, decision_maker always True, close_threshold <= 0.6 +] + +PROFILES_L2 = [ + ProspectProfile( + company_name="Apex Logistics", + company_size="enterprise", + industry="logistics", + budget_signal="unknown", # revealed after QUALIFY + pain_points=["route optimization", "driver coordination", "fuel tracking"], + decision_maker=True, + true_budget=0.7, + close_threshold=0.5, + stall_probability=0.0, + ), + # 9 more L2 profiles: budget hidden, one objection expected +] + +PROFILES_L3 = [ + ProspectProfile( + company_name="Nova Financial", + company_size="enterprise", + industry="finance", + budget_signal="unknown", + pain_points=["compliance reporting", "audit trails", "data silos"], + decision_maker=False, # must navigate to decision maker + true_budget=0.6, + close_threshold=0.55, + stall_probability=0.3, # will stall at turn 10 + ), + # 9 more L3 profiles: budget hidden, two objections, mode shift +] + +PROFILES_L4 = [ + ProspectProfile( + company_name="Cipher Tech", + company_size="small", + industry="technology", + budget_signal="high", # MISLEADING — true_budget is actually 
low + pain_points=["security", "compliance"], + decision_maker=True, + true_budget=0.2, # can't actually afford it + close_threshold=0.5, + stall_probability=0.5, + ), + # 9 more L4: misleading signals, correct answer is DISQUALIFY +] + +ALL_PROFILES = {1: PROFILES_L1, 2: PROFILES_L2, 3: PROFILES_L3, 4: PROFILES_L4} + +def sample_profile(difficulty: int) -> ProspectProfile: + return random.choice(ALL_PROFILES[difficulty]) +``` + +--- + +## Phase 3: Business Rules (Person A) — `server/rules.py` + +```python +# server/rules.py +from dataclasses import dataclass +from typing import Callable +from ..models import SalesPathAction, SalesPathState + + +@dataclass +class BusinessRule: + rule_id: str + name: str + description: str + check: Callable[[SalesPathState, SalesPathAction], bool] + # Returns True if VIOLATED + + +def _qualify_before_present(state: SalesPathState, action: SalesPathAction) -> bool: + if action.action_type == "PRESENT": + return "QUALIFY" not in state.steps_completed + return False + + +def _demo_before_negotiate(state: SalesPathState, action: SalesPathAction) -> bool: + if action.action_type == "NEGOTIATE": + return "OFFER_DEMO" not in state.steps_completed + return False + + +def _budget_known_to_negotiate(state: SalesPathState, action: SalesPathAction) -> bool: + if action.action_type == "NEGOTIATE": + return state.prospect_profile.get("budget_signal") == "unknown" + return False + + +def _discount_after_objections(state: SalesPathState, action: SalesPathAction) -> bool: + if action.action_type == "NEGOTIATE": + if "discount" in action.content.lower(): + return state.objections_handled < 2 + return False + + +def _no_repeat_action(state: SalesPathState, action: SalesPathAction) -> bool: + if state.conversation_history: + last_action = state.conversation_history[-1].get("action_type", "") + return last_action == action.action_type + return False + + +def _prospect_first(state: SalesPathState, action: SalesPathAction) -> bool: + if state.turn_number 
== 1: + return action.action_type != "PROSPECT" + return False + + +def _followup_timing(state: SalesPathState, action: SalesPathAction) -> bool: + if action.action_type == "FOLLOW_UP": + if state.conversation_history: + last_speaker = state.conversation_history[-1].get("speaker", "agent") + return last_speaker == "prospect" # prospect just responded + return False + + +def _disqualify_logic(state: SalesPathState, action: SalesPathAction) -> bool: + if action.action_type == "DISQUALIFY": + profile = state.prospect_profile + true_budget = state._hidden.get("true_budget", 0.5) + close_threshold = state._hidden.get("close_threshold", 0.5) + dm = profile.get("decision_maker", True) + # Violation: disqualifying when prospect is actually closeable + return (true_budget >= close_threshold) and dm + return False + + +def _close_requires_demo(state: SalesPathState, action: SalesPathAction) -> bool: + if action.action_type == "CLOSE": + if state.difficulty >= 2: + return "OFFER_DEMO" not in state.steps_completed + return False + + +BUSINESS_RULES = [ + BusinessRule("R01", "qualify_before_present", + "Must QUALIFY before PRESENT", _qualify_before_present), + BusinessRule("R02", "demo_before_negotiate", + "Must OFFER_DEMO before NEGOTIATE", _demo_before_negotiate), + BusinessRule("R03", "budget_known_to_negotiate", + "Budget must be known before NEGOTIATE", _budget_known_to_negotiate), + BusinessRule("R04", "discount_after_objections", + "Discount only after 2 objections", _discount_after_objections), + BusinessRule("R05", "no_repeat_action", + "Cannot repeat same action consecutively", _no_repeat_action), + BusinessRule("R06", "prospect_first", + "First action must be PROSPECT", _prospect_first), + BusinessRule("R07", "followup_timing", + "FOLLOW_UP only after prospect silence", _followup_timing), + BusinessRule("R08", "disqualify_logic", + "DISQUALIFY only when prospect is genuinely unqualified", _disqualify_logic), + BusinessRule("R09", "close_requires_demo", + "Must 
OFFER_DEMO before CLOSE (Levels 2+)", _close_requires_demo), +] + + +def check_rules(state: SalesPathState, action: SalesPathAction) -> list[str]: + """Returns list of violated rule IDs.""" + return [ + rule.rule_id + for rule in BUSINESS_RULES + if rule.check(state, action) + ] +``` + +--- + +## Phase 4: Prospect Simulator (Person A) — `server/prospect_simulator.py` + +```python +# server/prospect_simulator.py +# PURE RULE-BASED. No LLM. No imports from transformers. + +from ..models import SalesPathState, SalesPathAction + +RESPONSE_TEXT = { + "open:positive_signal": "That sounds interesting. Tell me more about how this works.", + "open:neutral_signal": "I see. We're evaluating a few options at the moment.", + "objection:price": "The pricing seems higher than what we budgeted for.", + "objection:timing": "The timing isn't ideal — we're in the middle of a quarter close.", + "objection:premature_pitch": "I'm not sure we're ready to discuss solutions yet. What do you know about our situation?", + "deflect:budget_not_discussed": "We haven't really talked about what we're looking for yet.", + "deflect:stall": "Let me get back to you on this. A lot is happening on our end.", + "accept:demo_scheduled": "Yes, let's set up a demo. What time works next week?", + "accept:close_success": "Alright, I think we can move forward with this. Send over the paperwork.", + "reject:close_failed": "I don't think we're ready to commit at this point.", + "silence": "", + "exit:disqualified": "I think we're done here. This isn't the right fit.", +} + + +class ProspectSimulator: + + def respond(self, action: SalesPathAction, state: SalesPathState) -> tuple[str, str]: + """ + Returns (response_token, response_text). + Deterministic — same inputs always produce same output. 
+ """ + token = self._get_token(action, state) + text = RESPONSE_TEXT[token] + return token, text + + def _get_token(self, action: SalesPathAction, state: SalesPathState) -> str: + atype = action.action_type + hidden = state._hidden + turn = state.turn_number + profile = state.prospect_profile + objections = state.objections_handled + difficulty = state.difficulty + + # Rule violation responses (priority — check first) + if "R01" in state.constraints_violated[-1:]: + return "objection:premature_pitch" + if "R03" in state.constraints_violated[-1:]: + return "deflect:budget_not_discussed" + + # Action-specific logic + if atype == "PROSPECT": + return "open:positive_signal" + + if atype == "QUALIFY": + # Reveal budget signal if it was hidden + if profile.get("budget_signal") == "unknown": + state.prospect_profile["budget_signal"] = hidden.get("revealed_budget", "medium") + return "open:neutral_signal" + + if atype == "PRESENT": + if difficulty >= 2: + return "objection:price" if objections == 0 else "open:positive_signal" + return "open:positive_signal" + + if atype == "HANDLE_OBJECTION": + state.objections_handled += 1 + if objections + 1 >= hidden.get("num_objections", 1): + return "open:positive_signal" + return "objection:timing" if objections == 0 else "open:positive_signal" + + if atype == "OFFER_DEMO": + return "accept:demo_scheduled" + + if atype == "NEGOTIATE": + return "open:neutral_signal" + + if atype == "CLOSE": + true_budget = hidden.get("true_budget", 0.7) + threshold = hidden.get("close_threshold", 0.5) + if true_budget >= threshold and profile.get("decision_maker", True): + return "accept:close_success" + return "reject:close_failed" + + if atype == "FOLLOW_UP": + return "open:neutral_signal" + + if atype == "DISQUALIFY": + return "exit:disqualified" + + # Mode shift at turn 10 for Level 3+ + if difficulty >= 3 and turn >= 10: + import random + if random.random() < hidden.get("stall_probability", 0.0): + return "deflect:stall" + + return 
"open:neutral_signal" +``` + +--- + +## Phase 5: Reward Function (Person B) — `server/reward.py` + +```python +# server/reward.py + +from ..models import SalesPathState, SalesPathAction + +DIFFICULTY_OPTIMAL_TURNS = {1: 5, 2: 8, 3: 12, 4: 14} + + +def compute_reward( + state: SalesPathState, + action: SalesPathAction, + response_token: str, + new_violations: list[str], + episode_done: bool, +) -> tuple[float, dict]: + """ + Returns (total_reward, component_dict). + Always returns components — never a single scalar. + """ + components = {} + + # --- Component 1: Outcome (only on terminal step) --- + r_outcome = 0.0 + if episode_done: + if response_token == "accept:close_success": + r_outcome = 1.0 + elif action.action_type == "DISQUALIFY": + # Check if disqualify was correct (no R08 violation) + if "R08" not in new_violations: + r_outcome = 0.5 + else: + r_outcome = -0.5 + elif state.turn_number >= 20: + r_outcome = -0.3 + elif len(state.constraints_violated) >= 3: + r_outcome = -0.5 + else: + r_outcome = -0.5 # failed close + components["r_outcome"] = r_outcome + + # --- Component 2: Compliance --- + total_violations = len(state.constraints_violated) + len(new_violations) + r_compliance = max(-1.0, -0.2 * len(new_violations)) # per-step signal + components["r_compliance"] = r_compliance + + # --- Component 3: Step Ordering --- + required = state.required_workflow + completed = state.steps_completed + if len(required) > 1 and len(completed) > 0: + # Count correct transitions + correct = sum( + 1 for i in range(min(len(completed), len(required))) + if completed[i] == required[i] + ) + r_ordering = correct / len(required) + else: + r_ordering = 1.0 if (not required or action.action_type == required[0]) else 0.0 + components["r_ordering"] = r_ordering + + # --- Component 4: Efficiency --- + if episode_done: + optimal = DIFFICULTY_OPTIMAL_TURNS.get(state.difficulty, 10) + overhead = max(0, state.turn_number - optimal) + r_efficiency = max(-0.3, -0.05 * overhead) + else: 
+ r_efficiency = 0.0 # only computed at episode end + components["r_efficiency"] = r_efficiency + + # --- Component 5: Format --- + r_format = 1.0 if action.is_valid() else -0.1 + components["r_format"] = r_format + + # --- Weighted total --- + weights = { + "r_outcome": 0.40, + "r_compliance": 0.30, + "r_ordering": 0.15, + "r_efficiency": 0.10, + "r_format": 0.05, + } + total = sum(weights[k] * v for k, v in components.items()) + components["total"] = total + + return total, components +``` + +--- + +## Phase 6: Environment Core (Person A) — `server/salespath_environment.py` + +```python +# server/salespath_environment.py +import uuid +from openenv.core.env_server import Environment +from ..models import SalesPathAction, SalesPathObservation, SalesPathState +from .task_bank import sample_profile +from .rules import check_rules, BUSINESS_RULES +from .reward import compute_reward +from .prospect_simulator import ProspectSimulator + +DIFFICULTY_WORKFLOW = { + 1: ["QUALIFY", "PRESENT", "CLOSE"], + 2: ["QUALIFY", "PRESENT", "HANDLE_OBJECTION", "OFFER_DEMO", "CLOSE"], + 3: ["QUALIFY", "PRESENT", "HANDLE_OBJECTION", "OFFER_DEMO", + "HANDLE_OBJECTION", "NEGOTIATE", "CLOSE"], + 4: [], # agent must determine; DISQUALIFY may be correct +} + +MAX_VIOLATIONS_BEFORE_TERMINATE = 3 +MAX_TURNS = 20 + + +class SalesPathEnvironment(Environment): + + def __init__(self): + super().__init__() + self._state = SalesPathState() + self._simulator = ProspectSimulator() + + def reset(self, difficulty: int = 1) -> SalesPathObservation: + profile = sample_profile(difficulty) + hidden = { + "true_budget": profile.true_budget, + "close_threshold": profile.close_threshold, + "stall_probability": profile.stall_probability, + "num_objections": {1: 0, 2: 1, 3: 2, 4: 2}[difficulty], + "revealed_budget": ( + "high" if profile.true_budget >= 0.7 + else "medium" if profile.true_budget >= 0.4 + else "low" + ), + } + public_profile = { + "company_name": profile.company_name, + "company_size": 
profile.company_size, + "industry": profile.industry, + "budget_signal": profile.budget_signal, + "pain_points": profile.pain_points, + "decision_maker": profile.decision_maker, + } + self._state = SalesPathState( + episode_id=str(uuid.uuid4()), + prospect_profile=public_profile, + required_workflow=DIFFICULTY_WORKFLOW[difficulty], + difficulty=difficulty, + ) + self._state._hidden = hidden + + return SalesPathObservation( + prospect_response=( + f"You are engaging {profile.company_name}, a {profile.company_size} " + f"{profile.industry} company. Pain points: {', '.join(profile.pain_points)}. " + f"Begin the sales conversation." + ), + workflow_stage="START", + steps_completed=[], + constraints_violated=[], + turn_number=0, + reward=0.0, + done=False, + info={"difficulty": difficulty, "episode_id": self._state.episode_id}, + ) + + def step(self, action: SalesPathAction) -> SalesPathObservation: + state = self._state + state.turn_number += 1 + + # Validate action format + if not action.is_valid(): + return SalesPathObservation( + prospect_response="Invalid action type.", + workflow_stage=state.workflow_stage, + steps_completed=list(state.steps_completed), + constraints_violated=list(state.constraints_violated), + turn_number=state.turn_number, + reward=-0.2, + done=False, + info={"error": f"Invalid action_type: {action.action_type}", + "r_format": -0.1}, + ) + + # Check business rules + new_violations = check_rules(state, action) + state.constraints_violated.extend(new_violations) + + # Update conversation history + state.conversation_history.append({ + "turn": state.turn_number, + "speaker": "agent", + "action_type": action.action_type, + "content": action.content, + }) + + # Update steps completed + if action.action_type not in state.steps_completed: + state.steps_completed.append(action.action_type) + state.workflow_stage = action.action_type + + # Get prospect response + response_token, response_text = self._simulator.respond(action, state) + 
state.conversation_history.append({ + "turn": state.turn_number, + "speaker": "prospect", + "response_token": response_token, + "text": response_text, + }) + + # Determine episode termination + terminal_actions = {"CLOSE", "DISQUALIFY"} + too_many_violations = len(state.constraints_violated) >= MAX_VIOLATIONS_BEFORE_TERMINATE + turn_limit = state.turn_number >= MAX_TURNS + done = ( + action.action_type in terminal_actions + or too_many_violations + or turn_limit + ) + state.done = done + + # Compute reward + total_reward, components = compute_reward( + state, action, response_token, new_violations, done + ) + + return SalesPathObservation( + prospect_response=response_text, + workflow_stage=state.workflow_stage, + steps_completed=list(state.steps_completed), + constraints_violated=list(state.constraints_violated), + turn_number=state.turn_number, + reward=total_reward, + reward_components=components, + done=done, + info={ + "response_token": response_token, + "new_violations": new_violations, + "episode_id": state.episode_id, + }, + ) + + @property + def state(self) -> SalesPathState: + return self._state +``` + +--- + +## Phase 7: FastAPI App (Person A) — `server/app.py` + +```python +# server/app.py — thin wrapper only +from openenv.core.env_server import create_fastapi_app +from ..models import SalesPathAction, SalesPathObservation +from .salespath_environment import SalesPathEnvironment + +app = create_fastapi_app( + SalesPathEnvironment, + SalesPathAction, + SalesPathObservation, +) +``` + +--- + +## Phase 8: Client (Person B) — `client.py` + +```python +# client.py +from openenv.core import EnvClient +from .models import SalesPathAction, SalesPathObservation, SalesPathState + + +class SalesPathEnv(EnvClient): + action_type = SalesPathAction + observation_type = SalesPathObservation + state_type = SalesPathState + + async def reset(self, difficulty: int = 1) -> SalesPathObservation: + return await super().reset(difficulty=difficulty) + + async def step(self, 
action_type: str, content: str, target: str = "") -> SalesPathObservation: + action = SalesPathAction( + action_type=action_type, + content=content, + target=target, + ) + return await super().step(action) +``` + +--- + +## Phase 9: Rollout Function (Person B) — `training/rollout.py` + +```python +# training/rollout.py +import re +from salespath_env.client import SalesPathEnv +from salespath_env.models import SalesPathObservation + +SYSTEM_PROMPT = """You are a B2B sales agent. Your goal is to close deals by following a strict workflow. + +Required workflow steps (in order): {workflow} + +Business rules — NEVER violate these: +- R01: Must QUALIFY before PRESENT +- R02: Must OFFER_DEMO before NEGOTIATE +- R03: Budget must be known before NEGOTIATE +- R04: Discount only after 2 objections handled +- R05: Cannot repeat same action twice in a row +- R06: First action must always be PROSPECT +- R07: FOLLOW_UP only after prospect goes silent +- R08: DISQUALIFY only if prospect is genuinely unqualified +- R09: Must OFFER_DEMO before CLOSE (difficulty 2+) + +Respond EXACTLY in this format: +ACTION: +CONTENT: """ + + +def parse_action(text: str) -> tuple[str, str]: + """Extract ACTION and CONTENT from model output.""" + action_match = re.search(r"ACTION:\s*(\w+)", text, re.IGNORECASE) + content_match = re.search(r"CONTENT:\s*(.+?)(?:\n|$)", text, re.IGNORECASE | re.DOTALL) + + action_type = action_match.group(1).upper() if action_match else "QUALIFY" + content = content_match.group(1).strip() if content_match else "Tell me more about your needs." 
+ + return action_type, content + + +def build_prompt(obs: SalesPathObservation, workflow: list[str], tokenizer) -> str: + messages = [ + {"role": "system", "content": SYSTEM_PROMPT.format(workflow=" → ".join(workflow))}, + {"role": "user", "content": ( + f"Prospect response: {obs.prospect_response}\n" + f"Current stage: {obs.workflow_stage}\n" + f"Steps completed: {obs.steps_completed}\n" + f"Turn: {obs.turn_number}/20\n" + f"Violations so far: {obs.constraints_violated}\n\n" + "What is your next action?" + )}, + ] + return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + + +async def run_episode(model, tokenizer, env_url: str, difficulty: int = 1) -> dict: + """Run one full episode. Returns trajectory with rewards.""" + DIFFICULTY_WORKFLOW = { + 1: ["QUALIFY", "PRESENT", "CLOSE"], + 2: ["QUALIFY", "PRESENT", "HANDLE_OBJECTION", "OFFER_DEMO", "CLOSE"], + 3: ["QUALIFY", "PRESENT", "HANDLE_OBJECTION", "OFFER_DEMO", + "HANDLE_OBJECTION", "NEGOTIATE", "CLOSE"], + 4: [], + } + workflow = DIFFICULTY_WORKFLOW[difficulty] + + async with SalesPathEnv(base_url=env_url) as env: + obs = await env.reset(difficulty=difficulty) + trajectory = [] + total_reward = 0.0 + + while not obs.done: + prompt = build_prompt(obs, workflow, tokenizer) + inputs = tokenizer(prompt, return_tensors="pt").to(model.device) + + with torch.no_grad(): + outputs = model.generate( + **inputs, + max_new_tokens=256, + temperature=0.8, + do_sample=True, + ) + generated = tokenizer.decode( + outputs[0][inputs["input_ids"].shape[1]:], + skip_special_tokens=True + ) + + action_type, content = parse_action(generated) + obs = await env.step(action_type, content) + + trajectory.append({ + "prompt": prompt, + "generated": generated, + "action_type": action_type, + "reward": obs.reward, + "components": obs.reward_components, + "done": obs.done, + }) + total_reward += obs.reward + + return { + "trajectory": trajectory, + "total_reward": total_reward, + "steps_completed": 
obs.steps_completed, + "violations": obs.constraints_violated, + "difficulty": difficulty, + } +``` + +--- + +## Phase 10: Curriculum Scheduler (Person B) — `training/curriculum.py` + +```python +# training/curriculum.py +from dataclasses import dataclass + +@dataclass +class CurriculumConfig: + thresholds: dict # mean_reward -> difficulty_distribution + + def get_distribution(self, mean_reward: float) -> dict: + for threshold in sorted(self.thresholds.keys(), reverse=True): + if mean_reward >= threshold: + return self.thresholds[threshold] + return self.thresholds[min(self.thresholds.keys())] + + +DEFAULT_CURRICULUM = CurriculumConfig( + thresholds={ + 0.0: {1: 0.90, 2: 0.10, 3: 0.00, 4: 0.00}, + 0.30: {1: 0.50, 2: 0.40, 3: 0.10, 4: 0.00}, + 0.50: {1: 0.20, 2: 0.40, 3: 0.35, 4: 0.05}, + 0.65: {1: 0.10, 2: 0.30, 3: 0.40, 4: 0.20}, + } +) + + +def sample_difficulty(curriculum: CurriculumConfig, mean_reward: float) -> int: + import random + dist = curriculum.get_distribution(mean_reward) + return random.choices( + list(dist.keys()), + weights=list(dist.values()), + k=1 + )[0] +``` + +--- + +## Phase 11: Training Script (Person B) — `training/grpo_train.py` + +```python +# training/grpo_train.py +import torch +import asyncio +import numpy as np +from unsloth import FastLanguageModel +from trl import GRPOConfig, GRPOTrainer +from curriculum import DEFAULT_CURRICULUM, sample_difficulty +from rollout import run_episode + +# --- Model Load --- +model, tokenizer = FastLanguageModel.from_pretrained( + model_name="unsloth/Qwen2.5-7B-Instruct", + max_seq_length=2048, + load_in_4bit=True, +) +model = FastLanguageModel.get_peft_model( + model, + r=16, + target_modules=["q_proj", "k_proj", "v_proj", "o_proj", + "gate_proj", "up_proj", "down_proj"], + lora_alpha=16, + lora_dropout=0, + bias="none", + use_gradient_checkpointing="unsloth", +) + +ENV_URL = "http://localhost:8000" # or HuggingFace Space URL + +# --- Reward function for GRPO (wraps environment) --- +def 
salespath_reward_fn(completions, prompts, **kwargs) -> list[float]: + """ + GRPO calls this with a batch of completions. + We run each through the environment and return rewards. + """ + rewards = [] + for completion in completions: + # Parse action from completion + from rollout import parse_action + action_type, content = parse_action(completion) + # For GRPO, we use a simplified single-step reward + # Full episode reward is tracked separately in curriculum loop + reward = kwargs.get("step_rewards", {}).get(completion, 0.0) + rewards.append(reward) + return rewards + + +# --- Training config --- +training_config = GRPOConfig( + output_dir="salespath_grpo_output", + num_train_epochs=3, + per_device_train_batch_size=2, + gradient_accumulation_steps=4, + num_generations=8, + max_new_tokens=256, + temperature=0.8, + learning_rate=1e-5, + logging_steps=10, + save_steps=100, + report_to="none", +) + +# --- Curriculum training loop --- +async def curriculum_train(): + mean_reward = 0.0 + reward_history = [] + + for step in range(500): + difficulty = sample_difficulty(DEFAULT_CURRICULUM, mean_reward) + result = await run_episode(model, tokenizer, ENV_URL, difficulty) + + reward_history.append(result["total_reward"]) + if len(reward_history) > 20: + mean_reward = np.mean(reward_history[-20:]) + + # Log metrics + if step % 10 == 0: + print(f"Step {step:4d} | Difficulty {difficulty} | " + f"Reward {result['total_reward']:.3f} | " + f"Mean(20) {mean_reward:.3f} | " + f"Violations {len(result['violations'])} | " + f"Steps {result['steps_completed']}") + + # Manual inspection every 50 steps + if step % 50 == 0: + print("\n=== RAW GENERATION SAMPLE ===") + if result["trajectory"]: + print(result["trajectory"][0]["generated"]) + print("==============================\n") + + +if __name__ == "__main__": + asyncio.run(curriculum_train()) +``` + +--- + +## Phase 12: Dockerfile (Person A) — `server/Dockerfile` + +```dockerfile +ARG BASE_IMAGE=openenv-base:latest +FROM ${BASE_IMAGE} + 
+COPY server/requirements.txt /tmp/requirements.txt +RUN pip install --no-cache-dir -r /tmp/requirements.txt + +COPY src/openenv/core/ /app/src/openenv/core/ +COPY salespath_env/ /app/salespath_env/ + +HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \ + CMD curl -f http://localhost:8000/health || exit 1 + +CMD ["uvicorn", "salespath_env.server.app:app", "--host", "0.0.0.0", "--port", "8000"] +``` + +`server/requirements.txt`: +``` +fastapi +uvicorn +pydantic>=2.0 +``` + +--- + +## Phase 13: Deploy to HuggingFace + +```bash +# From salespath_env/ directory +openenv push --repo-id Imsachin010/salespath-env + +# Verify it's running +curl -X POST https://imsachin010-salespath-env.hf.space/reset \ + -H "Content-Type: application/json" \ + -d '{"difficulty": 1}' +``` + +--- + +## Phase 14: Model Save (After Training) + +```python +# CORRECT save — do not change this +model.save_pretrained_merged( + "salespath_trained_merged", + tokenizer, + save_method="merged_16bit", +) + +# Push to HuggingFace Hub +model.push_to_hub_merged( + "Imsachin010/salespath-qwen25-7b", + tokenizer, + save_method="merged_16bit", +) +``` + +--- + +## Build Order Summary + +``` +Person A (Environment): Person B (Training): +1. models.py (wait for models.py) +2. server/task_bank.py 1. server/reward.py +3. server/rules.py 2. training/rollout.py +4. server/prospect_simulator.py 3. training/curriculum.py +5. server/salespath_environment 4. training/grpo_train.py +6. server/app.py 5. training/colab_train.ipynb +7. Dockerfile +8. openenv push → verify health + 6. Connect rollout to live env URL + 7. Run first training loop (difficulty=1 only) + 8. Verify reward > 0 on step 1 + 9. 
Enable curriculum +``` + +**Critical gate:** Person B does not run training until Person A has confirmed: +- `POST /reset` returns a valid observation +- `POST /step` with a valid action returns a valid observation +- `POST /step` with an invalid action returns error in `info`, not a 500 diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..1750eecb3bee3a18ce14c534e50d16a53b13b3cc --- /dev/null +++ b/Dockerfile @@ -0,0 +1,27 @@ +FROM python:3.11-slim + +# HuggingFace Spaces runs on port 7860 by default +ENV PORT=7860 +ENV PYTHONUNBUFFERED=1 +ENV PYTHONDONTWRITEBYTECODE=1 + +WORKDIR /app + +# Install system dependencies +RUN apt-get update && apt-get install -y --no-install-recommends \ + curl \ + && rm -rf /var/lib/apt/lists/* + +# Install Python dependencies +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +# Copy the salespath_env package +COPY salespath_env/ ./salespath_env/ + +# Health check +HEALTHCHECK --interval=30s --timeout=5s --start-period=15s --retries=3 \ + CMD curl -f http://localhost:${PORT}/health || exit 1 + +# Start the FastAPI server on HF Spaces port +CMD ["sh", "-c", "uvicorn salespath_env.server.app:app --host 0.0.0.0 --port ${PORT}"] diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..a79ccc6906f3128ce46dec35ddea56fbe2ef0ae5 --- /dev/null +++ b/README.md @@ -0,0 +1,51 @@ +--- +title: SalesPath Environment +emoji: 🤝 +colorFrom: blue +colorTo: indigo +sdk: docker +app_port: 7860 +pinned: false +license: mit +short_description: RL gym environment for sales agent training +--- + +# SalesPath Environment + +An [OpenEnv](https://github.com/openenv)-compatible Reinforcement Learning gym environment for training sales agents via LLM fine-tuning. 
+ +## API Endpoints + +| Method | Endpoint | Description | +|--------|----------|-------------| +| `POST` | `/reset` | Reset the environment, returns initial observation | +| `POST` | `/step` | Take an action, returns next observation + reward | +| `GET` | `/health` | Health check | + +## Quick Start + +### Reset +```bash +curl -X POST https://imsachin010-salespath-env.hf.space/reset \ + -H "Content-Type: application/json" \ + -d '{"difficulty": 1}' +``` + +### Step +```bash +curl -X POST https://imsachin010-salespath-env.hf.space/step \ + -H "Content-Type: application/json" \ + -d '{"action": {"action_type": "PROSPECT", "content": "Hello, tell me about your workflow challenges."}}' +``` + +## Action Types + +- `PROSPECT` — Initial outreach and discovery +- `QUALIFY` — Qualify the lead +- `PRESENT` — Deliver the sales pitch +- `HANDLE_OBJECTION` — Handle prospect objections +- `OFFER_DEMO` — Offer product demonstration +- `NEGOTIATE` — Discuss pricing and terms +- `FOLLOW_UP` — Follow-up message +- `DISQUALIFY` — Exit if prospect is not a fit +- `CLOSE` — Attempt to close the deal diff --git a/RULES.md b/RULES.md new file mode 100644 index 0000000000000000000000000000000000000000..b3bc529f685ff3e003ce4dc193e9464f9cbf6d2c --- /dev/null +++ b/RULES.md @@ -0,0 +1,354 @@ +# SalesPath — Agent Rules & Constraints +### Read this before touching any file. These are non-negotiable. + +--- + +## 0. Project Identity + +- **Project name:** `salespath_env` +- **HuggingFace repo:** `Imsachin010/salespath-env` +- **Theme:** Theme #2 — Long-Horizon Planning (Scale AI bonus prize) +- **Stack:** OpenEnv + GRPO (HF TRL) + Unsloth + Qwen 2.5 7B Instruct + +--- + +## 1. 
Directory Structure — Do Not Deviate + +``` +salespath_env/ +├── __init__.py +├── models.py ← ALL Pydantic dataclasses live here only +├── client.py ← SalesPathEnv(EnvClient) lives here only +├── README.md +├── openenv.yaml +├── pyproject.toml +├── server/ +│ ├── __init__.py +│ ├── salespath_environment.py ← SalesPathEnvironment(Environment) +│ ├── prospect_simulator.py ← ProspectSimulator (rule-based only) +│ ├── reward.py ← ALL reward logic lives here only +│ ├── task_bank.py ← ALL prospect profiles and tasks +│ ├── rules.py ← ALL business rule definitions +│ ├── app.py ← FastAPI app only, no logic +│ ├── requirements.txt +│ └── Dockerfile +training/ +├── grpo_train.py ← training script +├── rollout.py ← rollout function +├── curriculum.py ← difficulty scheduler +└── colab_train.ipynb ← Colab notebook for judges +``` + +--- + +## 2. OpenEnv API — Exact Signatures to Follow + +```python +# models.py — extend these base classes +from openenv.core import Action, Observation, State # actual imports + +class SalesPathAction(Action): + action_type: str # one of the 9 valid action types + content: str # natural language content of the action + target: str = "" # optional target (e.g., which objection) + +class SalesPathObservation(Observation): + prospect_response: str + workflow_stage: str + constraints_violated: list[str] + steps_completed: list[str] + turn_number: int + reward: float + done: bool + info: dict + +class SalesPathState(State): + episode_id: str + prospect_profile: dict + conversation_history: list[dict] + workflow_stage: str + steps_completed: list[str] + constraints_violated: list[str] + turn_number: int + difficulty: int # 1, 2, 3, or 4 + hidden_state: dict # NOT exposed to agent +``` + +```python +# server/salespath_environment.py +from openenv.core.env_server import Environment + +class SalesPathEnvironment(Environment): + def reset(self, difficulty: int = 1) -> SalesPathObservation: ... 
+ def step(self, action: SalesPathAction) -> SalesPathObservation: ... + @property + def state(self) -> SalesPathState: ... +``` + +```python +# server/app.py — nothing else in this file +from openenv.core.env_server import create_fastapi_app +from ..models import SalesPathAction, SalesPathObservation +from .salespath_environment import SalesPathEnvironment + +app = create_fastapi_app(SalesPathEnvironment, SalesPathAction, SalesPathObservation) +``` + +--- + +## 3. Hard Rules — Code Will Be Rejected If Violated + +### 3.1 No LLM in the Environment +- `ProspectSimulator` is a **pure rule-based state machine** +- No API calls, no model inference, no `transformers` imports inside `server/` +- If you find yourself writing `model.generate()` inside `server/`, stop. Wrong file. + +### 3.2 Immutable Prospect State +- Once `reset()` sets the prospect profile, agent actions **cannot modify `hidden_state`** +- `hidden_state` is read-only after `reset()` +- Never expose `hidden_state` fields in `SalesPathObservation` + +### 3.3 Reward Lives in One Place +- All reward computation goes in `server/reward.py` +- `salespath_environment.py` calls `compute_reward()` — it does not compute reward itself +- Never compute reward inside `step()` directly + +### 3.4 Business Rules Live in One Place +- All rule definitions go in `server/rules.py` as a list of `BusinessRule` dataclasses +- `step()` calls `check_rules(state, action)` from `rules.py` — it does not check rules inline + +### 3.5 Turn Limit is Absolute +- Max turns = 20. Hard terminate. No exceptions. 
+- Episode must set `done=True` and assign `r_outcome = -0.3` at turn 20 regardless of state + +### 3.6 Action Validation is Strict +- If `action_type` is not one of the 9 valid types, return `done=False`, `reward=-0.2`, observation with error message +- Do not raise exceptions to the agent — return a valid `SalesPathObservation` with error in `info` + +### 3.7 Reward Must Be Multi-Component +- Reward function must log all 5 components separately in `info` dict +- Never return a single scalar reward without component breakdown +- Component keys: `r_outcome`, `r_compliance`, `r_ordering`, `r_efficiency`, `r_format` + +### 3.8 No Global Mutable State in Environment +- Each WebSocket session gets its own `SalesPathEnvironment` instance +- No class-level variables that change during episodes +- No module-level state + +--- + +## 4. Valid Action Types — Exact Strings + +```python +VALID_ACTIONS = { + "PROSPECT", # initial outreach — only valid on turn 1 + "QUALIFY", # ask qualification questions + "PRESENT", # deliver pitch + "HANDLE_OBJECTION", # respond to raised objection + "OFFER_DEMO", # propose product demonstration + "NEGOTIATE", # discuss pricing/terms + "CLOSE", # submit closing offer → terminates episode + "FOLLOW_UP", # follow up after no response + "DISQUALIFY", # exit if prospect is not a fit → terminates episode +} +``` + +--- + +## 5. Business Rules — Exact Definitions + +These are checked after every `step()`. Each violation increments `constraints_violated`. 
+ +```python +RULES = [ + # ID Name Condition for VIOLATION + R01 "qualify_before_present" PRESENT called before any QUALIFY + R02 "demo_before_negotiate" NEGOTIATE called before OFFER_DEMO + R03 "budget_known_to_negotiate" NEGOTIATE called while budget_signal == "unknown" + R04 "discount_after_objections" Discount mentioned in NEGOTIATE before 2 objections handled + R05 "no_repeat_action" Same action_type on consecutive turns + R06 "prospect_first" Any action other than PROSPECT on turn 1 + R07 "followup_timing" FOLLOW_UP called when prospect responded last turn + R08 "disqualify_logic" DISQUALIFY called when budget >= threshold AND decision_maker==True + R09 "close_requires_demo" CLOSE called before OFFER_DEMO +] +``` + +Three violations → `done=True`, `r_outcome = -0.5` + +--- + +## 6. Prospect Simulator — Exact Response Rules + +`ProspectSimulator.respond(action, state)` returns one of these string tokens. The environment converts tokens to natural language text for the observation. + +```python +RESPONSE_TOKENS = { + "open:positive_signal", # prospect is engaged and open + "open:neutral_signal", # prospect acknowledges but non-committal + "objection:price", # raises price objection + "objection:timing", # raises timing objection + "objection:premature_pitch", # triggered by R01 violation + "deflect:budget_not_discussed", # triggered by R03 violation + "deflect:stall", # prospect stalls (Level 3+) + "accept:demo_scheduled", # agrees to demo + "accept:close_success", # agrees to close → episode success + "reject:close_failed", # rejects close + "silence", # no response (enables FOLLOW_UP) + "exit:disqualified", # prospect exits conversation +} +``` + +--- + +## 7. 
Difficulty Configuration + +```python +DIFFICULTY_CONFIG = { + 1: { + "max_turns": 20, + "workflow_steps": ["QUALIFY", "PRESENT", "CLOSE"], + "num_objections": 0, + "budget_hidden": False, + "mode_shift": False, + "optimal_turns": 5, + }, + 2: { + "max_turns": 20, + "workflow_steps": ["QUALIFY", "PRESENT", "HANDLE_OBJECTION", "OFFER_DEMO", "CLOSE"], + "num_objections": 1, + "budget_hidden": True, # revealed after QUALIFY + "mode_shift": False, + "optimal_turns": 8, + }, + 3: { + "max_turns": 20, + "workflow_steps": ["QUALIFY", "PRESENT", "HANDLE_OBJECTION", "OFFER_DEMO", + "HANDLE_OBJECTION", "NEGOTIATE", "CLOSE"], + "num_objections": 2, + "budget_hidden": True, + "mode_shift": True, # prospect signals shift at turn 10 + "optimal_turns": 12, + }, + 4: { + "max_turns": 20, + "workflow_steps": "full", # agent must determine correct path + "num_objections": 2, + "budget_hidden": True, + "mode_shift": True, + "misleading_signals": True, # budget signals are deceptive + "optimal_turns": 14, + }, +} +``` + +--- + +## 8. Reward — Exact Weights + +```python +REWARD_WEIGHTS = { + "r_outcome": 0.40, + "r_compliance": 0.30, + "r_ordering": 0.15, + "r_efficiency": 0.10, + "r_format": 0.05, +} + +OUTCOME_VALUES = { + "close_success": 1.0, + "disqualify_correct": 0.5, + "turn_limit_reached": -0.3, + "close_failed": -0.5, + "three_violations": -0.5, +} + +COMPLIANCE_PER_VIOLATION = -0.2 # capped at -1.0 +EFFICIENCY_PER_EXTRA_TURN = -0.05 # capped at -0.3 +FORMAT_PASS = 1.0 +FORMAT_FAIL = -0.1 +``` + +--- + +## 9. Training Rules + +### Prompt Format (what gets sent to the LLM) +``` +System: You are a B2B sales agent. 
Follow this workflow strictly: +{workflow_steps_for_difficulty} + +Business rules you must never violate: +{rules_list} + +Current state: +- Prospect: {prospect_summary} +- Stage: {workflow_stage} +- Steps done: {steps_completed} +- Turn: {turn_number}/20 + +Prospect said: {prospect_response} + +Respond with: +ACTION: +CONTENT: +``` + +### Response parsing +- Extract `ACTION:` line → `action_type` +- Extract `CONTENT:` line → `content` +- If parsing fails → `r_format = -0.1`, use fallback QUALIFY + +### GRPO config +```python +GRPOConfig( + num_generations=8, # rollouts per prompt + max_new_tokens=256, + temperature=0.8, + learning_rate=1e-5, + per_device_train_batch_size=2, + gradient_accumulation_steps=4, +) +``` + +--- + +## 10. What to Monitor During Training + +Log these every 10 steps. If any of these goes wrong, stop and inspect raw generations: + +| Metric | Healthy Range | Alarm | +|--------|--------------|-------| +| `mean_reward` | Rising | Flat for >50 steps | +| `mean_r_compliance` | Rising | < -0.5 after step 100 | +| `violations_per_episode` | Falling | > 3.0 after step 100 | +| `ordering_rate` | Rising toward 0.85 | < 0.3 after step 150 | +| `close_success_rate` | Rising | 0 after step 200 | + +Inspect raw generations every 50 steps. Look for: repeated actions, empty CONTENT, invalid ACTION types, CLOSE before QUALIFY. + +--- + +## 11. Save Model Correctly + +```python +# CORRECT — do not deviate +model.save_pretrained_merged( + "salespath_trained", + tokenizer, + save_method="merged_16bit", # NOT naive upcast of 4bit +) +``` + +Never do: `model.save_pretrained()` on a 4-bit model without merging first. + +--- + +## 12. 
File Ownership (2-Person Team) + +| Person | Files | +|--------|-------| +| **A** | `models.py`, `server/salespath_environment.py`, `server/prospect_simulator.py`, `server/rules.py`, `server/task_bank.py`, `server/app.py`, `Dockerfile` | +| **B** | `server/reward.py`, `training/grpo_train.py`, `training/rollout.py`, `training/curriculum.py`, `training/colab_train.ipynb`, `client.py` | + +Both: `README.md`, `openenv.yaml`, `pyproject.toml` diff --git a/push_to_hub.py b/push_to_hub.py new file mode 100644 index 0000000000000000000000000000000000000000..06b37097279d1e43d22e38ad5242a7931a8ddda0 --- /dev/null +++ b/push_to_hub.py @@ -0,0 +1,44 @@ +from huggingface_hub import HfApi +import os + +REPO_ID = "Imsachin010/salespath-env" +FOLDER_PATH = "." + +IGNORE_PATTERNS = [ + "*.pyc", + "**/__pycache__/**", + ".git/**", + ".spa/**", + ".SPA/**", + "*.egg-info/**", + "push_to_hub.py", + "salespath_env/server/Dockerfile", # root Dockerfile is used instead + "training/**", # exclude training scripts from Space +] + +def main(): + api = HfApi() + + api.create_repo( + repo_id=REPO_ID, + repo_type="space", + space_sdk="docker", + exist_ok=True, + private=False, + ) + + api.upload_folder( + folder_path=FOLDER_PATH, + repo_id=REPO_ID, + repo_type="space", + ignore_patterns=IGNORE_PATTERNS, + commit_message="Deploy SalesPath Environment", + ) + + print( + f"Live Space URL:\n" + f"https://{REPO_ID.replace('/', '-')}.hf.space" + ) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..4b40d28060f04c74e9fbb7fc504a3526567185b4 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,22 @@ +[build-system] +requires = ["setuptools>=42"] +build-backend = "setuptools.build_meta" + +[project] +name = "salespath_env" +version = "0.1.0" +requires-python = ">=3.10" +dependencies = [ + "openenv", + "fastapi", + "uvicorn", + "pydantic>=2.0", + "trl>=0.8.0", + "unsloth", + "torch", 
+ "transformers", +] + +[tool.setuptools.packages.find] +where = ["."] +include = ["salespath_env*"] diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..a4f942b97aa7be5a5bd8e26f6144938ed489dec3 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,4 @@ +fastapi>=0.110.0 +uvicorn[standard]>=0.29.0 +pydantic>=2.0 +openenv-core>=0.2.3 diff --git a/salespath_env.egg-info/PKG-INFO b/salespath_env.egg-info/PKG-INFO new file mode 100644 index 0000000000000000000000000000000000000000..21fa0a72066bd5ee8cd48ce9b5f114d1cfc165a1 --- /dev/null +++ b/salespath_env.egg-info/PKG-INFO @@ -0,0 +1,8 @@ +Metadata-Version: 2.4 +Name: salespath-env +Version: 0.1.0 +Requires-Python: >=3.10 +Requires-Dist: openenv-core>=0.2.3 +Requires-Dist: fastapi>=0.110.0 +Requires-Dist: uvicorn[standard]>=0.29.0 +Requires-Dist: pydantic>=2.0 diff --git a/salespath_env.egg-info/SOURCES.txt b/salespath_env.egg-info/SOURCES.txt new file mode 100644 index 0000000000000000000000000000000000000000..52403ce86fa6aafca18247291d6aa493bae694b2 --- /dev/null +++ b/salespath_env.egg-info/SOURCES.txt @@ -0,0 +1,17 @@ +README.md +pyproject.toml +salespath_env/__init__.py +salespath_env/client.py +salespath_env/models.py +salespath_env.egg-info/PKG-INFO +salespath_env.egg-info/SOURCES.txt +salespath_env.egg-info/dependency_links.txt +salespath_env.egg-info/requires.txt +salespath_env.egg-info/top_level.txt +salespath_env/server/__init__.py +salespath_env/server/app.py +salespath_env/server/prospect_simulator.py +salespath_env/server/reward.py +salespath_env/server/rules.py +salespath_env/server/salespath_environment.py +salespath_env/server/task_bank.py \ No newline at end of file diff --git a/salespath_env.egg-info/dependency_links.txt b/salespath_env.egg-info/dependency_links.txt new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/salespath_env.egg-info/dependency_links.txt @@ 
-0,0 +1 @@ + diff --git a/salespath_env.egg-info/requires.txt b/salespath_env.egg-info/requires.txt new file mode 100644 index 0000000000000000000000000000000000000000..d806fb9b629f3e2bcefd0a43d5911684c277f109 --- /dev/null +++ b/salespath_env.egg-info/requires.txt @@ -0,0 +1,4 @@ +openenv-core>=0.2.3 +fastapi>=0.110.0 +uvicorn[standard]>=0.29.0 +pydantic>=2.0 diff --git a/salespath_env.egg-info/top_level.txt b/salespath_env.egg-info/top_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..753add4d8e11513385d805d30bdd7d97a042e882 --- /dev/null +++ b/salespath_env.egg-info/top_level.txt @@ -0,0 +1 @@ +salespath_env diff --git a/salespath_env/README.md b/salespath_env/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/salespath_env/__init__.py b/salespath_env/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4696797dc43ad934832a788794495e414f3b5769 --- /dev/null +++ b/salespath_env/__init__.py @@ -0,0 +1,2 @@ +"""SalesPath OpenEnv package.""" + diff --git a/salespath_env/__pycache__/__init__.cpython-313.pyc b/salespath_env/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..940922f64a29a9619e655e86d83db8a411924d12 Binary files /dev/null and b/salespath_env/__pycache__/__init__.cpython-313.pyc differ diff --git a/salespath_env/__pycache__/client.cpython-313.pyc b/salespath_env/__pycache__/client.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4ac7d84a63dc74a59935e573d5df4fe13e5ed59 Binary files /dev/null and b/salespath_env/__pycache__/client.cpython-313.pyc differ diff --git a/salespath_env/__pycache__/models.cpython-313.pyc b/salespath_env/__pycache__/models.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ddc8b9462684a20a5e856c06c4da25af0e4b6f2a Binary files /dev/null and 
# salespath_env/client.py

from typing import Any, Dict

from openenv.core import EnvClient
from openenv.core.client_types import StepResult

from .models import (
    SalesPathAction,
    SalesPathObservation,
    SalesPathState,
)


class SalesPathEnv(EnvClient[SalesPathAction, SalesPathObservation, SalesPathState]):
    """Typed client for the SalesPath environment server.

    Implements the three serialisation hooks declared by ``EnvClient`` and
    overrides ``reset``/``step`` so callers receive a plain
    ``SalesPathObservation`` instead of a ``StepResult`` wrapper.
    """

    # ------------------------------------------------------------------ #
    # Abstract method implementations required by EnvClient               #
    # ------------------------------------------------------------------ #

    def _step_payload(self, action: SalesPathAction) -> Dict[str, Any]:
        """Serialise ``action`` to the JSON dict sent to the server.

        The dict is the action itself with no wrapper key; ``metadata`` is
        excluded from the dump.
        """
        return action.model_dump(exclude={"metadata"})

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[SalesPathObservation]:
        """Deserialise a server response into ``StepResult[SalesPathObservation]``.

        The observation may be nested under an ``"observation"`` key or be the
        payload itself. Top-level ``reward``/``done`` take precedence over the
        observation's own fields, but a missing *or null* top-level value falls
        back to the observation so ``None`` is never propagated into the
        result (and, through ``_with_step_fields``, into the observation).
        """
        obs_data = payload.get("observation")
        if obs_data is None:
            # Flat payload: the observation fields live at the top level.
            obs_data = payload
        obs = SalesPathObservation(**obs_data)

        reward = payload.get("reward")
        done = payload.get("done")
        return StepResult(
            observation=obs,
            reward=obs.reward if reward is None else reward,
            done=obs.done if done is None else done,
        )

    def _parse_state(self, payload: Dict[str, Any]) -> SalesPathState:
        """Deserialise a server response into ``SalesPathState``.

        Accepts the state either nested under a ``"state"`` key or flat; a
        null ``"state"`` value is treated like a missing one.
        """
        state_data = payload.get("state")
        if state_data is None:
            state_data = payload
        return SalesPathState(**state_data)

    # ------------------------------------------------------------------ #
    # Convenience wrappers that return the unwrapped observation directly #
    # ------------------------------------------------------------------ #

    @staticmethod
    def _with_step_fields(
        result: StepResult[SalesPathObservation],
    ) -> SalesPathObservation:
        """Return a copy of the observation carrying the wrapper's reward/done.

        Some server payloads provide reward/done only at top level, so the
        ``StepResult`` values are copied onto the observation to keep the two
        in sync.
        """
        return result.observation.model_copy(
            update={
                "reward": result.reward,
                "done": result.done,
            }
        )

    async def reset(
        self,
        difficulty: int = 1,
    ) -> SalesPathObservation:
        """Reset the environment at ``difficulty`` and return the first observation."""
        result = await super().reset(difficulty=difficulty)
        return self._with_step_fields(result)

    async def step(
        self,
        action_type: str,
        content: str,
        target: str = "",
    ) -> SalesPathObservation:
        """Build a ``SalesPathAction`` from the arguments, send it, and return the observation."""
        action = SalesPathAction(
            action_type=action_type,
            content=content,
            target=target,
        )
        result = await super().step(action)
        return self._with_step_fields(result)
+ """ + + prospect_response: str = "" + workflow_stage: str = "START" + + constraints_violated: List[str] = Field(default_factory=list) + steps_completed: List[str] = Field(default_factory=list) + + turn_number: int = 0 + + reward: float = 0.0 + reward_components: Dict = Field(default_factory=dict) + + done: bool = False + info: Dict = Field(default_factory=dict) + + +class SalesPathState(State): + """ + Internal environment state. + Includes hidden state not exposed to the agent. + """ + + episode_id: str = Field(default_factory=lambda: str(uuid.uuid4())) + + prospect_profile: Dict = Field(default_factory=dict) + conversation_history: List[Dict] = Field(default_factory=list) + + workflow_stage: str = "START" + required_workflow: List[str] = Field(default_factory=list) + + steps_completed: List[str] = Field(default_factory=list) + constraints_violated: List[str] = Field(default_factory=list) + + objections_handled: int = 0 + turn_number: int = 0 + difficulty: int = 1 + + done: bool = False + + # Hidden state — NEVER exposed in Observation + hidden_state: Dict = Field(default_factory=dict) \ No newline at end of file diff --git a/salespath_env/openenv.yaml b/salespath_env/openenv.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9cf65cd71742b59aadb42640e8290805ed90ee58 --- /dev/null +++ b/salespath_env/openenv.yaml @@ -0,0 +1,13 @@ +[project] +name = "salespath_env" +version = "0.1.0" +dependencies = [ + "openenv", + "fastapi", + "uvicorn", + "pydantic>=2.0", + "trl>=0.8.0", + "unsloth", + "torch", + "transformers", +] \ No newline at end of file diff --git a/salespath_env/pyproject.toml b/salespath_env/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/salespath_env/server/Dockerfile b/salespath_env/server/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..0d752ad01adc8d1b670eb208b93eb5595da36b79 --- /dev/null +++ 
b/salespath_env/server/Dockerfile @@ -0,0 +1,12 @@ +ARG BASE_IMAGE=openenv-base:latest +FROM ${BASE_IMAGE} + +COPY server/requirements.txt /tmp/requirements.txt +RUN pip install --no-cache-dir -r /tmp/requirements.txt + +COPY salespath_env/ /app/salespath_env/ + +HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \ + CMD curl -f http://localhost:8000/health || exit 1 + +CMD ["uvicorn", "salespath_env.server.app:app", "--host", "0.0.0.0", "--port", "8000"] \ No newline at end of file diff --git a/salespath_env/server/__init__.py b/salespath_env/server/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7ceba3a66eb98e188a300114132c13e88834d717 --- /dev/null +++ b/salespath_env/server/__init__.py @@ -0,0 +1,2 @@ +"""SalesPath environment server package.""" + diff --git a/salespath_env/server/__pycache__/__init__.cpython-313.pyc b/salespath_env/server/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..259bb92b0204cd8717e4e99f2592c337c0c5ef38 Binary files /dev/null and b/salespath_env/server/__pycache__/__init__.cpython-313.pyc differ diff --git a/salespath_env/server/__pycache__/app.cpython-313.pyc b/salespath_env/server/__pycache__/app.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..897ca7140db083e23f77d09c7d37bb1b64de83bd Binary files /dev/null and b/salespath_env/server/__pycache__/app.cpython-313.pyc differ diff --git a/salespath_env/server/__pycache__/prospect_simulator.cpython-313.pyc b/salespath_env/server/__pycache__/prospect_simulator.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..25fb34af20d3a289cf06bab1a49cf22e5fa4599e Binary files /dev/null and b/salespath_env/server/__pycache__/prospect_simulator.cpython-313.pyc differ diff --git a/salespath_env/server/__pycache__/reward.cpython-313.pyc b/salespath_env/server/__pycache__/reward.cpython-313.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..b5224df2aafd547ffd3947ca3dd88271255fb6fc Binary files /dev/null and b/salespath_env/server/__pycache__/reward.cpython-313.pyc differ diff --git a/salespath_env/server/__pycache__/rules.cpython-313.pyc b/salespath_env/server/__pycache__/rules.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7770489c01b25bd118f59d12876e1e9ee8859c0 Binary files /dev/null and b/salespath_env/server/__pycache__/rules.cpython-313.pyc differ diff --git a/salespath_env/server/__pycache__/salespath_environment.cpython-313.pyc b/salespath_env/server/__pycache__/salespath_environment.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3bdb03546d79e3f5735befac08bd666f956a5fef Binary files /dev/null and b/salespath_env/server/__pycache__/salespath_environment.cpython-313.pyc differ diff --git a/salespath_env/server/__pycache__/task_bank.cpython-313.pyc b/salespath_env/server/__pycache__/task_bank.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..26b44c5703659985aea561309175cc60cad2e66c Binary files /dev/null and b/salespath_env/server/__pycache__/task_bank.cpython-313.pyc differ diff --git a/salespath_env/server/app.py b/salespath_env/server/app.py new file mode 100644 index 0000000000000000000000000000000000000000..9acd4c0f4c1027f5ef9adf6da529c136d8006087 --- /dev/null +++ b/salespath_env/server/app.py @@ -0,0 +1,18 @@ +# salespath_env/server/app.py + +from openenv.core.env_server import create_fastapi_app + +from ..models import ( + SalesPathAction, + SalesPathObservation, +) +from .salespath_environment import ( + SalesPathEnvironment, +) + + +app = create_fastapi_app( + SalesPathEnvironment, + SalesPathAction, + SalesPathObservation, +) \ No newline at end of file diff --git a/salespath_env/server/prospect_simulator.py b/salespath_env/server/prospect_simulator.py new file mode 100644 index 
# salespath_env/server/prospect_simulator.py

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Annotation-only import: the code below reads attributes and never
    # constructs these models, so nothing is needed at runtime.
    from ..models import SalesPathAction, SalesPathState


# Natural-language text returned for each response token.
RESPONSE_TEXT = {
    "open:positive_signal": "That sounds interesting. Tell me more about how this works.",
    "open:neutral_signal": "I see. We're evaluating a few options at the moment.",

    "objection:price": "The pricing seems higher than what we budgeted for.",
    "objection:timing": "The timing isn't ideal — we're in the middle of a quarter close.",
    "objection:premature_pitch": (
        "I'm not sure we're ready to discuss solutions yet. "
        "What do you know about our current situation?"
    ),

    "deflect:budget_not_discussed": (
        "We haven't really talked about what we're looking for yet."
    ),
    "deflect:stall": (
        "Let me get back to you on this. A lot is happening on our end."
    ),

    "accept:demo_scheduled": (
        "Yes, let's set up a demo. What time works next week?"
    ),
    "accept:close_success": (
        "Alright, I think we can move forward with this. "
        "Send over the paperwork."
    ),

    "reject:close_failed": (
        "I don't think we're ready to commit at this point."
    ),

    "silence": "",

    "exit:disqualified": (
        "I think we're done here. This isn't the right fit."
    ),
}


class ProspectSimulator:
    """
    Pure rule-based simulator.
    No LLM. No transformers.

    Side effects on ``state``: QUALIFY reveals a hidden budget signal and
    HANDLE_OBJECTION increments ``objections_handled``.
    """

    def respond(
        self,
        action: "SalesPathAction",
        state: "SalesPathState",
    ) -> tuple[str, str]:
        """
        Returns:
            (response_token, response_text)
        """
        token = self._get_token(action, state)
        return token, RESPONSE_TEXT[token]

    @staticmethod
    def _rule_triggered_token(
        atype: str,
        state: "SalesPathState",
    ):
        """Return the override token for a rule violated by THIS action, else None.

        Only the most recent entry of ``constraints_violated`` is inspected,
        and it must still apply to the current action. Without that check one
        old R01/R03 violation would override every later reply, including the
        answer to a valid CLOSE.
        """
        if not state.constraints_violated:
            return None

        latest = state.constraints_violated[-1]

        # R01: a pitch without any prior QUALIFY.
        if (
            latest == "R01"
            and atype == "PRESENT"
            and "QUALIFY" not in state.steps_completed
        ):
            return "objection:premature_pitch"

        # R03: negotiating while the budget signal is still unknown.
        if (
            latest == "R03"
            and atype == "NEGOTIATE"
            and state.prospect_profile.get("budget_signal") == "unknown"
        ):
            return "deflect:budget_not_discussed"

        return None

    def _get_token(
        self,
        action: "SalesPathAction",
        state: "SalesPathState",
    ) -> str:
        """Pick the response token for ``action`` given the current ``state``."""
        atype = action.action_type
        profile = state.prospect_profile
        hidden = state.hidden_state

        # -----------------------------
        # Rule-triggered responses first
        # -----------------------------

        override = self._rule_triggered_token(atype, state)
        if override is not None:
            return override

        # -----------------------------
        # Action-based responses
        # -----------------------------

        if atype == "PROSPECT":
            return "open:positive_signal"

        if atype == "QUALIFY":
            # Reveal budget if hidden
            if profile.get("budget_signal") == "unknown":
                profile["budget_signal"] = hidden.get(
                    "revealed_budget",
                    "medium",
                )
            return "open:neutral_signal"

        if atype == "PRESENT":
            # From difficulty 2 up, the first pitch draws a price objection
            # until at least one objection has been handled.
            if state.difficulty >= 2 and state.objections_handled == 0:
                return "objection:price"
            return "open:positive_signal"

        if atype == "HANDLE_OBJECTION":
            handled_before = state.objections_handled
            state.objections_handled += 1

            required_objections = hidden.get("num_objections", 1)
            if state.objections_handled >= required_objections:
                return "open:positive_signal"

            # First objection handled but more are required: raise the next one.
            if handled_before == 0:
                return "objection:timing"

            return "open:positive_signal"

        if atype == "OFFER_DEMO":
            return "accept:demo_scheduled"

        if atype == "NEGOTIATE":
            return "open:neutral_signal"

        if atype == "CLOSE":
            true_budget = hidden.get("true_budget", 0.7)
            close_threshold = hidden.get("close_threshold", 0.5)
            decision_maker = profile.get("decision_maker", True)

            if true_budget >= close_threshold and decision_maker:
                return "accept:close_success"
            return "reject:close_failed"

        if atype == "FOLLOW_UP":
            return "open:neutral_signal"

        if atype == "DISQUALIFY":
            return "exit:disqualified"

        # -----------------------------
        # Difficulty 3+ mode shift
        # -----------------------------
        # NOTE(review): every recognised action type has already returned
        # above, so this stall logic only runs for unrecognised action types.
        # Confirm where the stall is meant to apply before moving it.

        if state.difficulty >= 3 and state.turn_number >= 10:
            import random

            if random.random() < hidden.get("stall_probability", 0.0):
                return "deflect:stall"

        return "open:neutral_signal"


# salespath_env/server/reward.py

# Turn budget per difficulty; turns beyond it are penalised at episode end.
DIFFICULTY_OPTIMAL_TURNS = {
    1: 5,
    2: 8,
    3: 12,
    4: 14,
}

# Weight of each component in the final scalar reward.
REWARD_WEIGHTS = {
    "r_outcome": 0.40,
    "r_compliance": 0.30,
    "r_ordering": 0.15,
    "r_efficiency": 0.10,
    "r_format": 0.05,
}

# Turn count at which a finished episode is scored as "turn limit reached".
MAX_TURNS = 20


def _outcome_reward(
    state: "SalesPathState",
    action: "SalesPathAction",
    response_token: str,
    new_violations: list[str],
) -> float:
    """Score a finished episode; the first matching case wins."""
    if response_token == "accept:close_success":
        return 1.0

    if action.action_type == "DISQUALIFY":
        # A disqualification that broke R08 was the wrong call.
        return 0.5 if "R08" not in new_violations else -0.5

    if state.turn_number >= MAX_TURNS:
        return -0.3

    # Three accumulated violations and every other terminal state
    # (for example a rejected close) score the same.
    return -0.5


def compute_reward(
    state: "SalesPathState",
    action: "SalesPathAction",
    response_token: str,
    new_violations: list[str],
    episode_done: bool,
) -> tuple[float, dict]:
    """Compute the weighted step reward and its component breakdown.

    Args:
        state: Episode state; reads ``turn_number``, ``difficulty``,
            ``required_workflow`` and ``steps_completed``.
        action: The action just taken; reads ``action_type`` and ``is_valid()``.
        response_token: Token returned by the prospect simulator for this step.
        new_violations: Rule IDs violated by this step only.
        episode_done: Whether this step ended the episode.

    Returns:
        (total_reward, reward_components), where the components dict holds the
        five ``r_*`` values plus ``"total"``.
    """
    components = {}

    # 1. Outcome reward (terminal only)
    components["r_outcome"] = (
        _outcome_reward(state, action, response_token, new_violations)
        if episode_done
        else 0.0
    )

    # 2. Compliance reward: -0.2 per new violation, floored at -1.0
    components["r_compliance"] = max(
        -1.0,
        -0.2 * len(new_violations),
    )

    # 3. Ordering reward: fraction of required steps matched position by position.
    # NOTE(review): an episode with no completed steps (or no required
    # workflow) scores a full 1.0 here; confirm that is intended.
    required = state.required_workflow
    completed = state.steps_completed

    if required and completed:
        matches = sum(
            1
            for expected, actual in zip(required, completed)
            if expected == actual
        )
        components["r_ordering"] = matches / len(required)
    else:
        components["r_ordering"] = 1.0

    # 4. Efficiency reward (terminal only): -0.05 per extra turn, floored at -0.3
    if episode_done:
        optimal = DIFFICULTY_OPTIMAL_TURNS.get(
            state.difficulty,
            10,
        )
        extra_turns = max(
            0,
            state.turn_number - optimal,
        )
        components["r_efficiency"] = max(
            -0.3,
            -0.05 * extra_turns,
        )
    else:
        components["r_efficiency"] = 0.0

    # 5. Format reward
    components["r_format"] = 1.0 if action.is_valid() else -0.1

    # Final weighted reward
    total_reward = sum(
        REWARD_WEIGHTS[key] * components[key]
        for key in REWARD_WEIGHTS
    )

    components["total"] = total_reward

    return total_reward, components
+ """ + if action.action_type == "NEGOTIATE": + return state.prospect_profile.get("budget_signal") == "unknown" + return False + + +def _discount_after_objections( + state: SalesPathState, + action: SalesPathAction, +) -> bool: + """ + R04: + Discount only after 2 objections handled. + """ + if action.action_type == "NEGOTIATE": + if "discount" in action.content.lower(): + return state.objections_handled < 2 + return False + + +def _no_repeat_action( + state: SalesPathState, + action: SalesPathAction, +) -> bool: + """ + R05: + Same action twice in a row is invalid. + """ + if state.conversation_history: + last_action = state.conversation_history[-1].get("action_type", "") + return last_action == action.action_type + return False + + +def _prospect_first( + state: SalesPathState, + action: SalesPathAction, +) -> bool: + """ + R06: + First action must be PROSPECT. + """ + if state.turn_number == 1: + return action.action_type != "PROSPECT" + return False + + +def _followup_timing( + state: SalesPathState, + action: SalesPathAction, +) -> bool: + """ + R07: + FOLLOW_UP only valid after silence. + If prospect just responded last turn, violation. + """ + if action.action_type == "FOLLOW_UP": + if state.conversation_history: + last_speaker = state.conversation_history[-1].get("speaker", "agent") + return last_speaker == "prospect" + return False + + +def _disqualify_logic( + state: SalesPathState, + action: SalesPathAction, +) -> bool: + """ + R08: + DISQUALIFY only when prospect is genuinely not closeable. + Violation if prospect is actually closeable. 
+ """ + if action.action_type == "DISQUALIFY": + true_budget = state.hidden_state.get("true_budget", 0.5) + close_threshold = state.hidden_state.get("close_threshold", 0.5) + decision_maker = state.prospect_profile.get("decision_maker", True) + + return (true_budget >= close_threshold) and decision_maker + + return False + + +def _close_requires_demo( + state: SalesPathState, + action: SalesPathAction, +) -> bool: + """ + R09: + Difficulty 2+ requires OFFER_DEMO before CLOSE. + """ + if action.action_type == "CLOSE": + if state.difficulty >= 2: + return "OFFER_DEMO" not in state.steps_completed + return False + + +BUSINESS_RULES = [ + BusinessRule( + "R01", + "qualify_before_present", + "Must QUALIFY before PRESENT", + _qualify_before_present, + ), + BusinessRule( + "R02", + "demo_before_negotiate", + "Must OFFER_DEMO before NEGOTIATE", + _demo_before_negotiate, + ), + BusinessRule( + "R03", + "budget_known_to_negotiate", + "Budget must be known before NEGOTIATE", + _budget_known_to_negotiate, + ), + BusinessRule( + "R04", + "discount_after_objections", + "Discount only after 2 objections handled", + _discount_after_objections, + ), + BusinessRule( + "R05", + "no_repeat_action", + "Cannot repeat same action consecutively", + _no_repeat_action, + ), + BusinessRule( + "R06", + "prospect_first", + "First action must be PROSPECT", + _prospect_first, + ), + BusinessRule( + "R07", + "followup_timing", + "FOLLOW_UP only after prospect silence", + _followup_timing, + ), + BusinessRule( + "R08", + "disqualify_logic", + "DISQUALIFY only when prospect is genuinely unqualified", + _disqualify_logic, + ), + BusinessRule( + "R09", + "close_requires_demo", + "Must OFFER_DEMO before CLOSE (difficulty 2+)", + _close_requires_demo, + ), +] + + +def check_rules( + state: SalesPathState, + action: SalesPathAction, +) -> list[str]: + """ + Returns list of violated rule IDs. 
class SalesPathEnvironment(Environment):
    """
    Core OpenEnv environment.
    All business logic routes through:
    - rules.py
    - reward.py
    - prospect_simulator.py
    """

    def __init__(self):
        super().__init__()
        self._state = SalesPathState()
        self._simulator = ProspectSimulator()

    def reset(self, difficulty: int = 1) -> SalesPathObservation:
        """
        Start a new episode.

        Args:
            difficulty: Level 1-4. Unknown levels fall back to 1, matching
                the fallback applied when sampling a profile, instead of
                failing with KeyError on the lookups below.

        Returns:
            The opening observation (turn 0, no reward, not done).
        """

        if difficulty not in DIFFICULTY_WORKFLOW:
            difficulty = 1

        profile = sample_profile(difficulty)

        hidden_state = {
            "true_budget": profile.true_budget,
            "close_threshold": profile.close_threshold,
            "stall_probability": profile.stall_probability,
            "num_objections": {
                1: 0,
                2: 1,
                3: 2,
                4: 2,
            }[difficulty],
            "revealed_budget": (
                "high"
                if profile.true_budget >= 0.7
                else "medium"
                if profile.true_budget >= 0.4
                else "low"
            ),
        }

        public_profile = {
            "company_name": profile.company_name,
            "company_size": profile.company_size,
            "industry": profile.industry,
            "budget_signal": profile.budget_signal,
            "pain_points": profile.pain_points,
            "decision_maker": profile.decision_maker,
        }

        self._state = SalesPathState(
            episode_id=str(uuid.uuid4()),
            prospect_profile=public_profile,
            conversation_history=[],
            workflow_stage="START",
            # Copy so per-episode state never aliases the module-level table.
            required_workflow=list(DIFFICULTY_WORKFLOW[difficulty]),
            steps_completed=[],
            constraints_violated=[],
            objections_handled=0,
            turn_number=0,
            difficulty=difficulty,
            done=False,
            hidden_state=hidden_state,
        )

        intro_message = (
            f"You are engaging {profile.company_name}, "
            f"a {profile.company_size} {profile.industry} company. "
            f"Pain points: {', '.join(profile.pain_points)}. "
            f"Begin the sales conversation."
        )

        return SalesPathObservation(
            prospect_response=intro_message,
            workflow_stage="START",
            constraints_violated=[],
            steps_completed=[],
            turn_number=0,
            reward=0.0,
            reward_components={},
            done=False,
            info={
                "difficulty": difficulty,
                "episode_id": self._state.episode_id,
            },
        )

    def step(
        self,
        action: SalesPathAction,
    ) -> SalesPathObservation:
        """
        One environment transition.

        Advances the turn counter, validates the action, applies the
        business rules, records the exchange, asks the simulator for the
        prospect's reply, decides termination and computes the reward.

        An invalid action type never raises: it yields a penalty
        observation. It still consumes a turn, so the turn cap is enforced
        on that path too; otherwise a policy that only emits invalid
        actions would never see done=True.
        """

        state = self._state

        # -----------------------------------
        # Advance turn
        # -----------------------------------

        state.turn_number += 1

        # -----------------------------------
        # Strict action validation
        # Must return observation, never crash
        # -----------------------------------

        if not action.is_valid():
            turn_limit_reached = state.turn_number >= MAX_TURNS

            if turn_limit_reached:
                state.done = True

            return SalesPathObservation(
                prospect_response="Invalid action type.",
                workflow_stage=state.workflow_stage,
                constraints_violated=list(state.constraints_violated),
                steps_completed=list(state.steps_completed),
                turn_number=state.turn_number,
                reward=-0.2,
                reward_components={
                    "r_format": -0.1,
                },
                done=turn_limit_reached,
                info={
                    "error": (
                        f"Invalid action_type: "
                        f"{action.action_type}"
                    )
                },
            )

        # -----------------------------------
        # Rule checks
        # (run before the action is recorded, so rules see prior history)
        # -----------------------------------

        new_violations = check_rules(
            state,
            action,
        )

        state.constraints_violated.extend(
            new_violations
        )

        # -----------------------------------
        # Record agent action
        # -----------------------------------

        state.conversation_history.append(
            {
                "turn": state.turn_number,
                "speaker": "agent",
                "action_type": action.action_type,
                "content": action.content,
            }
        )

        # -----------------------------------
        # Update workflow state
        # (each action type is recorded once, in first-seen order)
        # -----------------------------------

        if action.action_type not in state.steps_completed:
            state.steps_completed.append(
                action.action_type
            )

        state.workflow_stage = action.action_type

        # -----------------------------------
        # Prospect response
        # -----------------------------------

        response_token, response_text = (
            self._simulator.respond(
                action,
                state,
            )
        )

        state.conversation_history.append(
            {
                "turn": state.turn_number,
                "speaker": "prospect",
                "response_token": response_token,
                "text": response_text,
            }
        )

        # -----------------------------------
        # Episode termination
        # -----------------------------------

        terminal_actions = {
            "CLOSE",
            "DISQUALIFY",
        }

        too_many_violations = (
            len(state.constraints_violated)
            >= MAX_VIOLATIONS_BEFORE_TERMINATE
        )

        turn_limit_reached = (
            state.turn_number >= MAX_TURNS
        )

        done = (
            action.action_type in terminal_actions
            or too_many_violations
            or turn_limit_reached
        )

        state.done = done

        # -----------------------------------
        # Reward
        # -----------------------------------

        total_reward, components = (
            compute_reward(
                state=state,
                action=action,
                response_token=response_token,
                new_violations=new_violations,
                episode_done=done,
            )
        )

        return SalesPathObservation(
            prospect_response=response_text,
            workflow_stage=state.workflow_stage,
            constraints_violated=list(
                state.constraints_violated
            ),
            steps_completed=list(
                state.steps_completed
            ),
            turn_number=state.turn_number,
            reward=total_reward,
            reward_components=components,
            done=done,
            info={
                "response_token": response_token,
                "new_violations": new_violations,
                "episode_id": state.episode_id,
            },
        )

    @property
    def state(self) -> SalesPathState:
        return self._state
------------------------- + +PROFILES_L1 = [ + ProspectProfile( + company_name="Meridian Retail", + company_size="medium", + industry="retail", + budget_signal="high", + pain_points=[ + "manual inventory tracking", + "slow reporting", + ], + decision_maker=True, + true_budget=0.8, + close_threshold=0.5, + stall_probability=0.0, + ), + + ProspectProfile( + company_name="Northline Foods", + company_size="small", + industry="food distribution", + budget_signal="medium", + pain_points=[ + "supplier delays", + "inventory mismatch", + ], + decision_maker=True, + true_budget=0.6, + close_threshold=0.5, + stall_probability=0.0, + ), +] + + +# ------------------------- +# LEVEL 2 — Medium +# budget hidden initially +# one objection expected +# ------------------------- + +PROFILES_L2 = [ + ProspectProfile( + company_name="Apex Logistics", + company_size="enterprise", + industry="logistics", + budget_signal="unknown", + pain_points=[ + "route optimization", + "driver coordination", + "fuel tracking", + ], + decision_maker=True, + true_budget=0.7, + close_threshold=0.5, + stall_probability=0.0, + ), + + ProspectProfile( + company_name="Vertex Supply", + company_size="medium", + industry="manufacturing", + budget_signal="unknown", + pain_points=[ + "vendor visibility", + "purchase delays", + ], + decision_maker=True, + true_budget=0.55, + close_threshold=0.5, + stall_probability=0.0, + ), +] + + +# ------------------------- +# LEVEL 3 — Hard +# budget hidden +# 2 objections +# possible stalling +# decision maker may be absent +# ------------------------- + +PROFILES_L3 = [ + ProspectProfile( + company_name="Nova Financial", + company_size="enterprise", + industry="finance", + budget_signal="unknown", + pain_points=[ + "compliance reporting", + "audit trails", + "data silos", + ], + decision_maker=False, + true_budget=0.6, + close_threshold=0.55, + stall_probability=0.3, + ), + + ProspectProfile( + company_name="Atlas Health", + company_size="enterprise", + 
def sample_profile(difficulty: int) -> ProspectProfile:
    """
    Pick a random prospect profile for the given difficulty level.

    Levels that have no profile bank fall back to level 1.
    """
    pool = ALL_PROFILES.get(difficulty)

    if pool is None:
        pool = ALL_PROFILES[1]

    return random.choice(pool)
0000000000000000000000000000000000000000..62f9b32f6fcf28cbe00392e2b40f4e70a663ca92 Binary files /dev/null and b/training/__pycache__/test_rollout.cpython-313.pyc differ diff --git a/training/colab_train.ipynb b/training/colab_train.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..b1c65aa8b4b86cb89906a0f423fdd2a4c03d019b --- /dev/null +++ b/training/colab_train.ipynb @@ -0,0 +1,100 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# SalesPath Colab Training\n", + "\n", + "This notebook installs dependencies, runs a local environment server, validates rollout, and launches curriculum training." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install -U pip\n", + "!pip install fastapi uvicorn pydantic httpx torch transformers trl unsloth openenv" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# If the repo is not already present, clone it.\n", + "# !git clone https://github.com//salespath_env.git\n", + "# %cd salespath_env\n", + "\n", + "%cd /content/salespath_env" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Start the OpenEnv-compatible server in background.\n", + "!nohup python -m uvicorn salespath_env.server.app:app --host 0.0.0.0 --port 8000 > /content/server.log 2>&1 &\n", + "!sleep 3\n", + "!python -c \"import httpx; r=httpx.get('http://127.0.0.1:8000/health', timeout=30); print(r.status_code, r.text)\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Rollout smoke test (single episode)\n", + "!python -m training.test_rollout" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Curriculum run (example)\n", + "!python -m training.grpo_train --steps 30 --env-url 
@dataclass
class CurriculumConfig:
    """
    Schedule mapping a mean reward to a difficulty distribution.

    `thresholds` maps a reward floor to the distribution
    ({difficulty: probability}) used once the mean reward reaches it.
    """

    thresholds: dict

    def get_distribution(
        self,
        mean_reward: float,
    ) -> dict:
        """
        Return the distribution of the highest floor not above mean_reward.

        When mean_reward is below every floor, the lowest floor's
        distribution is returned.
        """
        reached = [
            floor
            for floor in self.thresholds
            if mean_reward >= floor
        ]

        chosen = max(reached) if reached else min(self.thresholds)

        return self.thresholds[chosen]
def parse_args():
    """Parse the command-line options of the debug-episode script."""
    cli = argparse.ArgumentParser(description="Debug stateful episode transitions.")

    # (flag, add_argument keyword settings), registered in this order.
    option_specs = (
        ("--env-url", {"default": "http://127.0.0.1:8000"}),
        ("--difficulty", {"type": int, "default": 2}),
    )
    for flag, settings in option_specs:
        cli.add_argument(flag, **settings)

    return cli.parse_args()
def _load_model_and_tokenizer(model_name: str):
    """
    Load the tokenizer and causal LM for `model_name`.

    The tokenizer is given a pad token (its EOS token) when it has none.
    Returns (model, tokenizer).
    """
    tok = AutoTokenizer.from_pretrained(model_name)

    # Reuse EOS for padding when the tokenizer defines no pad token.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    load_options = {
        "dtype": "auto",
        "device_map": "auto",
    }
    lm = AutoModelForCausalLM.from_pretrained(model_name, **load_options)

    return lm, tok
def _demo_required_before_close(prompt_text: str) -> bool:
    """
    Decide whether R09 (OFFER_DEMO before CLOSE) applies to this prompt.

    R09 only binds from difficulty 2 upwards. The prompt does not state the
    difficulty, but it lists the required workflow: a concrete workflow
    without OFFER_DEMO is the level-1 flow, where CLOSE needs no demo.
    When no concrete workflow can be read (e.g. "Dynamic"), the rule is
    assumed to apply.
    """
    match = re.search(
        r"Required workflow steps \(in order\):\s*(.*)",
        prompt_text,
    )
    if not match:
        return True

    listed = {step.strip().upper() for step in match.group(1).split("->")}
    known_steps = listed & VALID_ACTIONS

    if not known_steps:
        return True

    return "OFFER_DEMO" in known_steps


def salespath_reward_func(prompts, completions, **kwargs):
    """
    Lightweight GRPO reward signal aligned with project rules.
    Uses format validity + basic workflow order constraints.

    Args:
        prompts: Prompt strings, one per completion.
        completions: Generated completion strings.
        **kwargs: Extra columns passed by the trainer; unused.

    Returns:
        One float per completion: -0.2 for a malformed or unknown action,
        otherwise +0.1 minus 0.2 for each rule hint that is broken.
    """
    rewards: list[float] = []

    for prompt, completion in zip(prompts, completions):
        action_type, content = _extract_action_content(completion)
        steps_completed = _extract_steps_completed(prompt)

        # Format + valid action
        if action_type not in VALID_ACTIONS or not content:
            rewards.append(-0.2)
            continue

        reward = 0.1

        # Rule hints
        if not steps_completed and action_type != "PROSPECT":
            reward -= 0.2  # R06
        if action_type == "PRESENT" and "QUALIFY" not in steps_completed:
            reward -= 0.2  # R01
        if action_type == "NEGOTIATE" and "OFFER_DEMO" not in steps_completed:
            reward -= 0.2  # R02
        if (
            action_type == "CLOSE"
            and "OFFER_DEMO" not in steps_completed
            and _demo_required_before_close(prompt)
        ):
            reward -= 0.2  # R09 (difficulty 2+ only)

        rewards.append(float(reward))

    return rewards
def parse_args():
    """
    Parse the command-line options of the training entrypoint.

    Options are registered from a table, in the same order as they appear
    in --help: general options first, then the GRPO-specific knobs.
    """
    cli = argparse.ArgumentParser(description="SalesPath training entrypoint.")

    option_table = [
        ("--mode", {"choices": ["curriculum", "grpo"], "default": "curriculum"}),
        ("--model-name", {"default": DEFAULT_MODEL}),
        ("--env-url", {"default": DEFAULT_ENV_URL}),
        ("--steps", {"type": int, "default": 100, "help": "Curriculum rollout steps."}),
        ("--print-every", {"type": int, "default": 10}),
        ("--output-dir", {"default": "salespath_training_outputs"}),
        ("--hub-repo", {"default": "Imsachin010/salespath-qwen25-7b"}),
        ("--push-to-hub", {"action": "store_true"}),
        ("--push-merged", {"action": "store_true"}),
        # GRPO-specific knobs
        ("--grpo-steps", {"type": int, "default": 30}),
        ("--grpo-dataset-size", {"type": int, "default": 128}),
        ("--learning-rate", {"type": float, "default": 1e-5}),
        ("--per-device-train-batch-size", {"type": int, "default": 2}),
        ("--gradient-accumulation-steps", {"type": int, "default": 4}),
        ("--num-generations", {"type": int, "default": 8}),
        ("--max-completion-length", {"type": int, "default": 128}),
        ("--temperature", {"type": float, "default": 0.8}),
        ("--logging-steps", {"type": int, "default": 10}),
        ("--save-steps", {"type": int, "default": 100}),
    ]

    for flag, settings in option_table:
        cli.add_argument(flag, **settings)

    return cli.parse_args()
" + "Use an Unsloth merged-capable model to enable --push-merged." + ) + + +async def _main(): + args = parse_args() + if args.mode == "curriculum": + await _run_curriculum_mode(args) + return + + print("Launching TRL GRPO mode...") + run_grpo(args) + + +if __name__ == "__main__": + asyncio.run(_main()) + diff --git a/training/rollout.py b/training/rollout.py new file mode 100644 index 0000000000000000000000000000000000000000..f5ea264e385319b7fd06057bb5bd9b49e45dfbdc --- /dev/null +++ b/training/rollout.py @@ -0,0 +1,143 @@ +# training/rollout.py + +import re +import torch + +from salespath_env.client import SalesPathEnv +from salespath_env.models import SalesPathObservation + + +SYSTEM_PROMPT = """ +You are a B2B sales agent. + +Your goal is to close deals by following a strict workflow. + +Required workflow steps (in order): +{workflow} + +Business rules — NEVER violate these: + +- R01: Must QUALIFY before PRESENT +- R02: Must OFFER_DEMO before NEGOTIATE +- R03: Budget must be known before NEGOTIATE +- R04: Discount only after 2 objections handled +- R05: Cannot repeat same action twice in a row +- R06: First action must always be PROSPECT +- R07: FOLLOW_UP only after prospect goes silent +- R08: DISQUALIFY only if prospect is genuinely unqualified +- R09: Must OFFER_DEMO before CLOSE (difficulty 2+) + +You must respond EXACTLY in this format: + +ACTION: +CONTENT: +""" + + +def parse_action(text: str) -> tuple[str, str]: + """ + Extract ACTION and CONTENT from model output. + Fallback = QUALIFY if parsing fails. + """ + action_match = re.search(r"ACTION:\s*(\w+)", text, re.IGNORECASE) + content_match = re.search(r"CONTENT:\s*(.+?)(?:\n|$)", text, re.IGNORECASE | re.DOTALL) + + action_type = action_match.group(1).upper() if action_match else "QUALIFY" + content = content_match.group(1).strip() if content_match else "Tell me more about your current process." 
+ + return action_type, content + + +def build_prompt(obs: SalesPathObservation, workflow: list[str], tokenizer) -> str: + """Build model prompt from environment observation.""" + messages = [ + { + "role": "system", + "content": SYSTEM_PROMPT.format(workflow=" -> ".join(workflow)), + }, + { + "role": "user", + "content": ( + f"Prospect response: {obs.prospect_response}\n" + f"Current stage: {obs.workflow_stage}\n" + f"Steps completed: {obs.steps_completed}\n" + f"Turn: {obs.turn_number}/20\n" + f"Violations so far: {obs.constraints_violated}\n\n" + "What is your next action?" + ), + }, + ] + return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + + +async def run_episode( + model, + tokenizer, + env_url: str, + difficulty: int = 1, + message_timeout_s: float = 300.0, +) -> dict: + """ + Run one full episode using the stateful OpenEnv client. + Returns trajectory + rewards. + """ + DIFFICULTY_WORKFLOW = { + 1: ["QUALIFY", "PRESENT", "CLOSE"], + 2: ["QUALIFY", "PRESENT", "HANDLE_OBJECTION", "OFFER_DEMO", "CLOSE"], + 3: ["QUALIFY", "PRESENT", "HANDLE_OBJECTION", "OFFER_DEMO", "HANDLE_OBJECTION", "NEGOTIATE", "CLOSE"], + 4: [], + } + + workflow = DIFFICULTY_WORKFLOW[difficulty] + + async with SalesPathEnv(base_url=env_url) as env: + obs = await env.reset(difficulty=difficulty) + trajectory = [] + total_reward = 0.0 + + while not obs.done: + # --- Model inference (CPU/GPU — no network) --- + prompt = build_prompt(obs, workflow, tokenizer) + + inputs = tokenizer(prompt, return_tensors="pt").to(model.device) + + with torch.no_grad(): + outputs = model.generate( + **inputs, + max_new_tokens=128, + temperature=0.7, + do_sample=True, + ) + + generated = tokenizer.decode( + outputs[0][inputs["input_ids"].shape[1]:], + skip_special_tokens=True, + ) + + action_type, content = parse_action(generated) + + # --- Stateful step via OpenEnv client --- + obs = await env.step( + action_type=action_type, + content=content, + target="", + ) + + 
async def main():
    """Smoke-test the rollout loop with a small local model.

    Loads the tokenizer and model named by ``MODEL_NAME``, runs a single
    difficulty-1 episode against the environment at ``ENV_URL``, and prints
    the total reward, violations, completed steps and, when the trajectory
    is non-empty, the first raw generation.
    """
    print("Loading small local model...")

    tok = AutoTokenizer.from_pretrained(MODEL_NAME)

    load_options = {"dtype": "auto", "device_map": "auto"}
    lm = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **load_options)

    print("Running single episode...")

    episode = await run_episode(
        model=lm,
        tokenizer=tok,
        env_url=ENV_URL,
        difficulty=1,
        # allow up to 5 min per step (CPU inference is slow)
        message_timeout_s=300.0,
    )

    print("\n========== RESULT ==========")
    print(f"Total Reward: {episode['total_reward']:.4f}")
    for label, key in (
        ("Violations", "violations"),
        ("Steps Completed", "steps_completed"),
    ):
        print(f"{label}: {episode[key]}")

    # Nothing more to show when the episode produced no steps.
    if not episode["trajectory"]:
        return

    print("\n=== First Generation ===")
    first_step = episode["trajectory"][0]
    print(first_step["generated"])
file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391