Sayed223 committed on
Commit
f19b7e5
·
verified ·
1 Parent(s): be9873a

Update graders/graders.py

Browse files
Files changed (1) hide show
  1. graders/graders.py +22 -190
graders/graders.py CHANGED
@@ -1,197 +1,29 @@
1
- """
2
- Programmatic graders for CustomerSupportEnv tasks.
3
 
4
- Each grader accepts a completed Observation and returns a GraderResult
5
- with a score in [0.0, 1.0] and a detailed breakdown.
6
 
7
- Graders are deterministic — same inputs always produce same outputs.
8
- """
9
- from __future__ import annotations
10
- from typing import Dict
 
 
 
 
11
 
12
- from env.models import GraderResult, Observation, TicketStatus
13
- from env.tickets import get_ticket
14
-
15
-
16
- # ── Grader registry ───────────────────────────────────────────────────────────
17
-
18
- def grade_task_1(obs: Observation) -> GraderResult:
19
- """
20
- Task 1 (EASY): Resolve a standard auth ticket.
21
- Scoring:
22
- - 0.30 kb_searched before offer_solution
23
- - 0.25 empathize called at least once
24
- - 0.25 offer_solution payload mentions unlock/reset keywords
25
- - 0.20 resolve called (status == RESOLVED)
26
- """
27
- ticket = get_ticket("TKT-001")
28
- breakdown: Dict[str, float] = {}
29
-
30
- # Check conversation history for evidence of each required action
31
- agent_turns = [m.text.lower() for m in obs.history if m.role == "agent"]
32
- all_agent_text = " ".join(agent_turns)
33
-
34
- # 1. KB searched
35
- kb_score = 0.30 if obs.kb_searched else 0.0
36
- breakdown["kb_searched"] = kb_score
37
-
38
- # 2. Empathy expressed
39
- empathy_score = 0.25 if obs.empathized else 0.0
40
- breakdown["empathized"] = empathy_score
41
-
42
- # 3. Solution quality — unlock/reset keywords
43
- solution_keywords = ticket["solution_keywords"]
44
- kw_hits = sum(1 for kw in solution_keywords if kw in all_agent_text)
45
- sol_score = 0.25 * min(1.0, kw_hits / max(1, len(solution_keywords)))
46
- breakdown["solution_quality"] = round(sol_score, 3)
47
-
48
- # 4. Resolved cleanly (not timeout, not just escalated)
49
- resolved = obs.status == TicketStatus.RESOLVED.value or obs.status == TicketStatus.RESOLVED
50
- resolve_score = 0.20 if resolved else 0.0
51
- breakdown["resolved"] = resolve_score
52
-
53
- total = sum(breakdown.values())
54
- total = min(total, 0.999)
55
- passed = total >= 0.70
56
-
57
- return GraderResult(
58
- task_id="task_1",
59
- score=round(total, 3),
60
- breakdown=breakdown,
61
- passed=passed,
62
- reason=_build_reason(breakdown, passed)
63
- )
64
-
65
-
66
- def grade_task_2(obs: Observation) -> GraderResult:
67
- """
68
- Task 2 (MEDIUM): Multi-step billing dispute.
69
- Scoring:
70
- - 0.20 ask_clarify called
71
- - 0.20 kb_searched
72
- - 0.30 offer_solution mentions a specific credit/refund (amount or keyword)
73
- - 0.15 empathize called
74
- - 0.15 resolve called
75
- """
76
- ticket = get_ticket("TKT-003")
77
- breakdown: Dict[str, float] = {}
78
- all_agent_text = " ".join(m.text.lower() for m in obs.history if m.role == "agent")
79
-
80
- # 1. Clarification step
81
- breakdown["ask_clarify"] = 0.20 if obs.clarified else 0.0
82
-
83
- # 2. KB searched
84
- breakdown["kb_searched"] = 0.20 if obs.kb_searched else 0.0
85
-
86
- # 3. Specific solution with $ amount or keywords
87
- solution_keywords = ticket["solution_keywords"]
88
- kw_hits = sum(1 for kw in solution_keywords if kw in all_agent_text)
89
- # Extra check: requires a numeric/specific value, not just generic words
90
- has_amount = any(x in all_agent_text for x in ["$20", "twenty", "20 credit", "credit of"])
91
- quality = min(1.0, kw_hits / max(1, len(solution_keywords)))
92
- if has_amount:
93
- quality = min(1.0, quality + 0.3)
94
- breakdown["solution_quality"] = round(0.30 * quality, 3)
95
-
96
- # 4. Empathy
97
- breakdown["empathized"] = 0.15 if obs.empathized else 0.0
98
-
99
- # 5. Resolved
100
- resolved = obs.status in (TicketStatus.RESOLVED.value, TicketStatus.RESOLVED)
101
- breakdown["resolved"] = 0.15 if resolved else 0.0
102
-
103
- total = sum(breakdown.values())
104
- total = min(total, 0.999)
105
- passed = total >= 0.70
106
-
107
- return GraderResult(
108
- task_id="task_2",
109
- score=round(total, 3),
110
- breakdown=breakdown,
111
- passed=passed,
112
- reason=_build_reason(breakdown, passed)
113
- )
114
-
115
-
116
- def grade_task_3(obs: Observation) -> GraderResult:
117
- """
118
- Task 3 (HARD): Critical time-sensitive bug — data export stuck.
119
- Scoring:
120
- - 0.20 kb_searched
121
- - 0.15 empathize called
122
- - 0.35 solution mentions BOTH priority queue AND partial export (two-part solution)
123
- - 0.15 NOT escalated (in-tier resolution required for full score)
124
- - 0.15 resolve called
125
- Bonus deduction: -0.10 if escalated (overrides the 0.15 no-escalation credit)
126
- """
127
- ticket = get_ticket("TKT-006")
128
- breakdown: Dict[str, float] = {}
129
- all_agent_text = " ".join(m.text.lower() for m in obs.history if m.role == "agent")
130
-
131
- # 1. KB searched
132
- breakdown["kb_searched"] = 0.20 if obs.kb_searched else 0.0
133
-
134
- # 2. Empathy
135
- breakdown["empathized"] = 0.15 if obs.empathized else 0.0
136
-
137
- # 3. Two-part solution: priority queue + partial export
138
- has_priority_queue = any(x in all_agent_text for x in ["priority queue", "priority export", "move your", "moved your"])
139
- has_partial = any(x in all_agent_text for x in ["partial", "date range", "by quarter", "partial export"])
140
- has_urgency = any(x in all_agent_text for x in ["deadline", "1-2 hour", "urgent", "compliance", "monitor", "email you"])
141
-
142
- sol_quality = 0.0
143
- if has_priority_queue and has_partial:
144
- sol_quality = 1.0
145
- elif has_priority_queue or has_partial:
146
- sol_quality = 0.5
147
- if has_urgency:
148
- sol_quality = min(1.0, sol_quality + 0.2)
149
-
150
- breakdown["solution_quality"] = round(0.35 * sol_quality, 3)
151
-
152
- # 4. No escalation
153
- breakdown["no_escalation"] = 0.0 if obs.escalated else 0.15
154
-
155
- # 5. Resolved
156
- resolved = obs.status in (TicketStatus.RESOLVED.value, TicketStatus.RESOLVED)
157
- breakdown["resolved"] = 0.15 if resolved else 0.0
158
-
159
- total = sum(breakdown.values())
160
- total = min(total, 0.999)
161
- # Hard cap at 0.55 if escalated (escalation shows poor judgment on this task)
162
  if obs.escalated:
