Spaces:
Runtime error
Runtime error
File size: 5,338 Bytes
b77d3c5 ae60795 b77d3c5 ae60795 b77d3c5 ae60795 b77d3c5 ae60795 b77d3c5 ae60795 b77d3c5 ae60795 b77d3c5 ae60795 b77d3c5 ae60795 b77d3c5 ae60795 b77d3c5 ae60795 b77d3c5 ae60795 b77d3c5 ae60795 b77d3c5 ae60795 b77d3c5 ae60795 b77d3c5 ae60795 b77d3c5 ae60795 b77d3c5 ae60795 b77d3c5 ae60795 b77d3c5 ae60795 b77d3c5 ae60795 b77d3c5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 | # training/rollout.py
# Uses httpx (async HTTP) to call the environment's /reset and /step REST endpoints,
# so the same code works against a local server or a Hugging Face Space URL.
import sys
import os
import re
import httpx
import torch
# Put the repository root (the parent of this file's directory) at the front
# of sys.path so ``salespath_env`` can be imported however this module is run.
_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if _ROOT not in sys.path:
    sys.path.insert(0, _ROOT)
from salespath_env.models import SalesPathObservation
SYSTEM_PROMPT = """You are a B2B sales agent. Your goal is to close deals by following a strict workflow.
Required workflow steps (in order): {workflow}
Business rules — NEVER violate these:
- R01: Must QUALIFY before PRESENT
- R02: Must OFFER_DEMO before NEGOTIATE
- R03: Budget must be known before NEGOTIATE
- R04: Discount only after 2 objections handled
- R05: Cannot repeat same action twice in a row
- R06: First action must always be PROSPECT
- R07: FOLLOW_UP only after prospect goes silent
- R08: DISQUALIFY only if prospect is genuinely unqualified
- R09: Must OFFER_DEMO before CLOSE (difficulty 2+)
Respond EXACTLY in this format:
ACTION: <one of: PROSPECT, QUALIFY, PRESENT, HANDLE_OBJECTION, OFFER_DEMO, NEGOTIATE, CLOSE, FOLLOW_UP, DISQUALIFY>
CONTENT: <your message to the prospect>"""
# Action names listed in SYSTEM_PROMPT; keep in sync. Used to recognise
# multi-word spellings such as "handle objection" or "Offer-Demo".
_VALID_ACTIONS = frozenset({
    "PROSPECT", "QUALIFY", "PRESENT", "HANDLE_OBJECTION", "OFFER_DEMO",
    "NEGOTIATE", "CLOSE", "FOLLOW_UP", "DISQUALIFY",
})
# "ACTION:" then any same-line decoration (markdown "**", "<", quotes,
# spaces), then one word optionally followed by a second word joined by a
# space or hyphen. The decoration class excludes "\n" so an empty ACTION:
# line does not swallow the next line's text.
_ACTION_RE = re.compile(r"ACTION:[^\w\n]*(\w+(?:[ \-]\w+)?)", re.IGNORECASE)
# Everything after "CONTENT:" up to the next line that starts an ACTION:
# (models sometimes hallucinate further turns) or the end of the text.
_CONTENT_RE = re.compile(
    r"CONTENT:(.*?)(?=^[^\w\n]*ACTION:|\Z)",
    re.IGNORECASE | re.DOTALL | re.MULTILINE,
)
_DEFAULT_ACTION = "QUALIFY"
_DEFAULT_CONTENT = "Tell me more about your needs."


