File size: 5,338 Bytes
b77d3c5
ae60795
 
b77d3c5
ae60795
 
b77d3c5
ae60795
b77d3c5
 
ae60795
 
 
 
b77d3c5
ae60795
b77d3c5
 
ae60795
b77d3c5
ae60795
b77d3c5
 
 
 
 
 
 
 
 
 
 
 
ae60795
 
 
b77d3c5
 
 
ae60795
b77d3c5
 
 
ae60795
b77d3c5
 
 
 
 
ae60795
 
 
 
 
 
 
 
 
 
 
b77d3c5
 
 
 
ae60795
 
 
 
 
 
 
 
 
b77d3c5
 
 
 
 
 
 
ae60795
b77d3c5
 
 
ae60795
 
b77d3c5
 
 
 
ae60795
 
 
 
 
 
 
b77d3c5
 
 
 
ae60795
b77d3c5
 
 
 
 
 
 
ae60795
b77d3c5
ae60795
b77d3c5
 
 
 
 
 
 
 
 
ae60795
 
 
 
b77d3c5
ae60795
 
b77d3c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
# training/rollout.py
# Uses httpx (async HTTP) to call /reset and /step REST endpoints.
# The same code path works for both a local server and a Hugging Face Space URL.

import sys
import os
import re
import httpx
import torch

# Ensure the project root is on sys.path regardless of how this is invoked.
# The root is taken to be the parent of this file's directory. Prepending it
# lets the `salespath_env` import below resolve even when the project root is
# not already importable (e.g. when this file is run directly as a script).
_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if _ROOT not in sys.path:
    sys.path.insert(0, _ROOT)

from salespath_env.models import SalesPathObservation


# System prompt template for the sales agent.
# build_prompt fills the `{workflow}` placeholder via str.format, so any
# literal braces added to this text must be doubled ("{{" / "}}").
# The "ACTION: ... / CONTENT: ..." response format at the end is what
# parse_action extracts from the model's completion.
# NOTE(review): rule R06 says the first action must be PROSPECT, but the
# difficulty workflows in run_episode start at QUALIFY and parse_action falls
# back to QUALIFY. Confirm which opening move the environment expects.
SYSTEM_PROMPT = """You are a B2B sales agent. Your goal is to close deals by following a strict workflow.

Required workflow steps (in order): {workflow}

Business rules — NEVER violate these:
- R01: Must QUALIFY before PRESENT
- R02: Must OFFER_DEMO before NEGOTIATE
- R03: Budget must be known before NEGOTIATE
- R04: Discount only after 2 objections handled
- R05: Cannot repeat same action twice in a row
- R06: First action must always be PROSPECT
- R07: FOLLOW_UP only after prospect goes silent
- R08: DISQUALIFY only if prospect is genuinely unqualified
- R09: Must OFFER_DEMO before CLOSE (difficulty 2+)

Respond EXACTLY in this format:
ACTION: <one of: PROSPECT, QUALIFY, PRESENT, HANDLE_OBJECTION, OFFER_DEMO, NEGOTIATE, CLOSE, FOLLOW_UP, DISQUALIFY>
CONTENT: <your message to the prospect>"""


# Action names the agent may emit; mirrors the list advertised in SYSTEM_PROMPT.
_VALID_ACTIONS = frozenset({
    "PROSPECT",
    "QUALIFY",
    "PRESENT",
    "HANDLE_OBJECTION",
    "OFFER_DEMO",
    "NEGOTIATE",
    "CLOSE",
    "FOLLOW_UP",
    "DISQUALIFY",
})
# Used when the completion has no usable ACTION / CONTENT value.
_FALLBACK_ACTION = "QUALIFY"
_FALLBACK_CONTENT = "Tell me more about your needs."
# `\b` stops words such as "TRANSACTION:" from matching. The `\W` runs tolerate
# decoration around the label, e.g. "**ACTION:** CLOSE".
_ACTION_RE = re.compile(r"\bACTION\W*?:\W*(\w+)", re.IGNORECASE)
# `\s*` may skip blank lines after the label; the capture stops at end of line.
_CONTENT_RE = re.compile(r"\bCONTENT\W*?:\s*(.+)", re.IGNORECASE)


