Sayed223 commited on
Commit
19faf96
Β·
verified Β·
1 Parent(s): f19b7e5

Update env/environment.py

Browse files
Files changed (1) hide show
  1. env/environment.py +107 -297
env/environment.py CHANGED
@@ -1,153 +1,76 @@
1
- """
2
- CustomerSupportEnv β€” Core environment implementing the OpenEnv spec.
3
 
4
- step(action) β†’ StepResult(observation, reward, done, info)
5
- reset() β†’ Observation
6
- state() β†’ Observation
7
- """
8
  from __future__ import annotations
9
-
10
  import random
11
- from typing import Any, Dict, List, Optional, Tuple
12
 
13
  from env.models import (
14
- Action, ActionType, Category, Message, Observation,
15
- Priority, Reward, Sentiment, StepResult, TaskSpec, TicketStatus
16
  )
17
- from env.tickets import TICKETS, get_ticket
18
-
19
-
20
- # ── Reward constants ──────────────────────────────────────────────────────────
21
-
22
- R_SEARCH_KB = 2.0
23
- R_EMPATHIZE = 1.0
24
- R_ASK_CLARIFY = 1.0
25
- R_OFFER_SOLUTION = 3.0
26
- R_RESOLVE_GOOD = 5.0
27
- R_RESOLVE_BAD = -3.0
28
- R_ESCALATE = -1.0
29
- R_DUPLICATE_ACTION = -1.0
30
- R_SKIP_KB_PENALTY = -1.0
31
- R_TIMEOUT = -2.0
32
-
33
- CSAT_WEIGHTS = {
34
- "empathized": 0.3,
35
- "kb_searched": 0.3,
36
- "solution_offered": 0.4,
37
- }
38
-
39
- # Optimal trajectory (used for efficiency scoring)
40
- OPTIMAL_STEPS = 4 # search_kb, empathize, offer_solution, resolve
41
-
42
 
43
- # ── Task definitions ──────────────────────────────────────────────────────────
44
 
 
 
 
45
  TASKS: Dict[str, TaskSpec] = {
46
  "task_1": TaskSpec(
47
  task_id="task_1",
48
- name="Resolve a Standard Auth Ticket",
49
- description=(
50
- "Handle a frustrated customer locked out of their account. "
51
- "The agent must search the knowledge base, acknowledge the "
52
- "customer's frustration, offer a concrete solution, and resolve the ticket. "
53
- "EASY: single-step fix, KB articles directly address the issue."
54
- ),
55
  difficulty="easy",
56
  ticket_id="TKT-001",
57
- success_criteria=[
58
- "search_kb called before offer_solution",
59
- "empathize called at least once",
60
- "offer_solution payload mentions unlock or reset",
61
- "resolve called to close episode"
62
- ],
63
- max_turns=8,
64
- optimal_actions=["search_kb", "empathize", "offer_solution", "resolve"]
65
  ),
66
  "task_2": TaskSpec(
67
  task_id="task_2",
68
- name="Handle a Multi-Step Billing Dispute",
69
- description=(
70
- "Resolve a billing discrepancy for a customer who was overcharged after "
71
- "a plan downgrade. The agent must clarify details, check the KB, diagnose "
72
- "the root cause, provide a specific dollar credit, and confirm the fix. "
73
- "MEDIUM: requires clarification before diagnosis; generic solutions penalised."
74
- ),
75
  difficulty="medium",
76
  ticket_id="TKT-003",
77
- success_criteria=[
78
- "ask_clarify called at least once",
79
- "search_kb called",
80
- "offer_solution mentions credit or refund amount",
81
- "resolve called"
82
- ],
83
- max_turns=10,
84
- optimal_actions=["search_kb", "ask_clarify", "empathize", "offer_solution", "resolve"]
85
  ),
86
  "task_3": TaskSpec(
87
  task_id="task_3",
88
- name="Triage a Critical Time-Sensitive Bug Report",
89
- description=(
90
- "An enterprise customer has a compliance deadline tomorrow and a data export "
91
- "stuck at 12% for 6 hours. The agent must quickly diagnose the issue, "
92
- "deploy an immediate workaround (priority queue), offer a backup strategy "
93
- "(partial export), and close with a monitoring commitment. "
94
- "HARD: time pressure, two-part solution required, escalation penalised, "
95
- "generic solutions score low."
96
- ),
97
  difficulty="hard",
98
  ticket_id="TKT-006",
99
- success_criteria=[
100
- "search_kb called",
101
- "offer_solution mentions priority queue AND partial export",
102
- "solution demonstrates urgency awareness",
103
- "resolve called without escalation"
104
- ],
105
- max_turns=8,
106
- optimal_actions=["search_kb", "empathize", "ask_clarify", "offer_solution", "resolve"]
107
  )
108
  }