def parse_action(text: str) -> tuple[str, str]:
    """Extract the ACTION and CONTENT fields from raw model output.

    Args:
        text: Generated text, expected to follow the ``ACTION: ...`` /
            ``CONTENT: ...`` format requested by SYSTEM_PROMPT.

    Returns:
        ``(action_type, content)``.

        ``action_type`` is upper-cased. A two-word spelling whose
        underscore-joined form is a known action (e.g. ``"handle objection"``
        becomes ``"HANDLE_OBJECTION"``) is normalised. Otherwise the first
        word is used unchanged, so unknown actions are still forwarded to the
        server as before. If no ACTION field is present, ``"QUALIFY"`` is
        returned.

        ``content`` is the stripped text after ``CONTENT:``, possibly spanning
        several lines and ending before any later ``ACTION:`` line. If the
        field is missing or blank, a generic follow-up question is returned.
    """
    action_match = _ACTION_RE.search(text)
    if action_match:
        words = re.split(r"[ \-]", action_match.group(1).upper())
        joined = "_".join(words)
        # Only merge the second word when that yields a real action name;
        # "CLOSE the deal" must stay CLOSE, not CLOSE_THE.
        action_type = joined if joined in _VALID_ACTIONS else words[0]
    else:
        action_type = _DEFAULT_ACTION

    content_match = _CONTENT_RE.search(text)
    content = content_match.group(1).strip() if content_match else ""
    return action_type, content or _DEFAULT_CONTENT
def build_prompt(
    obs: SalesPathObservation,
    workflow: list[str],
    tokenizer,
    max_turns: int = 20,
) -> str:
    """Render the chat prompt the policy model sees for one turn.

    Args:
        obs: Current observation. Its ``prospect_response``,
            ``workflow_stage``, ``steps_completed``, ``turn_number`` and
            ``constraints_violated`` fields are shown to the model.
        workflow: Required action sequence for this difficulty. An empty
            list tells the model to determine the path itself.
        tokenizer: Object providing ``apply_chat_template``, such as a
            Hugging Face tokenizer.
        max_turns: Turn budget displayed as ``Turn: n/max_turns``. Defaults
            to 20, the previously hard-coded value; pass the environment's
            actual limit if it differs.

    Returns:
        The templated prompt string (not tokenized), ending with the
        assistant generation prompt.
    """
    workflow_text = " -> ".join(workflow) if workflow else "Dynamic — determine best path"
    user_text = (
        f"Prospect response: {obs.prospect_response}\n"
        f"Current stage: {obs.workflow_stage}\n"
        f"Steps completed: {obs.steps_completed}\n"
        f"Turn: {obs.turn_number}/{max_turns}\n"
        f"Violations so far: {obs.constraints_violated}\n\n"
        "What is your next action?"
    )
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT.format(workflow=workflow_text)},
        {"role": "user", "content": user_text},
    ]
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
def _parse_obs(data: dict) -> SalesPathObservation:
    """Build a SalesPathObservation from a /reset or /step JSON payload.

    Two payload shapes are accepted:

    * a flat observation dict, or
    * an envelope ``{"observation": {...}, ...}``. Here the other top-level
      keys (e.g. ``reward``, ``done``) fill in fields the nested dict does not
      carry; keys present in the nested dict take precedence.

    Keys that are not fields of SalesPathObservation are dropped, because the
    model would otherwise reject them.
    """
    nested = data.get("observation")
    if isinstance(nested, dict):
        # Envelope-style servers report reward/done beside the observation,
        # not inside it (OpenEnv's HTTP server does this — confirm for the
        # salespath_env server). Without merging them, obs.reward/obs.done
        # would silently keep the model's defaults.
        merged = {k: v for k, v in data.items() if k != "observation"}
        merged.update(nested)
    else:
        merged = data
    known = SalesPathObservation.model_fields
    return SalesPathObservation(**{k: v for k, v in merged.items() if k in known})
# Required workflow per difficulty level. Difficulty 4 has no fixed script;
# build_prompt renders an empty workflow as "Dynamic — determine best path".
_DIFFICULTY_WORKFLOW: dict[int, list[str]] = {
    1: ["QUALIFY", "PRESENT", "CLOSE"],
    2: ["QUALIFY", "PRESENT", "HANDLE_OBJECTION", "OFFER_DEMO", "CLOSE"],
    3: ["QUALIFY", "PRESENT", "HANDLE_OBJECTION", "OFFER_DEMO",
        "HANDLE_OBJECTION", "NEGOTIATE", "CLOSE"],
    4: [],
}