163
- total = min(total, 0.55)
164
-
165
- passed = total >= 0.70
166
-
167
- return GraderResult(
168
- task_id="task_3",
169
- score=round(total, 3),
170
- breakdown=breakdown,
171
- passed=passed,
172
- reason=_build_reason(breakdown, passed)
173
- )
174
 
 
 
175
 
176
- GRADERS = {
177
- "task_1": grade_task_1,
178
- "task_2": grade_task_2,
179
- "task_3": grade_task_3,
180
- }
181
 
182
-
183
- def grade(task_id: str, obs: Observation) -> GraderResult:
184
- """Grade a completed observation for the given task."""
185
- if task_id not in GRADERS:
186
- raise ValueError(f"No grader for task_id '{task_id}'. Valid: {list(GRADERS.keys())}")
187
- return GRADERS[task_id](obs)
188
-
189
-
190
- def _build_reason(breakdown: Dict[str, float], passed: bool) -> str:
191
- hits = [k for k, v in breakdown.items() if v > 0]
192
- misses = [k for k, v in breakdown.items() if v == 0]
193
- status = "PASS" if passed else "FAIL"
194
- msg = f"[{status}] Score components present: {hits}."
195
- if misses:
196
- msg += f" Missing: {misses}."
197
- return msg
 
1
+ from env.models import GraderResult
 
2
 
3
+ def grade(task_id, obs):
4
+ score = 0.0
5
 
6
+ if obs.kb_searched:
7
+ score += 0.2
8
+ if obs.empathized:
9
+ score += 0.2
10
+ if obs.solution_offered:
11
+ score += 0.3
12
+ if obs.done:
13
+ score += 0.3
14
 
15
+ # penalties
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  if obs.escalated:
17
+ score *= 0.7
 
 
 
 
 
 
 
 
 
 
18
 
19
+ if obs.turn > 6:
20
+ score *= 0.9
21
 
22
+ score = min(max(score, 0.001), 0.999)
 
 
 
 
23
 
24
+ return GraderResult(
25
+ task_id=task_id,
26
+ score=round(score, 3),
27
+ passed=score > 0.5,
28
+ breakdown={}
29
+ )