109
 
110
 
111
- # ── Environment ──────────────────────────────────��────────────────────────────
112
-
 
113
  class CustomerSupportEnv:
114
- """
115
- OpenEnv-compatible customer support RL environment.
116
-
117
- Usage:
118
- env = CustomerSupportEnv(task_id="task_1")
119
- obs = env.reset()
120
- result = env.step(Action(action_type="search_kb"))
121
- current = env.state()
122
- """
123
 
124
- VERSION = "1.0.0"
125
-
126
- def __init__(self, task_id: str = "task_1", seed: Optional[int] = None):
127
- if task_id not in TASKS:
128
- raise ValueError(f"Unknown task_id '{task_id}'. Valid: {list(TASKS.keys())}")
129
- self.task_id = task_id
130
  self.task = TASKS[task_id]
131
- self._seed = seed
132
  self._rng = random.Random(seed)
133
- self._obs: Observation = self._make_idle_obs()
134
-
135
- # ── OpenEnv API ───────────────────────────────────────────────────────────
136
 
137
  def reset(self) -> Observation:
138
- """Reset the environment and return the initial observation."""
139
- ticket_data = get_ticket(self.task.ticket_id)
 
 
 
 
 
 
 
140
  history = [
141
- Message(role=m["role"], text=m["text"], turn=m.get("turn", 0))
142
- for m in ticket_data["history"]
143
  ]
 
144
  self._obs = Observation(
145
  ticket_id=self.task.ticket_id,
146
- task_id=self.task_id,
147
  status=TicketStatus.OPEN,
148
- sentiment=ticket_data["sentiment"],
149
- priority=ticket_data["priority"],
150
- category=ticket_data["category"],
151
  turn=0,
152
  max_turns=self.task.max_turns,
153
  history=history,
@@ -159,208 +82,95 @@ class CustomerSupportEnv:
159
  escalated=False,
160
  cumulative_reward=0.0,
161
  done=False,
162
- info={"task_name": self.task.name, "difficulty": self.task.difficulty}
 
 
 
 
 
163
  )
164
  return self._obs
165
 
166
  def step(self, action: Action) -> StepResult:
167
- """
168
- Advance the environment by one step.
169
- Returns StepResult(observation, reward, done, info).
170
- """
171
- if self._obs.status == TicketStatus.IDLE:
172
- raise RuntimeError("Call reset() before step().")
173
- if self._obs.done:
174
- raise RuntimeError("Episode is done. Call reset() to start a new episode.")
175
-
176
  obs = self._obs
177
- ticket = get_ticket(obs.ticket_id)
178
- action_type = ActionType(action.action_type)
179
 
180
- step_reward, reason, penalty = 0.0, "", 0.0
181
  done = False
182
- info: Dict[str, Any] = {}
183
 
184
  obs.turn += 1
185
 
186
- # ── Dispatch action ────────────────────────────────────────────────
 
 
 
 
 
187
 
188
- if action_type == ActionType.SEARCH_KB:
189
- if obs.kb_searched:
190
- penalty = R_DUPLICATE_ACTION
191
- reason = "Duplicate search_kb β€” no new information."
192
- else:
193
- obs.kb_searched = True
194
- obs.kb_results = ticket["kb_articles"]
195
- step_reward = R_SEARCH_KB
196
- reason = f"Retrieved {len(obs.kb_results)} KB articles."
197
 
198
- elif action_type == ActionType.EMPATHIZE:
199
- if obs.empathized:
200
- reason = "Already empathized β€” no incremental reward."
201
- else:
202
- obs.empathized = True
203
- step_reward = R_EMPATHIZE
204
- reason = "Empathy acknowledged by customer."
205
- obs.history.append(Message(
206
- role="agent",
207
- text=self._rng.choice([
208
- "I completely understand how frustrating this situation must be. Let me help you immediately.",
209
- "I'm sorry you're going through this β€” that sounds really stressful. Let's fix it right away.",
210
- "Thank you for reaching out. I can see why this is a concern and I want to resolve it for you."
211
- ]),
212
- turn=obs.turn
213
- ))
214
- obs.history.append(Message(
215
- role="customer",
216
- text=self._rng.choice(["I appreciate that, thank you.", "Ok, let's get this sorted.", "Thank you."]),
217
- turn=obs.turn
218
- ))
219
 