def _generate(model, tokenizer, prompt: str, max_new_tokens: int, temperature: float) -> str:
    """Sample one completion for *prompt* and return only the newly generated text."""
    # apply_chat_template(tokenize=False) output already contains the
    # template's special tokens (e.g. BOS). Per the HF chat-templating docs,
    # tokenize it with add_special_tokens=False so they are not duplicated.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    # generate() returns prompt + continuation; slice off the prompt tokens.
    prompt_len = inputs["input_ids"].shape[1]
    return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)


async def run_episode(
    model,
    tokenizer,
    env_url: str,
    difficulty: int = 1,
    message_timeout_s: float = 300.0,
    max_new_tokens: int = 128,
    temperature: float = 0.8,
    max_steps: int = 100,
) -> dict:
    """Play one episode against the SalesPath REST server and record it.

    Resets the environment at ``difficulty``. Then, on each turn, it prompts
    the model, parses its ACTION/CONTENT reply, and POSTs it to ``/step``.
    This repeats until the server reports ``done`` or ``max_steps`` actions
    have been sent.

    Args:
        model: Causal LM exposing ``generate`` and ``device``.
        tokenizer: Matching tokenizer (``apply_chat_template``, call,
            ``decode``, ``eos_token_id``).
        env_url: Base URL of the environment server (local or Space).
        difficulty: Key of ``_DIFFICULTY_WORKFLOW`` (1-4); also sent to /reset.
        message_timeout_s: httpx timeout applied to every request.
        max_new_tokens: Generation budget per turn.
        temperature: Sampling temperature; sampling is always enabled.
        max_steps: Safety cap on /step calls, well above the 20-turn budget
            shown in the prompt. It keeps a server that never reports
            ``done`` from hanging the rollout.

    Returns:
        A dict with:

        * ``trajectory``: one entry per step with prompt, generated text,
          parsed action_type, reward, reward components and done flag.
        * ``total_reward``: the sum of per-step rewards.
        * ``steps_completed`` and ``violations``: taken from the final
          observation.
        * ``difficulty``: as passed in.

    Raises:
        KeyError: ``difficulty`` has no workflow entry.
        httpx.HTTPStatusError: /reset or /step returned an error status.
        httpx.TimeoutException: a request exceeded ``message_timeout_s``.
    """
    workflow = _DIFFICULTY_WORKFLOW[difficulty]
    trajectory: list[dict] = []
    total_reward = 0.0

    async with httpx.AsyncClient(base_url=env_url, timeout=message_timeout_s) as client:
        # --- Reset ---
        reset_resp = await client.post("/reset", json={"difficulty": difficulty})
        reset_resp.raise_for_status()
        obs = _parse_obs(reset_resp.json())

        while not obs.done and len(trajectory) < max_steps:
            # --- Model inference ---
            # NOTE(review): generate() is synchronous and blocks the event
            # loop while it runs; concurrent episodes will not overlap here.
            prompt = build_prompt(obs, workflow, tokenizer)
            generated = _generate(model, tokenizer, prompt, max_new_tokens, temperature)
            action_type, content = parse_action(generated)

            # --- Step via REST ---
            step_resp = await client.post(
                "/step",
                json={"action": {"action_type": action_type, "content": content, "target": ""}},
            )
            step_resp.raise_for_status()
            obs = _parse_obs(step_resp.json())

            # A missing (None) reward counts as 0 so that summing cannot
            # raise TypeError; whether None can occur depends on the
            # observation model — TODO confirm.
            reward = obs.reward if obs.reward is not None else 0.0
            trajectory.append({
                "prompt": prompt,
                "generated": generated,
                "action_type": action_type,
                "reward": reward,
                "components": obs.reward_components,
                "done": obs.done,
            })
            total_reward += reward

    return {
        "trajectory": trajectory,
        "total_reward": total_reward,
        "steps_completed": obs.steps_completed,
        "violations": obs.constraints_violated,
        "difficulty": difficulty,
    }