# training/rollout.py
#
# Rollout driver for the SalesPath environment.
# Uses httpx (async HTTP) to call the /reset and /step REST endpoints.
# This is robust for both a local server and an HF Space URL.
import os
import re
import sys

import httpx
import torch

# Ensure the project root is on sys.path regardless of how this is invoked.
_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if _ROOT not in sys.path:
    sys.path.insert(0, _ROOT)

from salespath_env.models import SalesPathObservation

SYSTEM_PROMPT = """You are a B2B sales agent. Your goal is to close deals by following a strict workflow.

Required workflow steps (in order): {workflow}

Business rules — NEVER violate these:
- R01: Must QUALIFY before PRESENT
- R02: Must OFFER_DEMO before NEGOTIATE
- R03: Budget must be known before NEGOTIATE
- R04: Discount only after 2 objections handled
- R05: Cannot repeat same action twice in a row
- R06: First action must always be PROSPECT
- R07: FOLLOW_UP only after prospect goes silent
- R08: DISQUALIFY only if prospect is genuinely unqualified
- R09: Must OFFER_DEMO before CLOSE (difficulty 2+)

Respond EXACTLY in this format:
ACTION:
CONTENT:
"""

# Expected workflow per difficulty level. Difficulty 4 has no fixed path —
# the agent must determine the best sequence dynamically. Hoisted to module
# level so it is not rebuilt on every episode.
DIFFICULTY_WORKFLOW: dict[int, list[str]] = {
    1: ["QUALIFY", "PRESENT", "CLOSE"],
    2: ["QUALIFY", "PRESENT", "HANDLE_OBJECTION", "OFFER_DEMO", "CLOSE"],
    3: ["QUALIFY", "PRESENT", "HANDLE_OBJECTION", "OFFER_DEMO",
        "HANDLE_OBJECTION", "NEGOTIATE", "CLOSE"],
    4: [],
}

# Hard cap on env steps per episode. The environment itself appears to cap
# episodes at 20 turns (see the Turn line in build_prompt) — TODO confirm —
# so this only guards against a misbehaving server that never sets `done`.
_MAX_STEPS = 64


def parse_action(text: str) -> tuple[str, str]:
    """Extract the ACTION and CONTENT fields from raw model output.

    Falls back to a safe default ("QUALIFY" plus a generic probing question)
    when a field is missing, so a malformed generation still produces a
    well-formed action for the environment.

    Args:
        text: Raw decoded model generation.

    Returns:
        (action_type, content) — action_type is upper-cased.
    """
    action_match = re.search(r"ACTION:\s*(\w+)", text, re.IGNORECASE)
    # Non-greedy + (?:\n|$): captures only the first line after CONTENT:.
    content_match = re.search(
        r"CONTENT:\s*(.+?)(?:\n|$)", text, re.IGNORECASE | re.DOTALL
    )
    action_type = action_match.group(1).upper() if action_match else "QUALIFY"
    content = (
        content_match.group(1).strip()
        if content_match
        else "Tell me more about your needs."
    )
    return action_type, content


def build_prompt(obs: SalesPathObservation, workflow: list[str], tokenizer) -> str:
    """Render the chat prompt for the current observation.

    Args:
        obs: Latest environment observation.
        workflow: Ordered workflow steps for this difficulty; an empty list
            means the path is dynamic.
        tokenizer: HF tokenizer providing `apply_chat_template`.

    Returns:
        The fully templated prompt string (generation prompt appended).
    """
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT.format(
            workflow=" -> ".join(workflow) if workflow else "Dynamic — determine best path"
        )},
        {"role": "user", "content": (
            f"Prospect response: {obs.prospect_response}\n"
            f"Current stage: {obs.workflow_stage}\n"
            f"Steps completed: {obs.steps_completed}\n"
            f"Turn: {obs.turn_number}/20\n"
            f"Violations so far: {obs.constraints_violated}\n\n"
            "What is your next action?"
        )},
    ]
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )


def _parse_obs(data: dict) -> SalesPathObservation:
    """Parse observation — handles nested 'observation' key from server."""
    obs_data = data.get("observation", data)
    # Drop unknown keys that Pydantic would reject.
    known = SalesPathObservation.model_fields.keys()
    obs_data = {k: v for k, v in obs_data.items() if k in known}
    return SalesPathObservation(**obs_data)


async def run_episode(
    model,
    tokenizer,
    env_url: str,
    difficulty: int = 1,
    message_timeout_s: float = 300.0,
) -> dict:
    """Run one full episode via REST. Returns trajectory + rewards.

    Args:
        model: Causal LM with `.generate()` and `.device`.
        tokenizer: Matching HF tokenizer.
        env_url: Base URL of the SalesPath environment server.
        difficulty: Episode difficulty; must be a key of DIFFICULTY_WORKFLOW.
        message_timeout_s: Per-request HTTP timeout in seconds.

    Returns:
        Dict with "trajectory", "total_reward", "steps_completed",
        "violations", and "difficulty".

    Raises:
        ValueError: If `difficulty` is not a known level.
        httpx.HTTPStatusError: On a non-2xx response from /reset or /step.
    """
    if difficulty not in DIFFICULTY_WORKFLOW:
        raise ValueError(
            f"Unknown difficulty {difficulty!r}; "
            f"expected one of {sorted(DIFFICULTY_WORKFLOW)}"
        )
    workflow = DIFFICULTY_WORKFLOW[difficulty]

    async with httpx.AsyncClient(
        base_url=env_url, timeout=message_timeout_s
    ) as client:
        # --- Reset ---
        reset_resp = await client.post("/reset", json={"difficulty": difficulty})
        reset_resp.raise_for_status()
        obs = _parse_obs(reset_resp.json())

        trajectory = []
        total_reward = 0.0

        # Step cap guards against a server that never reports done.
        while not obs.done and len(trajectory) < _MAX_STEPS:
            # --- Model inference ---
            prompt = build_prompt(obs, workflow, tokenizer)
            inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=128,
                    temperature=0.8,
                    do_sample=True,
                    pad_token_id=tokenizer.eos_token_id,
                )
            # Decode only the newly generated tokens, not the prompt.
            generated = tokenizer.decode(
                outputs[0][inputs["input_ids"].shape[1]:],
                skip_special_tokens=True,
            )
            action_type, content = parse_action(generated)

            # --- Step via REST ---
            step_resp = await client.post(
                "/step",
                json={"action": {"action_type": action_type, "content": content, "target": ""}},
            )
            step_resp.raise_for_status()
            obs = _parse_obs(step_resp.json())

            trajectory.append({
                "prompt": prompt,
                "generated": generated,
                "action_type": action_type,
                "reward": obs.reward,
                "components": obs.reward_components,
                "done": obs.done,
            })
            total_reward += obs.reward

    return {
        "trajectory": trajectory,
        "total_reward": total_reward,
        "steps_completed": obs.steps_completed,
        "violations": obs.constraints_violated,
        "difficulty": difficulty,
    }