220
- elif action_type == ActionType.ASK_CLARIFY:
221
- if obs.clarified:
222
- reason = "Already clarified β€” no incremental reward."
223
- else:
224
- obs.clarified = True
225
- step_reward = R_ASK_CLARIFY
226
- reason = "Clarifying question logged."
227
- clarify_q = action.payload or "Could you share your account email and any relevant reference numbers?"
228
- obs.history.append(Message(role="agent", text=clarify_q, turn=obs.turn))
229
- obs.history.append(Message(
230
- role="customer",
231
- text=self._rng.choice([
232
- "My account email is user@example.com. Order reference #482923.",
233
- "Sure β€” account email user@example.com, invoice #8821.",
234
- "My email is user@example.com. It started 3 days ago."
235
- ]),
236
- turn=obs.turn
237
- ))
238
-
239
- elif action_type == ActionType.OFFER_SOLUTION:
240
- if not obs.kb_searched:
241
- penalty = R_SKIP_KB_PENALTY
242
- reason = "Penalty: solution offered without consulting the knowledge base."
243
- solution_text = action.payload or ticket["canonical_solution"]
244
- quality = self._score_solution(solution_text, ticket)
245
  obs.solution_offered = True
246
- step_reward = R_OFFER_SOLUTION * quality
247
- reason = f"Solution offered. Quality score: {quality:.2f}."
248
- info["solution_quality"] = quality
249
- obs.history.append(Message(role="agent", text=solution_text, turn=obs.turn))
250
- obs.history.append(Message(
251
- role="customer",
252
- text=self._rng.choice(ticket["customer_followups"]),
253
- turn=obs.turn
254
- ))
255
-
256
- elif action_type == ActionType.ESCALATE:
257
- if obs.escalated:
258
- penalty = R_DUPLICATE_ACTION * 2
259
- reason = "Double escalation penalty."
260
- else:
261
- obs.escalated = True
262
- penalty = R_ESCALATE
263
- reason = "Escalated to tier-2. In-tier resolution preferred."
264
- obs.history.append(Message(
265
- role="system",
266
- text="Ticket escalated to tier-2 specialist team.",
267
- turn=obs.turn
268
- ))
269
 
270
- elif action_type == ActionType.RESOLVE:
271
  done = True
272
- obs.status = TicketStatus.RESOLVED if not obs.escalated else TicketStatus.ESCALATED
273
- if obs.solution_offered or obs.escalated:
274
- csat = self._compute_csat(obs)
275
- step_reward = R_RESOLVE_GOOD + csat * 2.0
276
- reason = f"Resolved. CSAT: {csat:.2f}/1.0"
277
- info["csat"] = csat
278
- else:
279
- step_reward = R_RESOLVE_BAD
280
- reason = "Penalty: resolved without offering a solution."
281
- obs.history.append(Message(
282
- role="agent",
283
- text="Thank you for your patience. I'm marking this ticket as resolved. Please don't hesitate to reach out if you need further help.",
284
- turn=obs.turn
285
- ))
286
-
287
- elif action_type == ActionType.SEND_MESSAGE:
288
- # Free-form message β€” small reward for engagement
289
- msg = action.payload or "I'm looking into this for you."
290
- obs.history.append(Message(role="agent", text=msg, turn=obs.turn))
291
- step_reward = 0.5
292
- reason = "Message sent."
293
-
294
- # ── Timeout check ─────────────────────────────────────────────────
295
-
296
- if obs.turn >= obs.max_turns and not done:
297
- penalty += R_TIMEOUT
 
 
 
 
 
 
 
 
 
 
298
  done = True
299
- obs.status = TicketStatus.TIMEOUT
300
- reason += " | Episode timed out."
301
-
302
- # ── Build reward ──────────────────────────────────────────────────
303
 