def parse_action(text: str) -> tuple[str, str]:
    """Extract the ACTION and CONTENT fields from a model completion.

    Args:
        text: Raw generated text, expected to contain lines of the form
            ``ACTION: <name>`` and ``CONTENT: <message>`` (case-insensitive).

    Returns:
        ``(action_type, content)``. ``action_type`` is upper-cased and is
        always one of the known action names. When the ACTION field is
        missing or names an unknown action, it is ``"QUALIFY"``. ``content``
        is the first line of text after the CONTENT label, stripped. When it
        is missing or blank, a generic discovery question is used instead.
    """
    action_type = _FALLBACK_ACTION
    action_match = _ACTION_RE.search(text)
    if action_match:
        candidate = action_match.group(1).upper()
        # Never forward an action name the environment does not define.
        if candidate in _VALID_ACTIONS:
            action_type = candidate

    content = ""
    content_match = _CONTENT_RE.search(text)
    if content_match:
        content = content_match.group(1).strip()
    # A bare "CONTENT:" label would otherwise send an empty message.
    if not content:
        content = _FALLBACK_CONTENT

    return action_type, content


def build_prompt(
    obs: SalesPathObservation,
    workflow: list[str],
    tokenizer,
    max_turns: int = 20,
) -> str:
    """Render the chat prompt asking the model for its next sales action.

    Args:
        obs: Current observation. Its prospect response, workflow stage,
            completed steps, turn number and violations are shown to the model.
        workflow: Ordered action names the agent should follow. An empty list
            tells the model to determine its own path.
        tokenizer: Object providing ``apply_chat_template`` (presumably a
            Hugging Face tokenizer).
        max_turns: Turn budget displayed as ``Turn: <n>/<max_turns>``. Defaults
            to 20, the value that was previously hard-coded.
            NOTE(review): confirm it matches the environment's episode limit.

    Returns:
        The templated prompt text (not tokenized), ending with the generation
        prompt so the model continues in the assistant role.
    """
    workflow_text = " -> ".join(workflow) if workflow else "Dynamic — determine best path"
    user_text = (
        f"Prospect response: {obs.prospect_response}\n"
        f"Current stage: {obs.workflow_stage}\n"
        f"Steps completed: {obs.steps_completed}\n"
        f"Turn: {obs.turn_number}/{max_turns}\n"
        f"Violations so far: {obs.constraints_violated}\n\n"
        "What is your next action?"
    )
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT.format(workflow=workflow_text)},
        {"role": "user", "content": user_text},
    ]
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)


def _parse_obs(data: dict) -> SalesPathObservation:
    """Build a ``SalesPathObservation`` from a /reset or /step JSON payload.

    Accepts either a flat observation dict or a wrapper shaped like
    ``{"observation": {...}, "reward": ..., "done": ...}``. With the wrapper,
    top-level ``reward`` / ``done`` values are copied into the observation
    when the nested dict does not already contain them. Some servers report
    these fields beside the nested observation rather than inside it, and
    losing ``done`` would keep the rollout loop from terminating.
    NOTE(review): confirm the exact payload shape of the target server.

    Keys that are not declared fields of ``SalesPathObservation`` are dropped
    before construction, so extra server fields are not rejected by Pydantic.
    """
    nested = data.get("observation")
    if isinstance(nested, dict):
        obs_data = dict(nested)  # copy so the caller's payload is not mutated
        for key in ("reward", "done"):
            # Skip None so a model field with a non-optional default keeps it.
            if key not in obs_data and data.get(key) is not None:
                obs_data[key] = data[key]
    else:
        obs_data = data

    known = SalesPathObservation.model_fields
    return SalesPathObservation(**{k: v for k, v in obs_data.items() if k in known})


