Spaces:
Runtime error
Runtime error
| # salespath_env/server/salespath_environment.py | |
| import uuid | |
| from openenv.core.env_server import Environment | |
| from ..models import ( | |
| SalesPathAction, | |
| SalesPathObservation, | |
| SalesPathState, | |
| ) | |
| from .task_bank import sample_profile | |
| from .rules import check_rules | |
| from .reward import compute_reward | |
| from .prospect_simulator import ProspectSimulator | |
# Required sequence of agent action types for each difficulty level.
# Level 3 deliberately lists HANDLE_OBJECTION twice (two objection rounds).
DIFFICULTY_WORKFLOW = {
    1: ["QUALIFY", "PRESENT", "CLOSE"],
    2: ["QUALIFY", "PRESENT", "HANDLE_OBJECTION", "OFFER_DEMO", "CLOSE"],
    3: [
        "QUALIFY",
        "PRESENT",
        "HANDLE_OBJECTION",
        "OFFER_DEMO",
        "HANDLE_OBJECTION",
        "NEGOTIATE",
        "CLOSE",
    ],
    4: [],  # Agent must determine; DISQUALIFY may be correct
}

# An episode terminates once the accumulated number of rule violations
# reaches this value.
MAX_VIOLATIONS_BEFORE_TERMINATE = 3

