# Source: salespath-env / salespath_env/server/salespath_environment.py
# Author: Imsachin010 · commit b77d3c5 ("first commit")
# salespath_env/server/salespath_environment.py
import uuid
from openenv.core.env_server import Environment
from ..models import (
SalesPathAction,
SalesPathObservation,
SalesPathState,
)
from .task_bank import sample_profile
from .rules import check_rules
from .reward import compute_reward
from .prospect_simulator import ProspectSimulator
# Reference action sequence the agent is expected to follow, keyed by
# difficulty level. Each value is an ordered list of action types.
# Difficulty 4 has no prescribed path: the agent must work it out itself,
# and DISQUALIFY may be the correct outcome.
DIFFICULTY_WORKFLOW = {
    1: ["QUALIFY", "PRESENT", "CLOSE"],
    2: ["QUALIFY", "PRESENT", "HANDLE_OBJECTION", "OFFER_DEMO", "CLOSE"],
    3: [
        "QUALIFY", "PRESENT",
        "HANDLE_OBJECTION", "OFFER_DEMO",
        "HANDLE_OBJECTION", "NEGOTIATE",
        "CLOSE",
    ],
    4: [],  # Agent must determine; DISQUALIFY may be correct
}

# An episode ends once this many constraint violations have accumulated.
MAX_VIOLATIONS_BEFORE_TERMINATE = 3

