Spaces:
Runtime error
Runtime error
Upload 8 files
Browse files
- env/__init__.py +0 -0
- env/environment.py +366 -0
- env/models.py +130 -0
- env/tickets.py +153 -0
- graders/__init__.py +0 -0
- graders/graders.py +194 -0
- tests/__init__.py +0 -0
- tests/test_env.py +226 -0
env/__init__.py
ADDED
|
File without changes
|
env/environment.py
ADDED
|
@@ -0,0 +1,366 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
CustomerSupportEnv — Core environment implementing the OpenEnv spec.
|
| 3 |
+
|
| 4 |
+
step(action) → StepResult(observation, reward, done, info)
|
| 5 |
+
reset() → Observation
|
| 6 |
+
state() → Observation
|
| 7 |
+
"""
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import random
|
| 11 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 12 |
+
|
| 13 |
+
from env.models import (
|
| 14 |
+
Action, ActionType, Category, Message, Observation,
|
| 15 |
+
Priority, Reward, Sentiment, StepResult, TaskSpec, TicketStatus
|
| 16 |
+
)
|
| 17 |
+
from env.tickets import TICKETS, get_ticket
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# ── Reward constants ──────────────────────────────────────────────────────────
|
| 21 |
+
|
| 22 |
+
# Per-action rewards (positive) and penalties (negative) applied by step().
R_SEARCH_KB = 2.0          # first knowledge-base search
R_EMPATHIZE = 1.0          # first empathy statement
R_ASK_CLARIFY = 1.0        # first clarifying question
R_OFFER_SOLUTION = 3.0     # scaled by the solution quality score (0.0–1.0)
R_RESOLVE_GOOD = 5.0       # resolve after a solution or escalation (plus CSAT bonus)
R_RESOLVE_BAD = -3.0       # resolve with neither a solution nor an escalation
R_ESCALATE = -1.0          # first escalation
R_DUPLICATE_ACTION = -1.0  # repeated search_kb; doubled for a repeated escalate
R_SKIP_KB_PENALTY = -1.0   # offer_solution before any search_kb
R_TIMEOUT = -2.0           # turn budget exhausted without resolve

# Contribution of each interaction flag to the synthetic CSAT score.
# Keys are Observation attribute names; the weights sum to 1.0.
CSAT_WEIGHTS = {
    "empathized": 0.3,
    "kb_searched": 0.3,
    "solution_offered": 0.4,
}

# Optimal trajectory (used for efficiency scoring)
OPTIMAL_STEPS = 4  # search_kb, empathize, offer_solution, resolve
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# ── Task definitions ──────────────────────────────────────────────────────────
|
| 44 |
+
|
| 45 |
+
# Registry of graded tasks, keyed by task_id. Each TaskSpec binds a task to
# one ticket (ticket_id), a turn budget (max_turns) and a reference action
# sequence (optimal_actions).
TASKS: Dict[str, TaskSpec] = {
    "task_1": TaskSpec(
        task_id="task_1",
        name="Resolve a Standard Auth Ticket",
        description=(
            "Handle a frustrated customer locked out of their account. "
            "The agent must search the knowledge base, acknowledge the "
            "customer's frustration, offer a concrete solution, and resolve the ticket. "
            "EASY: single-step fix, KB articles directly address the issue."
        ),
        difficulty="easy",
        ticket_id="TKT-001",
        success_criteria=[
            "search_kb called before offer_solution",
            "empathize called at least once",
            "offer_solution payload mentions unlock or reset",
            "resolve called to close episode"
        ],
        max_turns=8,
        optimal_actions=["search_kb", "empathize", "offer_solution", "resolve"]
    ),
    "task_2": TaskSpec(
        task_id="task_2",
        name="Handle a Multi-Step Billing Dispute",
        description=(
            "Resolve a billing discrepancy for a customer who was overcharged after "
            "a plan downgrade. The agent must clarify details, check the KB, diagnose "
            "the root cause, provide a specific dollar credit, and confirm the fix. "
            "MEDIUM: requires clarification before diagnosis; generic solutions penalised."
        ),
        difficulty="medium",
        ticket_id="TKT-003",
        success_criteria=[
            "ask_clarify called at least once",
            "search_kb called",
            "offer_solution mentions credit or refund amount",
            "resolve called"
        ],
        max_turns=10,
        optimal_actions=["search_kb", "ask_clarify", "empathize", "offer_solution", "resolve"]
    ),
    "task_3": TaskSpec(
        task_id="task_3",
        name="Triage a Critical Time-Sensitive Bug Report",
        description=(
            "An enterprise customer has a compliance deadline tomorrow and a data export "
            "stuck at 12% for 6 hours. The agent must quickly diagnose the issue, "
            "deploy an immediate workaround (priority queue), offer a backup strategy "
            "(partial export), and close with a monitoring commitment. "
            "HARD: time pressure, two-part solution required, escalation penalised, "
            "generic solutions score low."
        ),
        difficulty="hard",
        ticket_id="TKT-006",
        success_criteria=[
            "search_kb called",
            "offer_solution mentions priority queue AND partial export",
            "solution demonstrates urgency awareness",
            "resolve called without escalation"
        ],
        max_turns=8,
        optimal_actions=["search_kb", "empathize", "ask_clarify", "offer_solution", "resolve"]
    )
}
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
# ── Environment ───────────────────────────────────────────────────────────────
|
| 112 |
+
|
| 113 |
+
class CustomerSupportEnv:
    """
    OpenEnv-compatible customer support RL environment.

    One instance is bound to a single task (and therefore a single ticket).
    The agent drives an episode by calling ``reset()`` once and then
    ``step()`` until the returned ``done`` flag is true.

    Usage:
        env = CustomerSupportEnv(task_id="task_1")
        obs = env.reset()
        result = env.step(Action(action_type="search_kb"))
        current = env.state()
    """

    VERSION = "1.0.0"

    def __init__(self, task_id: str = "task_1", seed: Optional[int] = None):
        """
        Bind the environment to a task and seed its private RNG.

        Args:
            task_id: Key into ``TASKS``.
            seed: Seed for the ``random.Random`` used to pick canned replies.

        Raises:
            ValueError: If ``task_id`` is not a key of ``TASKS``.
        """
        if task_id not in TASKS:
            raise ValueError(f"Unknown task_id '{task_id}'. Valid: {list(TASKS.keys())}")
        self.task_id = task_id
        self.task = TASKS[task_id]
        self._seed = seed
        self._rng = random.Random(seed)
        # Until reset() is called the observation stays in the IDLE status.
        self._obs: Observation = self._make_idle_obs()

    # ── OpenEnv API ───────────────────────────────────────────────────────────

    def reset(self) -> Observation:
        """Reset the environment and return the initial observation."""
        ticket_data = get_ticket(self.task.ticket_id)
        history = [
            Message(role=m["role"], text=m["text"], turn=m.get("turn", 0))
            for m in ticket_data["history"]
        ]
        self._obs = Observation(
            ticket_id=self.task.ticket_id,
            task_id=self.task_id,
            status=TicketStatus.OPEN,
            sentiment=ticket_data["sentiment"],
            priority=ticket_data["priority"],
            category=ticket_data["category"],
            turn=0,
            max_turns=self.task.max_turns,
            history=history,
            kb_results=[],
            kb_searched=False,
            empathized=False,
            clarified=False,
            solution_offered=False,
            escalated=False,
            cumulative_reward=0.0,
            done=False,
            info={"task_name": self.task.name, "difficulty": self.task.difficulty}
        )
        return self._obs

    def step(self, action: Action) -> StepResult:
        """
        Advance the environment by one step.

        The turn counter is incremented, the action is dispatched, a timeout
        is applied if the turn budget is exhausted without ``resolve``, and
        the decomposed reward is built.

        Args:
            action: The agent's action; ``payload`` is used as free text by
                ask_clarify, offer_solution and send_message.

        Returns:
            StepResult(observation, reward, done, info).

        Raises:
            RuntimeError: If called before ``reset()`` or after the episode
                has finished.
            ValueError: If ``action.action_type`` is not a valid ActionType.
        """
        if self._obs.status == TicketStatus.IDLE:
            raise RuntimeError("Call reset() before step().")
        if self._obs.done:
            raise RuntimeError("Episode is done. Call reset() to start a new episode.")

        obs = self._obs
        ticket = get_ticket(obs.ticket_id)
        action_type = ActionType(action.action_type)

        step_reward, reason, penalty = 0.0, "", 0.0
        done = False
        info: Dict[str, Any] = {}

        obs.turn += 1

        # ── Dispatch action ───────────────────────────────────────────────

        if action_type == ActionType.SEARCH_KB:
            if obs.kb_searched:
                penalty = R_DUPLICATE_ACTION
                reason = "Duplicate search_kb — no new information."
            else:
                obs.kb_searched = True
                # Copy so that edits to the observation cannot leak into the
                # shared ticket database.
                obs.kb_results = list(ticket["kb_articles"])
                step_reward = R_SEARCH_KB
                reason = f"Retrieved {len(obs.kb_results)} KB articles."

        elif action_type == ActionType.EMPATHIZE:
            if obs.empathized:
                reason = "Already empathized — no incremental reward."
            else:
                obs.empathized = True
                step_reward = R_EMPATHIZE
                reason = "Empathy acknowledged by customer."
                obs.history.append(Message(
                    role="agent",
                    text=self._rng.choice([
                        "I completely understand how frustrating this situation must be. Let me help you immediately.",
                        "I'm sorry you're going through this — that sounds really stressful. Let's fix it right away.",
                        "Thank you for reaching out. I can see why this is a concern and I want to resolve it for you."
                    ]),
                    turn=obs.turn
                ))
                obs.history.append(Message(
                    role="customer",
                    text=self._rng.choice(["I appreciate that, thank you.", "Ok, let's get this sorted.", "Thank you."]),
                    turn=obs.turn
                ))

        elif action_type == ActionType.ASK_CLARIFY:
            if obs.clarified:
                reason = "Already clarified — no incremental reward."
            else:
                obs.clarified = True
                step_reward = R_ASK_CLARIFY
                reason = "Clarifying question logged."
                clarify_q = action.payload or "Could you share your account email and any relevant reference numbers?"
                obs.history.append(Message(role="agent", text=clarify_q, turn=obs.turn))
                obs.history.append(Message(
                    role="customer",
                    text=self._rng.choice([
                        "My account email is user@example.com. Order reference #482923.",
                        "Sure — account email user@example.com, invoice #8821.",
                        "My email is user@example.com. It started 3 days ago."
                    ]),
                    turn=obs.turn
                ))

        elif action_type == ActionType.OFFER_SOLUTION:
            solution_text = action.payload or ticket["canonical_solution"]
            quality = self._score_solution(solution_text, ticket)
            obs.solution_offered = True
            step_reward = R_OFFER_SOLUTION * quality
            reason = f"Solution offered. Quality score: {quality:.2f}."
            if not obs.kb_searched:
                penalty = R_SKIP_KB_PENALTY
                # Append rather than assign so the penalty explanation is not
                # lost behind the quality message.
                reason += " | Penalty: solution offered without consulting the knowledge base."
            info["solution_quality"] = quality
            obs.history.append(Message(role="agent", text=solution_text, turn=obs.turn))
            obs.history.append(Message(
                role="customer",
                text=self._rng.choice(ticket["customer_followups"]),
                turn=obs.turn
            ))

        elif action_type == ActionType.ESCALATE:
            if obs.escalated:
                penalty = R_DUPLICATE_ACTION * 2
                reason = "Double escalation penalty."
            else:
                obs.escalated = True
                penalty = R_ESCALATE
                reason = "Escalated to tier-2. In-tier resolution preferred."
                obs.history.append(Message(
                    role="system",
                    text="Ticket escalated to tier-2 specialist team.",
                    turn=obs.turn
                ))

        elif action_type == ActionType.RESOLVE:
            done = True
            obs.status = TicketStatus.RESOLVED if not obs.escalated else TicketStatus.ESCALATED
            if obs.solution_offered or obs.escalated:
                csat = self._compute_csat(obs)
                step_reward = R_RESOLVE_GOOD + csat * 2.0
                reason = f"Resolved. CSAT: {csat:.2f}/1.0"
                info["csat"] = csat
            else:
                step_reward = R_RESOLVE_BAD
                reason = "Penalty: resolved without offering a solution."
            obs.history.append(Message(
                role="agent",
                text="Thank you for your patience. I'm marking this ticket as resolved. Please don't hesitate to reach out if you need further help.",
                turn=obs.turn
            ))

        elif action_type == ActionType.SEND_MESSAGE:
            # Free-form message — small reward for engagement
            msg = action.payload or "I'm looking into this for you."
            obs.history.append(Message(role="agent", text=msg, turn=obs.turn))
            step_reward = 0.5
            reason = "Message sent."

        # ── Timeout check ─────────────────────────────────────────────────

        if obs.turn >= obs.max_turns and not done:
            penalty += R_TIMEOUT
            done = True
            obs.status = TicketStatus.TIMEOUT
            reason += " | Episode timed out."

        # ── Build reward ──────────────────────────────────────────────────

        net = step_reward + penalty
        # Measure efficiency against this task's own optimal trajectory; the
        # global OPTIMAL_STEPS is only a fallback when the task lists none.
        optimal_steps = len(self.task.optimal_actions) or OPTIMAL_STEPS
        efficiency = max(0.0, 1.0 - max(0, obs.turn - optimal_steps) * 0.1)
        process = min(1.0, (
            (0.25 if obs.kb_searched else 0) +
            (0.25 if obs.empathized else 0) +
            (0.25 if obs.solution_offered else 0) +
            (0.25 if done and obs.status == TicketStatus.RESOLVED else 0)
        ))
        reward = Reward(
            total=round(net, 3),
            process_score=round(process, 3),
            quality_score=round(info.get("solution_quality", 0.0), 3),
            efficiency_score=round(efficiency, 3),
            csat_score=round(info.get("csat", 0.0), 3),
            penalties=round(penalty, 3),
            reason=reason
        )

        obs.cumulative_reward = round(obs.cumulative_reward + net, 3)
        obs.done = done
        info["turn"] = obs.turn
        info["cumulative_reward"] = obs.cumulative_reward
        obs.info = info
        self._obs = obs

        return StepResult(observation=obs, reward=reward, done=done, info=info)

    def state(self) -> Observation:
        """Return current observation without advancing the environment."""
        return self._obs

    # ── Helpers ───────────────────────────────────────────────────────────────

    def _make_idle_obs(self) -> Observation:
        """Build the placeholder observation used before the first reset()."""
        return Observation(task_id=self.task_id)

    def _score_solution(self, solution_text: str, ticket: dict) -> float:
        """
        Score solution quality against expected keywords (0.0–1.0).

        The score is the fraction of the ticket's ``solution_keywords`` that
        occur (case-insensitively, as substrings) in ``solution_text``.
        Tickets without keywords score a neutral 0.5.
        """
        text_lower = solution_text.lower()
        keywords = ticket.get("solution_keywords", [])
        if not keywords:
            return 0.5
        hits = sum(1 for kw in keywords if kw.lower() in text_lower)
        return min(1.0, hits / len(keywords))

    def _compute_csat(self, obs: Observation) -> float:
        """
        Synthetic CSAT score (0.0–1.0) based on interaction quality.

        Sums the CSAT_WEIGHTS entries whose same-named observation flag is set.
        """
        score = sum(
            weight for flag, weight in CSAT_WEIGHTS.items() if getattr(obs, flag)
        )
        return round(score, 3)

    @staticmethod
    def list_tasks() -> List[str]:
        """Return the ids of all registered tasks."""
        return list(TASKS.keys())

    @staticmethod
    def get_task_spec(task_id: str) -> TaskSpec:
        """Return the TaskSpec for ``task_id``; raises KeyError if unknown."""
        return TASKS[task_id]
|
env/models.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Typed Pydantic models for CustomerSupportEnv (OpenEnv spec).
|
| 3 |
+
"""
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
from typing import Any, Dict, List, Optional
|
| 6 |
+
from pydantic import BaseModel, Field
|
| 7 |
+
from enum import Enum
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# ── Enumerations ──────────────────────────────────────────────────────────────
|
| 11 |
+
|
| 12 |
+
class TicketStatus(str, Enum):
    """Lifecycle state of a ticket; members compare equal to their string value."""

    IDLE = "idle"
    OPEN = "open"
    RESOLVED = "resolved"
    ESCALATED = "escalated"
    TIMEOUT = "timeout"
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class Sentiment(str, Enum):
    """Customer sentiment label attached to a ticket."""

    POSITIVE = "positive"
    NEUTRAL = "neutral"
    FRUSTRATED = "frustrated"
    ANGRY = "angry"
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class Priority(str, Enum):
    """Ticket priority label."""

    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
    URGENT = "urgent"
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class Category(str, Enum):
    """Topic category of a ticket."""

    AUTH = "auth"
    BILLING = "billing"
    FULFILLMENT = "fulfillment"
    BUG = "bug"
    SALES = "sales"
    GENERAL = "general"
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class ActionType(str, Enum):
    """Actions an agent may submit through step()."""

    SEARCH_KB = "search_kb"
    EMPATHIZE = "empathize"
    ASK_CLARIFY = "ask_clarify"
    OFFER_SOLUTION = "offer_solution"
    ESCALATE = "escalate"
    RESOLVE = "resolve"
    SEND_MESSAGE = "send_message"
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# ── Core Typed Models ─────────────────────────────────────────────────────────
|
| 54 |
+
|
| 55 |
+
class Message(BaseModel):
    """One entry of the ticket conversation history."""

    role: str  # "customer" | "agent" | "system"
    text: str
    turn: int = 0  # turn number the message belongs to
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class Observation(BaseModel):
    """Full typed observation returned by reset() and step()."""

    # Ticket identity and metadata; ticket_id defaults to None and status to IDLE.
    ticket_id: Optional[str] = None
    task_id: str = "task_1"
    status: TicketStatus = TicketStatus.IDLE
    sentiment: Optional[Sentiment] = None
    priority: Optional[Priority] = None
    category: Optional[Category] = None

    # Episode progress.
    turn: int = 0
    max_turns: int = 10
    history: List[Message] = Field(default_factory=list)
    kb_results: List[str] = Field(default_factory=list)

    # Flags recording which actions have been taken so far.
    kb_searched: bool = False
    empathized: bool = False
    clarified: bool = False
    solution_offered: bool = False
    escalated: bool = False

    cumulative_reward: float = 0.0
    done: bool = False
    info: Dict[str, Any] = Field(default_factory=dict)

    class Config:
        # Store validated enum fields as their plain string values.
        use_enum_values = True
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class Action(BaseModel):
    """Typed action submitted by the agent via step()."""

    action_type: ActionType
    payload: Optional[str] = None  # free-text for send_message / offer_solution

    class Config:
        # Store the validated action_type as its plain string value.
        use_enum_values = True
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class Reward(BaseModel):
    """Typed reward with decomposed components."""

    total: float
    process_score: float = 0.0      # correct action sequencing
    quality_score: float = 0.0      # solution quality / empathy
    efficiency_score: float = 0.0   # steps taken vs optimal
    csat_score: float = 0.0         # synthetic customer satisfaction (0–1)
    penalties: float = 0.0
    reason: str = ""
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
class StepResult(BaseModel):
    """Bundle of observation, reward, done flag and info for one step."""

    observation: Observation
    reward: Reward
    done: bool
    info: Dict[str, Any] = Field(default_factory=dict)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class TaskSpec(BaseModel):
    """Defines one graded task within the environment."""

    task_id: str
    name: str
    description: str
    difficulty: str  # easy | medium | hard
    ticket_id: str
    success_criteria: List[str]
    max_turns: int
    optimal_actions: List[str]
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
class GraderResult(BaseModel):
    """Outcome of grading one task."""

    task_id: str
    score: float  # 0.0 – 1.0
    breakdown: Dict[str, float]
    passed: bool
    reason: str