# An episode terminates once the turn counter reaches this value.
MAX_TURNS = 20
class SalesPathEnvironment(Environment):
    """
    Core OpenEnv environment for the SalesPath sales-conversation task.

    All business logic routes through:
    - rules.py              -- ``check_rules(state, action)`` yields the new
                               violations caused by an action
    - reward.py             -- ``compute_reward(...)`` returns
                               ``(total_reward, components)``
    - prospect_simulator.py -- ``ProspectSimulator.respond(action, state)``
                               returns ``(response_token, response_text)``

    Episode state lives in ``self._state`` (a ``SalesPathState``) and is
    replaced wholesale by every ``reset``.
    """

    # Number of objections recorded in ``hidden_state["num_objections"]``
    # for each difficulty level.
    _NUM_OBJECTIONS = {1: 0, 2: 1, 3: 2, 4: 2}

    # Action types that end the episode immediately.
    _TERMINAL_ACTIONS = frozenset({"CLOSE", "DISQUALIFY"})

    def __init__(self):
        super().__init__()
        # Default-constructed placeholder until the first reset(); step()
        # called before reset() operates on this state.
        self._state = SalesPathState()
        self._simulator = ProspectSimulator()

    def reset(self, difficulty: int = 1) -> SalesPathObservation:
        """
        Start a new episode.

        Samples a prospect profile for ``difficulty``, splits it into a
        public profile and a hidden state, and replaces the environment
        state with a fresh ``SalesPathState``.

        Args:
            difficulty: Difficulty level; must be a key of
                ``DIFFICULTY_WORKFLOW`` (1-4).

        Returns:
            The opening observation (turn 0, zero reward, not done).

        Raises:
            ValueError: If ``difficulty`` is not a supported level.
        """
        # Fail fast with a clear message instead of a bare KeyError raised
        # halfway through building the episode.
        if (
            difficulty not in DIFFICULTY_WORKFLOW
            or difficulty not in self._NUM_OBJECTIONS
        ):
            raise ValueError(
                f"Unsupported difficulty {difficulty!r}; "
                f"expected one of {sorted(DIFFICULTY_WORKFLOW)}"
            )

        profile = sample_profile(difficulty)

        # Ground-truth prospect parameters, kept separate from the
        # public profile below.
        hidden_state = {
            "true_budget": profile.true_budget,
            "close_threshold": profile.close_threshold,
            "stall_probability": profile.stall_probability,
            "num_objections": self._NUM_OBJECTIONS[difficulty],
            "revealed_budget": self._budget_label(profile.true_budget),
        }
        public_profile = {
            "company_name": profile.company_name,
            "company_size": profile.company_size,
            "industry": profile.industry,
            "budget_signal": profile.budget_signal,
            "pain_points": profile.pain_points,
            "decision_maker": profile.decision_maker,
        }

        self._state = SalesPathState(
            episode_id=str(uuid.uuid4()),
            prospect_profile=public_profile,
            conversation_history=[],
            workflow_stage="START",
            # Copy so per-episode mutation can never alter the shared
            # module-level workflow table.
            required_workflow=list(DIFFICULTY_WORKFLOW[difficulty]),
            steps_completed=[],
            constraints_violated=[],
            objections_handled=0,
            turn_number=0,
            difficulty=difficulty,
            done=False,
            hidden_state=hidden_state,
        )

        intro_message = (
            f"You are engaging {profile.company_name}, "
            f"a {profile.company_size} {profile.industry} company. "
            f"Pain points: {', '.join(profile.pain_points)}. "
            f"Begin the sales conversation."
        )
        return self._observation_from_state(
            prospect_response=intro_message,
            reward=0.0,
            reward_components={},
            done=False,
            info={
                "difficulty": difficulty,
                "episode_id": self._state.episode_id,
            },
        )

    def step(
        self,
        action: SalesPathAction,
    ) -> SalesPathObservation:
        """
        Apply one agent action and return the resulting observation.

        Order of operations: advance the turn counter, validate the action,
        run rule checks, record the agent turn, update workflow progress,
        query the prospect simulator, decide termination, compute reward.

        Args:
            action: The agent's action for this turn.

        Returns:
            The observation for this transition. Invalid actions, and calls
            made after the episode has ended, return an observation with an
            ``"error"`` entry in ``info`` rather than raising.
        """
        state = self._state

        # Stepping a finished episode would keep accruing turns and reward;
        # report it without mutating the final state.
        if state.done:
            return self._observation_from_state(
                prospect_response=(
                    "Episode is already done. Call reset() to start a new one."
                ),
                reward=0.0,
                reward_components={},
                done=True,
                info={
                    "error": "Episode already done; call reset().",
                    "episode_id": state.episode_id,
                },
            )

        # -----------------------------------
        # Advance turn (invalid actions also consume a turn)
        # -----------------------------------
        state.turn_number += 1
        turn_limit_reached = state.turn_number >= MAX_TURNS

        # -----------------------------------
        # Strict action validation
        # Must return observation, never crash
        # -----------------------------------
        if not action.is_valid():
            # Invalid actions still count toward MAX_TURNS; otherwise an
            # agent emitting only invalid actions would never terminate.
            state.done = turn_limit_reached
            return self._observation_from_state(
                prospect_response="Invalid action type.",
                reward=-0.2,
                # NOTE(review): the component (-0.1) does not sum to the
                # total reward (-0.2); confirm which value is intended.
                reward_components={
                    "r_format": -0.1,
                },
                done=turn_limit_reached,
                info={
                    "error": (
                        f"Invalid action_type: "
                        f"{action.action_type}"
                    ),
                    "episode_id": state.episode_id,
                },
            )

        # -----------------------------------
        # Rule checks (before the action is recorded)
        # -----------------------------------
        new_violations = check_rules(state, action)
        state.constraints_violated.extend(new_violations)

        # -----------------------------------
        # Record agent action
        # -----------------------------------
        state.conversation_history.append(
            {
                "turn": state.turn_number,
                "speaker": "agent",
                "action_type": action.action_type,
                "content": action.content,
            }
        )

        # -----------------------------------
        # Update workflow state (each action type is listed once)
        # -----------------------------------
        if action.action_type not in state.steps_completed:
            state.steps_completed.append(action.action_type)
        state.workflow_stage = action.action_type

        # -----------------------------------
        # Prospect response
        # -----------------------------------
        response_token, response_text = self._simulator.respond(
            action,
            state,
        )
        state.conversation_history.append(
            {
                "turn": state.turn_number,
                "speaker": "prospect",
                "response_token": response_token,
                "text": response_text,
            }
        )

        # -----------------------------------
        # Episode termination
        # -----------------------------------
        too_many_violations = (
            len(state.constraints_violated)
            >= MAX_VIOLATIONS_BEFORE_TERMINATE
        )
        done = (
            action.action_type in self._TERMINAL_ACTIONS
            or too_many_violations
            or turn_limit_reached
        )
        state.done = done

        # -----------------------------------
        # Reward
        # -----------------------------------
        total_reward, components = compute_reward(
            state=state,
            action=action,
            response_token=response_token,
            new_violations=new_violations,
            episode_done=done,
        )

        return self._observation_from_state(
            prospect_response=response_text,
            reward=total_reward,
            reward_components=components,
            done=done,
            info={
                "response_token": response_token,
                "new_violations": new_violations,
                "episode_id": state.episode_id,
            },
        )

    @property
    def state(self) -> SalesPathState:
        """
        Return the current episode state.

        Declared as a property so that ``env.state`` yields the state object
        itself, as OpenEnv's ``Environment`` interface expects. A plain
        method would make ``env.state`` a bound method.

        NOTE(review): the returned object includes ``hidden_state``
        (e.g. ``true_budget``). If it is served to the agent, that leaks the
        ground truth; confirm whether it should be filtered.
        """
        return self._state

    def _observation_from_state(
        self,
        *,
        prospect_response,
        reward,
        reward_components,
        done,
        info,
    ) -> SalesPathObservation:
        """Build an observation whose progress fields mirror ``self._state``."""
        state = self._state
        return SalesPathObservation(
            prospect_response=prospect_response,
            workflow_stage=state.workflow_stage,
            # Copies so callers cannot mutate the live episode state.
            constraints_violated=list(state.constraints_violated),
            steps_completed=list(state.steps_completed),
            turn_number=state.turn_number,
            reward=reward,
            reward_components=reward_components,
            done=done,
            info=info,
        )

    @staticmethod
    def _budget_label(true_budget) -> str:
        """Bucket a budget value: >= 0.7 -> "high", >= 0.4 -> "medium", else "low"."""
        if true_budget >= 0.7:
            return "high"
        if true_budget >= 0.4:
            return "medium"
        return "low"