def _generate_completion(
    model,
    tokenizer,
    prompt: str,
    max_new_tokens: int,
    temperature: float,
) -> str:
    """Sample one continuation of ``prompt`` and return only the new text."""
    # The prompt comes from apply_chat_template, which already inserts the
    # template's special tokens (e.g. BOS). Letting the tokenizer add them
    # again would duplicate them.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # generate() returns the prompt tokens followed by the continuation;
    # decode only the continuation.
    prompt_len = inputs["input_ids"].shape[1]
    return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)


async def run_episode(
    model,
    tokenizer,
    env_url: str,
    difficulty: int = 1,
    message_timeout_s: float = 300.0,
    *,
    max_steps: int = 100,
    max_new_tokens: int = 128,
    temperature: float = 0.8,
) -> dict:
    """Run one full episode against the environment's REST API.

    Resets the environment at ``difficulty``, then repeatedly prompts the
    model, parses its ACTION/CONTENT reply and posts it to ``/step`` until the
    environment reports ``done`` or ``max_steps`` steps have been taken.

    Args:
        model: Generation model exposing ``generate`` and ``device``.
        tokenizer: Tokenizer used both for the chat template and for encoding
            and decoding.
        env_url: Base URL of the environment server (local or hosted).
        difficulty: Key into the workflow table below (1-4). Difficulty 4
            has no fixed workflow.
        message_timeout_s: httpx timeout applied to each HTTP request.
        max_steps: Safety cap on the number of ``/step`` calls. It prevents an
            endless loop if the server never reports ``done``.
        max_new_tokens: Generation budget per model reply.
        temperature: Sampling temperature per model reply.

    Returns:
        A dict with ``trajectory`` (one entry per step with prompt, generated
        text, action, reward, reward components and done flag),
        ``total_reward``, ``steps_completed``, ``violations`` and
        ``difficulty`` from the final observation, plus ``truncated``. The
        ``truncated`` flag is True when the episode was cut off by
        ``max_steps``.

    Raises:
        KeyError: If ``difficulty`` is not in the workflow table.
        httpx.HTTPStatusError: If ``/reset`` or ``/step`` returns an error status.
    """
    DIFFICULTY_WORKFLOW = {
        1: ["QUALIFY", "PRESENT", "CLOSE"],
        2: ["QUALIFY", "PRESENT", "HANDLE_OBJECTION", "OFFER_DEMO", "CLOSE"],
        3: ["QUALIFY", "PRESENT", "HANDLE_OBJECTION", "OFFER_DEMO",
            "HANDLE_OBJECTION", "NEGOTIATE", "CLOSE"],
        4: [],
    }
    workflow = DIFFICULTY_WORKFLOW[difficulty]

    trajectory: list[dict] = []
    total_reward = 0.0
    truncated = False

    async with httpx.AsyncClient(base_url=env_url, timeout=message_timeout_s) as client:

        # --- Reset ---
        reset_resp = await client.post("/reset", json={"difficulty": difficulty})
        reset_resp.raise_for_status()
        obs = _parse_obs(reset_resp.json())

        while not obs.done:
            if len(trajectory) >= max_steps:
                truncated = True
                break

            # --- Model inference ---
            prompt = build_prompt(obs, workflow, tokenizer)
            generated = _generate_completion(
                model, tokenizer, prompt, max_new_tokens, temperature
            )
            action_type, content = parse_action(generated)

            # --- Step via REST ---
            step_resp = await client.post(
                "/step",
                json={"action": {"action_type": action_type, "content": content, "target": ""}},
            )
            step_resp.raise_for_status()
            obs = _parse_obs(step_resp.json())

            # A missing reward counts as 0.0 instead of crashing the sum.
            reward = obs.reward if obs.reward is not None else 0.0

            trajectory.append({
                "prompt": prompt,
                "generated": generated,
                "action_type": action_type,
                "reward": reward,
                "components": obs.reward_components,
                "done": obs.done,
            })
            total_reward += reward

    return {
        "trajectory": trajectory,
        "total_reward": total_reward,
        "steps_completed": obs.steps_completed,
        "violations": obs.constraints_violated,
        "difficulty": difficulty,
        "truncated": truncated,
    }