# Hard cap on agent turns per episode.
MAX_TURNS = 20
class SalesPathEnvironment(Environment):
    """
    Core OpenEnv environment for the SalesPath sales-conversation task.

    All business logic routes through:
    - rules.py               (``check_rules``: new constraint violations)
    - reward.py              (``compute_reward``: total reward + components)
    - prospect_simulator.py  (``ProspectSimulator.respond``: prospect reply)

    This class only does episode bookkeeping. It samples a prospect on
    ``reset``, advances turns, records conversation history and decides
    when an episode terminates.

    NOTE(review): ``step`` does not refuse to run after ``done`` is True.
    Callers are presumably expected to ``reset`` first; confirm.
    """

    # Number of objections the prospect raises, keyed by difficulty.
    # The keys mirror DIFFICULTY_WORKFLOW (1-4); an unknown difficulty
    # raises KeyError here.
    _NUM_OBJECTIONS_BY_DIFFICULTY = {1: 0, 2: 1, 3: 2, 4: 2}

    # Action types that end the episode as soon as they are taken.
    _TERMINAL_ACTIONS = frozenset({"CLOSE", "DISQUALIFY"})

    def __init__(self) -> None:
        super().__init__()
        self._state = SalesPathState()
        self._simulator = ProspectSimulator()

    @staticmethod
    def _budget_band(true_budget: float) -> str:
        """Map a numeric budget to "high" (>= 0.7), "medium" (>= 0.4) or "low"."""
        if true_budget >= 0.7:
            return "high"
        if true_budget >= 0.4:
            return "medium"
        return "low"

    def reset(self, difficulty: int = 1) -> SalesPathObservation:
        """
        Start a new episode.

        Args:
            difficulty: Difficulty level, used to sample the prospect profile
                and to look up the required workflow and objection count.

        Returns:
            The opening observation. It carries an intro message describing
            the prospect, zero reward and ``done=False``.

        Raises:
            KeyError: if ``difficulty`` is not one of the known levels (1-4).
                ``sample_profile`` may raise its own error first.
        """
        profile = sample_profile(difficulty)

        # Ground truth the agent never sees directly. It is consumed by the
        # simulator / rules / reward modules via ``state.hidden_state``.
        hidden_state = {
            "true_budget": profile.true_budget,
            "close_threshold": profile.close_threshold,
            "stall_probability": profile.stall_probability,
            "num_objections": self._NUM_OBJECTIONS_BY_DIFFICULTY[difficulty],
            "revealed_budget": self._budget_band(profile.true_budget),
        }

        # The subset of the profile that is visible to the agent.
        public_profile = {
            "company_name": profile.company_name,
            "company_size": profile.company_size,
            "industry": profile.industry,
            "budget_signal": profile.budget_signal,
            "pain_points": profile.pain_points,
            "decision_maker": profile.decision_maker,
        }

        self._state = SalesPathState(
            episode_id=str(uuid.uuid4()),
            prospect_profile=public_profile,
            conversation_history=[],
            workflow_stage="START",
            # Copy so that mutating the episode's workflow can never corrupt
            # the shared module-level DIFFICULTY_WORKFLOW table.
            required_workflow=list(DIFFICULTY_WORKFLOW[difficulty]),
            steps_completed=[],
            constraints_violated=[],
            objections_handled=0,
            turn_number=0,
            difficulty=difficulty,
            done=False,
            hidden_state=hidden_state,
        )

        intro_message = (
            f"You are engaging {profile.company_name}, "
            f"a {profile.company_size} {profile.industry} company. "
            f"Pain points: {', '.join(profile.pain_points)}. "
            f"Begin the sales conversation."
        )

        return SalesPathObservation(
            prospect_response=intro_message,
            workflow_stage="START",
            constraints_violated=[],
            steps_completed=[],
            turn_number=0,
            reward=0.0,
            reward_components={},
            done=False,
            info={
                "difficulty": difficulty,
                "episode_id": self._state.episode_id,
            },
        )

    def step(
        self,
        action: SalesPathAction,
    ) -> SalesPathObservation:
        """
        Perform one environment transition.

        Args:
            action: The agent's action. Its ``action_type`` is checked with
                ``action.is_valid()`` and its ``content`` is recorded in the
                conversation history.

        Returns:
            An observation with the prospect's response, the updated workflow
            bookkeeping, the reward (and its components) and ``done``.
            Invalid actions never raise. They return a penalty observation
            instead.
        """
        state = self._state

        # -----------------------------------
        # Advance turn (invalid actions also consume a turn)
        # -----------------------------------
        state.turn_number += 1
        turn_limit_reached = state.turn_number >= MAX_TURNS

        # -----------------------------------
        # Strict action validation
        # Must return observation, never crash
        # -----------------------------------
        if not action.is_valid():
            # Apply the turn limit here as well. Otherwise an agent emitting
            # only invalid actions would never reach a terminal observation.
            if turn_limit_reached:
                state.done = True
            return SalesPathObservation(
                prospect_response="Invalid action type.",
                workflow_stage=state.workflow_stage,
                constraints_violated=list(state.constraints_violated),
                steps_completed=list(state.steps_completed),
                turn_number=state.turn_number,
                # NOTE(review): the total (-0.2) does not equal the sum of
                # the reported components (-0.1). Confirm the intended split.
                reward=-0.2,
                reward_components={"r_format": -0.1},
                done=turn_limit_reached,
                info={
                    "error": f"Invalid action_type: {action.action_type}",
                    "episode_id": state.episode_id,
                },
            )

        # -----------------------------------
        # Rule checks. These run before the action is appended to the
        # history or to steps_completed.
        # -----------------------------------
        new_violations = check_rules(state, action)
        state.constraints_violated.extend(new_violations)

        # -----------------------------------
        # Record agent action
        # -----------------------------------
        state.conversation_history.append(
            {
                "turn": state.turn_number,
                "speaker": "agent",
                "action_type": action.action_type,
                "content": action.content,
            }
        )

        # -----------------------------------
        # Update workflow state. steps_completed keeps first-occurrence
        # order with no duplicates.
        # -----------------------------------
        if action.action_type not in state.steps_completed:
            state.steps_completed.append(action.action_type)
        state.workflow_stage = action.action_type

        # -----------------------------------
        # Prospect response
        # -----------------------------------
        response_token, response_text = self._simulator.respond(action, state)
        state.conversation_history.append(
            {
                "turn": state.turn_number,
                "speaker": "prospect",
                "response_token": response_token,
                "text": response_text,
            }
        )

        # -----------------------------------
        # Episode termination
        # -----------------------------------
        too_many_violations = (
            len(state.constraints_violated) >= MAX_VIOLATIONS_BEFORE_TERMINATE
        )
        done = (
            action.action_type in self._TERMINAL_ACTIONS
            or too_many_violations
            or turn_limit_reached
        )
        state.done = done

        # -----------------------------------
        # Reward
        # -----------------------------------
        total_reward, components = compute_reward(
            state=state,
            action=action,
            response_token=response_token,
            new_violations=new_violations,
            episode_done=done,
        )

        return SalesPathObservation(
            prospect_response=response_text,
            workflow_stage=state.workflow_stage,
            constraints_violated=list(state.constraints_violated),
            steps_completed=list(state.steps_completed),
            turn_number=state.turn_number,
            reward=total_reward,
            reward_components=components,
            done=done,
            info={
                "response_token": response_token,
                "new_violations": new_violations,
                "episode_id": state.episode_id,
            },
        )

    @property
    def state(self) -> SalesPathState:
        """The current episode state (the live object, not a copy)."""
        return self._state