Spaces:
Runtime error
Runtime error
Update graders/graders.py
Browse files- graders/graders.py +22 -190
graders/graders.py
CHANGED
|
@@ -1,197 +1,29 @@
|
|
| 1 |
-
|
| 2 |
-
Programmatic graders for CustomerSupportEnv tasks.
|
| 3 |
|
| 4 |
-
|
| 5 |
-
|
| 6 |
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
-
|
| 13 |
-
from env.tickets import get_ticket
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
# ββ Grader registry βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 17 |
-
|
| 18 |
-
def grade_task_1(obs: Observation) -> GraderResult:
|
| 19 |
-
"""
|
| 20 |
-
Task 1 (EASY): Resolve a standard auth ticket.
|
| 21 |
-
Scoring:
|
| 22 |
-
- 0.30 kb_searched before offer_solution
|
| 23 |
-
- 0.25 empathize called at least once
|
| 24 |
-
- 0.25 offer_solution payload mentions unlock/reset keywords
|
| 25 |
-
- 0.20 resolve called (status == RESOLVED)
|
| 26 |
-
"""
|
| 27 |
-
ticket = get_ticket("TKT-001")
|
| 28 |
-
breakdown: Dict[str, float] = {}
|
| 29 |
-
|
| 30 |
-
# Check conversation history for evidence of each required action
|
| 31 |
-
agent_turns = [m.text.lower() for m in obs.history if m.role == "agent"]
|
| 32 |
-
all_agent_text = " ".join(agent_turns)
|
| 33 |
-
|
| 34 |
-
# 1. KB searched
|
| 35 |
-
kb_score = 0.30 if obs.kb_searched else 0.0
|
| 36 |
-
breakdown["kb_searched"] = kb_score
|
| 37 |
-
|
| 38 |
-
# 2. Empathy expressed
|
| 39 |
-
empathy_score = 0.25 if obs.empathized else 0.0
|
| 40 |
-
breakdown["empathized"] = empathy_score
|
| 41 |
-
|
| 42 |
-
# 3. Solution quality β unlock/reset keywords
|
| 43 |
-
solution_keywords = ticket["solution_keywords"]
|
| 44 |
-
kw_hits = sum(1 for kw in solution_keywords if kw in all_agent_text)
|
| 45 |
-
sol_score = 0.25 * min(1.0, kw_hits / max(1, len(solution_keywords)))
|
| 46 |
-
breakdown["solution_quality"] = round(sol_score, 3)
|
| 47 |
-
|
| 48 |
-
# 4. Resolved cleanly (not timeout, not just escalated)
|
| 49 |
-
resolved = obs.status == TicketStatus.RESOLVED.value or obs.status == TicketStatus.RESOLVED
|
| 50 |
-
resolve_score = 0.20 if resolved else 0.0
|
| 51 |
-
breakdown["resolved"] = resolve_score
|
| 52 |
-
|
| 53 |
-
total = sum(breakdown.values())
|
| 54 |
-
total = min(total, 0.999)
|
| 55 |
-
passed = total >= 0.70
|
| 56 |
-
|
| 57 |
-
return GraderResult(
|
| 58 |
-
task_id="task_1",
|
| 59 |
-
score=round(total, 3),
|
| 60 |
-
breakdown=breakdown,
|
| 61 |
-
passed=passed,
|
| 62 |
-
reason=_build_reason(breakdown, passed)
|
| 63 |
-
)
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
def grade_task_2(obs: Observation) -> GraderResult:
|
| 67 |
-
"""
|
| 68 |
-
Task 2 (MEDIUM): Multi-step billing dispute.
|
| 69 |
-
Scoring:
|
| 70 |
-
- 0.20 ask_clarify called
|
| 71 |
-
- 0.20 kb_searched
|
| 72 |
-
- 0.30 offer_solution mentions a specific credit/refund (amount or keyword)
|
| 73 |
-
- 0.15 empathize called
|
| 74 |
-
- 0.15 resolve called
|
| 75 |
-
"""
|
| 76 |
-
ticket = get_ticket("TKT-003")
|
| 77 |
-
breakdown: Dict[str, float] = {}
|
| 78 |
-
all_agent_text = " ".join(m.text.lower() for m in obs.history if m.role == "agent")
|
| 79 |
-
|
| 80 |
-
# 1. Clarification step
|
| 81 |
-
breakdown["ask_clarify"] = 0.20 if obs.clarified else 0.0
|
| 82 |
-
|
| 83 |
-
# 2. KB searched
|
| 84 |
-
breakdown["kb_searched"] = 0.20 if obs.kb_searched else 0.0
|
| 85 |
-
|
| 86 |
-
# 3. Specific solution with $ amount or keywords
|
| 87 |
-
solution_keywords = ticket["solution_keywords"]
|
| 88 |
-
kw_hits = sum(1 for kw in solution_keywords if kw in all_agent_text)
|
| 89 |
-
# Extra check: requires a numeric/specific value, not just generic words
|
| 90 |
-
has_amount = any(x in all_agent_text for x in ["$20", "twenty", "20 credit", "credit of"])
|
| 91 |
-
quality = min(1.0, kw_hits / max(1, len(solution_keywords)))
|
| 92 |
-
if has_amount:
|
| 93 |
-
quality = min(1.0, quality + 0.3)
|
| 94 |
-
breakdown["solution_quality"] = round(0.30 * quality, 3)
|
| 95 |
-
|
| 96 |
-
# 4. Empathy
|
| 97 |
-
breakdown["empathized"] = 0.15 if obs.empathized else 0.0
|
| 98 |
-
|
| 99 |
-
# 5. Resolved
|
| 100 |
-
resolved = obs.status in (TicketStatus.RESOLVED.value, TicketStatus.RESOLVED)
|
| 101 |
-
breakdown["resolved"] = 0.15 if resolved else 0.0
|
| 102 |
-
|
| 103 |
-
total = sum(breakdown.values())
|
| 104 |
-
total = min(total, 0.999)
|
| 105 |
-
passed = total >= 0.70
|
| 106 |
-
|
| 107 |
-
return GraderResult(
|
| 108 |
-
task_id="task_2",
|
| 109 |
-
score=round(total, 3),
|
| 110 |
-
breakdown=breakdown,
|
| 111 |
-
passed=passed,
|
| 112 |
-
reason=_build_reason(breakdown, passed)
|
| 113 |
-
)
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
def grade_task_3(obs: Observation) -> GraderResult:
|
| 117 |
-
"""
|
| 118 |
-
Task 3 (HARD): Critical time-sensitive bug β data export stuck.
|
| 119 |
-
Scoring:
|
| 120 |
-
- 0.20 kb_searched
|
| 121 |
-
- 0.15 empathize called
|
| 122 |
-
- 0.35 solution mentions BOTH priority queue AND partial export (two-part solution)
|
| 123 |
-
- 0.15 NOT escalated (in-tier resolution required for full score)
|
| 124 |
-
- 0.15 resolve called
|
| 125 |
-
Bonus deduction: -0.10 if escalated (overrides the 0.15 no-escalation credit)
|
| 126 |
-
"""
|
| 127 |
-
ticket = get_ticket("TKT-006")
|
| 128 |
-
breakdown: Dict[str, float] = {}
|
| 129 |
-
all_agent_text = " ".join(m.text.lower() for m in obs.history if m.role == "agent")
|
| 130 |
-
|
| 131 |
-
# 1. KB searched
|
| 132 |
-
breakdown["kb_searched"] = 0.20 if obs.kb_searched else 0.0
|
| 133 |
-
|
| 134 |
-
# 2. Empathy
|
| 135 |
-
breakdown["empathized"] = 0.15 if obs.empathized else 0.0
|
| 136 |
-
|
| 137 |
-
# 3. Two-part solution: priority queue + partial export
|
| 138 |
-
has_priority_queue = any(x in all_agent_text for x in ["priority queue", "priority export", "move your", "moved your"])
|
| 139 |
-
has_partial = any(x in all_agent_text for x in ["partial", "date range", "by quarter", "partial export"])
|
| 140 |
-
has_urgency = any(x in all_agent_text for x in ["deadline", "1-2 hour", "urgent", "compliance", "monitor", "email you"])
|
| 141 |
-
|
| 142 |
-
sol_quality = 0.0
|
| 143 |
-
if has_priority_queue and has_partial:
|
| 144 |
-
sol_quality = 1.0
|
| 145 |
-
elif has_priority_queue or has_partial:
|
| 146 |
-
sol_quality = 0.5
|
| 147 |
-
if has_urgency:
|
| 148 |
-
sol_quality = min(1.0, sol_quality + 0.2)
|
| 149 |
-
|
| 150 |
-
breakdown["solution_quality"] = round(0.35 * sol_quality, 3)
|
| 151 |
-
|
| 152 |
-
# 4. No escalation
|
| 153 |
-
breakdown["no_escalation"] = 0.0 if obs.escalated else 0.15
|
| 154 |
-
|
| 155 |
-
# 5. Resolved
|
| 156 |
-
resolved = obs.status in (TicketStatus.RESOLVED.value, TicketStatus.RESOLVED)
|
| 157 |
-
breakdown["resolved"] = 0.15 if resolved else 0.0
|
| 158 |
-
|
| 159 |
-
total = sum(breakdown.values())
|
| 160 |
-
total = min(total, 0.999)
|
| 161 |
-
# Hard cap at 0.85 if escalated (escalation shows poor judgment on this task)
|
| 162 |
if obs.escalated:
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
passed = total >= 0.70
|
| 166 |
-
|
| 167 |
-
return GraderResult(
|
| 168 |
-
task_id="task_3",
|
| 169 |
-
score=round(total, 3),
|
| 170 |
-
breakdown=breakdown,
|
| 171 |
-
passed=passed,
|
| 172 |
-
reason=_build_reason(breakdown, passed)
|
| 173 |
-
)
|
| 174 |
|
|
|
|
|
|
|
| 175 |
|
| 176 |
-
|
| 177 |
-
"task_1": grade_task_1,
|
| 178 |
-
"task_2": grade_task_2,
|
| 179 |
-
"task_3": grade_task_3,
|
| 180 |
-
}
|
| 181 |
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
def _build_reason(breakdown: Dict[str, float], passed: bool) -> str:
|
| 191 |
-
hits = [k for k, v in breakdown.items() if v > 0]
|
| 192 |
-
misses = [k for k, v in breakdown.items() if v == 0]
|
| 193 |
-
status = "PASS" if passed else "FAIL"
|
| 194 |
-
msg = f"[{status}] Score components present: {hits}."
|
| 195 |
-
if misses:
|
| 196 |
-
msg += f" Missing: {misses}."
|
| 197 |
-
return msg
|
|
|
|
| 1 |
+
from env.models import GraderResult
|
|
|
|
| 2 |
|
| 3 |
+
def grade(task_id, obs):
|
| 4 |
+
score = 0.0
|
| 5 |
|
| 6 |
+
if obs.kb_searched:
|
| 7 |
+
score += 0.2
|
| 8 |
+
if obs.empathized:
|
| 9 |
+
score += 0.2
|
| 10 |
+
if obs.solution_offered:
|
| 11 |
+
score += 0.3
|
| 12 |
+
if obs.done:
|
| 13 |
+
score += 0.3
|
| 14 |
|
| 15 |
+
# penalties
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
if obs.escalated:
|
| 17 |
+
score *= 0.7
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
+
if obs.turn > 6:
|
| 20 |
+
score *= 0.9
|
| 21 |
|
| 22 |
+
score = min(max(score, 0.001), 0.999)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
+
return GraderResult(
|
| 25 |
+
task_id=task_id,
|
| 26 |
+
score=round(score, 3),
|
| 27 |
+
passed=score > 0.5,
|
| 28 |
+
breakdown={}
|
| 29 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|