304
- net = step_reward + penalty
305
- efficiency = max(0.0, 1.0 - max(0, obs.turn - OPTIMAL_STEPS) * 0.1)
306
- process = min(1.0, (
307
- (0.25 if obs.kb_searched else 0) +
308
- (0.25 if obs.empathized else 0) +
309
- (0.25 if obs.solution_offered else 0) +
310
- (0.25 if done and obs.status == TicketStatus.RESOLVED else 0)
311
- ))
312
- reward = Reward(
313
- total=round(net, 3),
314
- process_score=round(process, 3),
315
- quality_score=round(info.get("solution_quality", 0.0), 3),
316
- efficiency_score=round(efficiency, 3),
317
- csat_score=round(info.get("csat", 0.0), 3),
318
- penalties=round(penalty, 3),
319
- reason=reason
320
- )
321
 
322
- obs.cumulative_reward = round(obs.cumulative_reward + net, 3)
323
  obs.done = done
324
- info["turn"] = obs.turn
325
- info["cumulative_reward"] = obs.cumulative_reward
326
- obs.info = info
327
- self._obs = obs
328
-
329
- return StepResult(observation=obs, reward=reward, done=done, info=info)
330
-
331
- def state(self) -> Observation:
332
- """Return current observation without advancing the environment."""
333
- return self._obs
334
-
335
- # ── Helpers ───────────────────────────────────────────────────────────────
336
-
337
- def _make_idle_obs(self) -> Observation:
338
- return Observation(task_id=self.task_id)
339
-
340
- def _score_solution(self, solution_text: str, ticket: dict) -> float:
341
- """Score solution quality against expected keywords (0.0–1.0)."""
342
- text_lower = solution_text.lower()
343
- keywords = ticket.get("solution_keywords", [])
344
- if not keywords:
345
- return 0.5
346
- hits = sum(1 for kw in keywords if kw.lower() in text_lower)
347
- return min(1.0, hits / max(1, len(keywords)))
348
-
349
- def _compute_csat(self, obs: Observation) -> float:
350
- """Synthetic CSAT score (0.0–1.0) based on interaction quality."""
351
- score = 0.0
352
- if obs.empathized:
353
- score += CSAT_WEIGHTS["empathized"]
354
- if obs.kb_searched:
355
- score += CSAT_WEIGHTS["kb_searched"]
356
- if obs.solution_offered:
357
- score += CSAT_WEIGHTS["solution_offered"]
358
- return round(score, 3)
359
-
360
- @staticmethod
361
- def list_tasks() -> List[str]:
362
- return list(TASKS.keys())
363
-
364
- @staticmethod
365
- def get_task_spec(task_id: str) -> TaskSpec:
366
- return TASKS[task_id]
 
1
+ # ADVANCED ENTERPRISE INCIDENT ENVIRONMENT
 
2
 
 
 
 
 
3
  from __future__ import annotations
 
4
  import random
5
+ from typing import Dict, Any, Optional, List
6
 
7
  from env.models import (
8
+ Action, ActionType, Message, Observation,
9
+ Reward, StepResult, TaskSpec, TicketStatus
10
  )
11
+ from env.tickets import get_ticket
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
 
13
 
14
+ # =========================
15
+ # TASK DEFINITIONS
16
+ # =========================
17
  TASKS: Dict[str, TaskSpec] = {
18
  "task_1": TaskSpec(
19
  task_id="task_1",
20
+ name="Basic Auth Issue",
21
+ description="Resolve login issue with empathy and KB usage",
 
 
 
 
 
22
  difficulty="easy",
23
  ticket_id="TKT-001",
24
+ max_turns=8
 
 
 
 
 
 
 
25
  ),
26
  "task_2": TaskSpec(
27
  task_id="task_2",
28
+ name="Billing SLA Case",
29
+ description="Resolve billing with SLA and refund precision",
 
 
 
 
 
30
  difficulty="medium",
31
  ticket_id="TKT-003",
32
+ max_turns=10
 
 
 
 
 
 
 
33
  ),
34
  "task_3": TaskSpec(
35
  task_id="task_3",
36
+ name="Critical Enterprise Outage",
37
+ description="Handle high severity incident with urgency",
 
 
 
 
 
 
 
38
  difficulty="hard",
39
  ticket_id="TKT-006",
40
+ max_turns=8
 
 
 
 
 
 
 
41
  )
42
  }
43
 
44
 