|
env/tickets.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Ticket scenario database for CustomerSupportEnv.
|
| 3 |
+
Each ticket includes: metadata, customer history, KB articles,
|
| 4 |
+
canonical solution, and keyword-based solution validator.
|
| 5 |
+
"""
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
from typing import Dict, List, Any
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# Ticket scenarios keyed by ticket id. Every entry carries the same keys:
# subject, customer, priority, category, sentiment, history (opening
# messages), kb_articles, canonical_solution, solution_keywords (substrings
# used to score an offered solution) and customer_followups (canned replies).
TICKETS: Dict[str, Dict[str, Any]] = {
    "TKT-001": {
        "subject": "Cannot log in to my account",
        "customer": "Aria Shah",
        "priority": "high",
        "category": "auth",
        "sentiment": "frustrated",
        "history": [
            {"role": "customer", "text": "I've been locked out for 2 days! I tried resetting my password three times and nothing works. This is extremely urgent.", "turn": 0}
        ],
        "kb_articles": [
            "Password reset: Visit /forgot-password and enter your registered email. Reset links expire in 15 minutes.",
            "Account lockout policy: Accounts lock after 5 failed attempts. Auto-unlock after 30 minutes, or contact support for manual unlock.",
            "2FA issues: If locked out due to 2FA, an admin can bypass the second factor temporarily via the admin console."
        ],
        "canonical_solution": "I have manually unlocked your account and sent a fresh password reset link to your registered email. The link will expire in 15 minutes. If 2FA is causing issues I can temporarily bypass it.",
        "solution_keywords": ["unlock", "reset", "link", "email", "password"],
        "customer_followups": [
            "Thank you! I got the email and it worked.",
            "That fixed it, appreciate your help.",
            "Great, I'm back in now."
        ]
    },
    "TKT-002": {
        "subject": "Wrong item shipped — order #482923",
        "customer": "Bryce Lee",
        "priority": "urgent",
        "category": "fulfillment",
        "sentiment": "angry",
        "history": [
            {"role": "customer", "text": "This is unacceptable. I ordered a Red T-Shirt size L but you sent me a Blue size M. Order #482923. I need the right item immediately.", "turn": 0}
        ],
        "kb_articles": [
            "Return policy: Customers have 30 days to initiate a return. Use the portal at /returns. We cover return shipping for our errors.",
            "Priority re-ship: For fulfilment errors on orders >$25, approve a priority reship within 24h after return label is issued. No need to wait for return arrival.",
            "Compensation policy: For urgent orders or repeat fulfilment errors, issue a 15% discount code on next purchase."
        ],
        "canonical_solution": "I sincerely apologise. I've raised a priority reship for the Red T-Shirt size L — it will ship within 24 hours. I've emailed a pre-paid return label for the incorrect item, and added a 15% discount code to your account for the inconvenience.",
        "solution_keywords": ["reship", "return", "label", "correct", "apologise", "apologi", "discount"],
        "customer_followups": [
            "OK, as long as it ships today I'm fine with that.",
            "Got the email with the label. Thank you.",
            "Alright, I appreciate the quick response."
        ]
    },
    "TKT-003": {
        "subject": "Invoice #8821 shows wrong amount",
        "customer": "Cleo Park",
        "priority": "medium",
        "category": "billing",
        "sentiment": "neutral",
        "history": [
            {"role": "customer", "text": "Hello, invoice #8821 shows $49 but I'm on the $29/month Basic plan. I downgraded last month. Can you check?", "turn": 0}
        ],
        "kb_articles": [
            "Plan changes: Downgrades take effect at the start of the next billing cycle. The current period is charged at the old rate.",
            "Prorate credits: If a downgrade was confirmed before the cycle closed, a manual credit can be issued for the difference.",
            "Billing disputes: Finance team can adjust invoices within 60 days of issue date. Requires the invoice number and account email."
        ],
        "canonical_solution": "I've reviewed your account. Your downgrade was confirmed before the billing cycle closed, so I'm issuing a $20 credit to your account which will appear on your next invoice. Going forward you will be billed $29/month.",
        "solution_keywords": ["credit", "$20", "twenty", "correct", "billing", "downgrade", "refund"],
        "customer_followups": [
            "Perfect, that makes sense. Thanks.",
            "Great, I can see the credit on my account.",
            "Thanks for sorting that out quickly."
        ]
    },
    "TKT-004": {
        "subject": "App crashes on iOS 17 during PDF export",
        "customer": "Dev Okonkwo",
        "priority": "medium",
        "category": "bug",
        "sentiment": "neutral",
        "history": [
            {"role": "customer", "text": "Every time I tap 'Export PDF' the app force-quits. iPhone 14 Pro, iOS 17.4.1. Started after the last app update.", "turn": 0}
        ],
        "kb_articles": [
            "Known iOS 17 crash: The PDF export feature has a memory issue on iOS 17.3 and above introduced in app v4.1.0. Fix is in v4.2.1.",
            "Workaround: Use the web app at app.example.com/export for PDF exports until v4.2.1 is released (ETA: 5 business days).",
            "Bug reporting: Collect crash logs from Settings > Privacy > Analytics & Improvements > Analytics Data and share with devs@example.com."
        ],
        "canonical_solution": "This is a known bug in v4.1.0 on iOS 17.3+ — our engineering team has a fix ready in v4.2.1, releasing in 5 days. In the meantime, use our web app at app.example.com/export. I've also flagged your report to the engineering team.",
        "solution_keywords": ["known", "bug", "v4.2", "workaround", "web", "fix", "engineering"],
        "customer_followups": [
            "Good to know it's being fixed. I'll use the web app for now.",
            "Thanks for the workaround, that works.",
            "OK, I'll wait for the update."
        ]
    },
    "TKT-005": {
        "subject": "Bulk licence pricing for 50 seats",
        "customer": "Emma Ng",
        "priority": "low",
        "category": "sales",
        "sentiment": "positive",
        "history": [
            {"role": "customer", "text": "Hi! We're a team of about 50 and are considering your Pro plan. Do you offer bulk discounts? Also, is there an enterprise contract option?", "turn": 0}
        ],
        "kb_articles": [
            "Volume discounts: 10–24 seats: 10% off. 25–49 seats: 15% off. 50+ seats: 25% off annual plan.",
            "Enterprise contracts: Custom SLA, SSO, dedicated support, and invoice billing. Contact sales@example.com. Average deal closes in 2 weeks.",
            "Trial: Teams of 5+ can get a 30-day free trial of the Pro plan. No credit card required."
        ],
        "canonical_solution": "Great news — 50 seats qualifies for our 25% volume discount on the annual Pro plan. We also offer enterprise contracts with SSO, dedicated support, and custom SLA. I'd love to connect you with our enterprise team at sales@example.com, or I can have an account executive reach out directly.",
        "solution_keywords": ["25%", "twenty-five", "enterprise", "volume", "discount", "sales@", "executive"],
        "customer_followups": [
            "That sounds great, please have someone reach out.",
            "25% is better than I expected! I'll email sales.",
            "Perfect, we'll set up a call with the enterprise team."
        ]
    },
    "TKT-006": {
        "subject": "Data export taking over 6 hours",
        "customer": "Felix Martín",
        "priority": "high",
        "category": "bug",
        "sentiment": "frustrated",
        "history": [
            {"role": "customer", "text": "I started a full data export 6 hours ago and it's still at 12%. I have a compliance deadline tomorrow. This is critical.", "turn": 0}
        ],
        "kb_articles": [
            "Export timeouts: Large exports (>10GB) can time out. The system retries automatically but may take 8-12 hours total.",
            "Priority export queue: Support can manually move a job to the priority queue, cutting estimated time to 1-2 hours.",
            "Partial exports: Users can export data by date range to reduce file size. Recommended for compliance: export by quarter."
        ],
        "canonical_solution": "I've moved your export job to the priority queue — it should complete within 1-2 hours. As a backup, I recommend also starting a partial export by date range which will be much faster. I'll monitor and send you a confirmation email when the full export completes.",
        "solution_keywords": ["priority", "queue", "1-2 hour", "partial", "monitor", "email"],
        "customer_followups": [
            "Thank you! I'll start the partial export as backup.",
            "OK, I can see the progress picked up. Thanks.",
            "The priority queue worked, it's done now."
        ]
    }
}
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def get_ticket(ticket_id: str) -> Dict[str, Any]:
    """Return the ticket record stored in TICKETS under ``ticket_id``.

    Raises:
        ValueError: if ``ticket_id`` is not a key of TICKETS.
    """
    try:
        return TICKETS[ticket_id]
    except KeyError:
        # Surface the same ValueError callers already expect for unknown ids.
        raise ValueError(f"Unknown ticket_id: {ticket_id}") from None
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def all_ticket_ids() -> List[str]:
    """Return every ticket id defined in TICKETS, in definition order."""
    return [ticket_id for ticket_id in TICKETS]
|
graders/__init__.py
ADDED
|
File without changes
|
graders/graders.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Programmatic graders for CustomerSupportEnv tasks.
|
| 3 |
+
|
| 4 |
+
Each grader accepts a completed Observation and returns a GraderResult
|
| 5 |
+
with a score in [0.0, 1.0] and a detailed breakdown.
|
| 6 |
+
|
| 7 |
+
Graders are deterministic — same inputs always produce same outputs.
|
| 8 |
+
"""
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
from typing import Dict
|
| 11 |
+
|
| 12 |
+
from env.models import GraderResult, Observation, TicketStatus
|
| 13 |
+
from env.tickets import get_ticket
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# ── Grader registry ───────────────────────────────────────────────────────────
|
| 17 |
+
|
| 18 |
+
def grade_task_1(obs: Observation) -> GraderResult:
    """
    Task 1 (EASY): Resolve a standard auth ticket (TKT-001).

    Scoring (weights sum to 1.0):
    - 0.30 obs.kb_searched is set
    - 0.25 obs.empathized is set
    - 0.25 scaled by the fraction of the ticket's solution keywords found
           (case-insensitively) anywhere in the agent's messages
    - 0.20 ticket status is RESOLVED

    NOTE(review): only the kb_searched flag is checked here; whether the KB
    search happened *before* the solution was offered is not verified by this
    grader.

    Args:
        obs: The final observation of a completed episode.

    Returns:
        GraderResult whose ``passed`` is True when the rounded score >= 0.70.
    """
    ticket = get_ticket("TKT-001")
    breakdown: Dict[str, float] = {}

    # All agent messages, lower-cased, as one searchable string.
    all_agent_text = " ".join(m.text.lower() for m in obs.history if m.role == "agent")

    # 1. KB searched
    breakdown["kb_searched"] = 0.30 if obs.kb_searched else 0.0

    # 2. Empathy expressed
    breakdown["empathized"] = 0.25 if obs.empathized else 0.0

    # 3. Solution quality: share of keywords present. Keywords are lower-cased
    #    too, so a mixed-case keyword can still match the lower-cased text.
    solution_keywords = ticket["solution_keywords"]
    kw_hits = sum(1 for kw in solution_keywords if kw.lower() in all_agent_text)
    sol_score = 0.25 * min(1.0, kw_hits / max(1, len(solution_keywords)))
    breakdown["solution_quality"] = round(sol_score, 3)

    # 4. Resolved cleanly; status may be the enum member or its raw value.
    resolved = obs.status in (TicketStatus.RESOLVED.value, TicketStatus.RESOLVED)
    breakdown["resolved"] = 0.20 if resolved else 0.0

    # Round before thresholding so `passed` always agrees with the reported
    # score (a raw float sum can land a hair under 0.70).
    score = round(sum(breakdown.values()), 3)
    passed = score >= 0.70

    return GraderResult(
        task_id="task_1",
        score=score,
        breakdown=breakdown,
        passed=passed,
        reason=_build_reason(breakdown, passed),
    )
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def grade_task_2(obs: Observation) -> GraderResult:
    """
    Task 2 (MEDIUM): Multi-step billing dispute (TKT-003).

    Scoring (weights sum to 1.0):
    - 0.20 obs.clarified is set
    - 0.20 obs.kb_searched is set
    - 0.30 solution quality: fraction of the ticket's solution keywords found
           (case-insensitively) in the agent's messages, plus a 0.3 quality
           bonus when a specific credit amount is mentioned, capped at 1.0
    - 0.15 obs.empathized is set
    - 0.15 ticket status is RESOLVED

    Args:
        obs: The final observation of a completed episode.

    Returns:
        GraderResult whose ``passed`` is True when the rounded score >= 0.70.
    """
    ticket = get_ticket("TKT-003")
    breakdown: Dict[str, float] = {}
    all_agent_text = " ".join(m.text.lower() for m in obs.history if m.role == "agent")

    # 1. Clarification step
    breakdown["ask_clarify"] = 0.20 if obs.clarified else 0.0

    # 2. KB searched
    breakdown["kb_searched"] = 0.20 if obs.kb_searched else 0.0

    # 3. Specific solution with $ amount or keywords. Keywords are lower-cased
    #    so they can match the lower-cased agent text.
    solution_keywords = ticket["solution_keywords"]
    kw_hits = sum(1 for kw in solution_keywords if kw.lower() in all_agent_text)
    # Extra check: rewards a numeric/specific value, not just generic words
    has_amount = any(x in all_agent_text for x in ["$20", "twenty", "20 credit", "credit of"])
    quality = min(1.0, kw_hits / max(1, len(solution_keywords)))
    if has_amount:
        quality = min(1.0, quality + 0.3)
    breakdown["solution_quality"] = round(0.30 * quality, 3)

    # 4. Empathy
    breakdown["empathized"] = 0.15 if obs.empathized else 0.0

    # 5. Resolved; status may be the enum member or its raw value.
    resolved = obs.status in (TicketStatus.RESOLVED.value, TicketStatus.RESOLVED)
    breakdown["resolved"] = 0.15 if resolved else 0.0

    # Round before thresholding so `passed` always agrees with the reported
    # score (a raw float sum can land a hair under 0.70).
    score = round(sum(breakdown.values()), 3)
    passed = score >= 0.70

    return GraderResult(
        task_id="task_2",
        score=score,
        breakdown=breakdown,
        passed=passed,
        reason=_build_reason(breakdown, passed),
    )
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def grade_task_3(obs: Observation) -> GraderResult:
    """
    Task 3 (HARD): Critical time-sensitive bug — data export stuck.

    Scoring (weights sum to 1.0):
    - 0.20 obs.kb_searched is set
    - 0.15 obs.empathized is set
    - 0.35 solution quality: 1.0 when the agent mentions BOTH the priority
           queue AND a partial export, 0.5 for only one of them, plus 0.2
           when urgency/follow-up wording is present (capped at 1.0)
    - 0.15 NOT escalated (in-tier resolution required for full score)
    - 0.15 ticket status is RESOLVED

    If the ticket was escalated, the no-escalation credit is forfeited and
    the total is additionally capped at 0.55, so an escalated run cannot pass.

    Args:
        obs: The final observation of a completed episode.

    Returns:
        GraderResult whose ``passed`` is True when the rounded score >= 0.70.
    """
    breakdown: Dict[str, float] = {}
    all_agent_text = " ".join(m.text.lower() for m in obs.history if m.role == "agent")

    # 1. KB searched
    breakdown["kb_searched"] = 0.20 if obs.kb_searched else 0.0

    # 2. Empathy
    breakdown["empathized"] = 0.15 if obs.empathized else 0.0

    # 3. Two-part solution: priority queue + partial export
    has_priority_queue = any(
        x in all_agent_text
        for x in ["priority queue", "priority export", "move your", "moved your"]
    )
    # "partial" already covers "partial export", so it is not listed twice.
    has_partial = any(x in all_agent_text for x in ["partial", "date range", "by quarter"])
    has_urgency = any(
        x in all_agent_text
        for x in ["deadline", "1-2 hour", "urgent", "compliance", "monitor", "email you"]
    )

    sol_quality = 0.0
    if has_priority_queue and has_partial:
        sol_quality = 1.0
    elif has_priority_queue or has_partial:
        sol_quality = 0.5
    if has_urgency:
        sol_quality = min(1.0, sol_quality + 0.2)

    breakdown["solution_quality"] = round(0.35 * sol_quality, 3)

    # 4. No escalation
    breakdown["no_escalation"] = 0.0 if obs.escalated else 0.15

    # 5. Resolved; status may be the enum member or its raw value.
    resolved = obs.status in (TicketStatus.RESOLVED.value, TicketStatus.RESOLVED)
    breakdown["resolved"] = 0.15 if resolved else 0.0

    total = sum(breakdown.values())
    # Hard cap at 0.55 if escalated (escalation shows poor judgment on this task)
    if obs.escalated:
        total = min(total, 0.55)

    # Round before thresholding so `passed` always agrees with the reported
    # score (a raw float sum can land a hair under 0.70).
    score = round(total, 3)
    passed = score >= 0.70

    return GraderResult(
        task_id="task_3",
        score=score,
        breakdown=breakdown,
        passed=passed,
        reason=_build_reason(breakdown, passed),
    )
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
# Maps each task id to the grader function that scores it; consumed by grade().
GRADERS = dict(
    task_1=grade_task_1,
    task_2=grade_task_2,
    task_3=grade_task_3,
)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def grade(task_id: str, obs: Observation) -> GraderResult:
    """Score ``obs`` with the grader registered for ``task_id``.

    Raises:
        ValueError: if no grader is registered under ``task_id``.
    """
    grader = GRADERS.get(task_id)
    if grader is None:
        raise ValueError(f"No grader for task_id '{task_id}'. Valid: {list(GRADERS.keys())}")
    return grader(obs)
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def _build_reason(breakdown: Dict[str, float], passed: bool) -> str:
|
| 188 |
+
hits = [k for k, v in breakdown.items() if v > 0]
|
| 189 |
+
misses = [k for k, v in breakdown.items() if v == 0]
|
| 190 |
+
status = "PASS" if passed else "FAIL"
|
| 191 |
+
msg = f"[{status}] Score components present: {hits}."
|
| 192 |
+
if misses:
|
| 193 |
+
msg += f" Missing: {misses}."
|
| 194 |
+
return msg
|
tests/__init__.py
ADDED
|
File without changes
|
tests/test_env.py
ADDED
|
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tests for CustomerSupportEnv.
|
| 3 |
+
Run: python -m pytest tests/ -v
|
| 4 |
+
"""
|
| 5 |
+
import sys, os
|
| 6 |
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
| 7 |
+
|
| 8 |
+
import pytest
|
| 9 |
+
from env.environment import CustomerSupportEnv, TASKS
|
| 10 |
+
from env.models import Action, ActionType, TicketStatus
|
| 11 |
+
from graders.graders import grade
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# ── Fixtures ──────────────────────────────────────────────────────────────────
|
| 15 |
+
|
| 16 |
+
@pytest.fixture
def env1():
    """Provide a freshly reset task_1 environment seeded with 0."""
    environment = CustomerSupportEnv(task_id="task_1", seed=0)
    environment.reset()
    return environment
|
| 21 |
+
|
| 22 |
+
@pytest.fixture
def env2():
    """Provide a freshly reset task_2 environment seeded with 0."""
    environment = CustomerSupportEnv(task_id="task_2", seed=0)
    environment.reset()
    return environment
|
| 27 |
+
|
| 28 |
+
@pytest.fixture
def env3():
    """Provide a freshly reset task_3 environment seeded with 0."""
    environment = CustomerSupportEnv(task_id="task_3", seed=0)
    environment.reset()
    return environment
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# ── reset() ───────────────────────────────────────────────────────────────────
|
| 36 |
+
|
| 37 |
+
def test_reset_returns_observation():
    """reset() on task_1 returns the opening observation for TKT-001."""
    first_obs = CustomerSupportEnv(task_id="task_1", seed=0).reset()
    assert first_obs.ticket_id == "TKT-001"
    assert first_obs.done is False
    assert first_obs.turn == 0
    # Status may be exposed as the enum member or as its raw value.
    assert first_obs.status in (TicketStatus.OPEN.value, TicketStatus.OPEN)
|
| 44 |
+
|
| 45 |
+
def test_reset_clears_state(env1):
    """reset() after a step wipes the KB flag, turn counter and reward."""
    search = Action(action_type=ActionType.SEARCH_KB)
    env1.step(search)
    fresh = env1.reset()
    assert fresh.kb_searched is False
    assert fresh.turn == 0
    assert fresh.cumulative_reward == 0.0
|
| 51 |
+
|
| 52 |
+
def test_reset_loads_history(env1):
    """After reset, the history is non-empty and opens with the customer."""
    history = env1.state().history
    assert len(history) >= 1
    assert history[0].role == "customer"
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# ── state() ───────────────────────────────────────────────────────────────────
|
| 59 |
+
|
| 60 |
+
def test_state_does_not_advance(env1):
    """Repeated state() calls leave the turn counter unchanged."""
    before = env1.state()
    env1.state()  # extra read in between must be side-effect free
    after = env1.state()
    assert before.turn == after.turn
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
# ── step() ────────────────────────────────────────────────────────────────────
|
| 68 |
+
|
| 69 |
+
def test_step_search_kb(env1):
    """A first KB search earns 2.0 reward and populates kb_results."""
    outcome = env1.step(Action(action_type=ActionType.SEARCH_KB))
    assert outcome.reward.total == 2.0
    observation = outcome.observation
    assert observation.kb_searched is True
    assert len(observation.kb_results) > 0
|
| 74 |
+
|
| 75 |
+
def test_step_search_kb_duplicate_penalised(env1):
    """Searching the KB a second time yields a negative reward."""
    env1.step(Action(action_type=ActionType.SEARCH_KB))
    repeat = env1.step(Action(action_type=ActionType.SEARCH_KB))
    assert repeat.reward.total < 0
|
| 79 |
+
|
| 80 |
+
def test_step_empathize(env1):
    """A first empathize action earns 1.0 reward and sets the flag."""
    outcome = env1.step(Action(action_type=ActionType.EMPATHIZE))
    assert outcome.reward.total == 1.0
    assert outcome.observation.empathized is True
|
| 84 |
+
|
| 85 |
+
def test_step_empathize_no_double_reward(env1):
    """Empathizing a second time earns no additional reward."""
    env1.step(Action(action_type=ActionType.EMPATHIZE))
    repeat = env1.step(Action(action_type=ActionType.EMPATHIZE))
    assert repeat.reward.total == 0.0
|
| 89 |
+
|
| 90 |
+
def test_step_offer_solution_without_kb_penalised(env1):
    """Offering a solution before any KB search incurs a -1.0 penalty."""
    blind_offer = Action(
        action_type=ActionType.OFFER_SOLUTION,
        payload="I have unlocked your account and sent a reset link.",
    )
    outcome = env1.step(blind_offer)
    assert outcome.reward.penalties == -1.0
|
| 96 |
+
|
| 97 |
+
def test_step_offer_solution_with_kb_rewarded(env1):
    """Offering a solution after a KB search yields a positive reward."""
    env1.step(Action(action_type=ActionType.SEARCH_KB))
    informed_offer = Action(
        action_type=ActionType.OFFER_SOLUTION,
        payload="I have unlocked your account and sent a password reset link.",
    )
    outcome = env1.step(informed_offer)
    assert outcome.reward.total > 0
|
| 104 |
+
|
| 105 |
+
def test_step_resolve_without_solution_penalised(env1):
    """Resolving with no solution offered costs -3.0 and ends the episode."""
    outcome = env1.step(Action(action_type=ActionType.RESOLVE))
    assert outcome.reward.total == -3.0
    assert outcome.done is True
|
| 109 |
+
|
| 110 |
+
def test_step_resolve_good(env1):
    """Search, offer, then resolve: final step earns >= 5.0 and ends the episode."""
    env1.step(Action(action_type=ActionType.SEARCH_KB))
    offer = Action(
        action_type=ActionType.OFFER_SOLUTION,
        payload="Account unlocked and reset email sent.",
    )
    env1.step(offer)
    final = env1.step(Action(action_type=ActionType.RESOLVE))
    assert final.reward.total >= 5.0
    assert final.done is True
|
| 119 |
+
|
| 120 |
+
def test_step_raises_before_reset():
    """step() on an environment that was never reset raises RuntimeError."""
    unreset_env = CustomerSupportEnv(task_id="task_1")
    search = Action(action_type=ActionType.SEARCH_KB)
    with pytest.raises(RuntimeError):
        unreset_env.step(search)
|
| 124 |
+
|
| 125 |
+
def test_step_raises_after_done(env1):
    """Once the episode is finished, any further step() raises RuntimeError."""
    env1.step(Action(action_type=ActionType.RESOLVE))  # ends the episode
    search = Action(action_type=ActionType.SEARCH_KB)
    with pytest.raises(RuntimeError):
        env1.step(search)
|
| 129 |
+
|
| 130 |
+
def test_timeout_penalty(env1):
    """Step to one turn short of max_turns and check the turn counter.

    NOTE(review): despite its name, this test does not assert any timeout
    penalty; it only checks that the turn counter has reached at least
    max_turns - 1 after max_turns - 1 steps.
    """
    steps_to_take = env1._obs.max_turns - 1
    for _ in range(steps_to_take):
        env1.step(Action(action_type=ActionType.EMPATHIZE))
    current = env1.state()
    assert current.turn >= current.max_turns - 1
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
# ── Graders ───────────────────────────────────────────────────────────────────
|
| 139 |
+
|
| 140 |
+
def test_grader_task1_optimal(env1):
    """The full search/empathize/offer/resolve path scores >= 0.90 and passes."""
    script = [
        Action(action_type=ActionType.SEARCH_KB),
        Action(action_type=ActionType.EMPATHIZE),
        Action(
            action_type=ActionType.OFFER_SOLUTION,
            payload="I have unlocked your account and sent a password reset link to your email.",
        ),
        Action(action_type=ActionType.RESOLVE),
    ]
    for action in script:
        env1.step(action)
    verdict = grade("task_1", env1.state())
    assert verdict.score >= 0.90
    assert verdict.passed is True
|
| 151 |
+
|
| 152 |
+
def test_grader_task1_minimal(env1):
    """Just resolve with no other steps — should fail with a score below 0.40."""
    env1.step(Action(action_type=ActionType.RESOLVE))
    verdict = grade("task_1", env1.state())
    assert verdict.score < 0.40
    assert verdict.passed is False
|
| 158 |
+
|
| 159 |
+
def test_grader_task1_score_in_range(env1):
    """Grading an untouched episode still yields a score within [0, 1]."""
    score = grade("task_1", env1.state()).score
    assert 0.0 <= score <= 1.0
|
| 162 |
+
|
| 163 |
+
def test_grader_task2_requires_clarify(env2):
    """Medium task: skipping the clarify step earns no ask_clarify credit."""
    env2.step(Action(action_type=ActionType.SEARCH_KB))
    offer = Action(
        action_type=ActionType.OFFER_SOLUTION,
        payload="I have applied a $20 credit to your account.",
    )
    env2.step(offer)
    env2.step(Action(action_type=ActionType.RESOLVE))
    verdict = grade("task_2", env2.state())
    assert verdict.breakdown.get("ask_clarify", 0) == 0.0
|
| 173 |
+
|
| 174 |
+
def test_grader_task2_full_score(env2):
    """Search, clarify, empathize, offer a $20 credit, resolve: score >= 0.70."""
    script = [
        Action(action_type=ActionType.SEARCH_KB),
        Action(
            action_type=ActionType.ASK_CLARIFY,
            payload="Can you confirm your account email and invoice number?",
        ),
        Action(action_type=ActionType.EMPATHIZE),
        Action(
            action_type=ActionType.OFFER_SOLUTION,
            payload="I have issued a $20 credit to your account. Your plan is now corrected to $29/month.",
        ),
        Action(action_type=ActionType.RESOLVE),
    ]
    for action in script:
        env2.step(action)
    verdict = grade("task_2", env2.state())
    assert verdict.score >= 0.70
|
| 185 |
+
|
| 186 |
+
def test_grader_task3_two_part_solution(env3):
    """A priority-queue plus partial-export answer passes the hard task.

    The payload mentions both halves of the two-part solution and a follow-up
    ("I will email you"), then the ticket is resolved without escalation.
    """
    env3.step(Action(action_type=ActionType.SEARCH_KB))
    env3.step(Action(action_type=ActionType.EMPATHIZE))
    env3.step(Action(
        action_type=ActionType.OFFER_SOLUTION,
        # The em dash below was mis-encoded in the original text; restored.
        payload="I have moved your export job to the priority queue — it will complete in 1-2 hours. "
                "As a backup, please start a partial export by date range which will be much faster. "
                "I will email you when the full export completes."
    ))
    env3.step(Action(action_type=ActionType.RESOLVE))
    result = grade("task_3", env3.state())
    assert result.score >= 0.70
    assert result.passed is True
|
| 199 |
+
|
| 200 |
+
def test_grader_task3_escalation_capped(env3):
    """Escalating on the hard task keeps the graded score at or below 0.55."""
    for kind in (ActionType.SEARCH_KB, ActionType.ESCALATE, ActionType.RESOLVE):
        env3.step(Action(action_type=kind))
    verdict = grade("task_3", env3.state())
    assert verdict.score <= 0.55
|
| 206 |
+
|
| 207 |
+
def test_grader_deterministic(env1):
    """Same inputs → same grader output every time."""

    def play_and_grade():
        # Identical two-step episode, graded from the resulting state.
        env1.step(Action(action_type=ActionType.SEARCH_KB))
        env1.step(Action(action_type=ActionType.RESOLVE))
        return grade("task_1", env1.state())

    first = play_and_grade()
    env1.reset()
    second = play_and_grade()
    assert first.score == second.score
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
# ── Task specs ────────────────────────────────────────────────────────────────
|
| 220 |
+
|
| 221 |
+
def test_task_list():
    """list_tasks() exposes exactly the three known task ids."""
    expected_ids = {"task_1", "task_2", "task_3"}
    assert set(CustomerSupportEnv.list_tasks()) == expected_ids
|
| 223 |
+
|
| 224 |
+
def test_task_difficulty_progression():
    """Tasks 1 through 3 are labelled easy, medium and hard, in that order."""
    ordered_ids = ["task_1", "task_2", "task_3"]
    difficulties = [TASKS[task_id].difficulty for task_id in ordered_ids]
    assert difficulties == ["easy", "medium", "hard"]
|