45
+ # =========================
46
+ # ENVIRONMENT
47
+ # =========================
48
  class CustomerSupportEnv:
 
 
 
 
 
 
 
 
 
49
 
50
+ def __init__(self, task_id="task_1", seed=None):
 
 
 
 
 
51
  self.task = TASKS[task_id]
 
52
  self._rng = random.Random(seed)
53
+ self._obs: Observation = None
 
 
54
 
55
  def reset(self) -> Observation:
56
+ ticket = get_ticket(self.task.ticket_id)
57
+
58
+ # πŸ”₯ ADVANCED CONTEXT
59
+ self.sla_deadline = self._rng.choice([3, 5, 7])
60
+ self.customer_tier = self._rng.choice(["free", "premium", "enterprise"])
61
+ self.issue_severity = self._rng.choice(["low", "medium", "critical"])
62
+ self.escalation_risk = self._rng.choice([0.2, 0.5, 0.8])
63
+ self.hidden_failure_mode = self._rng.choice([True, False])
64
+
65
  history = [
66
+ Message(role=m["role"], text=m["text"], turn=0)
67
+ for m in ticket["history"]
68
  ]
69
+
70
  self._obs = Observation(
71
  ticket_id=self.task.ticket_id,
72
+ task_id=self.task.task_id,
73
  status=TicketStatus.OPEN,
 
 
 
74
  turn=0,
75
  max_turns=self.task.max_turns,
76
  history=history,
 
82
  escalated=False,
83
  cumulative_reward=0.0,
84
  done=False,
85
+ info={
86
+ "sla": self.sla_deadline,
87
+ "tier": self.customer_tier,
88
+ "severity": self.issue_severity,
89
+ "risk": self.escalation_risk
90
+ }
91
  )
92
  return self._obs
93
 
94
  def step(self, action: Action) -> StepResult:
 
 
 
 
 
 
 
 
 
95
  obs = self._obs
96
+ action_type = action.action_type
 
97
 
98
+ reward = 0.0
99
  done = False
 
100
 
101
  obs.turn += 1
102
 
103
+ # =========================
104
+ # ACTIONS
105
+ # =========================
106
+ if action_type == "search_kb":
107
+ reward += 2 if not obs.kb_searched else -1
108
+ obs.kb_searched = True
109
 
110
+ elif action_type == "empathize":
111
+ reward += 1 if not obs.empathized else 0
112
+ obs.empathized = True
 
 
 
 
 
 
113
 
114
+ elif action_type == "ask_clarify":
115
+ reward += 1 if not obs.clarified else 0
116
+ obs.clarified = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
+ elif action_type == "offer_solution":
119
+ reward += 3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  obs.solution_offered = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
+ elif action_type == "resolve":
123
  done = True
124
+ reward += 5 if obs.solution_offered else -3
125
+
126
+ elif action_type == "escalate":
127
+ obs.escalated = True
128
+ reward -= 1
129
+
130
+ # =========================
131
+ # ADVANCED LOGIC
132
+ # =========================
133
+
134
+ # SLA pressure
135
+ if obs.turn > self.sla_deadline:
136
+ reward -= 1
137
+
138
+ # severity boost
139
+ if self.issue_severity == "critical":
140
+ if action_type == "search_kb":
141
+ reward += 0.5
142
+
143
+ # enterprise expectations
144
+ if self.customer_tier == "enterprise":
145
+ if done and obs.turn <= self.sla_deadline:
146
+ reward += 2
147
+ elif done:
148
+ reward -= 1
149
+
150
+ # escalation risk
151
+ if self.escalation_risk > 0.7 and not obs.empathized:
152
+ reward -= 1
153
+
154
+ # hidden failure mode (novel)
155
+ if self.hidden_failure_mode and action_type == "offer_solution" and not obs.kb_searched:
156
+ reward -= 2
157
+
158
+ # auto escalation
159
+ if self.escalation_risk > 0.7 and obs.turn > 3 and not obs.empathized:
160
  done = True
161
+ obs.escalated = True
162
+ reward -= 3
 
 
163
 
164
+ # efficiency
165
+ if done and obs.turn <= 4:
166
+ reward += 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
 
 
168
  obs.done = done
169
+ obs.cumulative_reward += reward
170
+
171
+ return StepResult(
172
+ observation=obs,
173
+ reward=Reward(total=round(reward, 3)),
174
+ done=done,
175
+ info={}
176
+ )